Batched autodiff #2181
Conversation
I haven't yet fully made the changes in enzyme-tblgen.cpp, and in any case this currently only works for the simple test case:

```cpp
mlir::Value itmp = ({
  // Computing MulFOp
  auto fwdarg_0 = dif;
  auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(1));
  if (gutils->width != 1) {
    fwdarg_1 = builder.create<tensor::SplatOp>(
        op.getLoc(),
        mlir::RankedTensorType::get({gutils->width}, fwdarg_1.getType()),
        fwdarg_1);
  }
  builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1);
});
```

This is the MLIR that is generated for the simple test:

```mlir
func.func private @fwddiffe2square(%arg0: f64, %arg1: tensor<2xf64>) -> tensor<2xf64> {
  %splat = tensor.splat %arg0 : tensor<2xf64>
  %0 = arith.mulf %arg1, %splat : tensor<2xf64>
  %splat_0 = tensor.splat %arg0 : tensor<2xf64>
  %1 = arith.mulf %arg1, %splat_0 : tensor<2xf64>
  %2 = arith.addf %0, %1 : tensor<2xf64>
  %3 = arith.mulf %arg0, %arg0 : f64
  return %2 : tensor<2xf64>
}
```
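As a rough sketch, the width-dependent splat could also be factored out of the emitted snippet into a small helper instead of being open-coded by enzyme-tblgen; this is roughly what the `makeBatched` TODO below refers to. `broadcastIfBatched` is a hypothetical name, not an existing `gutils` method, and this is only an illustration of the same logic as above:

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

// Hypothetical helper (not part of this PR): broadcast a scalar SSA value
// to tensor<width x T> when differentiating with a batch width != 1.
static mlir::Value broadcastIfBatched(mlir::OpBuilder &builder,
                                      mlir::Location loc, mlir::Value v,
                                      int64_t width) {
  if (width == 1)
    return v; // unbatched mode: leave the scalar untouched
  auto batchedTy = mlir::RankedTensorType::get({width}, v.getType());
  return builder.create<mlir::tensor::SplatOp>(loc, batchedTy, v);
}
```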
This still requires changes in the tblgen-generated derivative files. For example, createForwardModeTangent in MulFOpFwdDerivative could be altered like this:

```cpp
LogicalResult createForwardModeTangent(Operation *op0, OpBuilder &builder,
                                       MGradientUtils *gutils) const {
  auto op = cast<arith::MulFOp>(op0);
  if (gutils->width != 1) {
    auto newop = gutils->getNewFromOriginal(op0);
    for (auto res : newop->getResults()) {
      res.setType(mlir::RankedTensorType::get({gutils->width}, res.getType()));
    }
  }
  gutils->eraseIfUnused(op);
  if (gutils->isConstantInstruction(op))
    return success();
  mlir::Value res = nullptr;
  if (!gutils->isConstantValue(op->getOperand(0))) {
    auto dif = gutils->invertPointerM(op->getOperand(0), builder);
    {
      mlir::Value itmp = ({
        // Computing MulFOp
        auto fwdarg_0 = dif;
        dif.dump();
        // TODO: gutils->makeBatched(...)
        auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(1));
        builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1);
      });
      itmp.dump();
      if (!res)
        res = itmp;
      else {
        auto operandType = cast<AutoDiffTypeInterface>(res.getType());
        res = operandType.createAddOp(builder, op.getLoc(), res, itmp);
      }
    }
  }
  if (!gutils->isConstantValue(op->getOperand(1))) {
    auto dif = gutils->invertPointerM(op->getOperand(1), builder);
    {
      mlir::Value itmp = ({
        // Computing MulFOp
        auto fwdarg_0 = dif;
        dif.dump();
        auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(0));
        builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1);
      });
      if (!res)
        res = itmp;
      else {
        auto operandType = cast<AutoDiffTypeInterface>(res.getType());
        res = operandType.createAddOp(builder, op.getLoc(), res, itmp);
      }
    }
  }
  assert(res);
  gutils->setDiffe(op->getResult(0), res, builder);
  return success();
}
```
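With a helper like the hypothetical `broadcastIfBatched` sketched earlier, each operand block in the generated tangent could broadcast the non-active operand before building the multiply instead of open-coding the splat. A minimal sketch of the first block only, under the same assumptions and mirroring the code above rather than adding new behaviour:

```cpp
// Sketch: first operand block of createForwardModeTangent, using the
// hypothetical broadcastIfBatched helper in place of the inline splat.
if (!gutils->isConstantValue(op->getOperand(0))) {
  auto dif = gutils->invertPointerM(op->getOperand(0), builder);
  mlir::Value other = gutils->getNewFromOriginal(op->getOperand(1));
  other = broadcastIfBatched(builder, op.getLoc(), other, gutils->width);
  mlir::Value itmp =
      builder.create<arith::MulFOp>(op.getLoc(), dif, other);
  if (!res) {
    res = itmp;
  } else {
    auto operandType = cast<AutoDiffTypeInterface>(res.getType());
    res = operandType.createAddOp(builder, op.getLoc(), res, itmp);
  }
}
```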
This reverts commit c06ed01.
Fix the format/etc., then I think this is good to go!
Co-authored-by: Billy Moses <wmoses@google.com>
@wsmoses I started looking into control flow (…)
That’s totally fine/go for it, but also feel free to merge whenever it’s in a good state.
Added some type conversions to tensor types if `width != 1`. The simple test case seems correct now. Corresponding Enzyme-JAX PR: EnzymeAD/Enzyme-JAX#197