-
Notifications
You must be signed in to change notification settings - Fork 113
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add type conversions for width != 1. This still requires changes in the tblgenerated derivative files. For example, createForwardModeTangent in MulFOpFwdDerivative could be altered like this: ``` LogicalResult createForwardModeTangent(Operation *op0, OpBuilder &builder, MGradientUtils *gutils) const { auto op = cast<arith::MulFOp>(op0); if (gutils->width != 1) { auto newop = gutils->getNewFromOriginal(op0); for (auto res : newop->getResults()) { res.setType(mlir::RankedTensorType::get({gutils->width}, res.getType())); } } gutils->eraseIfUnused(op); if (gutils->isConstantInstruction(op)) return success(); mlir::Value res = nullptr; if (!gutils->isConstantValue(op->getOperand(0))) { auto dif = gutils->invertPointerM(op->getOperand(0), builder); { mlir::Value itmp = ({ // Computing MulFOp auto fwdarg_0 = dif; dif.dump(); // TODO: gutils->makeBatched(...) auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(1)); builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1); }); itmp.dump(); if (!res) res = itmp; else { auto operandType = cast<AutoDiffTypeInterface>(res.getType()); res = operandType.createAddOp(builder, op.getLoc(), res, itmp); } } } if (!gutils->isConstantValue(op->getOperand(1))) { auto dif = gutils->invertPointerM(op->getOperand(1), builder); { mlir::Value itmp = ({ // Computing MulFOp auto fwdarg_0 = dif; dif.dump(); auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(0)); builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1); }); if (!res) res = itmp; else { auto operandType = cast<AutoDiffTypeInterface>(res.getType()); res = operandType.createAddOp(builder, op.getLoc(), res, itmp); } } } assert(res); gutils->setDiffe(op->getResult(0), res, builder); return success(); } ``` * add code to tblgen generator, this eventually needs to be a single function call. * a test and formatting * use tensor splatop * remove stale enzyme-tblgen changes * do the simple batching in enzyme-tblgen * include tensor in all AutoDiffOpInterfaceImpls * add enzyme broadcastop * getShadowType for TensorTypeInterface * create broadcastop in enzyme-tblgen * Revert "include tensor in all AutoDiffOpInterfaceImpls" This reverts commit c06ed01. * test * DenseI64ArrayAttr for shape instead of scalar width * `llvm::SmallVector` --> `ArrayRef` * formatting * use getShadowType in BroadcastOp builder Co-authored-by: Billy Moses <wmoses@google.com> * unstructured control flow test * scf.for * formatting * support `scf.if` test * formatting * forgotten includes --------- Co-authored-by: Jules Merckx <jumerckx@mac.local> Co-authored-by: Billy Moses <wmoses@google.com>
- Loading branch information
1 parent
8e79483
commit eeb6200
Showing
18 changed files
with
239 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// RUN: %eopt --enzyme %s | FileCheck %s | ||
|
||
module { | ||
func.func @square(%x : f64, %y : f64) -> f64 { | ||
%c = arith.cmpf ult, %x, %y : f64 | ||
cf.cond_br %c, ^blk2(%x : f64), ^blk2(%y : f64) | ||
|
||
^blk2(%r : f64): | ||
return %r : f64 | ||
} | ||
func.func @dsq(%x : f64, %dx : tensor<2xf64>, %y : f64, %dy : tensor<2xf64>) -> tensor<2xf64> { | ||
%r = enzyme.fwddiff @square(%x, %dx, %y, %dy) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> (tensor<2xf64>) | ||
return %r : tensor<2xf64> | ||
} | ||
} | ||
|
||
// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3:.+]]: tensor<2xf64>) -> tensor<2xf64> { | ||
// CHECK-NEXT: %[[i0:.+]] = call @fwddiffesquare(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]) : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> tensor<2xf64> | ||
// CHECK-NEXT: return %[[i0]] : tensor<2xf64> | ||
// CHECK-NEXT: } | ||
// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3]]: tensor<2xf64>) -> tensor<2xf64> { | ||
// CHECK-NEXT: %[[i0:.+]] = arith.cmpf ult, %[[arg0]], %[[arg2]] : f64 | ||
// CHECK-NEXT: cf.cond_br %[[i0]], ^bb1(%[[arg0]], %[[arg1]] : f64, tensor<2xf64>), ^bb1(%[[arg2]], %[[arg3]] : f64, tensor<2xf64>) | ||
// CHECK-NEXT: ^bb1(%[[i1:.+]]: f64, %[[i2:.+]]: tensor<2xf64>): // 2 preds: ^bb0, ^bb0 | ||
// CHECK-NEXT: return %[[i2]] : tensor<2xf64> | ||
// CHECK-NEXT: } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// RUN: %eopt --enzyme %s | FileCheck %s | ||
|
||
module { | ||
func.func @square(%x : f64) -> f64 { | ||
%cst = arith.constant 10.000000e+00 : f64 | ||
%c0 = arith.constant 0 : index | ||
%c1 = arith.constant 1 : index | ||
%c10 = arith.constant 10 : index | ||
%r = scf.for %arg1 = %c0 to %c10 step %c1 iter_args(%arg2 = %cst) -> (f64) { | ||
%n = arith.addf %arg2, %x : f64 | ||
scf.yield %n : f64 | ||
} | ||
return %r : f64 | ||
} | ||
func.func @dsq(%x : f64, %dx : tensor<2xf64>) -> tensor<2xf64> { | ||
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (f64, tensor<2xf64>) -> (tensor<2xf64>) | ||
return %r : tensor<2xf64> | ||
} | ||
} | ||
|
||
// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { | ||
// CHECK-DAG: %[[cst:.+]] = arith.constant dense<0.000000e+00> : tensor<2xf64> | ||
// CHECK-DAG: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64 | ||
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index | ||
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index | ||
// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index | ||
// CHECK-NEXT: %[[i0:.+]]:2 = scf.for %[[arg2:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[arg3:.+]] = %[[cst_0]], %[[arg4:.+]] = %[[cst]]) -> (f64, tensor<2xf64>) { | ||
// CHECK-NEXT: %[[i1:.+]] = arith.addf %[[arg4]], %[[arg1]] : tensor<2xf64> | ||
// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[arg3]], %[[arg0]] : f64 | ||
// CHECK-NEXT: scf.yield %[[i2]], %[[i1]] : f64, tensor<2xf64> | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: return %[[i0]]#1 : tensor<2xf64> | ||
// CHECK-NEXT: } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
// RUN: %eopt --enzyme %s | FileCheck %s | ||
|
||
module { | ||
func.func @square(%x : f64, %c : i1) -> f64 { | ||
%c2 = arith.constant 2.000000e+00 : f64 | ||
%c10 = arith.constant 10.000000e+00 : f64 | ||
%r:2 = scf.if %c -> (f64, f64) { | ||
%mul = arith.mulf %x, %x : f64 | ||
scf.yield %mul, %c2 : f64, f64 | ||
} else { | ||
%add = arith.addf %x, %x : f64 | ||
scf.yield %add, %c10 : f64, f64 | ||
} | ||
%res = arith.mulf %r#0, %r#1 : f64 | ||
return %res : f64 | ||
} | ||
func.func @dsq(%x : f64, %dx : tensor<2xf64>, %c : i1) -> tensor<2xf64> { | ||
%r = enzyme.fwddiff @square(%x, %dx, %c) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_const>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (f64, tensor<2xf64>, i1) -> (tensor<2xf64>) | ||
return %r : tensor<2xf64> | ||
} | ||
} | ||
|
||
// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: i1) -> tensor<2xf64> { | ||
// CHECK-DAG: %[[cst2:.+]] = arith.constant 2.000000e+00 : f64 | ||
// CHECK-DAG: %[[cst10:.+]] = arith.constant 1.000000e+01 : f64 | ||
// CHECK-NEXT: %[[r0:.+]]:3 = scf.if %[[arg2]] -> (f64, tensor<2xf64>, f64) { | ||
// CHECK-NEXT: %[[t4:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64> | ||
// CHECK-NEXT: %[[t5:.+]] = arith.mulf %[[arg1]], %[[t4]] : tensor<2xf64> | ||
// CHECK-NEXT: %[[t6:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64> | ||
// CHECK-NEXT: %[[t7:.+]] = arith.mulf %[[arg1]], %[[t6]] : tensor<2xf64> | ||
// CHECK-NEXT: %[[t8:.+]] = arith.addf %[[t5]], %[[t7]] : tensor<2xf64> | ||
// CHECK-NEXT: %[[t9:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 | ||
// CHECK-NEXT: scf.yield %[[t9]], %[[t8]], %[[cst2]] : f64, tensor<2xf64>, f64 | ||
// CHECK-NEXT: } else { | ||
// CHECK-NEXT: %[[e4:.+]] = arith.addf %[[arg1]], %[[arg1]] : tensor<2xf64> | ||
// CHECK-NEXT: %[[e5:.+]] = arith.addf %[[arg0]], %[[arg0]] : f64 | ||
// CHECK-NEXT: scf.yield %[[e5]], %[[e4]], %[[cst10]] : f64, tensor<2xf64>, f64 | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: %[[r1:.+]] = "enzyme.broadcast"(%[[r0]]#2) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64> | ||
// CHECK-NEXT: %[[r2:.+]] = arith.mulf %[[r0]]#1, %[[r1]] : tensor<2xf64> | ||
// CHECK-NEXT: %[[r3:.+]] = arith.mulf %[[r0]]#0, %[[r0]]#2 : f64 | ||
// CHECK-NEXT: return %[[r2]] : tensor<2xf64> | ||
// CHECK-NEXT: } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// RUN: %eopt --enzyme %s | FileCheck %s | ||
|
||
module { | ||
func.func @square(%x : f64) -> f64{ | ||
%y = arith.mulf %x, %x : f64 | ||
return %y : f64 | ||
} | ||
func.func @dsq(%x : f64, %dx : tensor<2xf64>) -> tensor<2xf64> { | ||
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (f64, tensor<2xf64>) -> (tensor<2xf64>) | ||
return %r : tensor<2xf64> | ||
} | ||
} | ||
|
||
// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { | ||
// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (f64, tensor<2xf64>) -> tensor<2xf64> | ||
// CHECK-NEXT: return %[[i0]] : tensor<2xf64> | ||
// CHECK-NEXT: } | ||
// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { | ||
// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : f64 -> tensor<2xf64> | ||
// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2xf64> | ||
// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : f64 -> tensor<2xf64> | ||
// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64> | ||
// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64> | ||
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64> | ||
// CHECK-NEXT: return %[[i2]] : tensor<2xf64> | ||
// CHECK-NEXT: } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// RUN: %eopt --enzyme %s | FileCheck %s | ||
|
||
module { | ||
func.func @square(%x : tensor<10xf64>) -> tensor<10xf64>{ | ||
%y = arith.mulf %x, %x : tensor<10xf64> | ||
return %y : tensor<10xf64> | ||
} | ||
func.func @dsq(%x : tensor<10xf64>, %dx : tensor<2x10xf64>) -> tensor<2x10xf64> { | ||
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<2x10xf64>) | ||
return %r : tensor<2x10xf64> | ||
} | ||
} | ||
|
||
// CHECK: func.func @dsq(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> { | ||
// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64> | ||
// CHECK-NEXT: return %[[i0]] : tensor<2x10xf64> | ||
// CHECK-NEXT: } | ||
// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> { | ||
// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{shape = array<i64: 2>}> : (tensor<10xf64>) -> tensor<2x10xf64> | ||
// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2x10xf64> | ||
// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{shape = array<i64: 2>}> : (tensor<10xf64>) -> tensor<2x10xf64> | ||
// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2x10xf64> | ||
// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2x10xf64> | ||
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<10xf64> | ||
// CHECK-NEXT: return %[[i2]] : tensor<2x10xf64> | ||
// CHECK-NEXT: } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters