Batched autodiff (#2181)
* add type conversions for width != 1.

This still requires changes in the tblgen-generated derivative files. For example, `createForwardModeTangent` in `MulFOpFwdDerivative` could be altered like this:
```
  LogicalResult createForwardModeTangent(Operation *op0, OpBuilder &builder, MGradientUtils *gutils) const
  {
    auto op = cast<arith::MulFOp>(op0);
    if (gutils->width != 1) {
      auto newop = gutils->getNewFromOriginal(op0);
      for (auto res : newop->getResults()) {
        res.setType(mlir::RankedTensorType::get({gutils->width}, res.getType()));
      }
    }
    gutils->eraseIfUnused(op);
    if (gutils->isConstantInstruction(op))
      return success();
    mlir::Value res = nullptr;
    if (!gutils->isConstantValue(op->getOperand(0)))
    {
      auto dif = gutils->invertPointerM(op->getOperand(0), builder);
      {
        mlir::Value itmp = ({
          // Computing MulFOp
          auto fwdarg_0 = dif;
          // TODO: gutils->makeBatched(...) -- the primal operand still needs
          // to be broadcast to the batch width before the multiply.
          auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(1));
          builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1);
        });
        if (!res)
          res = itmp;
        else
        {
          auto operandType = cast<AutoDiffTypeInterface>(res.getType());
          res = operandType.createAddOp(builder, op.getLoc(), res, itmp);
        }
      }
    }
    if (!gutils->isConstantValue(op->getOperand(1)))
    {
      auto dif = gutils->invertPointerM(op->getOperand(1), builder);
      {
        mlir::Value itmp = ({
          // Computing MulFOp
          auto fwdarg_0 = dif;
          auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(0));
          builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1);
        });
        if (!res)
          res = itmp;
        else
        {
          auto operandType = cast<AutoDiffTypeInterface>(res.getType());
          res = operandType.createAddOp(builder, op.getLoc(), res, itmp);
        }
      }
    }
    assert(res);
    gutils->setDiffe(op->getResult(0), res, builder);
    return success();
  }
```

* add code to the tblgen generator; this eventually needs to be a single function call.

* a test and formatting

* use tensor splatop

* remove stale enzyme-tblgen changes

* do the simple batching in enzyme-tblgen

* include tensor in all AutoDiffOpInterfaceImpls

* add enzyme broadcastop

* getShadowType for TensorTypeInterface

* create broadcastop in enzyme-tblgen

* Revert "include tensor in all AutoDiffOpInterfaceImpls"

This reverts commit c06ed01.

* test

* DenseI64ArrayAttr for shape instead of scalar width

* `llvm::SmallVector` --> `ArrayRef`

* formatting

* use getShadowType in BroadcastOp builder

Co-authored-by: Billy Moses <wmoses@google.com>

* unstructured control flow test

* scf.for

* formatting

* support `scf.if` test

* formatting

* forgotten includes

---------

Co-authored-by: Jules Merckx <jumerckx@mac.local>
Co-authored-by: Billy Moses <wmoses@google.com>
3 people authored Dec 27, 2024
1 parent 8e79483 commit eeb6200
Showing 18 changed files with 239 additions and 11 deletions.
16 changes: 16 additions & 0 deletions enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
@@ -192,4 +192,20 @@ def GenericAdjointOp : Enzyme_Op<"genericAdjoint", [AttrSizedOperandSegments]> {

}

def BroadcastOp : Enzyme_Op<"broadcast"> {
let description = [{
    Broadcast the operand by prepending extra dimensions, with sizes provided by the `shape` attribute, to its type.
    For a scalar operand, a ranked tensor is created.

NOTE: Only works for scalar and *ranked* tensor operands for now.
}];

let arguments = (ins AnyType:$input, DenseI64ArrayAttr:$shape);
let results = (outs AnyRankedTensor:$output);

let builders = [
OpBuilder<(ins "Value":$input, "ArrayRef<int64_t>":$shape)>
];
}

#endif // ENZYME_OPS
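The builder declared above derives the result type from the input type and `shape`, so callers only supply a value and the batch shape. A minimal usage sketch (the helper name is hypothetical; it assumes the Enzyme dialect headers are included and `width` is the batch size):

```cpp
// Broadcast `input` to the batch width, e.g. f64 -> tensor<2xf64> when
// width == 2, as exercised by the tests added in this commit.
static mlir::Value broadcastToWidth(mlir::OpBuilder &builder,
                                    mlir::Location loc, mlir::Value input,
                                    int64_t width) {
  return builder.create<mlir::enzyme::BroadcastOp>(
      loc, input, llvm::ArrayRef<int64_t>{width});
}
```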
15 changes: 15 additions & 0 deletions enzyme/Enzyme/MLIR/Dialect/Ops.cpp
@@ -27,6 +27,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/IntegerSet.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"

@@ -191,3 +192,17 @@ LogicalResult BatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

return success();
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input,
ArrayRef<int64_t> shape) {
auto shapeAttr = builder.getDenseI64ArrayAttr(shape);
auto resultTy = input.getType();
for (auto s : llvm::reverse(shape)) {
resultTy = resultTy.cast<AutoDiffTypeInterface>().getShadowType(s);
}
build(builder, result, resultTy, input, shapeAttr);
}
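For intuition, the loop above folds `shape` in reverse so that each `getShadowType(s)` call (implemented on the type interfaces later in this diff) prepends one dimension; a worked example under that assumption:

```cpp
// Result-type folding in BroadcastOp::build, traced by hand:
//   input type f64, shape = {3, 2}
//   llvm::reverse(shape) visits 2, then 3:
//     f64 -> tensor<2xf64> -> tensor<3x2xf64>
// The result type carries the shape entries in their original order,
// which is exactly why the iteration runs in reverse.
```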
@@ -17,6 +17,7 @@
#include "Interfaces/GradientUtilsReverse.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"

@@ -69,3 +70,10 @@ void mlir::enzyme::registerArithDialectAutoDiffInterface(
arith::ConstantOp::attachInterface<ArithConstantOpBatchInterface>(*context);
});
}

void mlir::enzyme::registerTensorDialectAutoDiffInterface(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *context, tensor::TensorDialect *) {
registerInterfaces(context);
});
}
@@ -45,8 +45,11 @@ class FloatTypeInterface
}

Type getShadowType(Type self, unsigned width) const {
-    assert(width == 1 && "unsupported width != 1");
-    return self;
+    if (width > 1) {
+      return RankedTensorType::get({width}, self);
+    } else {
+      return self;
+    }
}

bool isMutable(Type self) const { return false; }
@@ -106,7 +109,14 @@ class TensorTypeInterface
}

Type getShadowType(Type self, unsigned width) const {
-    assert(width == 1 && "unsupported width != 1");
+    if (width != 1) {
+      auto tenType = self.cast<TensorType>();
+      auto shape = tenType.getShape();
+      SmallVector<int64_t, 4> newShape;
+      newShape.push_back(width);
+      newShape.append(shape.begin(), shape.end());
+      return RankedTensorType::get(newShape, tenType.getElementType());
+    }
return self;
}
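A quick sanity sketch of the two implementations above (hypothetical test body; assumes an `MLIRContext` with these interface implementations registered, and uses this codebase's `Type::cast` style):

```cpp
mlir::MLIRContext ctx;
auto f64 = mlir::FloatType::getF64(&ctx);

// FloatTypeInterface: scalars widen to a rank-1 batch tensor.
auto scalarShadow = f64.cast<AutoDiffTypeInterface>().getShadowType(/*width=*/2);
assert(scalarShadow == mlir::RankedTensorType::get({2}, f64));

// TensorTypeInterface: the batch width is prepended to the existing shape.
auto vec = mlir::RankedTensorType::get({10}, f64);
auto vecShadow = vec.cast<AutoDiffTypeInterface>().getShadowType(/*width=*/2);
assert(vecShadow == mlir::RankedTensorType::get({2, 10}, f64));
```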

@@ -74,7 +74,8 @@ void mlir::enzyme::detail::branchingForwardHandler(Operation *inst,
newVals.push_back(gutils->invertPointerM(op, builder));
} else {
          Type retTy =
-             arg.getType().cast<AutoDiffTypeInterface>().getShadowType();
+             arg.getType().cast<AutoDiffTypeInterface>().getShadowType(
+                 gutils->width);
auto toret = retTy.cast<AutoDiffTypeInterface>().createNullValue(
builder, op.getLoc());
newVals.push_back(toret);
@@ -146,7 +147,7 @@ LogicalResult mlir::enzyme::detail::memoryIdentityForwardHandler(
if (auto iface =
dyn_cast<AutoDiffTypeInterface>(operand.get().getType())) {
if (!iface.isMutable()) {
-        Type retTy = iface.getShadowType();
+        Type retTy = iface.getShadowType(gutils->width);
auto toret = retTy.cast<AutoDiffTypeInterface>().createNullValue(
builder, operand.get().getLoc());
newOperands.push_back(toret);
@@ -346,7 +347,7 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
<< result.getType() << "\n";
return failure();
}
-    newOpResultTypes.push_back(typeIface.getShadowType());
+    newOpResultTypes.push_back(typeIface.getShadowType(gutils->width));
}

SmallVector<Value> newOperands;
@@ -432,4 +433,5 @@ void mlir::enzyme::registerCoreDialectAutodiffInterfaces(
enzyme::registerCFDialectAutoDiffInterface(registry);
enzyme::registerLinalgDialectAutoDiffInterface(registry);
enzyme::registerFuncDialectAutoDiffInterface(registry);
enzyme::registerTensorDialectAutoDiffInterface(registry);
}
@@ -260,6 +260,7 @@ void registerCFDialectAutoDiffInterface(DialectRegistry &registry);
void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry);
void registerMathDialectAutoDiffInterface(DialectRegistry &registry);
void registerFuncDialectAutoDiffInterface(DialectRegistry &registry);
void registerTensorDialectAutoDiffInterface(DialectRegistry &registry);

void registerCoreDialectAutodiffInterfaces(DialectRegistry &registry);

6 changes: 4 additions & 2 deletions enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp
@@ -245,9 +245,11 @@ FunctionOpInterface CloneFunctionWithReturns(
mlir::Value val = blk.getArgument(i);
mlir::Value dval;
if (i == ArgActivity.size() - 1)
-        dval = blk.addArgument(val.getType(), val.getLoc());
+        dval = blk.addArgument(getShadowType(val.getType(), width),
+                               val.getLoc());
       else
-        dval = blk.insertArgument(blk.args_begin() + i + 1, val.getType(),
+        dval = blk.insertArgument(blk.args_begin() + i + 1,
+                                  getShadowType(val.getType(), width),
                                   val.getLoc());
ptrInputs.map(oval, dval);
}
3 changes: 2 additions & 1 deletion enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
@@ -108,7 +108,8 @@ mlir::Value mlir::enzyme::MGradientUtils::invertPointerM(mlir::Value v,
return invertedPointers.lookupOrNull(v);

if (isConstantValue(v)) {
-    if (auto iface = v.getType().dyn_cast<AutoDiffTypeInterface>()) {
+    if (auto iface =
+            getShadowType(v.getType()).dyn_cast<AutoDiffTypeInterface>()) {
OpBuilder::InsertionGuard guard(Builder2);
if (auto op = v.getDefiningOp())
Builder2.setInsertionPoint(getNewFromOriginal(op));
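Continuing the change above: because the interface is now looked up on the shadow type rather than the primal type, the null value created for a constant primal is already batch-shaped. An illustration at `width == 2`, reusing the names from the surrounding code (a sketch, not part of the diff):

```cpp
// For a constant f64 value v, getShadowType(v.getType()) is tensor<2xf64>,
// so createNullValue materializes the batched zero the scf.for test expects:
//   %cst = arith.constant dense<0.000000e+00> : tensor<2xf64>
Type shadowTy = getShadowType(v.getType());
Value zero = shadowTy.cast<AutoDiffTypeInterface>().createNullValue(
    Builder2, v.getLoc());
```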
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Passes/CMakeLists.txt
@@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIREnzymeTransforms
MLIRFuncDialect
MLIRFuncTransforms
MLIRGPUDialect
MLIRTensorDialect
MLIRIR
MLIRLLVMDialect
MLIRMathDialect
5 changes: 5 additions & 0 deletions enzyme/Enzyme/MLIR/Passes/Passes.h
@@ -15,6 +15,7 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include "Dialect/Dialect.h"

@@ -80,6 +81,10 @@ namespace affine {
class AffineDialect;
} // end namespace affine

namespace tensor {
class TensorDialect;
} // end namespace tensor

namespace LLVM {
class LLVMDialect;
} // end namespace LLVM
3 changes: 2 additions & 1 deletion enzyme/Enzyme/MLIR/Passes/Passes.td
@@ -16,7 +16,8 @@ def DifferentiatePass : Pass<"enzyme"> {
let dependentDialects = [
"arith::ArithDialect",
"complex::ComplexDialect",
"cf::ControlFlowDialect"
"cf::ControlFlowDialect",
"tensor::TensorDialect",
];
let constructor = "mlir::enzyme::createDifferentiatePass()";
}
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/enzymemlir-opt.cpp
@@ -67,6 +67,7 @@ int main(int argc, char **argv) {
registry.insert<mlir::omp::OpenMPDialect>();
registry.insert<mlir::math::MathDialect>();
registry.insert<mlir::linalg::LinalgDialect>();
registry.insert<mlir::tensor::TensorDialect>();
registry.insert<DLTIDialect>();

registry.insert<mlir::enzyme::EnzymeDialect>();
26 changes: 26 additions & 0 deletions enzyme/test/MLIR/ForwardMode/batched_branch.mlir
@@ -0,0 +1,26 @@
// RUN: %eopt --enzyme %s | FileCheck %s

module {
func.func @square(%x : f64, %y : f64) -> f64 {
%c = arith.cmpf ult, %x, %y : f64
cf.cond_br %c, ^blk2(%x : f64), ^blk2(%y : f64)

^blk2(%r : f64):
return %r : f64
}
func.func @dsq(%x : f64, %dx : tensor<2xf64>, %y : f64, %dy : tensor<2xf64>) -> tensor<2xf64> {
%r = enzyme.fwddiff @square(%x, %dx, %y, %dy) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> (tensor<2xf64>)
return %r : tensor<2xf64>
}
}

// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3:.+]]: tensor<2xf64>) -> tensor<2xf64> {
// CHECK-NEXT: %[[i0:.+]] = call @fwddiffesquare(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]) : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> tensor<2xf64>
// CHECK-NEXT: return %[[i0]] : tensor<2xf64>
// CHECK-NEXT: }
// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3]]: tensor<2xf64>) -> tensor<2xf64> {
// CHECK-NEXT: %[[i0:.+]] = arith.cmpf ult, %[[arg0]], %[[arg2]] : f64
// CHECK-NEXT: cf.cond_br %[[i0]], ^bb1(%[[arg0]], %[[arg1]] : f64, tensor<2xf64>), ^bb1(%[[arg2]], %[[arg3]] : f64, tensor<2xf64>)
// CHECK-NEXT: ^bb1(%[[i1:.+]]: f64, %[[i2:.+]]: tensor<2xf64>): // 2 preds: ^bb0, ^bb0
// CHECK-NEXT: return %[[i2]] : tensor<2xf64>
// CHECK-NEXT: }
33 changes: 33 additions & 0 deletions enzyme/test/MLIR/ForwardMode/batched_for.mlir
@@ -0,0 +1,33 @@
// RUN: %eopt --enzyme %s | FileCheck %s

module {
func.func @square(%x : f64) -> f64 {
%cst = arith.constant 10.000000e+00 : f64
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
%r = scf.for %arg1 = %c0 to %c10 step %c1 iter_args(%arg2 = %cst) -> (f64) {
%n = arith.addf %arg2, %x : f64
scf.yield %n : f64
}
return %r : f64
}
func.func @dsq(%x : f64, %dx : tensor<2xf64>) -> tensor<2xf64> {
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (f64, tensor<2xf64>) -> (tensor<2xf64>)
return %r : tensor<2xf64>
}
}

// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> {
// CHECK-DAG: %[[cst:.+]] = arith.constant dense<0.000000e+00> : tensor<2xf64>
// CHECK-DAG: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
// CHECK-NEXT: %[[i0:.+]]:2 = scf.for %[[arg2:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[arg3:.+]] = %[[cst_0]], %[[arg4:.+]] = %[[cst]]) -> (f64, tensor<2xf64>) {
// CHECK-NEXT: %[[i1:.+]] = arith.addf %[[arg4]], %[[arg1]] : tensor<2xf64>
// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[arg3]], %[[arg0]] : f64
// CHECK-NEXT: scf.yield %[[i2]], %[[i1]] : f64, tensor<2xf64>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[i0]]#1 : tensor<2xf64>
// CHECK-NEXT: }
43 changes: 43 additions & 0 deletions enzyme/test/MLIR/ForwardMode/batched_if.mlir
@@ -0,0 +1,43 @@
// RUN: %eopt --enzyme %s | FileCheck %s

module {
func.func @square(%x : f64, %c : i1) -> f64 {
%c2 = arith.constant 2.000000e+00 : f64
%c10 = arith.constant 10.000000e+00 : f64
%r:2 = scf.if %c -> (f64, f64) {
%mul = arith.mulf %x, %x : f64
scf.yield %mul, %c2 : f64, f64
} else {
%add = arith.addf %x, %x : f64
scf.yield %add, %c10 : f64, f64
}
%res = arith.mulf %r#0, %r#1 : f64
return %res : f64
}
func.func @dsq(%x : f64, %dx : tensor<2xf64>, %c : i1) -> tensor<2xf64> {
%r = enzyme.fwddiff @square(%x, %dx, %c) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_const>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (f64, tensor<2xf64>, i1) -> (tensor<2xf64>)
return %r : tensor<2xf64>
}
}

// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: i1) -> tensor<2xf64> {
// CHECK-DAG: %[[cst2:.+]] = arith.constant 2.000000e+00 : f64
// CHECK-DAG: %[[cst10:.+]] = arith.constant 1.000000e+01 : f64
// CHECK-NEXT: %[[r0:.+]]:3 = scf.if %[[arg2]] -> (f64, tensor<2xf64>, f64) {
// CHECK-NEXT: %[[t4:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
// CHECK-NEXT: %[[t5:.+]] = arith.mulf %[[arg1]], %[[t4]] : tensor<2xf64>
// CHECK-NEXT: %[[t6:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
// CHECK-NEXT: %[[t7:.+]] = arith.mulf %[[arg1]], %[[t6]] : tensor<2xf64>
// CHECK-NEXT: %[[t8:.+]] = arith.addf %[[t5]], %[[t7]] : tensor<2xf64>
// CHECK-NEXT: %[[t9:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64
// CHECK-NEXT: scf.yield %[[t9]], %[[t8]], %[[cst2]] : f64, tensor<2xf64>, f64
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[e4:.+]] = arith.addf %[[arg1]], %[[arg1]] : tensor<2xf64>
// CHECK-NEXT: %[[e5:.+]] = arith.addf %[[arg0]], %[[arg0]] : f64
// CHECK-NEXT: scf.yield %[[e5]], %[[e4]], %[[cst10]] : f64, tensor<2xf64>, f64
// CHECK-NEXT: }
// CHECK-NEXT: %[[r1:.+]] = "enzyme.broadcast"(%[[r0]]#2) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
// CHECK-NEXT: %[[r2:.+]] = arith.mulf %[[r0]]#1, %[[r1]] : tensor<2xf64>
// CHECK-NEXT: %[[r3:.+]] = arith.mulf %[[r0]]#0, %[[r0]]#2 : f64
// CHECK-NEXT: return %[[r2]] : tensor<2xf64>
// CHECK-NEXT: }
26 changes: 26 additions & 0 deletions enzyme/test/MLIR/ForwardMode/batched_scalar.mlir
@@ -0,0 +1,26 @@
// RUN: %eopt --enzyme %s | FileCheck %s

module {
func.func @square(%x : f64) -> f64{
%y = arith.mulf %x, %x : f64
return %y : f64
}
func.func @dsq(%x : f64, %dx : tensor<2xf64>) -> tensor<2xf64> {
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (f64, tensor<2xf64>) -> (tensor<2xf64>)
return %r : tensor<2xf64>
}
}

// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> {
// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (f64, tensor<2xf64>) -> tensor<2xf64>
// CHECK-NEXT: return %[[i0]] : tensor<2xf64>
// CHECK-NEXT: }
// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> {
// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : f64 -> tensor<2xf64>
// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2xf64>
// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : f64 -> tensor<2xf64>
// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64>
// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64>
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64>
// CHECK-NEXT: return %[[i2]] : tensor<2xf64>
// CHECK-NEXT: }
26 changes: 26 additions & 0 deletions enzyme/test/MLIR/ForwardMode/batched_tensor.mlir
@@ -0,0 +1,26 @@
// RUN: %eopt --enzyme %s | FileCheck %s

module {
func.func @square(%x : tensor<10xf64>) -> tensor<10xf64>{
%y = arith.mulf %x, %x : tensor<10xf64>
return %y : tensor<10xf64>
}
func.func @dsq(%x : tensor<10xf64>, %dx : tensor<2x10xf64>) -> tensor<2x10xf64> {
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<2x10xf64>)
return %r : tensor<2x10xf64>
}
}

// CHECK: func.func @dsq(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> {
// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64>
// CHECK-NEXT: return %[[i0]] : tensor<2x10xf64>
// CHECK-NEXT: }
// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> {
// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{shape = array<i64: 2>}> : (tensor<10xf64>) -> tensor<2x10xf64>
// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2x10xf64>
// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{shape = array<i64: 2>}> : (tensor<10xf64>) -> tensor<2x10xf64>
// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2x10xf64>
// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2x10xf64>
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<10xf64>
// CHECK-NEXT: return %[[i2]] : tensor<2x10xf64>
// CHECK-NEXT: }
13 changes: 12 additions & 1 deletion enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
@@ -275,8 +275,19 @@ SmallVector<bool, 1> prepareArgs(const Twine &curIndent, raw_ostream &os,
os << ord;
}
if (!vecValue && !startsWith(ord, "local")) {
-      if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives))
+      if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) {
os << ")";
if (intrinsic == MLIRDerivatives) {
os << ";\n";
os << "if (gutils->width != 1) {\n"
<< " " << argName << "_" << (idx - 1)
<< " = builder.create<enzyme::BroadcastOp>(\n"
<< " op.getLoc(),\n"
<< " " << argName << "_" << (idx - 1) << ",\n"
<< " llvm::SmallVector<int64_t>({gutils->width}));\n"
<< "}";
}
}

if (lookup && intrinsic != MLIRDerivatives)
os << ", " << builder << ")";
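Concretely, for a primal operand of an op like `arith.mulf`, the streamed strings above make the generated `createForwardModeTangent` read roughly as follows (a sketch of the generated C++, using the generator's `argName_N` naming; it replaces the `gutils->makeBatched(...)` TODO from the commit message at the top):

```cpp
auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(1));
if (gutils->width != 1) {
  // Widen the scalar primal operand to the batch shape so it can be
  // combined with the already-batched tangent value.
  fwdarg_1 = builder.create<enzyme::BroadcastOp>(
      op.getLoc(),
      fwdarg_1,
      llvm::SmallVector<int64_t>({gutils->width}));
}
```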
