Skip to content

Commit

Permalink
fix for reversemode
Browse files Browse the repository at this point in the history
  • Loading branch information
jumerckx committed Jan 3, 2025
1 parent 7bc73fa commit a8d9eb5
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 10 deletions.
27 changes: 27 additions & 0 deletions enzyme/test/MLIR/ReverseMode/batched_square.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math %s | FileCheck %s

module {
func.func @square(%x: f64) -> f64 {
%next = arith.mulf %x, %x : f64
return %next : f64
}

func.func @dsquare(%x: f64, %dr: tensor<2xf64>) -> tensor<2xf64> {
%r = enzyme.autodiff @square(%x, %dr) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_activenoneed>], width=2 } : (f64, tensor<2xf64>) -> tensor<2xf64>
return %r : tensor<2xf64>
}
}

// CHECK: func.func @dsquare(%arg0: f64, %arg1: tensor<2xf64>) -> tensor<2xf64> {
// CHECK-NEXT: %0 = call @diffesquare(%arg0, %arg1) : (f64, tensor<2xf64>) -> tensor<2xf64>
// CHECK-NEXT: return %0 : tensor<2xf64>
// CHECK-NEXT: }

// CHECK: func.func private @diffe2square(%arg0: f64, %arg1: tensor<2xf64>) -> tensor<2xf64> {
// CHECK-NEXT: %0 = "enzyme.broadcast"(%arg0) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
// CHECK-NEXT: %1 = arith.mulf %arg1, %0 : tensor<2xf64>
// CHECK-NEXT: %2 = "enzyme.broadcast"(%arg0) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
// CHECK-NEXT: %3 = arith.mulf %arg1, %2 : tensor<2xf64>
// CHECK-NEXT: %4 = arith.addf %1, %3 : tensor<2xf64>
// CHECK-NEXT: return %4 : tensor<2xf64>
// CHECK-NEXT: }
21 changes: 11 additions & 10 deletions enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,16 +277,17 @@ SmallVector<bool, 1> prepareArgs(const Twine &curIndent, raw_ostream &os,
if (!vecValue && !startsWith(ord, "local")) {
if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) {
os << ")";
if (intrinsic == MLIRDerivatives) {
os << ";\n";
os << "if (gutils->width != 1) {\n"
<< " " << argName << "_" << (idx - 1)
<< " = builder.create<enzyme::BroadcastOp>(\n"
<< " op.getLoc(),\n"
<< " " << argName << "_" << (idx - 1) << ",\n"
<< " llvm::SmallVector<int64_t>({gutils->width}));\n"
<< "}";
}
}
if (intrinsic == MLIRDerivatives) {
os << ";\n";
os << curIndent << "if (gutils->width != 1) {\n"
<< curIndent << " " << argName << "_" << (idx - 1)
<< " = builder.create<enzyme::BroadcastOp>(\n"
<< curIndent << " op.getLoc(),\n"
<< curIndent << " " << argName << "_" << (idx - 1) << ",\n"
<< curIndent
<< " llvm::SmallVector<int64_t>({gutils->width}));\n"
<< curIndent << "}";
}

if (lookup && intrinsic != MLIRDerivatives)
Expand Down

0 comments on commit a8d9eb5

Please sign in to comment.