diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index 77d94eb0f8b9..d3260500cfa8 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -178,7 +178,7 @@ struct OpBinder { } if (auto arrayAttr = dyn_cast(attr)) { for (auto element : arrayAttr) { - auto integerAttr = element.dyn_cast(); + auto integerAttr = dyn_cast(element); if (!integerAttr) return failure(); IntegerType t = cast(integerAttr.getType()); @@ -200,7 +200,7 @@ struct OpBinder { return success(); if (auto arrayAttr = dyn_cast(attr)) { for (auto element : arrayAttr) { - StringAttr stringAttr = element.dyn_cast(); + StringAttr stringAttr = dyn_cast(element); if (!stringAttr) return failure(); values.push_back(stringAttr.getValue().str()); diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h index 876b81092ae9..7edf514b47aa 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h @@ -98,7 +98,7 @@ TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty, // Compute the knowledge based on the inferred type. auto inferredKnowledge = ValueKnowledge::getPessimisticValueState(); - inferredKnowledge.dtype = result_ty.cast().getElementType(); + inferredKnowledge.dtype = cast(result_ty).getElementType(); inferredKnowledge.hasRank = predictedShape.hasRank(); if (predictedShape.hasRank()) { for (auto dim : predictedShape.getDims()) { diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index f661f0e02ebd..f8829ec06b25 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1287,7 +1287,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.getLoc(), axisScalar, finalOffset); Torch::BaseTensorType resultTensorType = - resultType.cast(); + cast(resultType); if (!resultTensorType.hasDtype()) { return rewriter.notifyMatchFailure( binder.op, "expected result type to have a dtype"); @@ -1899,7 +1899,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // If its a dense resource attr we need to convert to a dense type: if (DenseResourceElementsAttr rattr = - attr.dyn_cast_or_null()) { + dyn_cast_or_null(attr)) { // Bytes are stored in little endian order. Big endian support will // require swizzling. if (!Endian::little) { @@ -1916,7 +1916,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Attribute splattr; if (isa(attr)) { - auto denseAttr = attr.cast(); + auto denseAttr = cast(attr); splattr = denseAttr.getSplatValue(); } diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 4a3ca533d242..f0b1e14780e9 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1366,7 +1366,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // set the splitted axis to variable shape llvm::SmallVector intermediateShape(result0Ty.getSizes()); for (auto result : binder.op->getResultTypes()) { - int64_t d = result.cast().getSizes()[dim]; + int64_t d = cast(result).getSizes()[dim]; intermediateShape[dim] = d == intermediateShape[dim] ? d : -1; } @@ -1437,7 +1437,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( llvm::SmallVector intermediateShape(result0Ty.getSizes()); for (auto result : binder.op->getResultTypes()) { - int64_t d = result.cast().getSizes()[dim]; + int64_t d = cast(result).getSizes()[dim]; intermediateShape[dim] = d == intermediateShape[dim] ? d : -1; } diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index c15261c3bd19..abd119fc0ac5 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -272,9 +272,9 @@ class ConvertAtenAddOp : public OpConversionPattern { convertScalarToDtype(rewriter, loc, adaptor.getA(), resultType); Value operandB = convertScalarToDtype(rewriter, loc, adaptor.getB(), resultType); - if (resultType.isa()) { + if (isa(resultType)) { rewriter.replaceOpWithNewOp(op, operandA, operandB); - } else if (resultType.isa()) { + } else if (isa(resultType)) { rewriter.replaceOpWithNewOp(op, operandA, operandB); } else { return rewriter.notifyMatchFailure( diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 56d80425df6b..adaaca504263 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -1840,7 +1840,7 @@ class ConvertAtenViewAsRealOp : public OpConversionPattern { RankedTensorType inputType = input.getType().cast(); auto inputElementType = getElementTypeOrSelf(input.getType()); - if (!inputElementType.isa()) { + if (!isa(inputElementType)) { return op.emitError("only ComplexType is allowed as input type"); } Type elementType = resultType.getElementType(); diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 9bd18715cf54..6ca3e3c6a063 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -131,7 +131,7 @@ class ConvertAtenMmOp : public OpConversionPattern { auto resultTy = op.getType().cast(); auto resultDTy = resultTy.toBuiltinTensor().getElementType(); Type newResultType = getTypeConverter()->convertType(op.getType()); - Type elementType = newResultType.cast().getElementType(); + Type elementType = cast(newResultType).getElementType(); auto accumulatorDType = getDefaultAccType(rewriter, resultDTy); if (accumulatorDType != resultDTy) { elementType = accumulatorDType; @@ -201,7 +201,7 @@ class ConvertAtenMmOp : public OpConversionPattern { if (accumulatorDType != resultDTy) { Type resultElementType = - newResultType.cast().getElementType(); + cast(newResultType).getElementType(); matmul = torch_to_linalg::convertTensorToElementType( rewriter, loc, matmul, resultElementType); } @@ -307,7 +307,7 @@ class ConvertAtenMatmulOp : public OpConversionPattern { unsigned rhsRank = rhsType.getRank(); Type newResultType = getTypeConverter()->convertType(op.getType()); - auto resultType = newResultType.cast(); + auto resultType = cast(newResultType); Type elementType = resultType.getElementType(); // The different cases of torch_matmul op is mentioned here: @@ -600,9 +600,9 @@ class ConvertAtenBmmOp : public OpConversionPattern { RankedTensorType rhsType = rhs.getType().cast(); Type newResultType = getTypeConverter()->convertType(op.getType()); Type resultElementType = - newResultType.cast().getElementType(); - Type lhsElementType = lhsType.cast().getElementType(); - Type rhsElementType = rhsType.cast().getElementType(); + cast(newResultType).getElementType(); + Type lhsElementType = cast(lhsType).getElementType(); + Type rhsElementType = cast(rhsType).getElementType(); if (lhsType.getRank() != 3 || rhsType.getRank() != 3) { return rewriter.notifyMatchFailure( @@ -712,9 +712,9 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { auto weightDTy = weight.getType().cast().getElementType(); auto resultDTy = resultTy.toBuiltinTensor().getElementType(); - if (!inputDTy.isa() || - !weightDTy.isa() || - !resultDTy.isa()) + if (!isa(inputDTy) || + !isa(weightDTy) || + !isa(resultDTy)) return op.emitError("unimplemented: non-fp not-int type"); size_t inRank = input.getType().cast().getRank(); size_t numSpatialDims = inRank - 2; @@ -790,9 +790,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { SmallVector outDims{inBatch, weightBatch}; Value paddedInput; if (transposed) { - if (!inputDTy.isa() || - !weightDTy.isa() || - !resultDTy.isa()) + if (!isa(inputDTy) || !isa(weightDTy) || + !isa(resultDTy)) return rewriter.notifyMatchFailure( op, "transpose does not support non-fp type yet"); @@ -927,10 +926,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { accumulatorDType); if (bias.getType().isa()) { Value c0; - if (accumulatorDType.isa()) { + if (isa(accumulatorDType)) { c0 = rewriter.create( loc, FloatAttr::get(accumulatorDType, 0.0)); - } else if (accumulatorDType.isa()) { + } else if (isa(accumulatorDType)) { c0 = rewriter.create( loc, IntegerAttr::get(accumulatorDType, 0)); } @@ -1021,7 +1020,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { Type newResultType = getTypeConverter()->convertType(op.getType()); if (accumulatorDType != resultDTy) { Type resultElementType = - newResultType.cast().getElementType(); + cast(newResultType).getElementType(); conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } @@ -1081,7 +1080,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { Type newResultType = getTypeConverter()->convertType(op.getType()); if (accumulatorDType != resultDTy) { Type resultElementType = - newResultType.cast().getElementType(); + cast(newResultType).getElementType(); conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } @@ -1125,7 +1124,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { Type newResultType = getTypeConverter()->convertType(op.getType()); if (accumulatorDType != resultDTy) { Type resultElementType = - newResultType.cast().getElementType(); + cast(newResultType).getElementType(); conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } @@ -1203,7 +1202,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { Type newResultType = getTypeConverter()->convertType(op.getType()); if (accumulatorDType != resultDTy) { Type resultElementType = - newResultType.cast().getElementType(); + cast(newResultType).getElementType(); conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index b1f114af8c72..c85604dc1d53 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -154,7 +154,7 @@ static LogicalResult createPoolingOp( SmallVectorImpl &outTensorShape, Value &paddedInput, Value &result) { Location loc = op->getLoc(); Type elementType = self.getType().cast().getElementType(); - if (!elementType.isa() && !supportNonFPInput) + if (!isa(elementType) && !supportNonFPInput) return op->emitError("unimplemented: non-floating point type"); Value initValue = @@ -248,7 +248,7 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern { Type elementType = self.getType().cast().getElementType(); TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( elementType, - APFloat::getInf(elementType.cast().getFloatSemantics(), + APFloat::getInf(cast(elementType).getFloatSemantics(), /*Negative=*/true)); Value initValue = rewriter.create(op->getLoc(), smallestFPValueAttr); @@ -366,7 +366,7 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern { TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( elementType, APFloat::getInf( - elementType.cast().getFloatSemantics(), + cast(elementType).getFloatSemantics(), /*Negative=*/true)); if (failed(createPoolingOp( op, rewriter, self, /*supportNonFPInput=*/true, ceilMode, @@ -447,7 +447,7 @@ class ConvertAtenMaxPool2dWithIndicesOp // `maxpool2d` contains the result of maxpool2d operation over the input. auto smallestFPValueAttr = rewriter.getFloatAttr( elementType, - APFloat::getInf(elementType.cast().getFloatSemantics(), + APFloat::getInf(cast(elementType).getFloatSemantics(), /*Negative=*/true)); Value maxPool2d, paddedInput; SmallVector outTensorShape; @@ -586,7 +586,7 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { self.getType().cast().getElementType(); Type resultType = typeConverter->convertType(op.getType()); Type resultElementType = - resultType.cast().getElementType(); + cast(resultType).getElementType(); bool ceilMode; SmallVector kernelSizeIntValues; @@ -647,9 +647,9 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { /*iteratorTypes=*/iteratorTypesAvg, [&](OpBuilder &b, Location loc, ValueRange args) { Value avg; - if (resultElementType.isa()) + if (isa(resultElementType)) avg = b.create(loc, args[0], divisor); - else if (resultElementType.isa()) + else if (isa(resultElementType)) avg = b.create(loc, args[0], divisor); b.create(loc, avg); }) @@ -739,7 +739,7 @@ class AdaptiveMaxPoolingHelper : public AdaptivePoolingHelper { Type auxTensorElementType = auxTensorType.getElementType(); auto smallestFPValueAttr = rewriter.getFloatAttr( elementType, - APFloat::getInf(elementType.cast().getFloatSemantics(), + APFloat::getInf(cast(elementType).getFloatSemantics(), /*Negative=*/true)); buffVal = rewriter.create(loc, elementType, smallestFPValueAttr); diff --git a/lib/Conversion/TorchToLinalg/Random.cpp b/lib/Conversion/TorchToLinalg/Random.cpp index 35c349a6a673..3b18844df516 100644 --- a/lib/Conversion/TorchToLinalg/Random.cpp +++ b/lib/Conversion/TorchToLinalg/Random.cpp @@ -130,7 +130,7 @@ class ConvertAtenUniformOp : public OpConversionPattern { RankedTensorType resultType = self.getType().cast(); Type elemTy = resultType.getElementType(); - if (!elemTy.isa()) + if (!isa(elemTy)) return rewriter.notifyMatchFailure(op, "This op only support float type"); if (!generator.getType().isa()) diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index 825389050211..bd8b1fc6bfb1 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -70,7 +70,7 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { input.getType().template cast(); Type idxElementType = getElementTypeOrSelf(typec->convertType(idxResultType)); - if (!idxElementType.isa()) + if (!isa(idxElementType)) return rewriter.notifyMatchFailure( op, opName + " to linalg.* requires integer-like result type"); @@ -89,8 +89,8 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { Type inElementType = inputType.getElementType(); bool isUnsigned = false; - if (!inElementType.isa()) { - if (inElementType.isa()) { + if (!isa(inElementType)) { + if (isa(inElementType)) { auto integerTy = op.getSelf() .getType() .template cast() @@ -121,22 +121,21 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { loc, getAsOpFoldResult(resultShape), inElementType); Value fillValue; - if (inElementType.isa()) { + if (isa(inElementType)) { fillValue = rewriter.create( - loc, - rewriter.getFloatAttr( - inElementType, - APFloat::getInf( - inElementType.cast().getFloatSemantics(), - /*Negative=*/isMax))); + loc, rewriter.getFloatAttr( + inElementType, + APFloat::getInf( + cast(inElementType).getFloatSemantics(), + /*Negative=*/isMax))); } else if (!isUnsigned) { - auto width = inElementType.cast().getWidth(); + auto width = cast(inElementType).getWidth(); auto init = isMax ? APSInt::getSignedMinValue(width) : APSInt::getSignedMaxValue(width); fillValue = rewriter.create( loc, rewriter.getIntegerAttr(inElementType, init)); } else if (isUnsigned) { - auto width = inElementType.cast().getWidth(); + auto width = cast(inElementType).getWidth(); auto init = isMax ? APInt::getMinValue(width) : APInt::getMaxValue(width); fillValue = rewriter.create( loc, rewriter.getIntegerAttr(inElementType, init)); @@ -180,7 +179,7 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { rewriter.create(loc, dim)); Value resultVal, predicate; - if (inElementType.isa()) { + if (isa(inElementType)) { arith::CmpFPredicate predType; if (isMax) { predType = arith::CmpFPredicate::OGT; @@ -300,21 +299,21 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc, return b.create(loc, b.getZeroAttr(elementType)); if (isa(op)) { - if (elementType.isa()) + if (isa(elementType)) return b.create(loc, b.getFloatAttr(elementType, 1.0)); - else if (elementType.isa()) + else if (isa(elementType)) return b.create(loc, b.getIntegerAttr(elementType, 1)); } if (isa(op)) { - if (elementType.isa()) + if (isa(elementType)) return b.create( loc, b.getFloatAttr( elementType, APFloat::getInf( - elementType.cast().getFloatSemantics(), + cast(elementType).getFloatSemantics(), /*Negative=*/true))); - else if (elementType.isa() && + else if (isa(elementType) && elementType.getIntOrFloatBitWidth() != 8) return b.create( loc, b.getIntegerAttr(elementType, @@ -323,14 +322,14 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc, } if (isa(op)) { - if (elementType.isa()) + if (isa(elementType)) return b.create( loc, b.getFloatAttr( elementType, APFloat::getInf( - elementType.cast().getFloatSemantics(), + cast(elementType).getFloatSemantics(), /*Negative=*/false))); - else if (elementType.isa() && + else if (isa(elementType) && elementType.getIntOrFloatBitWidth() != 8) return b.create( loc, b.getIntegerAttr(elementType, @@ -359,25 +358,25 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, Value self = convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); Value result = payloadArgs[1]; - if (resultElementType.isa()) + if (isa(resultElementType)) return b.create(loc, self, result); - else if (resultElementType.isa()) + else if (isa(resultElementType)) return b.create(loc, self, result); } else if (isa(op)) { Value self = convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); Value result = payloadArgs[1]; - if (resultElementType.isa()) + if (isa(resultElementType)) return b.create(loc, self, result); - else if (resultElementType.isa()) + else if (isa(resultElementType)) return b.create(loc, self, result); } else if (auto max = dyn_cast(op)) { Value self = convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); Value result = payloadArgs[1]; - if (resultElementType.isa()) + if (isa(resultElementType)) return b.create(loc, self, result); - else if (resultElementType.isa()) { + else if (isa(resultElementType)) { IntegerType intType = max.getSelf() .getType() .cast() @@ -392,9 +391,9 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, Value self = convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); Value result = payloadArgs[1]; - if (resultElementType.isa()) + if (isa(resultElementType)) return b.create(loc, self, result); - else if (resultElementType.isa()) { + else if (isa(resultElementType)) { IntegerType intType = min.getSelf() .getType() .cast() @@ -626,10 +625,10 @@ class ConvertReductionOp : public ConversionPattern { ConversionPatternRewriter &rewriter) const { if ((isa(op) || isa(op) || isa(op)) && - !elemType.isa()) + !isa(elemType)) return rewriter.notifyMatchFailure( op, "only float types are valid for vector norm ops"); - if (isa(op) && elemType.isa() && + if (isa(op) && isa(elemType) && elemType.getIntOrFloatBitWidth() == 8) return rewriter.notifyMatchFailure(op, "uint8 is not supported"); diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index 385f5b435e1b..1a549cd5e399 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -100,7 +100,7 @@ class ConvertAtenConstantPadNdOp } Type newResultType = getTypeConverter()->convertType(op.getType()); - Type elementType = newResultType.cast().getElementType(); + Type elementType = cast(newResultType).getElementType(); Value castedValue = convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType); @@ -553,7 +553,7 @@ class ConvertAtenArangeStartStepOp // The size of the result is calculated as follows: // ceil((end - start)/step) Value resultShape; - if (dtype.isa()) { + if (isa(dtype)) { Value subOut = rewriter.create(loc, end, start); resultShape = rewriter.create(loc, subOut, step); } else { @@ -585,7 +585,7 @@ class ConvertAtenArangeStartStepOp index = castIndexToInt64(b, loc, index); index = convertScalarToDtype(b, loc, index, dtype); Value mulOut, result; - if (dtype.isa()) { + if (isa(dtype)) { mulOut = b.create(loc, step, index); result = b.create(loc, start, mulOut); } else { diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 85680d456c19..f2f5c2ecf8e7 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -35,16 +35,16 @@ using namespace mlir::torch::Torch; template static bool hasElementType(Value tensor) { auto tensorType = tensor.getType().cast(); Type tensorElementType = tensorType.getElementType(); - return tensorElementType.isa(); + return isa(tensorElementType); } template static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type, Value lhs, Value rhs) { - if (type.isa()) + if (isa(type)) return b.create(loc, fpred, lhs, rhs); - if (IntegerType intType = type.dyn_cast()) { + if (IntegerType intType = dyn_cast(type)) { if (intType.isUnsigned()) return b.create(loc, iupred, lhs, rhs); if (intType.isSigned()) @@ -319,7 +319,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(bitwiseAndScalar.getType()) .cast() .getElementType(); - if (!dtype.isa()) { + if (!isa(dtype)) { bitwiseAndScalar.emitError( "bitwise_and.Scalar does not support non-integer input dtype."); return nullptr; @@ -371,7 +371,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(bitwiseRightShiftTensor.getType()) .cast() .getElementType(); - if (!dtype.isa()) { + if (!isa(dtype)) { bitwiseRightShiftTensor.emitError( "Bitwise_Right_Shift op does not support non-integer input dtype."); return nullptr; @@ -385,7 +385,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(bitwiseLeftShiftTensor.getType()) .cast() .getElementType(); - if (!dtype.isa()) { + if (!isa(dtype)) { bitwiseLeftShiftTensor.emitError( "Bitwise_Left_Shift op does not support non-integer input dtype."); return nullptr; @@ -623,7 +623,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype, /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType); - if (dtype.isa()) { + if (isa(dtype)) { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); } else { @@ -647,7 +647,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType, /*originalScalar=*/sub.getAlpha()); - if (dtype.isa()) { + if (isa(dtype)) { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); } else { @@ -664,10 +664,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value alpha = convertScalarToDtype( b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(), /*dstOriginalDtype=*/dtype); - if (dtype.isa()) { + if (isa(dtype)) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); - } else if (dtype.isa()) { + } else if (isa(dtype)) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); } @@ -690,10 +690,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value alpha = convertScalarToDtype(b, loc, operands[2], dtype, /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType); - if (dtype.isa()) { + if (isa(dtype)) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); - } else if (dtype.isa()) { + } else if (isa(dtype)) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); } @@ -708,9 +708,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - if (dtype.isa()) { + if (isa(dtype)) { return b.create(loc, lhs, rhs); - } else if (dtype.isa()) { + } else if (isa(dtype)) { return b.create(loc, lhs, rhs); } else { return b.create(loc, lhs, rhs); @@ -720,7 +720,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(atan2.getType()) .cast() .getElementType(); - if (!dtype.isa()) { + if (!isa(dtype)) { atan2.emitError("Atan2 requires floating point result type"); return nullptr; } @@ -759,9 +759,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - if (dtype.isa()) + if (isa(dtype)) return b.create(loc, lhs, rhs); - else if (dtype.isa()) { + else if (isa(dtype)) { if (dtype.isUnsignedInteger()) return b.create(loc, lhs, rhs); return b.create(loc, lhs, rhs); @@ -777,7 +777,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value div; - if (dtype.isa()) + if (isa(dtype)) div = b.create(loc, lhs, rhs); else { if (dtype.isUnsignedInteger()) @@ -798,7 +798,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( if (roundingMode == "trunc") { // "trunc" - rounds the results of the division towards zero. Equivalent // to C-style integer division. - if (dtype.isa()) { + if (isa(dtype)) { Value ceil = b.create(loc, div); Value floor = b.create(loc, div); Value cstZero = b.create(loc, b.getZeroAttr(dtype)); @@ -811,7 +811,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( if (roundingMode == "floor") { // "floor" - rounds the results of the division down. Equivalent to // floor division in Python (the // operator) - if (dtype.isa()) + if (isa(dtype)) return b.create(loc, div); else if (!dtype.isUnsignedInteger()) { Type defaultIntToFloatType = b.getF64Type(); @@ -831,7 +831,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( if (auto pow = dyn_cast(op)) { Type dtype = pow.getType().cast().getDtype(); - if (!dtype.isa()) { + if (!isa(dtype)) { pow.emitError("unimplemented: non-floating point dtype"); return nullptr; } @@ -857,7 +857,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(pow.getType()) .cast() .getElementType(); - if (!dtype.isa()) { + if (!isa(dtype)) { pow.emitError("unimplemented: non-floating point dtype"); return nullptr; } @@ -870,7 +870,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(imag.getType()) .cast() .getElementType(); - if (!dtype.isa()) { + if (!isa(dtype)) { imag.emitError("unimplemented: non-floating point dtype"); return nullptr; } @@ -882,7 +882,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(real.getType()) .cast() .getElementType(); - if (!dtype.isa()) { + if (!isa(dtype)) { real.emitError("unimplemented: non-floating point dtype"); return nullptr; } @@ -898,10 +898,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); - if (dtype.isa()) + if (isa(dtype)) return b.create(loc, arith::CmpFPredicate::UGT, payloadArgs[0], otherPromoted); - if (IntegerType intType = dtype.dyn_cast()) { + if (IntegerType intType = dyn_cast(dtype)) { if (!operands[1].getType().isa()) { // TODO: Promote tensor args from integer to float. gtScalar.emitError( @@ -928,10 +928,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); - if (dtype.isa()) + if (isa(dtype)) return b.create(loc, arith::CmpFPredicate::UGE, payloadArgs[0], otherPromoted); - if (IntegerType intType = dtype.dyn_cast()) { + if (IntegerType intType = dyn_cast(dtype)) { if (!operands[1].getType().isa()) { // TODO: Promote tensor args from integer to float. geScalar.emitError( @@ -955,7 +955,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); - if (dtype.isa()) { + if (isa(dtype)) { if (!operands[1].getType().isa()) { // TODO: Promote tensor operand from integer to float. eqScalar.emitError( @@ -971,7 +971,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); - if (dtype.isa()) { + if (isa(dtype)) { if (!operands[1].getType().isa()) { // TODO: Promote tensor operand from integer to float. neScalar.emitError( @@ -989,10 +989,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( // TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share // a lot of code that can be refactored. - if (dtype.isa()) + if (isa(dtype)) return b.create(loc, arith::CmpFPredicate::ULT, payloadArgs[0], otherPromoted); - if (IntegerType intType = dtype.dyn_cast()) { + if (IntegerType intType = dyn_cast(dtype)) { if (!operands[1].getType().isa()) { // TODO: Promote tensor operand from integer to float. ltScalar.emitError( @@ -1017,10 +1017,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( // TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code // that can be refactored. - if (dtype.isa()) + if (isa(dtype)) return b.create(loc, arith::CmpFPredicate::ULE, payloadArgs[0], otherPromoted); - if (IntegerType intType = dtype.dyn_cast()) { + if (IntegerType intType = dyn_cast(dtype)) { if (!operands[1].getType().isa()) { // TODO: Promote tensor operand from integer to float. leScalar.emitError( @@ -1096,14 +1096,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(clamp.getType()) .cast() .getElementType(); - if (!dtype.isa()) { + if (!isa(dtype)) { clamp.emitError("unimplement type for clamp"); return nullptr; } Type dstOriginalDtype = clamp.getType().cast().getDtype(); bool isUnsigned = isa(dstOriginalDtype); - if (auto intTy = dstOriginalDtype.dyn_cast()) { + if (auto intTy = dyn_cast(dstOriginalDtype)) { isUnsigned = intTy.isUnsigned(); } auto cmpSelect = [&](Value input, Value clamp, bool getMax) -> Value { @@ -1112,11 +1112,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp( /*dstOriginalDtype=*/dstOriginalDtype); Value pred; - if (dtype.isa()) { + if (isa(dtype)) { auto cmp = getMax ? arith::CmpFPredicate::UGT : arith::CmpFPredicate::ULT; pred = b.create(loc, cmp, input, clamp); - } else if (dtype.isa()) { + } else if (isa(dtype)) { auto cmp = isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt; if (getMax) @@ -1151,10 +1151,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( isMinNone = false; auto minPromoted = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value pred; - if (dtype.isa()) { + if (isa(dtype)) { pred = b.create(loc, arith::CmpFPredicate::ULT, result, minPromoted); - } else if (dtype.isa()) { + } else if (isa(dtype)) { pred = b.create(loc, arith::CmpIPredicate::slt, result, minPromoted); } else { @@ -1169,10 +1169,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( max = isMinNone ? payloadArgs[1] : payloadArgs[2]; auto maxPromoted = convertScalarToDtype(b, loc, max, dtype); Value pred; - if (dtype.isa()) { + if (isa(dtype)) { pred = b.create(loc, arith::CmpFPredicate::UGT, result, maxPromoted); - } else if (dtype.isa()) { + } else if (isa(dtype)) { pred = b.create(loc, arith::CmpIPredicate::sgt, result, maxPromoted); } else { @@ -1194,10 +1194,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value alpha = convertScalarToDtype( b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(), /*dstOriginalDtype=*/dtype); - if (dtype.isa()) { + if (isa(dtype)) { Value mult = b.create(loc, self, alpha); return b.create(loc, other, mult); - } else if (dtype.isa()) { + } else if (isa(dtype)) { Value mult = b.create(loc, self, alpha); return b.create(loc, other, mult); } @@ -1211,9 +1211,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, operands[1], dtype); - if (dtype.isa()) + if (isa(dtype)) return b.create(loc, lhs, rhs); - if (dtype.isa()) + if (isa(dtype)) return b.create(loc, lhs, rhs); mulScalar.emitError("unimplemented: Only integer/float dtype supported"); return nullptr; @@ -1246,7 +1246,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(divScalar.getType()) .cast() .getElementType(); - if (!dtype.isa()) { + if (!isa(dtype)) { divScalar.emitError("unimplemented: non-floating point dtype"); return nullptr; } @@ -1263,9 +1263,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value other = convertScalarToDtype(b, loc, operands[1], newResultType); Value result; - if (newResultType.isa()) { + if (isa(newResultType)) { result = b.create(loc, self, other); - } else if (newResultType.isa()) { + } else if (isa(newResultType)) { result = b.create(loc, self, other); } else { remScalar.emitError( @@ -1283,9 +1283,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType); Value result; - if (newResultType.isa()) { + if (isa(newResultType)) { result = b.create(loc, self, other); - } else if (newResultType.isa()) { + } else if (isa(newResultType)) { result = b.create(loc, self, other); } else { remTensor.emitError( @@ -1303,12 +1303,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType); Value result; - if (newResultType.isa()) { + if (isa(newResultType)) { Value n = b.create(loc, self, other); n = b.create(loc, n); Value n_y = b.create(loc, n, other); result = b.create(loc, self, n_y); - } else if (newResultType.isa()) { + } else if (isa(newResultType)) { Value n = b.create(loc, self, other); Value n_y = b.create(loc, n, other); result = b.create(loc, self, n_y); @@ -1349,7 +1349,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value value = convertScalarToDtype(b, loc, adaptor.getValue(), dtype); Value predicate; - if (dtype.isa()) + if (isa(dtype)) predicate = b.create(loc, arith::CmpFPredicate::ULE, self, threshold); else @@ -1372,7 +1372,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value constantZero = b.create(loc, b.getZeroAttr(dtype)); Value predicate; - if (dtype.isa()) + if (isa(dtype)) predicate = b.create(loc, arith::CmpFPredicate::ULE, self, threshold); else @@ -1426,7 +1426,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type elementType = converter->convertType(bitwiseNot.getType()) .cast() .getElementType(); - if (elementType.isa()) { + if (isa(elementType)) { bitwiseNot.emitError("Bitwise_Not does not support floating point dtype"); return nullptr; } @@ -2253,7 +2253,7 @@ class ConvertLogitOp : public OpConversionPattern { auto inputType = input.getType().cast(); auto inputElementType = inputType.getElementType(); - if (!inputElementType.isa()) { + if (!isa(inputElementType)) { op.emitError("Logit does not support non-floating point type"); return failure(); } diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 0e49eee04745..7c8b7c8a4980 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -548,7 +548,7 @@ FailureOr torch_to_linalg::getBackendTypeForScalarType( } Type type = *maybeType; // The linalg-on-tensors backend currently expects integers to be signless. - if (auto intType = type.dyn_cast()) { + if (auto intType = dyn_cast(type)) { type = IntegerType::get(context, intType.getWidth(), IntegerType::Signless); } return type; diff --git a/lib/Conversion/TorchToSCF/TorchToSCF.cpp b/lib/Conversion/TorchToSCF/TorchToSCF.cpp index 208dcefcc85f..60206f03999b 100644 --- a/lib/Conversion/TorchToSCF/TorchToSCF.cpp +++ b/lib/Conversion/TorchToSCF/TorchToSCF.cpp @@ -140,11 +140,11 @@ class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern { // If the target type is non-torch type, then use TypeConverter to convert // the type of the source. - if (targetType.isa()) { + if (isa(targetType)) { targetType = Torch::FloatType::get(op->getContext()); torchArg = typeConverter->materializeSourceConversion( rewriter, scfWhileOp.getLoc(), targetType, {to}); - } else if (targetType.isa()) { + } else if (isa(targetType)) { unsigned bitWidth = targetType.getIntOrFloatBitWidth(); if (bitWidth == 1) targetType = Torch::BoolType::get(op->getContext()); @@ -179,7 +179,7 @@ class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern { // If the argument is a torch tensor, directly add it in the list of // iter args. - if (torchType.isa()) { + if (isa(torchType)) { loopConditionIterArgs.push_back(torchArg); continue; } @@ -262,11 +262,11 @@ class ConvertTorchPrimLoopForLikeOp : public OpConversionPattern { // If the target type is non-torch type, then use TypeConverter to convert // the type of the source. - if (targetType.isa()) { + if (isa(targetType)) { targetType = Torch::FloatType::get(op->getContext()); torchArg = typeConverter->materializeSourceConversion( rewriter, scfForOp.getLoc(), targetType, {to}); - } else if (targetType.isa()) { + } else if (isa(targetType)) { unsigned bitWidth = targetType.getIntOrFloatBitWidth(); if (bitWidth == 1) targetType = Torch::BoolType::get(op->getContext()); diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 4d6c8d194554..ae4e69ccd4a0 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -42,11 +42,11 @@ static Value getConstantLike(OpBuilder &b, Location loc, T constant, Value val) { Type ty = getElementTypeOrSelf(val.getType()); auto getAttr = [&]() -> Attribute { - if (ty.isa()) + if (isa(ty)) return b.getIntegerAttr(ty, constant); - if (ty.isa()) + if (isa(ty)) return b.getFloatAttr(ty, constant); - if (auto complexTy = ty.dyn_cast()) + if (auto complexTy = dyn_cast(ty)) return complex::NumberAttr::get(complexTy, constant, 0); llvm_unreachable("unhandled element type"); }; @@ -105,17 +105,17 @@ bool skipMultiplyAlpha(Value alphaValue) { static FailureOr getMaxValueOfDtype(Operation *op, Type elementType, PatternRewriter &rewriter) { auto constType = RankedTensorType::get({}, elementType); - if (elementType.isa()) { + if (isa(elementType)) { auto constAttr = SplatElementsAttr::get( constType, - APFloat::getInf(elementType.cast().getFloatSemantics(), + APFloat::getInf(cast(elementType).getFloatSemantics(), /*negative=*/false)); return rewriter .create(op->getLoc(), constType, constAttr) .getResult(); } - if (elementType.isa()) { - auto integerType = elementType.cast(); + if (isa(elementType)) { + auto integerType = cast(elementType); DenseElementsAttr constAttr; if (integerType.isUnsigned()) { constAttr = SplatElementsAttr::get( @@ -134,17 +134,17 @@ static FailureOr getMaxValueOfDtype(Operation *op, Type elementType, static FailureOr getMinValueOfDtype(Operation *op, Type elementType, PatternRewriter &rewriter) { auto constType = RankedTensorType::get({}, elementType); - if (elementType.isa()) { + if (isa(elementType)) { auto constAttr = SplatElementsAttr::get( constType, - APFloat::getInf(elementType.cast().getFloatSemantics(), + APFloat::getInf(cast(elementType).getFloatSemantics(), /*negative=*/true)); return rewriter .create(op->getLoc(), constType, constAttr) .getResult(); } - if (elementType.isa()) { - auto integerType = elementType.cast(); + if (isa(elementType)) { + auto integerType = cast(elementType); DenseElementsAttr constAttr; if (integerType.isUnsigned()) { constAttr = SplatElementsAttr::get( @@ -446,7 +446,7 @@ class ConvertAtenMulDivOp : public OpConversionPattern { op, "only support constant str rounding mode"); // if trunc and int, do nothing - if (roundingMode == "trunc" && outElemTy.isa()) { + if (roundingMode == "trunc" && isa(outElemTy)) { // "trunc" - rounds the results of the division towards zero. Equivalent // to C-style integer division. auto sign = rewriter.create(loc, result); @@ -457,7 +457,7 @@ class ConvertAtenMulDivOp : public OpConversionPattern { if (roundingMode == "floor") { // "floor" - rounds the results of the division down. Equivalent to // floor division in Python (the // operator) - if (outElemTy.isa()) + if (isa(outElemTy)) result = rewriter.create(loc, result).getResult(); else if (!outElemTy.isUnsignedInteger()) { TensorType defaultIntToFloatType = @@ -518,10 +518,10 @@ class ConvertAtenCompareOp : public OpConversionPattern { chlo::ComparisonTypeAttr compareTypeAttr; chlo::ComparisonDirectionAttr compareDirectionAttr; - if (lhsElemTy.isa()) { + if (isa(lhsElemTy)) { compareTypeAttr = chlo::ComparisonTypeAttr::get( op->getContext(), chlo::ComparisonType::FLOAT); - } else if (lhsElemTy.isa()) { + } else if (isa(lhsElemTy)) { compareTypeAttr = chlo::ComparisonTypeAttr::get( op->getContext(), chlo::ComparisonType::SIGNED); } @@ -985,14 +985,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto lhsTy = lhs.getType().cast(); auto lhsElemTy = lhsTy.getElementType(); - if (!lhsElemTy.isa()) { + if (!isa(lhsElemTy)) { return op->emitError("only float tensor in relu op is supported"); } Value zeroTensor; zeroTensor = getConstantLike( rewriter, op->getLoc(), - APFloat::getZero(lhsElemTy.cast().getFloatSemantics(), + APFloat::getZero(cast(lhsElemTy).getFloatSemantics(), false), lhs); rewriter.replaceOpWithNewOp(op, lhs, zeroTensor); @@ -1160,7 +1160,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.getI64IntegerAttr(feature_index)); output = hlo::promoteType(rewriter, op.getLoc(), batchNormTrainingResult.getResult(0), - outputTy.cast()); + cast(outputTy)); } else { auto batchNormTrainingResult = rewriter.create( @@ -1204,7 +1204,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( runningVar, rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(feature_index)); output = hlo::promoteType(rewriter, op.getLoc(), bnResult, - outputTy.cast()); + cast(outputTy)); } else { output = rewriter.create( op.getLoc(), inputCasted.getType(), inputCasted, weight, bias, @@ -1478,7 +1478,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ->convertType(op.getType()) .cast(); auto dtype = outType.getElementType(); - if (!dtype.isa() && !dtype.isa()) { + if (!isa(dtype) && !isa(dtype)) { return rewriter.notifyMatchFailure( op, "unimplemented: only int or float dtype supported"); } @@ -1607,7 +1607,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto shape_tensor = rewriter.create( loc, rewriter.getI64TensorAttr(elements)); auto outTy = getTypeConverter()->convertType(op.getType()); - auto outElemTy = outTy.cast().getElementType(); + auto outElemTy = cast(outTy).getElementType(); Value from = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getFrom(), outElemTy); Value to = diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index 115609b461c3..ac1c8bacf9a8 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -34,14 +34,14 @@ static Value createInitialValueForGatherScatterOp(Operation *op, PatternRewriter &rewriter) { auto elementTy = constType.getElementType(); if (isa(op)) { - if (elementTy.isa()) { + if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getZero( - elementTy.cast().getFloatSemantics(), + cast(elementTy).getFloatSemantics(), /*negative=*/false)}); return rewriter.create(op->getLoc(), constType, constAttr); - } else if (elementTy.isa() && + } else if (isa(elementTy) && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index 40b0dd691071..b8a5321306bb 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -37,14 +37,14 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, // Avg pooling if (isa(op)) { - if (elementTy.isa()) { + if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getZero( - elementTy.cast().getFloatSemantics(), + cast(elementTy).getFloatSemantics(), /*negative=*/false)}); return rewriter.create(op->getLoc(), constType, constAttr); - } else if (elementTy.isa() && + } else if (isa(elementTy) && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); @@ -55,14 +55,14 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, // Max pooling if (isa(op)) { - if (elementTy.isa()) { + if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( - constType, {APFloat::getInf( - elementTy.cast().getFloatSemantics(), - /*negative=*/true)}); + constType, + {APFloat::getInf(cast(elementTy).getFloatSemantics(), + /*negative=*/true)}); return rewriter.create(op->getLoc(), constType, constAttr); - } else if (elementTy.isa() && + } else if (isa(elementTy) && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index 0b27d0748855..c525c8b40de5 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -37,14 +37,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, auto constType = RankedTensorType::get({}, elementTy); if (isa(op)) { - if (elementTy.isa()) { + if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getZero( - elementTy.cast().getFloatSemantics(), + cast(elementTy).getFloatSemantics(), /*negative=*/false)}); return rewriter.create(op->getLoc(), constType, constAttr); - } else if (elementTy.isa() && + } else if (isa(elementTy) && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); @@ -54,14 +54,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, } if (isa(op)) { - if (elementTy.isa()) { + if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( - constType, {APFloat::getInf( - elementTy.cast().getFloatSemantics(), - /*negative=*/true)}); + constType, + {APFloat::getInf(cast(elementTy).getFloatSemantics(), + /*negative=*/true)}); return rewriter.create(op->getLoc(), constType, constAttr); - } else if (elementTy.isa() && + } else if (isa(elementTy) && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, @@ -72,14 +72,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, } if (isa(op)) { - if (elementTy.isa()) { + if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( - constType, {APFloat::getInf( - elementTy.cast().getFloatSemantics(), - /*negative=*/false)}); + constType, + {APFloat::getInf(cast(elementTy).getFloatSemantics(), + /*negative=*/false)}); return rewriter.create(op->getLoc(), constType, constAttr); - } else if (elementTy.isa() && + } else if (isa(elementTy) && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, @@ -234,7 +234,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( "only floating-point or integer datatype legalization supported"); } // Currently, (u)int8 dtype is not supported! - if (inputElemTy.isa() && + if (isa(inputElemTy) && inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " @@ -305,7 +305,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( "Only floating-point or integer datatype legalization supported"); } // Currently, (u)int8 dtype is not supported - if (inputElemTy.isa() && + if (isa(inputElemTy) && inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " @@ -319,7 +319,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( ->convertType(op.getResult(1).getType()) .template cast(); Type idxElementType = idxResultType.getElementType(); - if (!idxElementType.isa()) { + if (!isa(idxElementType)) { return op.emitError("Aten.max.dim needs integer-like result"); } @@ -404,7 +404,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( "only floating-point or integer datatype legalization supported"); } // Currently, (u)int8 dtype is not supported - if (inputElemTy.isa() && + if (isa(inputElemTy) && inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " @@ -466,7 +466,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( "only floating-point or integer datatype legalization supported"); } // Currently, (u)int8 dtype is not supported - if (inputElemTy.isa() && + if (isa(inputElemTy) && inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " @@ -529,7 +529,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( "only floating-point or integer datatype legalization supported"); } // Currently, (u)int8 dtype is not supported - if (inputElemTy.isa() && + if (isa(inputElemTy) && inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " @@ -603,7 +603,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } // Currently, (u)int8 dtype is not supported - if (inputElemTy.isa() && + if (isa(inputElemTy) && inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " @@ -715,7 +715,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } auto inputRank = inputType.getRank(); auto inputElemType = inputType.getElementType(); - if (!inputElemType.isa()) { + if (!isa(inputElemType)) { return op.emitError( "only float dtype allowed in input tensor of AtenFrobeniusNormDimOp"); } @@ -830,7 +830,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( auto outType = getTypeConverter()->convertType(op.getType()).cast(); auto outElemType = outType.getElementType(); - if (!outElemType.isa()) { + if (!isa(outElemType)) { return op.emitError("only float dtype allowed in AtenLinalgVectorNormOp"); } @@ -912,7 +912,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( op->getLoc(), blockArgumentTy, DenseElementsAttr::get( blockArgumentTy, - APFloat(outElemType.cast().getFloatSemantics(), 1))); + APFloat(cast(outElemType).getFloatSemantics(), 1))); auto reciprocalOrd = rewriter.create( op->getLoc(), blockArgumentTy, constantOne, ord); auto output = rewriter.create( diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 910ba049d07d..c098a91446aa 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -51,7 +51,7 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "Only Tensor types supported in TOSA"); - if (selfTy.getElementType().isa()) { + if (isa(selfTy.getElementType())) { rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( @@ -146,12 +146,12 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, return rewriter.notifyMatchFailure(op, "Unable to extract the scalar constant"); - if (dtype.isa()) { + if (isa(dtype)) { tosaTensor = tosa::getConstTensor(rewriter, op, (isFloat ? doubleValue : intValue), dshape, dtype) .value(); - } else if (auto intType = dtype.dyn_cast()) { + } else if (auto intType = dyn_cast(dtype)) { auto w = intType.getWidth(); if (w != 1 && w != 32 && w != 64) return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { @@ -279,7 +279,7 @@ class ConvertAtenAddSubOp : public OpConversionPattern { } Type rhsAlphaMulElemType; - if (outElemTy.isa()) { + if (isa(outElemTy)) { rhsAlphaMulElemType = outElemTy; } else { // if output type is 64, input type should also be 32 @@ -362,7 +362,7 @@ class ConvertAtenCompareOp : public OpConversionPattern { std::is_same() || std::is_same() || std::is_same(); - if (lhsElemTy.isa() && isBitwiseOp) { + if (isa(lhsElemTy) && isBitwiseOp) { return rewriter.notifyMatchFailure(op, "For bitwise operators, only integer " "datatype legalization is supported"); @@ -452,8 +452,7 @@ class ConvertAtenMulOp : public OpConversionPattern { rhsTensor = rhsType ? rhs : rhsAsTensor; } - if (outElemTy.isa() || - outElemTy.isa()) { + if (isa(outElemTy) || isa(outElemTy)) { auto outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); @@ -550,7 +549,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); auto selfTy = self.getType().cast(); - if (selfTy && selfTy.getElementType().isa()) { + if (selfTy && isa(selfTy.getElementType())) { rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self); return success(); @@ -567,7 +566,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); auto selfTy = self.getType().cast(); - if (selfTy && selfTy.getElementType().isa()) { + if (selfTy && isa(selfTy.getElementType())) { rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self); return success(); @@ -594,7 +593,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } // Rescale the clampIn for quantized types. TBD - if (!selfTy.getElementType().isa()) { + if (!isa(selfTy.getElementType())) { return rewriter.notifyMatchFailure( op, "Only floating-point datatype legalization currently supported"); } @@ -614,7 +613,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value self = adaptor.getSelf(); auto selfTy = self.getType().cast(); - if (!selfTy.getElementType().isa()) { + if (!isa(selfTy.getElementType())) { return rewriter.notifyMatchFailure( op, "Only floating-point datatype legalization currently supported"); } @@ -1027,7 +1026,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Only ranked tensor types supported in TOSA Pow"); - if (!expTy.getElementType().isa()) + if (!isa(expTy.getElementType())) return rewriter.notifyMatchFailure( op, "Only floating-point datatype legalization supported"); @@ -1061,7 +1060,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Only ranked tensor types supported in TOSA Pow"); - if (!selfTy.getElementType().isa()) + if (!isa(selfTy.getElementType())) return rewriter.notifyMatchFailure( op, "Only floating-point datatype legalization supported"); @@ -1095,7 +1094,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Only ranked tensor types supported in TOSA Pow"); - if (!selfTy.getElementType().isa()) + if (!isa(selfTy.getElementType())) return rewriter.notifyMatchFailure( op, "Only floating-point datatype legalization supported"); @@ -1977,7 +1976,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Only ranked tensor types supported in TOSA Rsub"); - if (!selfTy.getElementType().isa()) + if (!isa(selfTy.getElementType())) return rewriter.notifyMatchFailure( op, "Only floating-point datatype legalization supported"); @@ -2103,7 +2102,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // TBD: This is only valid for quantized 8-bit. For 16-bit, the bias (and // accumulator) are 48-bit and not 32-bit, and requires the use of APInt to // define a 48-bit int. - if (inputElemTy.isa()) { + if (isa(inputElemTy)) { SmallVector zeroVec(weightShape[0], 0); bias = tosa::getConstTensor( rewriter, op, zeroVec, {static_cast(weightShape[0])}) @@ -2121,7 +2120,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Bias provided but not a ranked tensor"); } auto biasElemTy = - inputElemTy.isa() ? inputElemTy : rewriter.getI32Type(); + isa(inputElemTy) ? inputElemTy : rewriter.getI32Type(); int64_t groups; if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups))) { @@ -2308,7 +2307,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .getResult(); Value rescaledResult = transposedOutput; - if (inputElemTy.isa()) { + if (isa(inputElemTy)) { rescaledResult = tosa::buildRescaleOpConvOutput( rewriter, op, transposedOutput, inputTy, weightTy, outputTy); } @@ -2419,7 +2418,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Note: cudnn_enabled is not handled. // FIXME: Handle training and momentum. - if (op.getMomentum().getType().isa()) + if (isa(op.getMomentum().getType())) return rewriter.notifyMatchFailure(op, "Unsupported None for momentum"); auto meanType = adaptor.getRunningMean().getType().dyn_cast(); @@ -2440,7 +2439,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (toBcastType.getRank() > 1) return rewriter.notifyMatchFailure(op, "Rank cannot be more than 1"); - RankedTensorType outTensorType = outType.cast(); + RankedTensorType outTensorType = cast(outType); SmallVector newShape = { makeShapeTorchCompatible(toBcastType.getShape())[0]}; for (auto i = 2; i < outTensorType.getRank(); ++i) @@ -2523,9 +2522,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Note: cudnn_enabled is not handled. // FIXME: Handle the None cases for the optional parameters. - if (adaptor.getWeight().getType().isa()) + if (isa(adaptor.getWeight().getType())) return rewriter.notifyMatchFailure(op, "Unsupported None for weight"); - if (adaptor.getBias().getType().isa()) + if (isa(adaptor.getBias().getType())) return rewriter.notifyMatchFailure(op, "Unsupported None for bias"); auto weightType = adaptor.getWeight().getType().cast(); @@ -2889,7 +2888,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Only floating-point or integer datatype legalization supported"); // Integer types with width > 32 are not supported - auto selfIntType = selfElemTy.dyn_cast(); + auto selfIntType = dyn_cast(selfElemTy); if (selfIntType && selfIntType.getWidth() > 32) { return rewriter.notifyMatchFailure( op, "Integer types with width greater than 32 are not supported"); @@ -3168,7 +3167,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Only tensor types are currently supported"); auto selfElemTy = selfType.getElementType(); - if (!selfElemTy.isa()) { + if (!isa(selfElemTy)) { return rewriter.notifyMatchFailure( op, "Only floating-point datatype legalization supported"); } @@ -3205,7 +3204,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Only tensor types are currently supported"); auto selfElemTy = selfType.getElementType(); - if (!selfElemTy.isa()) { + if (!isa(selfElemTy)) { return rewriter.notifyMatchFailure( op, "Only floating-point datatype legalization supported"); } @@ -3269,7 +3268,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } // Integer types with width > 32 are not supported - auto selfIntType = selfElemTy.dyn_cast(); + auto selfIntType = dyn_cast(selfElemTy); if (selfIntType && selfIntType.getWidth() > 32) { return rewriter.notifyMatchFailure( op, "Integer types with width greater than 32 are not supported"); @@ -3335,7 +3334,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( typeConverter->convertType(op.getType()).cast(); auto indicesType = indices.getType().dyn_cast(); - if (!indicesType || !indicesType.getElementType().isa()) + if (!indicesType || !isa(indicesType.getElementType())) return rewriter.notifyMatchFailure( op, "Indices must be of integer tensor type"); @@ -4525,8 +4524,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!selfType.hasStaticShape() || !otherType.hasStaticShape()) return rewriter.notifyMatchFailure( op, "Only tensor types with static shape are supported"); - if (!selfType.getElementType().isa() || - !otherType.getElementType().isa()) { + if (!isa(selfType.getElementType()) || + !isa(otherType.getElementType())) { return rewriter.notifyMatchFailure( op, "unimplemented: only FP element type is supported"); } @@ -4630,7 +4629,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // TODO: Add support for pin_memory features. // The pin_memory should be either `False` or `none`. bool pinMemory; - if (!op.getPinMemory().getType().isa() && + if (!isa(op.getPinMemory().getType()) && (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) || pinMemory)) { return rewriter.notifyMatchFailure( @@ -4796,7 +4795,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( "Unable to extract the scalar constant"); auto outElemTy = resultType.getElementType(); - if (outElemTy.isa()) { + if (isa(outElemTy)) { rewriter.replaceOpWithNewOp( op, resultType, DenseElementsAttr::get(resultType, {intValue})); } else if (outElemTy.isF64()) { @@ -4885,7 +4884,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } // Only `none`, `contiguous` and `preserve` memory_format is supported. - if (!op.getMemoryFormat().getType().isa()) { + if (!isa(op.getMemoryFormat().getType())) { int64_t memoryFormat; if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat))) return rewriter.notifyMatchFailure( @@ -4944,7 +4943,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto divTensor = self; // tosa::DivOp only supports int - if (outElemTy.isa()) { + if (isa(outElemTy)) { auto otherTensorReciprocal = rewriter.create( op.getLoc(), otherTensor.getType(), otherTensor); divTensor = rewriter.create( @@ -5792,7 +5791,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "Only single values are supported."); auto elementTy = outputTy.getElementType(); - if (!elementTy.isa()) + if (!isa(elementTy)) return rewriter.notifyMatchFailure(op, "Only integer values are supported."); diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index b4e82360c60f..f6bc54180e7b 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -119,7 +119,7 @@ tosa::DivOp createBinaryOpAndCast(PatternRewriter &rewriter, Value lhs, Value rhs) { auto lhsElemTy = lhs.getType().cast().getElementType(); auto rhsElemTy = rhs.getType().cast().getElementType(); - if (lhsElemTy.isa() || rhsElemTy.isa()) { + if (isa(lhsElemTy) || isa(rhsElemTy)) { (void)rewriter.notifyMatchFailure(op, "tosa.div only supports integer type"); } @@ -213,7 +213,7 @@ std::optional convertTorchIndexToTfIndices(PatternRewriter &rewriter, std::optional convertGatherNdOp(PatternRewriter &rewriter, Operation *op, Type outType, Value paramsValue, Value indicesValue) { - auto resultType = outType.dyn_cast(); + auto resultType = dyn_cast(outType); auto paramsType = paramsValue.getType().dyn_cast(); auto indicesType = indicesValue.getType().dyn_cast(); @@ -419,7 +419,7 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, Operation *op, Type outType, Value paramsValue, Value indicesValue, Value fillValues) { - auto resultType = outType.dyn_cast(); + auto resultType = dyn_cast(outType); auto paramsType = paramsValue.getType().dyn_cast(); auto indicesType = indicesValue.getType().dyn_cast(); auto fillValuesType = fillValues.getType().dyn_cast(); @@ -988,7 +988,7 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op, return std::nullopt; Type elemType = output_type.getElementType(); - if (!elemType.isa()) { + if (!isa(elemType)) { op->emitOpError("Only floating-point datatype legalization supported for " "AtenLinalgVectorNorm op"); return std::nullopt; diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 1fcc91991f37..7e7195e836bd 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -166,7 +166,7 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, // Create a zero constant tensor of the desired type and shape. std::optional getZerosLikeTensor(PatternRewriter &rewriter, Operation *op, Type type) { - RankedTensorType resultType = type.dyn_cast(); + RankedTensorType resultType = dyn_cast(type); if (!resultType) { (void)rewriter.notifyMatchFailure(op, "not ranked tensor type"); @@ -179,7 +179,7 @@ std::optional getZerosLikeTensor(PatternRewriter &rewriter, Attribute zeroAttr = rewriter.getZeroAttr(zeroType); return CreateOpAndInfer(rewriter, op->getLoc(), zeroType, - zeroAttr.cast()) + cast(zeroAttr)) .getResult(); } @@ -322,7 +322,7 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op, Value src, Type destType, Value &result) { Type srcElemTy = src.getType().dyn_cast().getElementType(); - Type destElemTy = destType.dyn_cast().getElementType(); + Type destElemTy = dyn_cast(destType).getElementType(); if (failed(checkValidityOfCast(srcElemTy, destElemTy))) return rewriter.notifyMatchFailure( @@ -451,7 +451,7 @@ LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input, // Tosa supports FP16 and FP32 accumulator type for FP16 input. When the time // FP16 is supported, the accumulator type can be selected based on trade-off // between performance and accuracy. Set to FP32 by default. - accType = inputETy.isa() + accType = isa(inputETy) ? mlir::TypeAttr::get(rewriter.getF32Type()) : mlir::TypeAttr::get(rewriter.getIntegerType(32)); diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index 4d42b5fea943..e99ed57f6a23 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -27,9 +27,9 @@ LogicalResult verifyLinalgCompatibleTypes(Operation *op, // TODO: Remove this check but use a separate verification pass to verify the // invariants expected by later passes. auto isValidLinalgType = [](Type type) { - if (type.isa()) + if (isa(type)) return false; - auto tensor = type.dyn_cast(); + auto tensor = dyn_cast(type); return !tensor || tensor.toBuiltinTensor().dyn_cast_or_null(); }; @@ -43,8 +43,8 @@ LogicalResult verifyLinalgCompatibleTypes(Operation *op, LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v) { Type type = v.getType(); - if (type.isa() || type.isa() || - type.isa()) + if (isa(type) || isa(type) || + isa(type)) return rewriter.notifyMatchFailure(op, "unimplemented None type arg"); return success(); } @@ -104,7 +104,7 @@ void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim, Type lhsType = lhsDim.getType(); Type rhsType = rhsDim.getType(); auto checkIntOrIndex = [](Type type) { - assert((type.isa() || type.isa()) && + assert((isa(type) || isa(type)) && "must be either integer or index type"); }; checkIntOrIndex(lhsType); @@ -198,13 +198,13 @@ Value getTensorSize(OpBuilder &b, Location loc, Value tensor) { // Creates a constant of type `elemType` with value `val`. Value getConstant(OpBuilder &b, Location loc, int64_t val, Type elemType) { TypedAttr attr = {}; - if (elemType.isa()) + if (isa(elemType)) attr = b.getFloatAttr(elemType, val); - if (elemType.isa()) + if (isa(elemType)) attr = b.getIndexAttr(val); - if (elemType.isa()) - attr = b.getIntegerAttr( - elemType, APInt(elemType.cast().getWidth(), val)); + if (isa(elemType)) + attr = b.getIntegerAttr(elemType, + APInt(cast(elemType).getWidth(), val)); if (!attr) return nullptr; return b.create(loc, elemType, attr); @@ -264,7 +264,7 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, return scalar; auto isByteOrChar = [](Type type) { - if (auto integerTy = type.dyn_cast()) { + if (auto integerTy = dyn_cast(type)) { return integerTy.getWidth() == 8; } return false; @@ -303,10 +303,10 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, if (dtype.isSignlessInteger(1)) { Type scalarType = scalar.getType(); Value cstZero = b.create(loc, b.getZeroAttr(scalarType)); - if (scalarType.isa()) { + if (isa(scalarType)) { return b.create(loc, arith::CmpFPredicate::UNE, scalar, cstZero); - } else if (scalarType.isa()) { + } else if (isa(scalarType)) { return b.create(loc, arith::CmpIPredicate::ne, scalar, cstZero); } else { @@ -317,14 +317,14 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, } } - if (auto dtypeFloat = dtype.dyn_cast()) { - if (auto scalarFloat = scalarType.dyn_cast()) { + if (auto dtypeFloat = dyn_cast(dtype)) { + if (auto scalarFloat = dyn_cast(scalarType)) { if (scalarFloat.getWidth() > dtypeFloat.getWidth()) return b.create(loc, dtype, scalar); // Only scalarFloat width < dtypeFloat width can reach here. return b.create(loc, dtype, scalar); } - assert(scalarType.isa()); + assert(isa(scalarType)); if (scalarType.isSignlessInteger(1) || (srcOriginalDtype.has_value() && srcOriginalDtype->isUnsignedInteger())) return b.create(loc, dtype, scalar); @@ -333,11 +333,11 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, return b.create(loc, dtype, scalar); } - if (auto dtypeInteger = dtype.dyn_cast()) { - if (auto scalarFloat = scalarType.dyn_cast()) + if (auto dtypeInteger = dyn_cast(dtype)) { + if (auto scalarFloat = dyn_cast(scalarType)) return b.create(loc, dtype, scalar); - assert(scalarType.isa()); - auto scalarInteger = scalarType.cast(); + assert(isa(scalarType)); + auto scalarInteger = cast(scalarType); if (scalarInteger.getWidth() > dtypeInteger.getWidth()) return b.create(loc, dtype, scalar); if (scalarType.isSignlessInteger(1) || diff --git a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp index 1e8c91e8afd4..8f34358b9c0f 100644 --- a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp +++ b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp @@ -49,7 +49,7 @@ allocateBuffersForResults(Location loc, TMTensorOp tmtensorOp, size_t resultIndex = en.index(); Type resultType = en.value(); - auto tensorType = resultType.dyn_cast(); + auto tensorType = dyn_cast(resultType); if (tensorType == nullptr) { tmtensorOp.emitOpError() << "tensor to buffer conversion expects ranked tensor results"; diff --git a/lib/Dialect/Torch/IR/TorchDialect.cpp b/lib/Dialect/Torch/IR/TorchDialect.cpp index e7fcbb434a2c..d57b3e74198e 100644 --- a/lib/Dialect/Torch/IR/TorchDialect.cpp +++ b/lib/Dialect/Torch/IR/TorchDialect.cpp @@ -100,10 +100,12 @@ void TorchDialect::initialize() { addOperations< #define GET_OP_LIST #include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc" + >(); addTypes< #define GET_TYPEDEF_LIST #include "torch-mlir/Dialect/Torch/IR/TorchTypes.cpp.inc" + >(); addInterfaces(); } @@ -144,35 +146,34 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op, Operation *TorchDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - if (auto integerType = type.dyn_cast()) - return builder.create(loc, value.cast()); + if (auto integerType = dyn_cast(type)) + return builder.create(loc, cast(value)); - if (auto floatType = type.dyn_cast()) - return builder.create(loc, value.cast()); + if (auto floatType = dyn_cast(type)) + return builder.create(loc, cast(value)); - if (auto numberType = type.dyn_cast()) { - if (auto floatValue = value.dyn_cast()) { + if (auto numberType = dyn_cast(type)) { + if (auto floatValue = dyn_cast(value)) { return builder.create(loc, floatValue); - } else if (auto intValue = value.dyn_cast()) { + } else if (auto intValue = dyn_cast(value)) { return builder.create(loc, intValue); } } - if (type.isa()) { - return builder.create(loc, - value.cast()); + if (isa(type)) { + return builder.create(loc, cast(value)); } - if (type.isa()) + if (isa(type)) return builder.create(loc); - if (auto stringAttr = value.dyn_cast()) + if (auto stringAttr = dyn_cast(value)) return builder.create(loc, stringAttr); - if (auto elementsAttr = value.dyn_cast()) { + if (auto elementsAttr = dyn_cast(value)) { // Only !torch.vtensor can be constant folded. !torch.tensor has // non-trivial aliasing semantics which prevent deduplicating it. - assert(type.isa() && "should be a vtensor type!"); + assert(isa(type) && "should be a vtensor type!"); return builder.create(loc, elementsAttr); } diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index be81eae53186..8ac7d1805f97 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -41,9 +41,8 @@ Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder, return value; // If the type is a tensor, then adjust the static information. - if ((type.isa() && desiredType.isa()) || - (type.isa() && - desiredType.isa())) { + if ((isa(type) && isa(desiredType)) || + (isa(type) && isa(desiredType))) { Value adjusted = builder.create(value.getLoc(), desiredType, value); return adjusted; @@ -90,7 +89,7 @@ Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc, // then we do the copy by going to a value tensor and back. if (tensor.getType().isa()) tensor = builder.create(loc, tensor); - if (newType.isa()) + if (isa(newType)) tensor = builder.create(loc, tensor); return tensor; @@ -132,11 +131,11 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) { static Value getScalarIntValue(Value input, Location loc, PatternRewriter &rewriter) { auto inputType = input.getType(); - if (inputType.isa()) { + if (isa(inputType)) { return input; } - auto inputTensorType = inputType.dyn_cast(); + auto inputTensorType = dyn_cast(inputType); if (!inputTensorType) return nullptr; @@ -166,11 +165,11 @@ static Value getScalarIntValue(Value input, Location loc, static Value getScalarFloatValue(Value input, Location loc, PatternRewriter &rewriter) { auto inputType = input.getType(); - if (inputType.isa()) { + if (isa(inputType)) { return input; } - auto inputTensorType = inputType.dyn_cast(); + auto inputTensorType = dyn_cast(inputType); if (!inputTensorType) return nullptr; @@ -273,7 +272,7 @@ LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) { LogicalResult PrimListConstructOp::verify() { auto resultType = getResult().getType(); - auto resultElementType = resultType.dyn_cast().getContainedType(); + auto resultElementType = dyn_cast(resultType).getContainedType(); auto matchResultElementType = [&](Type type) { return isValidSubtype(type, resultElementType); }; @@ -606,7 +605,7 @@ static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) { Type rhsType = rhs.getType(); // If either type is a NoneType, make it be the lhsType. - if (rhsType.isa()) { + if (isa(rhsType)) { std::swap(lhsType, rhsType); std::swap(lhs, rhs); } @@ -615,14 +614,14 @@ static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) { // If both types are the singleton `!torch.none` type, then we don't even need // to look at the values. - if (lhsType.isa() && rhsType.isa()) + if (isa(lhsType) && isa(rhsType)) return IntegerAttr::get(IntegerType::get(op->getContext(), 1), equalIsTrue); // If neither type is a subtype of the other, then the result is false. // TODO: Implement and use subtype infra for this. // For now, check a specific case. // If the rhs is not OptionalType, then we know it cannot be None. - if (lhsType.isa() && !rhsType.isa()) { + if (isa(lhsType) && !isa(rhsType)) { return IntegerAttr::get(IntegerType::get(op->getContext(), 1), !equalIsTrue); } @@ -640,9 +639,9 @@ OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) { auto step = adaptor.getStep(); if (!lo || !hi || !step) return nullptr; - auto loInt = lo.dyn_cast_or_null().getValue(); - auto hiInt = hi.dyn_cast_or_null().getValue(); - auto stepInt = step.dyn_cast_or_null().getValue(); + auto loInt = dyn_cast_or_null(lo).getValue(); + auto hiInt = dyn_cast_or_null(hi).getValue(); + auto stepInt = dyn_cast_or_null(step).getValue(); // TODO: Implement folding for negative steps. if (stepInt.isNegative()) return nullptr; @@ -650,7 +649,7 @@ OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) { // r[i] = lo + step*i such that i >= 0 and r[i] < hi // So maximize `i` such that lo + step * i < hi // ==> i == ceildiv(hi - lo, step) - return IntegerAttr::get(lo.cast().getType(), + return IntegerAttr::get(cast(lo).getType(), llvm::APIntOps::RoundingSDiv(hiInt - loInt, stepInt, APInt::Rounding::UP)); } @@ -665,10 +664,10 @@ OpFoldResult Aten__DeriveIndexOp::fold(FoldAdaptor adaptor) { auto step = adaptor.getStep(); if (!index || !start || !step) return nullptr; - auto indexInt = index.dyn_cast_or_null().getValue(); - auto startInt = start.dyn_cast_or_null().getValue(); - auto stepInt = step.dyn_cast_or_null().getValue(); - return IntegerAttr::get(index.cast().getType(), + auto indexInt = dyn_cast_or_null(index).getValue(); + auto startInt = dyn_cast_or_null(start).getValue(); + auto stepInt = dyn_cast_or_null(step).getValue(); + return IntegerAttr::get(cast(index).getType(), startInt + stepInt * indexInt); } @@ -2807,9 +2806,9 @@ void Torch::ConstantNumberOp::getCanonicalizationPatterns( Value constValue; Attribute value = op.getValueAttr(); - if (auto floatValue = value.dyn_cast()) { + if (auto floatValue = dyn_cast(value)) { constValue = rewriter.create(loc, floatValue); - } else if (auto intValue = value.dyn_cast()) { + } else if (auto intValue = dyn_cast(value)) { constValue = rewriter.create(loc, intValue); } else { return failure(); @@ -3127,6 +3126,9 @@ void PrimListUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns, if (!listConstruct) return failure(); + if (op->getNumResults() != listConstruct.getElements().size()) + return failure(); + rewriter.replaceOp(op, listConstruct.getElements()); return success(); }); @@ -3228,9 +3230,9 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef operands, BinaryFloatOperatorFn f) { double lhs, rhs; auto parseDoubleAttribute = [](Attribute attr, double &value) -> bool { - if (auto intLhs = attr.dyn_cast_or_null()) { + if (auto intLhs = dyn_cast_or_null(attr)) { value = static_cast(intLhs.getValue().getSExtValue()); - } else if (auto floatLhs = attr.dyn_cast_or_null()) { + } else if (auto floatLhs = dyn_cast_or_null(attr)) { value = floatLhs.getValue().convertToDouble(); } else { return false; @@ -4049,7 +4051,7 @@ OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) { } Type resultType = getResult().getType(); - BaseTensorType resultTensorType = resultType.dyn_cast(); + BaseTensorType resultTensorType = dyn_cast(resultType); if (!resultTensorType || !resultTensorType.hasDtype() || !resultTensorType.hasSizes()) { return nullptr; @@ -4070,11 +4072,11 @@ OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) { return nullptr; } auto elementType = shapedty.getElementType(); - if (elementType.isa()) { + if (isa(elementType)) { Attribute attribute = IntegerAttr::get(elementType, 1); return DenseElementsAttr::get(shapedty, attribute); } - if (elementType.isa()) { + if (isa(elementType)) { Attribute attribute = FloatAttr::get(elementType, 1.0); return DenseElementsAttr::get(shapedty, attribute); } @@ -4088,7 +4090,7 @@ OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) { } Type resultType = getResult().getType(); - BaseTensorType resultTensorType = resultType.dyn_cast(); + BaseTensorType resultTensorType = dyn_cast(resultType); if (!resultTensorType || !resultTensorType.hasDtype() || !resultTensorType.hasSizes()) { return nullptr; @@ -4110,11 +4112,11 @@ OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) { } auto elementType = shapedty.getElementType(); - if (elementType.isa()) { + if (isa(elementType)) { Attribute attribute = IntegerAttr::get(elementType, 0); return DenseElementsAttr::get(shapedty, attribute); } - if (elementType.isa()) { + if (isa(elementType)) { Attribute attribute = FloatAttr::get(elementType, 0.0); return DenseElementsAttr::get(shapedty, attribute); } @@ -4129,7 +4131,7 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) { } Type resultType = getResult().getType(); - BaseTensorType resultTensorType = resultType.dyn_cast(); + BaseTensorType resultTensorType = dyn_cast(resultType); if (!resultTensorType || !resultTensorType.hasDtype() || !resultTensorType.hasSizes()) { return nullptr; @@ -4147,14 +4149,14 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) { mlir::RankedTensorType::get(sizes, resultTensorType.getDtype()); auto elementType = shapedty.getElementType(); - if (elementType.isa()) { + if (isa(elementType)) { int64_t value = 0; if (matchPattern(getFillValue(), m_TorchConstantInt(&value))) { Attribute attribute = IntegerAttr::get(elementType, value); return DenseElementsAttr::get(shapedty, attribute); } } - if (elementType.isa()) { + if (isa(elementType)) { double value = 0.0; if (matchPattern(getFillValue(), m_TorchConstantFloat(&value))) { Attribute attribute = FloatAttr::get(elementType, value); @@ -4735,15 +4737,14 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() { auto initialize = cast(getBody()->getTerminator()); for (Attribute symName : initialize.getSlotSymNames()) { auto wasInserted = initializedGlobalSlots - .insert(symName.cast().getAttr()) + .insert(cast(symName).getAttr()) .second; if (!wasInserted) return initialize.emitError("duplicate initialization of global slot: ") << symName; } auto lessThanByStringValue = [](Attribute lhs, Attribute rhs) { - return lhs.cast().getValue() < - rhs.cast().getValue(); + return cast(lhs).getValue() < cast(rhs).getValue(); }; auto known = llvm::to_vector(knownGlobalSlots); llvm::sort(known, lessThanByStringValue); @@ -4756,7 +4757,7 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() { InFlightDiagnostic diag = initialize.emitOpError( "must have one initializer for each global slot in the module"); for (auto knownGlobalSlot : known) { - auto symName = FlatSymbolRefAttr::get(knownGlobalSlot.cast()); + auto symName = FlatSymbolRefAttr::get(cast(knownGlobalSlot)); if (!initializedGlobalSlots.count(knownGlobalSlot)) { diag.attachNote( symbolTable.lookup(symName.getAttr()).getLoc()) @@ -4767,7 +4768,7 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() { if (!knownGlobalSlots.count(initializedGlobalSlot)) { diag.attachNote().append( "unexpected global slot initializer for non-existent global slot ", - FlatSymbolRefAttr::get(initializedGlobalSlot.cast())); + FlatSymbolRefAttr::get(cast(initializedGlobalSlot))); } } return diag; diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index b22c82b8a28f..c162166cdd13 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -29,7 +29,7 @@ bool Torch::isValidSubtype(Type subtype, Type type) { // For a UnionType to be a subtype, all of its contained types must be // subtypes. - if (auto unionType = subtype.dyn_cast()) { + if (auto unionType = dyn_cast(subtype)) { for (auto containedType : unionType.getContainedTypes()) { if (!isValidSubtype(containedType, type)) return false; @@ -37,17 +37,17 @@ bool Torch::isValidSubtype(Type subtype, Type type) { return true; } - if (auto any = type.dyn_cast()) + if (auto any = dyn_cast(type)) return true; - if (auto number = type.dyn_cast()) - return subtype.isa() || subtype.isa(); + if (auto number = dyn_cast(type)) + return isa(subtype) || isa(subtype); - if (auto optional = type.dyn_cast()) + if (auto optional = dyn_cast(type)) return isValidSubtype(subtype, optional.getContainedType()) || - subtype.isa(); + isa(subtype); - if (auto unionType = type.dyn_cast()) { + if (auto unionType = dyn_cast(type)) { for (auto containedType : unionType.getContainedTypes()) { if (isValidSubtype(subtype, containedType)) return true; @@ -55,10 +55,10 @@ bool Torch::isValidSubtype(Type subtype, Type type) { return false; } - if (auto tuple = type.dyn_cast()) { - if (!subtype.isa()) + if (auto tuple = dyn_cast(type)) { + if (!isa(subtype)) return false; - auto subtypes = subtype.cast().getContainedTypes(); + auto subtypes = cast(subtype).getContainedTypes(); auto types = tuple.getContainedTypes(); if (subtypes.size() != types.size()) return false; @@ -69,14 +69,14 @@ bool Torch::isValidSubtype(Type subtype, Type type) { return true; } - auto subtypeTensorType = subtype.dyn_cast(); - auto typeTensorType = type.dyn_cast(); + auto subtypeTensorType = dyn_cast(subtype); + auto typeTensorType = dyn_cast(type); if (subtypeTensorType && typeTensorType) { // Check that both tensors have the same `BaseTensorType` subtype. // TODO: This is not subtyping according to PEP 483. See description // of NonValueTensorType. - if (subtypeTensorType.isa() != - typeTensorType.isa()) + if (isa(subtypeTensorType) != + isa(typeTensorType)) return false; // `type` must not have more static information than `subtype`, and `type` @@ -181,23 +181,23 @@ void Torch::UnionType::print(AsmPrinter &printer) const { static bool isValidTorchDtype(Type dtype) { // For complex types, get the underlying element type - if (dtype.isa()) { - dtype = dtype.cast().getElementType(); + if (isa(dtype)) { + dtype = cast(dtype).getElementType(); } // Torch quantized types. - if (dtype.isa()) + if (isa(dtype)) return true; // Builtin floating point types. - if (dtype.isa()) + if (isa(dtype)) return true; if (dtype.isa()) return true; - if (dtype.isa()) + if (isa(dtype)) return true; // Builtin integer types. - if (IntegerType type = dtype.dyn_cast()) { + if (IntegerType type = dyn_cast(dtype)) { if (type.isSignless() && type.getWidth() == 1) return true; if (type.isSigned()) { @@ -273,7 +273,7 @@ verifyTensorType(function_ref emitError, } } } - if (!optionalSparsity.isa()) { + if (!isa(optionalSparsity)) { emitError() << "invalid sparsity encoding attribute"; return failure(); } @@ -441,12 +441,12 @@ ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) { } static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) { - if (auto floatType = dtype.dyn_cast()) { + if (auto floatType = dyn_cast(dtype)) { return dtype; - } else if (auto integerType = dtype.dyn_cast()) { + } else if (auto integerType = dyn_cast(dtype)) { return IntegerType::get(context, integerType.getWidth(), IntegerType::Signless); - } else if (dtype.isa()) { + } else if (isa(dtype)) { return dtype; } @@ -502,8 +502,8 @@ void ValueTensorType::print(AsmPrinter &printer) const { } Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) { - assert(((lhs.isa() && rhs.isa()) || - (lhs.isa() && rhs.isa())) && + assert(((isa(lhs) && isa(rhs)) || + (isa(lhs) && isa(rhs))) && "expected lhs and rhs to have same sense of value semantics"); // First, calculate the dtype. @@ -566,21 +566,21 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) { // linkage) and the predicates themselves can't be added/used in the // specification of the parameters of the Torch_DictType. static bool isAnyTorchDictKeyType(Type type) { - return type.isa() || type.isa() || - type.isa() || type.isa() || - type.isa() || type.isa(); + return isa(type) || isa(type) || + isa(type) || isa(type) || + isa(type) || isa(type); } static bool isAnyTorchType(Type type) { return isValidSubtype(type, Torch::NumberType::get(type.getContext())) || - type.isa() || type.isa() || - type.isa() || type.isa() || - type.isa() || type.isa() || - type.isa() || type.isa() || - type.isa() || type.isa() || - type.isa() || type.isa() || - type.isa() || type.isa() || - type.isa(); + isa(type) || isa(type) || + isa(type) || isa(type) || + isa(type) || isa(type) || + isa(type) || isa(type) || + isa(type) || isa(type) || + isa(type) || isa(type) || + isa(type) || isa(type) || + isa(type); } LogicalResult diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp index 2891a22eb817..000efbc7ceb1 100644 --- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp @@ -53,7 +53,7 @@ class AdjustCallingConventionForFunc auto typeBoundAttr = func.getArgAttrOfType(type.index(), typeBoundIdent); Type bound = typeBoundAttr ? typeBoundAttr.getValue() : Type(); - if (!bound.isa()) + if (!isa(bound)) return rewriter.notifyMatchFailure( func, "unimplemented: preserving aliasing for non-value-semantic " "type bounds"); @@ -72,10 +72,10 @@ class AdjustCallingConventionForFunc SmallVector newResultTypes; for (auto type : func.getFunctionType().getResults()) { - if (auto none = type.dyn_cast()) { + if (auto none = dyn_cast(type)) { continue; } - if (auto tuple = type.dyn_cast()) { + if (auto tuple = dyn_cast(type)) { llvm::append_range(newResultTypes, tuple.getContainedTypes()); continue; } @@ -133,12 +133,12 @@ class AdjustCallingConventionForCall int newOpResultIdx = 0; SmallVector newResults; for (auto type : call.getResultTypes()) { - if (type.isa()) { + if (isa(type)) { newResults.push_back( rewriter.create(call.getLoc(), type)); continue; } - if (type.isa()) { + if (isa(type)) { newResults.push_back(rewriter.create( call.getLoc(), type, newCall.getResults())); continue; diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index d6cd34b82091..737e093a479f 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1386,7 +1386,7 @@ static Value getSoftmaxResult(OpTy op, Value self, Type resultType, unNormalizedExp, sum); if (resultType != accumulatorType) result = convertTensorToDtype(rewriter, loc, result, - resultType.cast().getDtype()); + cast(resultType).getDtype()); return result; } @@ -1405,7 +1405,7 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern { op, "expected result type to have a dtype"); } Type resultTensorDtype = resultTensorType.getDtype(); - if (!resultTensorDtype.isa()) + if (!isa(resultTensorDtype)) return rewriter.notifyMatchFailure(op, "Only support floating-point type"); @@ -1980,7 +1980,7 @@ class DecomposeAtenLinalgCrossOp : public OpRewritePattern { } Type dtype = resType.getDtype(); - if (dtype.isa()) { + if (isa(dtype)) { return rewriter.notifyMatchFailure( op, "lowering of aten.linalg_cross for complex inputs dtype is " "currently unimplemented"); @@ -2015,7 +2015,7 @@ class DecomposeAtenLinalgCrossOp : public OpRewritePattern { Value none = rewriter.create(loc); // idx = torch.arange(3) - auto outType = opType.dyn_cast(); + auto outType = dyn_cast(opType); auto arangeType = outType.getWithSizesAndDtype( llvm::ArrayRef(3), IntegerType::get(op.getContext(), 64, IntegerType::Signed)); @@ -5873,7 +5873,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, Value keepDim = op.getKeepdim(); BaseTensorType inputTensorTy = self.getType().cast(); Type outputType = op.getType(); - BaseTensorType outputTensorType = outputType.cast(); + BaseTensorType outputTensorType = cast(outputType); if (!outputTensorType.hasDtype()) { return rewriter.notifyMatchFailure(op, "expected result type to have a dtype"); @@ -5918,7 +5918,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, Type meanDimResultType = inputTensorTy; for (unsigned i = 0; i < dimListElements.size(); i++) meanDimResultType = computeReductionType( - rewriter, op, meanDimResultType.cast(), + rewriter, op, cast(meanDimResultType), dimListElements[i], /*keepDim=*/true); @@ -6214,7 +6214,7 @@ class DecomposeAtenRandintLowOp : public OpRewritePattern { Location loc = op.getLoc(); Type resultType = op.getType(); - BaseTensorType resultTensorType = resultType.cast(); + BaseTensorType resultTensorType = cast(resultType); if (!resultTensorType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp index ce9d58a9c2f4..1dc0cc9a9d8d 100644 --- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp +++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp @@ -207,7 +207,7 @@ class QuantizeAccumulator : public OpRewritePattern { return failure(); Type resultETy = resultTy.getDtype(); - if (!resultETy.isa()) + if (!isa(resultETy)) return failure(); Value lhsScale; diff --git a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp index 1e8c90deac4e..5d59dfd8c596 100644 --- a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp +++ b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp @@ -183,13 +183,13 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) { } LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) { - if (Value value = point.dyn_cast()) { + if (Value value = dyn_cast(point)) { bool isSafe = isValueSafeTransferFunction(value); auto *state = getOrCreate(value); propagateIfChanged(state, state->setSafe(isSafe)); // Handle GlobalSlotGetOp's. - if (auto opResult = value.dyn_cast()) { + if (auto opResult = dyn_cast(value)) { if (auto globalSlotGet = dyn_cast(opResult.getOwner())) { auto *flatSymbolRefPoint = getProgramPoint( @@ -205,7 +205,7 @@ LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) { return success(); } - if (auto *genericProgramPoint = point.dyn_cast()) { + if (auto *genericProgramPoint = dyn_cast(point)) { if (auto *flatSymbolRefPoint = dyn_cast(genericProgramPoint)) { if (initializeGlobalSlotsOp) { @@ -396,7 +396,7 @@ class InlineGlobalSlotsPass // This could be left to SymbolDCE but it's not hard to do here. for (FlatSymbolRefAttr symName : llvm::map_range(safeToInline, [](Attribute attr) { - return attr.cast(); + return cast(attr); })) { auto globalSlot = symbolTable.lookup(symName.getValue()); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 250577b132ee..78bf8504d564 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -46,14 +46,14 @@ static LogicalResult checkType(Operation *op, Type type, // can statically pattern match and eliminate from the program. // For example, a tensor operand might be optional, and the backend // will pattern-match statically whether it is passed as a tensor or None. - if (type.isa()) + if (isa(type)) return success(); // We blanket prohibit non-value-semantic tensors. // All of our backends are currently based on value-semantic tensors, so // we consider it our responsibility to lower all non-value-semantic tensors // to value-semantic tensors. - if (type.isa()) { + if (isa(type)) { if (actuallyEmitDiagnostics) { return op ->emitError("unsupported by backend contract: non-value tensor type") @@ -84,7 +84,7 @@ static LogicalResult checkType(Operation *op, Type type, // have an sufficiently rich system for representing PyTorch type promotion // rules. So we consider it our responsibility to ensure that all dtypes are // statically known. - if (auto tensorType = type.dyn_cast()) { + if (auto tensorType = dyn_cast(type)) { if (!tensorType.hasSizes()) { if (actuallyEmitDiagnostics) { return op @@ -115,7 +115,7 @@ static LogicalResult checkType(Operation *op, Type type, // Optional types are also in the category of types which we don't expect // backends to dynamically compute with, but they can be pattern matched // in many cases that are practically necessary. - if (auto optionalType = type.dyn_cast()) { + if (auto optionalType = dyn_cast(type)) { // TODO: Be stricter about tensor types. // See comment below for ListType. if (optionalType.getContainedType().isa()) @@ -127,7 +127,7 @@ static LogicalResult checkType(Operation *op, Type type, // backends to dynamically compute with, but they can be pattern matched // in many cases that are practically necessary. For example, the // strides of a convolution op are represented as a list. - if (auto listType = type.dyn_cast()) { + if (auto listType = dyn_cast(type)) { // TODO: Be stricter about tensor types. // For the moment, there are cases (such as for torch.cat) where we end // up with `!torch.list` which doesn't have shape or dtype in @@ -141,7 +141,7 @@ static LogicalResult checkType(Operation *op, Type type, // Tuple types are also in the category of types which we don't expect // backends to dynamically compute with, but they can be pattern matched // in many cases that are practically necessary. - if (auto tupleType = type.dyn_cast()) { + if (auto tupleType = dyn_cast(type)) { for (auto containedType : tupleType.getContainedTypes()) { if (failed(checkType(op, containedType, actuallyEmitDiagnostics))) return failure(); diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp index 7db6bc6776b3..4026d0464dca 100644 --- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp +++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp @@ -140,7 +140,7 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock auto returnOp = ops.returnOp.value(); for (auto operand : llvm::enumerate(returnOp->getOperands())) { auto type = operand.value().getType(); - if (!type.isa()) + if (!isa(type)) continue; originalReturnTypes[operand.index()] = type; } diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index f8161de1fa0b..746b9068284c 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -38,15 +38,15 @@ static void createOverwriteTensorContents(PatternRewriter &rewriter, } static Type getContainerOrTensorTypeWithValueSemantics(Type type) { - if (auto optionalType = type.dyn_cast()) { + if (auto optionalType = dyn_cast(type)) { Type newContainedType = getContainerOrTensorTypeWithValueSemantics( optionalType.getContainedType()); return OptionalType::get(newContainedType); - } else if (auto listType = type.dyn_cast()) { + } else if (auto listType = dyn_cast(type)) { Type newContainedType = getContainerOrTensorTypeWithValueSemantics(listType.getContainedType()); return ListType::get(newContainedType); - } else if (auto tensorType = type.dyn_cast()) { + } else if (auto tensorType = dyn_cast(type)) { return tensorType.getWithValueSemantics(); } else { return nullptr; @@ -92,10 +92,10 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { SmallVector newOperands; for (OpOperand &opOperand : op->getOpOperands()) { Type operandType = opOperand.get().getType(); - if (operandType.isa()) { + if (isa(operandType)) { opOperand.set(rewriter.create(op->getLoc(), opOperand.get())); - } else if (auto listType = operandType.dyn_cast()) { + } else if (auto listType = dyn_cast(operandType)) { if (!(listType.getContainedType().isa() || listType.getContainedType().isa())) continue; @@ -144,7 +144,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { } opOperand.set(rewriter.create( op->getLoc(), newListType, newListElements)); - } else if (auto optionalType = operandType.dyn_cast()) { + } else if (auto optionalType = dyn_cast(operandType)) { // TODO: A more general way to handle the optional type is to // introduce a `copy.to_optional_vtensor` op. if (!optionalType.getContainedType().isa()) @@ -450,7 +450,7 @@ struct ReduceOpVariantsPass auto hasValueSemantics = [](Type t) { // TODO: Make this an allowlist based on a closed torch dialect // type system. - if (auto tensorType = t.dyn_cast()) { + if (auto tensorType = dyn_cast(t)) { return false; } return true; diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp index a34e0208c9d9..8049d8af8d59 100644 --- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp @@ -170,7 +170,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, if (operandType == desiredType) return operand; - if (desiredType.isa()) { + if (isa(desiredType)) { // Generator's are currently passed as Any because TorchScript cannot // compile a function with Generator type arguments. // Ignoring that hack, this is a correct handling of Any type should we need @@ -180,8 +180,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, // The type `!torch.number` can be an `int`, `float`, or `complex`. // TODO: Add a new type `Torch::ComplexType` to handle the complex case. - if (desiredType.isa() && - operandType.isa()) { + if (isa(desiredType) && + isa(operandType)) { return b.create(loc, desiredType, operand).getResult(); } @@ -189,7 +189,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, // `Scalar` inputs. At compile time, such inputs will usually be // resolved to an `int`, `float`, or `None` so we need to derefine // to match the library function signature. - if (auto unionType = desiredType.dyn_cast()) { + if (auto unionType = dyn_cast(desiredType)) { if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) { return containedType .isa(); @@ -200,8 +200,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, // Operands with type `!torch.none` correspond to library function inputs with // types like `!torch.optional<...>` or `!torch.union<..., none>`, so here the // type is derefined to match the expected type of the library function. - if (operandType.isa()) { - assert(!desiredType.isa() && + if (isa(operandType)) { + assert(!isa(desiredType) && "Don't expect library functions to have NoneType parameters"); return b.create(loc, desiredType, operand).getResult(); } @@ -211,8 +211,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, // dtype of input scalars. However, this also means we sometimes have to // manually turn `Scalar`s into `float`s when inserting the shape functions // into the IR. - if (operandType.isa() && - desiredType.isa()) { + if (isa(operandType) && + isa(desiredType)) { return b.create(loc, desiredType, operand).getResult(); } @@ -224,8 +224,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, // type). // A case where this happens is `!torch.optional` -> // `!torch.optional>>`. - if (auto operandOptionalType = operandType.dyn_cast()) { - if (desiredType.isa()) { + if (auto operandOptionalType = dyn_cast(operandType)) { + if (isa(desiredType)) { // if optional is None: // return derefine(None) // else: @@ -258,7 +258,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, // If the desired type is OptionalType, then recursively adjust the operand to // the contained type, then derefine it to `!torch.optional`. For example, // `!torch.vtensor -> !torch.optional>>`. - if (auto desiredOptionalType = desiredType.dyn_cast()) { + if (auto desiredOptionalType = dyn_cast(desiredType)) { FailureOr adjusted = adjustFunctionArg( b, loc, operand, desiredOptionalType.getContainedType(), baseTransformation); @@ -267,7 +267,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, return b.create(loc, desiredType, *adjusted).getResult(); } - if (auto desiredListType = desiredType.dyn_cast()) { + if (auto desiredListType = dyn_cast(desiredType)) { // Pseudocode: // // operand = ... @@ -311,7 +311,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, // The library functions use `float` where the operator // signature uses `Scalar` (see comments in torch_ods_gen.py for // explanation). - if (desiredType.isa() && + if (isa(desiredType) && operand.getType().isa()) { return b.create(loc, desiredType, operand).getResult(); } diff --git a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp index ac6d1ceac363..860ae79bdb86 100644 --- a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp @@ -29,7 +29,7 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc, // Turn every tensor into a tuple of (tensor_rank, tensor_dtype) auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand, Type desiredType) -> Value { - if (desiredType.isa() && + if (isa(desiredType) && operand.getType().isa()) { Type intType = Torch::IntType::get(b.getContext()); Type sizeListType = Torch::ListType::get(intType); diff --git a/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp index f755b5c0a405..9b1c5e7fdccd 100644 --- a/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp @@ -38,7 +38,7 @@ shapeFunctionArgsBuilder(OpBuilder &b, Location loc, Type desiredType) -> Value { // The shape library functions have tensor operands replaced with // `!torch.list` types for the shape. Get the sizes. - auto desiredListType = desiredType.dyn_cast(); + auto desiredListType = dyn_cast(desiredType); if (!desiredListType) return operand; if (operand.getType().isa() && diff --git a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp index 1a2d3d545cbe..05daa41382cd 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp @@ -262,13 +262,13 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp, originalResultType.template dyn_cast()) { // If we didn't get any new information, there is nothing left for us to do. updatedType = meetTensorTypes(originalBaseTensorType, - newResultType.cast()); + cast(newResultType)); if (!updatedType || updatedType == originalBaseTensorType) return rewriter.notifyMatchFailure( calculateOp, "New type information does not refine old type"); } else if (auto originalResultType = result.getType().template dyn_cast()) { - if (!newResultType.isa()) { + if (!isa(newResultType)) { return rewriter.notifyMatchFailure( calculateOp, "Refinement of `NumberType` must be a `FloatType` or `IntType`"); @@ -291,10 +291,10 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp, } if (!originalTypedValue) { rewriter.setInsertionPointAfter(calculateOp); - if (originalResultType.isa()) { + if (isa(originalResultType)) { originalTypedValue = rewriter.create( loc, originalResultType, result); - } else if (originalResultType.isa()) { + } else if (isa(originalResultType)) { originalTypedValue = rewriter.create(loc, originalResultType, result); } else { @@ -314,14 +314,14 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp, OpOperand &use = yieldValues->getOpOperand(resultNum); Value def = use.get(); Value newYieldedValue; - if (def.isa() && - def.cast() + if (isa(def) && + cast(def) .getDefiningOp() ->hasTrait()) { newYieldedValue = def; } else { rewriter.setInsertionPoint(yieldValues); - if (updatedType.isa()) { + if (isa(updatedType)) { newYieldedValue = rewriter.create(loc, updatedType, def); } else { diff --git a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp index fbbd6c48043b..d68b0d4bd3a7 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp @@ -53,8 +53,9 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op, op, "Failed to convert `dtypeScalarType` to a builtin type"); } impliedTypeFromDtype = - originalResultType.cast().getWithSizesAndDtype( - originalResultType.getOptionalSizes(), *builtinType); + cast(originalResultType) + .getWithSizesAndDtype(originalResultType.getOptionalSizes(), + *builtinType); } else { return rewriter.notifyMatchFailure(op, "Unimplemented: Expected result type to " @@ -179,7 +180,7 @@ class RefineNumToTensorScalarOpType } Type inputType = getBuiltInTypeForTorchScalar(op.getA().getType()); auto impliedTypeFromInputType = - originalResultType.cast() + cast(originalResultType) .getWithSizesAndDtype(originalResultType.getOptionalSizes(), inputType) .cast(); diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp index 1669be7c4e62..c56376a6c1bc 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp @@ -98,7 +98,7 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op, auto originalResultType = result.getType().cast(); auto impliedTypesFromShape = - originalResultType.cast() + cast(originalResultType) .getWithSizesAndDtype(ArrayRef(sizes), originalResultType.getOptionalDtype()) .cast(); diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index c6b23a2f4862..d38d4423f690 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -107,8 +107,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { return torch_upstream::ScalarType::QInt8; if (type.isa()) return torch_upstream::ScalarType::QInt32; - if (type.isa()) { - mlir::Type complexElemType = type.cast().getElementType(); + if (isa(type)) { + mlir::Type complexElemType = cast(type).getElementType(); if (complexElemType.isF16()) return torch_upstream::ScalarType::ComplexHalf; if (complexElemType.isF32()) @@ -121,9 +121,9 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { Type Torch::getTypeForTorchType( MLIRContext *context, Type type, mlir::IntegerType::SignednessSemantics signedness) { - if (type.isa()) + if (isa(type)) return IntegerType::get(context, 64, signedness); - if (type.isa()) + if (isa(type)) return Float64Type::get(context); llvm::report_fatal_error("unhandled type for getTypeForTorchType"); } @@ -187,14 +187,14 @@ Torch::getTorchTypeForScalarType(MLIRContext *context, Type Torch::getDefaultDtypeForTorchScalar(Type type) { MLIRContext *context = type.getContext(); - if (type.isa()) { + if (isa(type)) { // For now, use float32 which is the initial default dtype returned by // `torch.get_default_dtype`. return Float32Type::get(context); } - if (type.isa()) + if (isa(type)) return IntegerType::get(context, 64, IntegerType::Signed); - if (type.isa()) + if (isa(type)) return IntegerType::get(context, 1); llvm_unreachable( "getDefaultDtypeForTorchScalar called on an unsupported type"); @@ -202,11 +202,11 @@ Type Torch::getDefaultDtypeForTorchScalar(Type type) { Type Torch::getBuiltInTypeForTorchScalar(Type type) { MLIRContext *context = type.getContext(); - if (type.isa()) + if (isa(type)) return Float64Type::get(context); - if (type.isa()) + if (isa(type)) return IntegerType::get(context, 64, IntegerType::Signed); - if (type.isa()) + if (isa(type)) return IntegerType::get(context, 1); llvm_unreachable( "getBuiltInTypeForTorchScalar called on an unsupported type"); diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp index ac9a72586bef..4b89b8da1d6b 100644 --- a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp +++ b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp @@ -62,15 +62,14 @@ Operation *TorchConversionDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - if (auto integerType = type.dyn_cast()) - return builder.create(loc, value.cast()); + if (auto integerType = dyn_cast(type)) + return builder.create(loc, cast(value)); - if (auto floatType = type.dyn_cast()) - return builder.create(loc, value.cast()); + if (auto floatType = dyn_cast(type)) + return builder.create(loc, cast(value)); - if (type.isa()) { - return builder.create(loc, - value.cast()); + if (isa(type)) { + return builder.create(loc, cast(value)); } return arith::ConstantOp::materialize(builder, value, type, loc); diff --git a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp index 7bcb67b17c61..12e30f287f3f 100644 --- a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp +++ b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp @@ -95,7 +95,7 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { // get outputs Type newResultType = getTypeConverter()->convertType(op.getType(0)); - auto resultType = newResultType.cast(); + auto resultType = cast(newResultType); if (!resultType) { return failure(); } diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp index c6085f419eac..0c8cdf2fc54d 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp @@ -33,7 +33,7 @@ class VerifyStablehloBackendContractPass converter.addConversion([](Type type) -> Type { auto elemTy = type; if (isa(type)) - elemTy = type.cast().getElementType(); + elemTy = cast(type).getElementType(); if (BaseMemRefType::isValidElementType(elemTy)) return type; return nullptr; diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index 4ada196e944c..1cf52144e0a7 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -54,11 +54,11 @@ void mlir::torch::RefBackend::registerRefBackendPasses() { ::registerPasses(); } //===----------------------------------------------------------------------===// static bool isArgMemRefTypeValid(Type type) { - if (auto memRefType = type.dyn_cast()) { + if (auto memRefType = dyn_cast(type)) { Type elemTy = memRefType.getElementType(); if (elemTy.isa()) { return true; - } else if (auto integerTy = elemTy.dyn_cast()) { + } else if (auto integerTy = dyn_cast(elemTy)) { if (integerTy.isSignlessInteger(64)) return true; if (integerTy.isSignlessInteger(32)) @@ -69,7 +69,7 @@ static bool isArgMemRefTypeValid(Type type) { return true; if (integerTy.isSignlessInteger(1)) return true; - } else if (auto complexTy = elemTy.dyn_cast()) { + } else if (auto complexTy = dyn_cast(elemTy)) { return complexTy.getElementType().isa(); } } @@ -81,7 +81,7 @@ static void addEmitCInterfaceAttr(func::FuncOp func) { } static Type getAbiTypeForMemRef(Type type) { - return UnrankedMemRefType::get(type.cast().getElementType(), 0); + return UnrankedMemRefType::get(cast(type).getElementType(), 0); } // Helper function to get the type string for one return value like i32, f64, @@ -90,12 +90,12 @@ static Type getAbiTypeForMemRef(Type type) { static std::string getTypeToken(Type type) { if (type.isSignlessInteger()) return ("i" + Twine(type.getIntOrFloatBitWidth())).str(); - else if (type.isa()) + else if (isa(type)) return ("f" + Twine(type.getIntOrFloatBitWidth())).str(); - else if (auto complexTy = type.dyn_cast()) + else if (auto complexTy = dyn_cast(type)) return ("c" + Twine(complexTy.getElementType().getIntOrFloatBitWidth())) .str(); - else if (auto memRefType = type.dyn_cast()) + else if (auto memRefType = dyn_cast(type)) return "mr" + getTypeToken(memRefType.getElementType()); llvm_unreachable( @@ -171,7 +171,7 @@ static LogicalResult mungeFunction( for (auto en : llvm::enumerate(types)) { Type retType = en.value(); Value retVal = op.getOperand(en.index()); - if (auto memrefReturnType = retType.dyn_cast()) { + if (auto memrefReturnType = dyn_cast(retType)) { auto elemType = memrefReturnType.getElementType(); retType = UnrankedMemRefType::get(elemType, 0); // Cast to unranked memref type before sending it as a function diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index c993d72d2bcb..7340cb458457 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -704,6 +704,7 @@ "PermuteModule_basic", "PermuteNegativeIndexModule_basic", "PowIntFloatModule_basic", + "PrimListUnpackNumMismatchModule_basic", "PrimMaxIntModule_basic", "PrimMinIntDynamicModule_basic", "PrimMinIntModule_basic", @@ -1365,6 +1366,7 @@ "Permute0RankModule_basic", "PermuteModule_basic", "PermuteNegativeIndexModule_basic", + "PrimListUnpackNumMismatchModule_basic", "PrimsSqueezeEmptyDimensionsModule_basic", "PrimsSqueezeModule_basic", "PrimsSumFloatModule_basic", @@ -1552,7 +1554,10 @@ # failed to legalize operation 'torch.operator' "ElementwisePreluModule_basic", - "ElementwisePreluStaticModule_basic", + "ElementwisePreluStaticModule_basic", + + # It appears that you're trying to get value out of a tracing tensor + "PrimListUnpackNumMismatchModule_basic", } MAKE_FX_TOSA_CRASHING_SET = {"CumsumModule_basic"} diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py index 7c39b3cd7f08..6a2505a9f581 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py @@ -758,6 +758,33 @@ def SliceCopyNonZeroDim_Module_basic(module, tu: TestUtils): module.forward(tu.rand(10, 4, 4), tu.rand(10, 2, 4)) +# ============================================================================== +class PrimListUnpackNumMismatchModule(torch.nn.Module): + def __init__(self): + super().__init__() + + + @export + @annotate_args([ + None, + ([5, 4, 3, 2, 1], torch.float32, True), + ]) + def forward(self, x): + if len(x.shape) == 5: + b0, t, c0, h0, w0 = x.shape + b, c, h, w = torch.mul(b0, t), c0, h0, w0 + else: + b1, c1, h1, w1 = x.shape + b, c, h, w = b1, c1, h1, w1 + res = torch.reshape(x, [b, c, h, w]) + return res + + +@register_test_case(module_factory=lambda: PrimListUnpackNumMismatchModule()) +def PrimListUnpackNumMismatchModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3, 2, 1)) + + # ==============================================================================