Skip to content

Commit

Permalink
[mlir] Update the return type of getNum{Dynamic|Scalable}Dims (#110472)
Browse files Browse the repository at this point in the history

Updates the return type of `getNumDynamicDims` and `getNumScalableDims`
from `int64_t` to `size_t`. This is for consistency with other
helpers/methods that return "size" and to reduce the number of
`static_cast`s in various places.
  • Loading branch information
banach-space authored Sep 30, 2024
1 parent 0617629 commit bfde178
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ class SparseTensorType {
/// Returns the number of dimensions which have dynamic sizes.
/// The return type is `size_t` to maintain consistency with
/// `ShapedType::Trait<T>::getNumDynamicDims`.
int64_t getNumDynamicDims() const { return rtp.getNumDynamicDims(); }
size_t getNumDynamicDims() const { return rtp.getNumDynamicDims(); }

ArrayRef<LevelType> getLvlTypes() const { return enc.getLvlTypes(); }
LevelType getLvlType(Level l) const {
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/BuiltinTypeInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {

/// If this is a ranked type, return the number of dimensions with dynamic
/// size. Otherwise, abort.
int64_t getNumDynamicDims() const {
size_t getNumDynamicDims() const {
return llvm::count_if($_type.getShape(), ::mlir::ShapedType::isDynamic);
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1253,7 +1253,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
}

/// Get the number of scalable dimensions.
int64_t getNumScalableDims() const {
size_t getNumScalableDims() const {
return llvm::count(getScalableDims(), true);
}

Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,7 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
LogicalResult AllocTensorOp::verify() {
if (getCopy() && !getDynamicSizes().empty())
return emitError("dynamic sizes not needed when copying a tensor");
if (!getCopy() && getType().getNumDynamicDims() !=
static_cast<int64_t>(getDynamicSizes().size()))
if (!getCopy() && getType().getNumDynamicDims() != getDynamicSizes().size())
return emitError("expected ")
<< getType().getNumDynamicDims() << " dynamic sizes";
if (getCopy() && getCopy().getType() != getType())
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2045,8 +2045,7 @@ void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
LogicalResult AllocOp::verify() {
auto memRefType = llvm::cast<MemRefType>(getMemref().getType());

if (static_cast<int64_t>(getDynamicSizes().size()) !=
memRefType.getNumDynamicDims())
if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
return emitOpError("dimension operand count does not equal memref "
"dynamic dimension count");

Expand Down
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,7 @@ static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
if (!memRefType)
return op.emitOpError("result must be a memref");

if (static_cast<int64_t>(op.getDynamicSizes().size()) !=
memRefType.getNumDynamicDims())
if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
return op.emitOpError("dimension operand count does not equal memref "
"dynamic dimension count");

Expand Down Expand Up @@ -283,8 +282,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
// Create new memref type (which will have fewer dynamic dimensions).
MemRefType newMemRefType =
MemRefType::Builder(memrefType).setShape(newShapeConstants);
assert(static_cast<int64_t>(dynamicSizes.size()) ==
newMemRefType.getNumDynamicDims());
assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());

// Create and insert the alloc op for the new memref.
auto newAlloc = rewriter.create<AllocLikeOp>(
Expand Down
9 changes: 3 additions & 6 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,7 @@ static RankedTensorType
foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
SmallVector<Value> &foldedDynamicSizes) {
SmallVector<int64_t> staticShape(type.getShape());
assert(type.getNumDynamicDims() ==
static_cast<int64_t>(dynamicSizes.size()) &&
assert(type.getNumDynamicDims() == dynamicSizes.size() &&
"incorrect number of dynamic sizes");

// Compute new static and dynamic sizes.
Expand Down Expand Up @@ -894,8 +893,7 @@ void EmptyOp::build(OpBuilder &builder, OperationState &result,
}

LogicalResult EmptyOp::verify() {
if (getType().getNumDynamicDims() !=
static_cast<int64_t>(getDynamicSizes().size()))
if (getType().getNumDynamicDims() != getDynamicSizes().size())
return emitOpError("incorrect number of dynamic sizes, has ")
<< getDynamicSizes().size() << ", expected "
<< getType().getNumDynamicDims();
Expand Down Expand Up @@ -3672,8 +3670,7 @@ void SplatOp::getAsmResultNames(
}

LogicalResult SplatOp::verify() {
if (getType().getNumDynamicDims() !=
static_cast<int64_t>(getDynamicSizes().size()))
if (getType().getNumDynamicDims() != getDynamicSizes().size())
return emitOpError("incorrect number of dynamic sizes, has ")
<< getDynamicSizes().size() << ", expected "
<< getType().getNumDynamicDims();
Expand Down

0 comments on commit bfde178

Please sign in to comment.