From a62ac638c33c83e78fb84e67e63c6ffc64cbe70e Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Fri, 3 Jan 2025 14:33:46 +0100 Subject: [PATCH] fix for reversemode --- .../test/MLIR/ReverseMode/batched_square.mlir | 27 +++++++++++++++++++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 20 +++++++------- 2 files changed, 37 insertions(+), 10 deletions(-) create mode 100644 enzyme/test/MLIR/ReverseMode/batched_square.mlir diff --git a/enzyme/test/MLIR/ReverseMode/batched_square.mlir b/enzyme/test/MLIR/ReverseMode/batched_square.mlir new file mode 100644 index 00000000000..a69ff6fb96b --- /dev/null +++ b/enzyme/test/MLIR/ReverseMode/batched_square.mlir @@ -0,0 +1,27 @@ +// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math %s | FileCheck %s + +module { + func.func @square(%x: f64) -> f64 { + %next = arith.mulf %x, %x : f64 + return %next : f64 + } + + func.func @dsquare(%x: f64, %dr: tensor<2xf64>) -> tensor<2xf64> { + %r = enzyme.autodiff @square(%x, %dr) { activity=[#enzyme], ret_activity=[#enzyme], width=2 } : (f64, tensor<2xf64>) -> tensor<2xf64> + return %r : tensor<2xf64> + } +} + +// CHECK: func.func @dsquare(%arg0: f64, %arg1: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %0 = call @diffesquare(%arg0, %arg1) : (f64, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: return %0 : tensor<2xf64> +// CHECK-NEXT: } + +// CHECK: func.func private @diffe2square(%arg0: f64, %arg1: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %0 = "enzyme.broadcast"(%arg0) <{shape = array}> : (f64) -> tensor<2xf64> +// CHECK-NEXT: %1 = arith.mulf %arg1, %0 : tensor<2xf64> +// CHECK-NEXT: %2 = "enzyme.broadcast"(%arg0) <{shape = array}> : (f64) -> tensor<2xf64> +// CHECK-NEXT: %3 = arith.mulf %arg1, %2 : tensor<2xf64> +// CHECK-NEXT: %4 = arith.addf %1, %3 : tensor<2xf64> +// CHECK-NEXT: return %4 : tensor<2xf64> +// CHECK-NEXT: } diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index dccbc7b7923..7f0d0ecd1e4 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -277,16 +277,16 @@ SmallVector prepareArgs(const Twine &curIndent, raw_ostream &os, if (!vecValue && !startsWith(ord, "local")) { if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) { os << ")"; - if (intrinsic == MLIRDerivatives) { - os << ";\n"; - os << "if (gutils->width != 1) {\n" - << " " << argName << "_" << (idx - 1) - << " = builder.create(\n" - << " op.getLoc(),\n" - << " " << argName << "_" << (idx - 1) << ",\n" - << " llvm::SmallVector({gutils->width}));\n" - << "}"; - } + } + if (intrinsic == MLIRDerivatives) { + os << ";\n"; + os << curIndent << "if (gutils->width != 1) {\n" + << curIndent << " " << argName << "_" << (idx - 1) + << " = builder.create(\n" + << curIndent << " op.getLoc(),\n" + << curIndent << " " << argName << "_" << (idx - 1) << ",\n" + << curIndent << " llvm::SmallVector({gutils->width}));\n" + << curIndent << "}"; } if (lookup && intrinsic != MLIRDerivatives)