Merge pull request #233 from Xilinx/bump_to_308c45e6
Merge with fixes of 308c45e (9)
mgehre-amd authored Aug 15, 2024
2 parents bb40cfa + 55a27dc commit 1affa1f
Showing 46 changed files with 458 additions and 427 deletions.
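Most of the changes in this merge are mechanical: they migrate from the member-style casts (val.cast<T>(), val.dyn_cast<T>(), val.isa<T>()), which are deprecated in recent LLVM/MLIR, to the equivalent free functions (cast<T>(val), dyn_cast<T>(val), isa<T>(val)). The following is a minimal illustrative sketch of the two styles; it is not part of the diff, it assumes the usual MLIR headers, and the helper name getElementTypeExample is made up for the example.

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Returns the element type of a shaped type, or the type itself otherwise.
Type getElementTypeExample(Type ty) {
  // Old, deprecated member style:
  //   if (ty.isa<ShapedType>()) return ty.cast<ShapedType>().getElementType();
  // New free-function style used throughout this merge:
  if (auto shaped = dyn_cast<ShapedType>(ty)) // null on mismatch
    return shaped.getElementType();
  return ty; // cast<ShapedType>(ty) would assert here instead of returning null
}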
4 changes: 2 additions & 2 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
@@ -178,7 +178,7 @@ struct OpBinder {
}
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
for (auto element : arrayAttr) {
- auto integerAttr = element.dyn_cast<IntegerAttr>();
+ auto integerAttr = dyn_cast<IntegerAttr>(element);
if (!integerAttr)
return failure();
IntegerType t = cast<IntegerType>(integerAttr.getType());
@@ -200,7 +200,7 @@ struct OpBinder {
return success();
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
for (auto element : arrayAttr) {
- StringAttr stringAttr = element.dyn_cast<StringAttr>();
+ StringAttr stringAttr = dyn_cast<StringAttr>(element);
if (!stringAttr)
return failure();
values.push_back(stringAttr.getValue().str());
@@ -98,7 +98,7 @@ TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,

// Compute the knowledge based on the inferred type.
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
- inferredKnowledge.dtype = result_ty.cast<ShapedType>().getElementType();
+ inferredKnowledge.dtype = cast<ShapedType>(result_ty).getElementType();
inferredKnowledge.hasRank = predictedShape.hasRank();
if (predictedShape.hasRank()) {
for (auto dim : predictedShape.getDims()) {
6 changes: 3 additions & 3 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -1287,7 +1287,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.getLoc(), axisScalar, finalOffset);

Torch::BaseTensorType resultTensorType =
- resultType.cast<Torch::BaseTensorType>();
+ cast<Torch::BaseTensorType>(resultType);
if (!resultTensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
binder.op, "expected result type to have a dtype");
@@ -1899,7 +1899,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(

// If its a dense resource attr we need to convert to a dense type:
if (DenseResourceElementsAttr rattr =
- attr.dyn_cast_or_null<DenseResourceElementsAttr>()) {
+ dyn_cast_or_null<DenseResourceElementsAttr>(attr)) {
// Bytes are stored in little endian order. Big endian support will
// require swizzling.
if (!Endian::little) {
@@ -1916,7 +1916,7 @@

Attribute splattr;
if (isa<SplatElementsAttr>(attr)) {
- auto denseAttr = attr.cast<DenseElementsAttr>();
+ auto denseAttr = cast<DenseElementsAttr>(attr);
splattr = denseAttr.getSplatValue<Attribute>();
}

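One hunk above switches attr.dyn_cast_or_null<DenseResourceElementsAttr>() to dyn_cast_or_null<DenseResourceElementsAttr>(attr). The _or_null variant matters because the attribute may legitimately be absent. A small hedged sketch of the difference, not taken from the patch (the helper name intValueOrDefault is invented for illustration):

#include "mlir/IR/BuiltinAttributes.h"

using namespace mlir;

// attr may be null, e.g. when an optional attribute was never set.
int64_t intValueOrDefault(Attribute attr, int64_t deflt) {
  // dyn_cast_or_null returns null for both "null input" and "wrong kind";
  // a plain dyn_cast asserts if attr itself is null.
  if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
    return intAttr.getInt();
  return deflt;
}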
4 changes: 2 additions & 2 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -1366,7 +1366,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
// set the splitted axis to variable shape
llvm::SmallVector<int64_t> intermediateShape(result0Ty.getSizes());
for (auto result : binder.op->getResultTypes()) {
- int64_t d = result.cast<Torch::ValueTensorType>().getSizes()[dim];
+ int64_t d = cast<Torch::ValueTensorType>(result).getSizes()[dim];
intermediateShape[dim] = d == intermediateShape[dim] ? d : -1;
}

@@ -1437,7 +1437,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(

llvm::SmallVector<int64_t> intermediateShape(result0Ty.getSizes());
for (auto result : binder.op->getResultTypes()) {
- int64_t d = result.cast<Torch::ValueTensorType>().getSizes()[dim];
+ int64_t d = cast<Torch::ValueTensorType>(result).getSizes()[dim];
intermediateShape[dim] = d == intermediateShape[dim] ? d : -1;
}

4 changes: 2 additions & 2 deletions lib/Conversion/TorchToArith/TorchToArith.cpp
@@ -272,9 +272,9 @@ class ConvertAtenAddOp : public OpConversionPattern<AtenAddOp> {
convertScalarToDtype(rewriter, loc, adaptor.getA(), resultType);
Value operandB =
convertScalarToDtype(rewriter, loc, adaptor.getB(), resultType);
- if (resultType.isa<mlir::FloatType>()) {
+ if (isa<mlir::FloatType>(resultType)) {
rewriter.replaceOpWithNewOp<arith::AddFOp>(op, operandA, operandB);
- } else if (resultType.isa<mlir::IntegerType>()) {
+ } else if (isa<mlir::IntegerType>(resultType)) {
rewriter.replaceOpWithNewOp<arith::AddIOp>(op, operandA, operandB);
} else {
return rewriter.notifyMatchFailure(
2 changes: 1 addition & 1 deletion lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -1840,7 +1840,7 @@ class ConvertAtenViewAsRealOp : public OpConversionPattern<AtenViewAsRealOp> {

RankedTensorType inputType = input.getType().cast<RankedTensorType>();
auto inputElementType = getElementTypeOrSelf(input.getType());
- if (!inputElementType.isa<ComplexType>()) {
+ if (!isa<ComplexType>(inputElementType)) {
return op.emitError("only ComplexType is allowed as input type");
}
Type elementType = resultType.getElementType();
35 changes: 17 additions & 18 deletions lib/Conversion/TorchToLinalg/Linear.cpp
@@ -131,7 +131,7 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
auto resultTy = op.getType().cast<ValueTensorType>();
auto resultDTy = resultTy.toBuiltinTensor().getElementType();
Type newResultType = getTypeConverter()->convertType(op.getType());
- Type elementType = newResultType.cast<TensorType>().getElementType();
+ Type elementType = cast<TensorType>(newResultType).getElementType();
auto accumulatorDType = getDefaultAccType(rewriter, resultDTy);
if (accumulatorDType != resultDTy) {
elementType = accumulatorDType;
@@ -201,7 +201,7 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {

if (accumulatorDType != resultDTy) {
Type resultElementType =
- newResultType.cast<RankedTensorType>().getElementType();
+ cast<RankedTensorType>(newResultType).getElementType();
matmul = torch_to_linalg::convertTensorToElementType(
rewriter, loc, matmul, resultElementType);
}
@@ -307,7 +307,7 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
unsigned rhsRank = rhsType.getRank();

Type newResultType = getTypeConverter()->convertType(op.getType());
- auto resultType = newResultType.cast<RankedTensorType>();
+ auto resultType = cast<RankedTensorType>(newResultType);
Type elementType = resultType.getElementType();

// The different cases of torch_matmul op is mentioned here:
@@ -600,9 +600,9 @@ class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
Type newResultType = getTypeConverter()->convertType(op.getType());
Type resultElementType =
- newResultType.cast<RankedTensorType>().getElementType();
- Type lhsElementType = lhsType.cast<RankedTensorType>().getElementType();
- Type rhsElementType = rhsType.cast<RankedTensorType>().getElementType();
+ cast<RankedTensorType>(newResultType).getElementType();
+ Type lhsElementType = cast<RankedTensorType>(lhsType).getElementType();
+ Type rhsElementType = cast<RankedTensorType>(rhsType).getElementType();

if (lhsType.getRank() != 3 || rhsType.getRank() != 3) {
return rewriter.notifyMatchFailure(
@@ -712,9 +712,9 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
auto weightDTy = weight.getType().cast<RankedTensorType>().getElementType();
auto resultDTy = resultTy.toBuiltinTensor().getElementType();

- if (!inputDTy.isa<mlir::FloatType, mlir::IntegerType>() ||
- !weightDTy.isa<mlir::FloatType, mlir::IntegerType>() ||
- !resultDTy.isa<mlir::FloatType, mlir::IntegerType>())
+ if (!isa<mlir::FloatType, mlir::IntegerType>(inputDTy) ||
+ !isa<mlir::FloatType, mlir::IntegerType>(weightDTy) ||
+ !isa<mlir::FloatType, mlir::IntegerType>(resultDTy))
return op.emitError("unimplemented: non-fp not-int type");
size_t inRank = input.getType().cast<RankedTensorType>().getRank();
size_t numSpatialDims = inRank - 2;
@@ -790,9 +790,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
SmallVector<Value> outDims{inBatch, weightBatch};
Value paddedInput;
if (transposed) {
- if (!inputDTy.isa<mlir::FloatType>() ||
- !weightDTy.isa<mlir::FloatType>() ||
- !resultDTy.isa<mlir::FloatType>())
+ if (!isa<mlir::FloatType>(inputDTy) || !isa<mlir::FloatType>(weightDTy) ||
+ !isa<mlir::FloatType>(resultDTy))
return rewriter.notifyMatchFailure(
op, "transpose does not support non-fp type yet");

@@ -927,10 +926,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
accumulatorDType);
if (bias.getType().isa<Torch::NoneType>()) {
Value c0;
- if (accumulatorDType.isa<mlir::FloatType>()) {
+ if (isa<mlir::FloatType>(accumulatorDType)) {
c0 = rewriter.create<arith::ConstantOp>(
loc, FloatAttr::get(accumulatorDType, 0.0));
- } else if (accumulatorDType.isa<mlir::IntegerType>()) {
+ } else if (isa<mlir::IntegerType>(accumulatorDType)) {
c0 = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(accumulatorDType, 0));
}
@@ -1021,7 +1020,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
Type newResultType = getTypeConverter()->convertType(op.getType());
if (accumulatorDType != resultDTy) {
Type resultElementType =
- newResultType.cast<RankedTensorType>().getElementType();
+ cast<RankedTensorType>(newResultType).getElementType();
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}
@@ -1081,7 +1080,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
Type newResultType = getTypeConverter()->convertType(op.getType());
if (accumulatorDType != resultDTy) {
Type resultElementType =
- newResultType.cast<RankedTensorType>().getElementType();
+ cast<RankedTensorType>(newResultType).getElementType();
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}
@@ -1125,7 +1124,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
Type newResultType = getTypeConverter()->convertType(op.getType());
if (accumulatorDType != resultDTy) {
Type resultElementType =
- newResultType.cast<RankedTensorType>().getElementType();
+ cast<RankedTensorType>(newResultType).getElementType();
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}
@@ -1203,7 +1202,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
Type newResultType = getTypeConverter()->convertType(op.getType());
if (accumulatorDType != resultDTy) {
Type resultElementType =
- newResultType.cast<RankedTensorType>().getElementType();
+ cast<RankedTensorType>(newResultType).getElementType();
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}
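Several of the Linear.cpp hunks above rely on the variadic form of isa, which tests one value against several candidate types in a single call. A short illustrative sketch, not from the patch (the helper name isFpOrInt is invented), assuming the MLIR builtin types:

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Accepts the element types the convolution lowering handles.
bool isFpOrInt(Type dty) {
  // isa<A, B>(x) is shorthand for isa<A>(x) || isa<B>(x).
  return isa<mlir::FloatType, mlir::IntegerType>(dty);
}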
16 changes: 8 additions & 8 deletions lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -154,7 +154,7 @@ static LogicalResult createPoolingOp(
SmallVectorImpl<Value> &outTensorShape, Value &paddedInput, Value &result) {
Location loc = op->getLoc();
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
- if (!elementType.isa<mlir::FloatType>() && !supportNonFPInput)
+ if (!isa<mlir::FloatType>(elementType) && !supportNonFPInput)
return op->emitError("unimplemented: non-floating point type");

Value initValue =
@@ -248,7 +248,7 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
- APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
+ APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/true));
Value initValue =
rewriter.create<arith::ConstantOp>(op->getLoc(), smallestFPValueAttr);
@@ -366,7 +366,7 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
APFloat::getInf(
- elementType.cast<mlir::FloatType>().getFloatSemantics(),
+ cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/true));
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
@@ -447,7 +447,7 @@ class ConvertAtenMaxPool2dWithIndicesOp
// `maxpool2d` contains the result of maxpool2d operation over the input.
auto smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
- APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
+ APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/true));
Value maxPool2d, paddedInput;
SmallVector<Value, 4> outTensorShape;
@@ -586,7 +586,7 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
self.getType().cast<RankedTensorType>().getElementType();
Type resultType = typeConverter->convertType(op.getType());
Type resultElementType =
- resultType.cast<RankedTensorType>().getElementType();
+ cast<RankedTensorType>(resultType).getElementType();

bool ceilMode;
SmallVector<Value, Dim> kernelSizeIntValues;
@@ -647,9 +647,9 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
/*iteratorTypes=*/iteratorTypesAvg,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value avg;
- if (resultElementType.isa<mlir::IntegerType>())
+ if (isa<mlir::IntegerType>(resultElementType))
avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
- else if (resultElementType.isa<mlir::FloatType>())
+ else if (isa<mlir::FloatType>(resultElementType))
avg = b.create<arith::DivFOp>(loc, args[0], divisor);
b.create<linalg::YieldOp>(loc, avg);
})
@@ -739,7 +739,7 @@ class AdaptiveMaxPoolingHelper : public AdaptivePoolingHelper {
Type auxTensorElementType = auxTensorType.getElementType();
auto smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
- APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
+ APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/true));
buffVal = rewriter.create<arith::ConstantOp>(loc, elementType,
smallestFPValueAttr);
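The Pooling.cpp hunks above all touch the same idiom: the element type's float semantics are used to materialize negative infinity as the seed value of a max reduction. A standalone sketch of that idiom under the assumption of LLVM's APFloat API (the function name negativeInfinity is made up for the example):

#include "llvm/ADT/APFloat.h"

using llvm::APFloat;

// Negative infinity in IEEE single precision, e.g. to seed a running maximum.
APFloat negativeInfinity() {
  return APFloat::getInf(APFloat::IEEEsingle(), /*Negative=*/true);
}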
2 changes: 1 addition & 1 deletion lib/Conversion/TorchToLinalg/Random.cpp
@@ -130,7 +130,7 @@ class ConvertAtenUniformOp : public OpConversionPattern<AtenUniformOp> {
RankedTensorType resultType = self.getType().cast<RankedTensorType>();
Type elemTy = resultType.getElementType();

- if (!elemTy.isa<mlir::FloatType>())
+ if (!isa<mlir::FloatType>(elemTy))
return rewriter.notifyMatchFailure(op, "This op only support float type");

if (!generator.getType().isa<Torch::NoneType>())
