Skip to content

Commit

Permalink
mlir: Add Enzyme ops removal on structured control flow (#2200)
Browse files Browse the repository at this point in the history
* mlir: Add Enzyme ops removal on structured control flow

* format

* use AutoDiffTypeInterface for batching

* remove

* add test with unknown number of iterations

* don't push same value twice

* tensor extract/insert

* reserve the right size

* better batchType

* better comment
  • Loading branch information
Pangoraw authored Jan 5, 2025
1 parent c759460 commit 5b330a9
Show file tree
Hide file tree
Showing 12 changed files with 701 additions and 148 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@ using namespace mlir;
using namespace mlir::enzyme;

namespace {

/// Returns the type holding `width` batched copies of `type`.
///
/// - width == 1: batching is a no-op; the input type is returned unchanged.
/// - ranked tensor: a new leading dimension of size `width` is prepended.
/// - any other type (scalars such as float/int/complex): wrapped into a
///   rank-1 tensor of shape {width}.
///
/// Precondition: `width` is positive, and `type` is not an unranked tensor
/// (an unranked TensorType has no shape to prepend to; previously this would
/// trip an unrelated assertion inside TensorType::getShape()).
static mlir::Type batchType(mlir::Type type, int64_t width) {
  assert(width > 0 && "batch width must be positive");

  if (width == 1)
    return type;

  if (auto TT = dyn_cast<mlir::TensorType>(type)) {
    // Fail loudly here rather than inside getShape() below.
    assert(TT.hasRank() && "cannot batch an unranked tensor");
    SmallVector<int64_t> shape;
    shape.reserve(TT.getShape().size() + 1);
    shape.push_back(width);
    shape.append(TT.getShape().begin(), TT.getShape().end());
    return TT.clone(shape);
  }

  return RankedTensorType::get({width}, type);
}

class FloatTypeInterface
: public AutoDiffTypeInterface::ExternalModel<FloatTypeInterface,
FloatType> {
Expand All @@ -44,12 +60,8 @@ class FloatTypeInterface
return a;
}

Type getShadowType(Type self, unsigned width) const {
if (width > 1) {
return RankedTensorType::get({width}, self);
} else {
return self;
}
/// Shadow (derivative) type of a float scalar: delegates to batchType, so it
/// is the scalar itself for width == 1 and a tensor of `width` copies
/// otherwise.
Type getShadowType(Type self, int64_t width) const {
return batchType(self, width);
}

bool isMutable(Type self) const { return false; }
Expand Down Expand Up @@ -108,16 +120,8 @@ class TensorTypeInterface
return added;
}

Type getShadowType(Type self, unsigned width) const {
if (width != 1) {
auto tenType = self.cast<TensorType>();
auto shape = tenType.getShape();
SmallVector<int64_t, 4> newShape;
newShape.push_back(width);
newShape.append(shape.begin(), shape.end());
return RankedTensorType::get(newShape, tenType.getElementType());
}
return self;
/// Shadow (derivative) type of a tensor: delegates to batchType, which
/// prepends a leading `width` dimension (or returns the tensor unchanged
/// when width == 1).
Type getShadowType(Type self, int64_t width) const {
return batchType(self, width);
}

bool isMutable(Type self) const { return false; }
Expand Down Expand Up @@ -148,9 +152,8 @@ class IntegerTypeInterface
return a;
}

Type getShadowType(Type self, unsigned width) const {
assert(width == 1 && "unsupported width != 1");
return self;
/// Shadow (derivative) type of an integer: delegates to batchType. Note this
/// replaces the previous hard assert on width != 1, so integers now support
/// batched widths like the other types.
Type getShadowType(Type self, int64_t width) const {
return batchType(self, width);
}

bool isMutable(Type self) const { return false; }
Expand Down Expand Up @@ -182,9 +185,8 @@ class ComplexTypeInterface
return builder.create<complex::ConjOp>(loc, a)->getResult(0);
}

Type getShadowType(Type self, unsigned width) const {
assert(width == 1 && "unsupported width != 1");
return self;
/// Shadow (derivative) type of a complex value: delegates to batchType, which
/// wraps the scalar into a {width} tensor when width > 1.
Type getShadowType(Type self, int64_t width) const {
return batchType(self, width);
}

bool isMutable(Type self) const { return false; }
Expand Down
Loading

0 comments on commit 5b330a9

Please sign in to comment.