Skip to content

Commit

Permalink
better batchType
Browse files Browse the repository at this point in the history
  • Loading branch information
Pangoraw committed Jan 5, 2025
1 parent 652f057 commit 0746782
Showing 1 changed file with 11 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,18 @@ using namespace mlir::enzyme;
namespace {

static mlir::Type batchType(mlir::Type type, int64_t width) {
if (width > 1 || ShapedType::isDynamic(width)) {
if (auto TT = dyn_cast<mlir::TensorType>(type)) {
SmallVector<int64_t> shape;
shape.reserve(TT.getShape().size() + 1);
shape.push_back(width);
shape.append(TT.getShape().begin(), TT.getShape().end());
return TT.clone(shape);
}
return RankedTensorType::get({width}, type);
if (width == 1)
return type;

if (auto TT = dyn_cast<mlir::TensorType>(type)) {
SmallVector<int64_t> shape;
shape.reserve(TT.getShape().size() + 1);
shape.push_back(width);
shape.append(TT.getShape().begin(), TT.getShape().end());
return TT.clone(shape);
}
return type;

return RankedTensorType::get({width}, type);
}

class FloatTypeInterface
Expand Down

0 comments on commit 0746782

Please sign in to comment.