diff --git a/lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.cpp b/lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.cpp index 4e2dd22fe..3b2fb6f78 100644 --- a/lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.cpp +++ b/lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.cpp @@ -9,7 +9,8 @@ #include "lib/Dialect/ModArith/IR/ModArithOps.h" #include "lib/Dialect/ModArith/IR/ModArithTypes.h" #include "lib/Utils/ConversionUtils.h" -#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Affine/Utils.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project @@ -98,8 +99,8 @@ struct ConvertConstant : public OpConversionPattern { } }; -struct ConvertExt : public OpConversionPattern { - ConvertExt(mlir::MLIRContext *context) +struct ConvertExtSI : public OpConversionPattern { + ConvertExtSI(mlir::MLIRContext *context) : OpConversionPattern(context) {} using OpConversionPattern::OpConversionPattern; @@ -116,20 +117,19 @@ struct ConvertExt : public OpConversionPattern { } }; -template -struct ConvertBinOp : public OpConversionPattern { - ConvertBinOp(mlir::MLIRContext *context) - : OpConversionPattern(context) {} +struct ConvertExtUI : public OpConversionPattern { + ConvertExtUI(mlir::MLIRContext *context) + : OpConversionPattern(context) {} - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - SourceArithOp op, typename SourceArithOp::Adaptor adaptor, + ::mlir::arith::ExtUIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { ImplicitLocOpBuilder b(op.getLoc(), rewriter); - auto result = - b.create(adaptor.getLhs(), adaptor.getRhs()); + auto result = b.create( + op.getLoc(), convertArithType(op.getType()), adaptor.getIn()); rewriter.replaceOp(op, result); return success(); } @@ -161,22 +161,23 @@ void ArithToModArith::runOnOperation() { target.addDynamicallyLegalOp< memref::AllocOp, memref::DeallocOp, memref::StoreOp, memref::SubViewOp, - memref::CopyOp, tensor::FromElementsOp, tensor::ExtractOp>( - [&](Operation *op) { - return typeConverter.isLegal(op->getOperandTypes()) && - typeConverter.isLegal(op->getResultTypes()); - }); + memref::CopyOp, tensor::FromElementsOp, tensor::ExtractOp, + affine::AffineStoreOp, affine::AffineLoadOp>([&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); RewritePatternSet patterns(context); patterns - .add, ConvertBinOp, ConvertBinOp, ConvertAny, ConvertAny, ConvertAny, ConvertAny, ConvertAny, ConvertAny, - ConvertAny, ConvertAny >( + ConvertAny, ConvertAny, + ConvertAny, ConvertAny>( typeConverter, context); addStructuralConversionPatterns(typeConverter, patterns, target); diff --git a/lib/Dialect/CGGI/Conversions/CGGIToTfheRustBool/CGGIToTfheRustBool.cpp b/lib/Dialect/CGGI/Conversions/CGGIToTfheRustBool/CGGIToTfheRustBool.cpp index a4f73ee27..2a4698952 100644 --- a/lib/Dialect/CGGI/Conversions/CGGIToTfheRustBool/CGGIToTfheRustBool.cpp +++ b/lib/Dialect/CGGI/Conversions/CGGIToTfheRustBool/CGGIToTfheRustBool.cpp @@ -118,7 +118,7 @@ struct AddBoolServerKeyArg : public OpConversionPattern { }; template -struct ConvertBinOp : public OpConversionPattern { +struct ConvertCGGIBinOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( @@ -136,12 +136,14 @@ struct ConvertBinOp : public OpConversionPattern { } }; -using ConvertBoolAndOp = ConvertBinOp; -using ConvertBoolNandOp = ConvertBinOp; -using ConvertBoolOrOp = ConvertBinOp; -using ConvertBoolNorOp = ConvertBinOp; -using ConvertBoolXorOp = ConvertBinOp; -using ConvertBoolXNorOp = ConvertBinOp; +using ConvertBoolAndOp = ConvertCGGIBinOp; +using ConvertBoolNandOp = + ConvertCGGIBinOp; +using ConvertBoolOrOp = ConvertCGGIBinOp; +using ConvertBoolNorOp = ConvertCGGIBinOp; +using ConvertBoolXorOp = ConvertCGGIBinOp; +using ConvertBoolXNorOp = + ConvertCGGIBinOp; struct ConvertBoolNotOp : public OpConversionPattern { ConvertBoolNotOp(mlir::MLIRContext *context) diff --git a/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp b/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp index 97c7e86e5..3ee76a090 100644 --- a/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp +++ b/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp @@ -278,9 +278,9 @@ struct LWEToOpenfhe : public impl::LWEToOpenfheBase { ConvertEncodeOp, ConvertEncryptOp, ConvertDecryptOp, // Scheme-agnostic RLWE Arithmetic Ops: - ConvertBinOp, - ConvertBinOp, - ConvertBinOp, + ConvertLWEBinOp, + ConvertLWEBinOp, + ConvertLWEBinOp, ConvertUnaryOp, /////////////////////////////////// diff --git a/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h b/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h index b975869bd..b0041a283 100644 --- a/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h +++ b/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h @@ -45,7 +45,7 @@ struct ConvertUnaryOp : public OpConversionPattern { }; template -struct ConvertBinOp : public OpConversionPattern { +struct ConvertLWEBinOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( diff --git a/lib/Dialect/Polynomial/Conversions/PolynomialToModArith/PolynomialToModArith.cpp b/lib/Dialect/Polynomial/Conversions/PolynomialToModArith/PolynomialToModArith.cpp index 39f7c6ef0..50d490fe0 100644 --- a/lib/Dialect/Polynomial/Conversions/PolynomialToModArith/PolynomialToModArith.cpp +++ b/lib/Dialect/Polynomial/Conversions/PolynomialToModArith/PolynomialToModArith.cpp @@ -581,8 +581,8 @@ struct ConvertLeadingTerm : public OpConversionPattern { }; template -struct ConvertBinop : public OpConversionPattern { - ConvertBinop(mlir::MLIRContext *context) +struct ConvertPolyBinop : public OpConversionPattern { + ConvertPolyBinop(mlir::MLIRContext *context) : OpConversionPattern(context) {} using OpConversionPattern::OpConversionPattern; @@ -1293,8 +1293,8 @@ void PolynomialToModArith::runOnOperation() { RewritePatternSet patterns(context); patterns.add, - ConvertBinop, + ConvertPolyBinop, + ConvertPolyBinop, ConvertLeadingTerm, ConvertMonomial, ConvertMonicMonomialMul, ConvertConstant, ConvertMulScalar, ConvertNTT, ConvertINTT>( typeConverter, context); diff --git a/lib/Utils/ConversionUtils.h b/lib/Utils/ConversionUtils.h index 74ed7a2ee..3fdf9f7af 100644 --- a/lib/Utils/ConversionUtils.h +++ b/lib/Utils/ConversionUtils.h @@ -80,6 +80,25 @@ struct ConvertAny : public ConversionPattern { } }; +template +struct ConvertBinOp : public OpConversionPattern { + ConvertBinOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + SourceArithOp op, typename SourceArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto result = + b.create(adaptor.getLhs(), adaptor.getRhs()); + rewriter.replaceOp(op, result); + return success(); + } +}; + struct ContextAwareTypeConverter : public TypeConverter { public: // Convert types of the values in the input range, taking into account the diff --git a/scripts/templates/Conversion/lib/BUILD.jinja b/scripts/templates/Conversion/lib/BUILD.jinja index 99156707b..09c4022e0 100644 --- a/scripts/templates/Conversion/lib/BUILD.jinja +++ b/scripts/templates/Conversion/lib/BUILD.jinja @@ -11,7 +11,7 @@ cc_library( hdrs = ["{{ pass_name }}.h"], deps = [ ":pass_inc_gen", - "@heir//lib/Utils/ConversionUtils", + "@heir//lib/Utils:ConversionUtils", "@heir//lib/Dialect/{{ source_dialect_name }}/IR:Dialect", "@heir//lib/Dialect/{{ target_dialect_name }}/IR:Dialect", "@llvm-project//mlir:IR", diff --git a/scripts/templates/Conversion/lib/ConversionPass.cpp.jinja b/scripts/templates/Conversion/lib/ConversionPass.cpp.jinja index c85474784..6eca8dfde 100644 --- a/scripts/templates/Conversion/lib/ConversionPass.cpp.jinja +++ b/scripts/templates/Conversion/lib/ConversionPass.cpp.jinja @@ -5,7 +5,7 @@ #include "lib/Utils/ConversionUtils.h" #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project -namespace mlir::heir { +namespace mlir::heir::{{ source_dialect_namespace }} { #define GEN_PASS_DEF_{{ pass_name | upper }} #include "lib/Dialect/{{ source_dialect_name }}/Conversions/{{ pass_name }}/{{ pass_name }}.h.inc" @@ -59,4 +59,4 @@ struct {{ pass_name }} : public impl::{{ pass_name }}Base<{{ pass_name }}> { } }; -} // namespace mlir::heir +} // namespace mlir::heir::{{ source_dialect_namespace }} diff --git a/scripts/templates/Conversion/lib/ConversionPass.h.jinja b/scripts/templates/Conversion/lib/ConversionPass.h.jinja index a8cdbdb0f..b349a05a5 100644 --- a/scripts/templates/Conversion/lib/ConversionPass.h.jinja +++ b/scripts/templates/Conversion/lib/ConversionPass.h.jinja @@ -3,7 +3,7 @@ #include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project -namespace mlir::heir { +namespace mlir::heir::{{ source_dialect_namespace }} { #define GEN_PASS_DECL #include "lib/Dialect/{{ source_dialect_name }}/Conversions/{{ pass_name }}/{{ pass_name }}.h.inc" @@ -11,6 +11,6 @@ namespace mlir::heir { #define GEN_PASS_REGISTRATION #include "lib/Dialect/{{ source_dialect_name }}/Conversions/{{ pass_name }}/{{ pass_name }}.h.inc" -} // namespace mlir::heir +} // namespace mlir::heir::{{ source_dialect_namespace }} #endif // LIB_DIALECT_{{ source_dialect_name | upper }}_CONVERSIONS_{{ pass_name | upper }}_{{ pass_name | upper }}_H_ diff --git a/tests/Dialect/Arith/Conversions/ArithToModArith/arith-to-mod-arith.mlir b/tests/Dialect/Arith/Conversions/ArithToModArith/arith-to-mod-arith.mlir index b81321ac5..0ebd3bc56 100644 --- a/tests/Dialect/Arith/Conversions/ArithToModArith/arith-to-mod-arith.mlir +++ b/tests/Dialect/Arith/Conversions/ArithToModArith/arith-to-mod-arith.mlir @@ -61,6 +61,8 @@ module attributes {tf_saved_model.semantics} { memref.global "private" constant @__constant_16x1xi8 : memref<16x1xi8> = dense<[[-9], [-54], [57], [71], [104], [115], [98], [99], [64], [-26], [127], [25], [-82], [68], [95], [86]]> {alignment = 64 : i64} func.func @test_memref_global(%arg0: memref<1x1xi32>) -> memref<1x1xi32> { %c429_i32 = arith.constant 429 : i32 + %c33_i8 = arith.constant 33 : i8 + %c33 = arith.extui %c33_i8 : i8 to i32 %c0 = arith.constant 0 : index %0 = memref.get_global @__constant_16x1xi8 : memref<16x1xi8> %3 = memref.get_global @__constant_16xi32_0 : memref<16xi32> @@ -72,7 +74,27 @@ module attributes {tf_saved_model.semantics} { %a24 = arith.extsi %22 : i8 to i32 %25 = arith.muli %24, %a24 : i32 %26 = arith.addi %21, %25 : i32 - memref.store %26, %alloc[%c0, %c0] : memref<1x1xi32> + %27 = arith.addi %26, %c33 : i32 + memref.store %27, %alloc[%c0, %c0] : memref<1x1xi32> + return %alloc : memref<1x1xi32> + } +} + +// CHECK-LABEL: @test_affine +// CHECK-SAME: (%[[ARG:.*]]: memref<1x1x!Z128_i9_>) -> memref<1x1x!Z2147483648_i33_> { +module attributes {tf_saved_model.semantics} { + func.func @test_affine(%arg0: memref<1x1xi8>) -> memref<1x1xi32> { + %c429_i32 = arith.constant 429 : i32 + %c33_i8 = arith.constant 33 : i8 + %c33 = arith.extui %c33_i8 : i8 to i32 + %0 = affine.load %arg0[0, 0] : memref<1x1xi8> + %c0 = arith.constant 0 : index + %1 = arith.extsi %0 : i8 to i32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1xi32> + // CHECK: %[[ENC:.*]] = mod_arith.mod_switch %{{.*}}: !Z128_i9_ to !Z2147483648_i33_ + %25 = arith.muli %1, %c33 : i32 + %26 = arith.addi %c429_i32, %25 : i32 + affine.store %26, %alloc[0, 0] : memref<1x1xi32> return %alloc : memref<1x1xi32> } }