
Batched autodiff #2181

Merged: 22 commits into EnzymeAD:main, Dec 27, 2024

Conversation

@jumerckx (Collaborator) commented Nov 28, 2024

Added some type conversions to tensor types when `width != 1`. The simple test case now appears correct.
Corresponding Enzyme-JAX PR: EnzymeAD/Enzyme-JAX#197
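
Concretely, the type rule being added can be sketched like this (illustrative helper, not the actual PR code; the name `widenForBatch` is hypothetical):

```
// Illustration only (hypothetical helper, not from this PR): when batching
// with width != 1, a scalar tangent type T is widened to tensor<width x T>.
#include "mlir/IR/BuiltinTypes.h"

static mlir::Type widenForBatch(mlir::Type scalarTy, int64_t width) {
  if (width == 1)
    return scalarTy; // unbatched: leave the type unchanged
  return mlir::RankedTensorType::get({width}, scalarTy);
}
```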

@jumerckx (Collaborator, Author) commented Dec 2, 2024

I haven't yet fully made the changes in `enzyme-tblgen.cpp`, and in any case this currently only works for the simple test case. For now, I added the following manually in `ArithDerivatives.inc`:

```
mlir::Value itmp = ({
  // Computing MulFOp
  auto fwdarg_0 = dif;
  auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(1));
  if (gutils->width != 1)
  {
    // Broadcast the scalar primal operand to tensor<width x T> so it
    // matches the batched tangent operand fwdarg_0.
    fwdarg_1 = builder.create<tensor::SplatOp>(
        op.getLoc(),
        mlir::RankedTensorType::get({gutils->width},
                                    fwdarg_1.getType()),
        fwdarg_1);
  }
  builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1);
});
```

But this is the MLIR code that is generated for this simple test:

```
func.func private @fwddiffe2square(%arg0: f64, %arg1: tensor<2xf64>) -> tensor<2xf64> {
  %splat = tensor.splat %arg0 : tensor<2xf64>
  %0 = arith.mulf %arg1, %splat : tensor<2xf64>
  %splat_0 = tensor.splat %arg0 : tensor<2xf64>
  %1 = arith.mulf %arg1, %splat_0 : tensor<2xf64>
  %2 = arith.addf %0, %1 : tensor<2xf64>
  %3 = arith.mulf %arg0, %arg0 : f64
  return %2 : tensor<2xf64>
}
```

This still requires changes in the tblgen-generated derivative files. For example, `createForwardModeTangent` in `MulFOpFwdDerivative` could be altered like this:
```
  LogicalResult createForwardModeTangent(Operation *op0, OpBuilder &builder, MGradientUtils *gutils) const
  {
    auto op = cast<arith::MulFOp>(op0);
    if (gutils->width != 1) {
      auto newop = gutils->getNewFromOriginal(op0);
      for (auto res : newop->getResults()) {
        res.setType(mlir::RankedTensorType::get({gutils->width}, res.getType()));
      }
    }
    gutils->eraseIfUnused(op);
    if (gutils->isConstantInstruction(op))
      return success();
    mlir::Value res = nullptr;
    if (!gutils->isConstantValue(op->getOperand(0)))
    {
      auto dif = gutils->invertPointerM(op->getOperand(0), builder);
      {
        mlir::Value itmp = ({
          // Computing MulFOp
          auto fwdarg_0 = dif;
          // TODO: gutils->makeBatched(...)
          auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(1));
          builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1);
        });
        if (!res)
          res = itmp;
        else
        {
          auto operandType = cast<AutoDiffTypeInterface>(res.getType());
          res = operandType.createAddOp(builder, op.getLoc(), res, itmp);
        }
      }
    }
    if (!gutils->isConstantValue(op->getOperand(1)))
    {
      auto dif = gutils->invertPointerM(op->getOperand(1), builder);
      {
        mlir::Value itmp = ({
          // Computing MulFOp
          auto fwdarg_0 = dif;
          auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(0));
          builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1);
        });
        if (!res)
          res = itmp;
        else
        {
          auto operandType = cast<AutoDiffTypeInterface>(res.getType());
          res = operandType.createAddOp(builder, op.getLoc(), res, itmp);
        }
      }
    }
    assert(res);
    gutils->setDiffe(op->getResult(0), res, builder);
    return success();
  }
```
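
As for the `gutils->makeBatched(...)` TODO above, here is a hypothetical sketch of what such a helper could look like, assuming it simply splats a scalar primal value to the batched tensor type, mirroring the manual `ArithDerivatives.inc` edit:

```
// Hypothetical sketch of the makeBatched helper referenced in the TODO
// (illustration only, not part of this PR). Assumes batching a scalar
// value means splatting it to tensor<width x T>.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

static mlir::Value makeBatched(mlir::Value v, mlir::OpBuilder &builder,
                               mlir::Location loc, int64_t width) {
  if (width == 1)
    return v; // nothing to broadcast in the unbatched case
  auto batchedTy = mlir::RankedTensorType::get({width}, v.getType());
  return builder.create<mlir::tensor::SplatOp>(loc, batchedTy, v);
}
```

The tblgen-emitted code could then call such a helper for each scalar primal operand instead of emitting the splat inline.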
@wsmoses (Member) commented Dec 22, 2024

fix the format/etc then I think this is good to go!

jumerckx changed the title from "(WIP) Batched autodiff" to "Batched autodiff" on Dec 25, 2024
jumerckx marked this pull request as ready for review on December 25, 2024, 12:01
@jumerckx (Collaborator, Author) commented:
@wsmoses I started looking into control flow (cf and scf) but can move these commits out if you'd like to merge without those changes first.

@wsmoses (Member) commented Dec 25, 2024

That’s totally fine/go for it, but also feel free to merge whenever it’s in a good state

wsmoses merged commit eeb6200 into EnzymeAD:main on Dec 27, 2024 (13 of 21 checks passed).