diff --git a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h index 5d2095f04f14..14e9202222c6 100644 --- a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h +++ b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h @@ -97,6 +97,10 @@ getBackendTypeForScalarType(MLIRContext *context, bool isUnsignedTorchType(Type type); +LogicalResult permuteTensor(Operation *op, PatternRewriter &rewriter, + Location loc, SmallVector dimensions, + Value input, Value &result); + } // namespace torch_to_linalg } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index f9b5cada1049..05636459b2fe 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -7181,6 +7181,31 @@ def Torch_Aten_AdaptiveAvgPool3dBackwardOp : Torch_Op<"aten._adaptive_avg_pool3d }]; } +def Torch_AtenAdaptiveMaxPool1dOp : Torch_Op<"aten.adaptive_max_pool1d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::adaptive_max_pool1d : (Tensor, int[]) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size + ); + let results = (outs + AnyTorchTensorType:$result0, + AnyTorchTensorType:$result1 + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAdaptiveMaxPool1dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 2); + } + void AtenAdaptiveMaxPool1dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 2); + } + }]; +} + def Torch_AtenAdaptiveMaxPool2dOp : Torch_Op<"aten.adaptive_max_pool2d", [ AllowsTypeRefinement, HasValueSemantics, @@ -7206,6 +7231,31 @@ def Torch_AtenAdaptiveMaxPool2dOp : Torch_Op<"aten.adaptive_max_pool2d", [ }]; } +def Torch_AtenAdaptiveMaxPool3dOp : Torch_Op<"aten.adaptive_max_pool3d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::adaptive_max_pool3d : (Tensor, int[]) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size + ); + let results = (outs + AnyTorchTensorType:$result0, + AnyTorchTensorType:$result1 + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAdaptiveMaxPool3dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 2); + } + void AtenAdaptiveMaxPool3dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 2); + } + }]; +} + def Torch_AtenTopkOp : Torch_Op<"aten.topk", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 2e3f3e8b8053..f998240b3472 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -306,7 +306,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); }); patterns.onOp( - "AveragePool", 19, + "AveragePool", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; SmallVector dilation; @@ -357,7 +357,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, "padding list size does not match twice the number of axes"); } - if (binder.s64IntegerArrayAttr(strides, "strides", {1})) { + if (binder.s64IntegerArrayAttr( + strides, "strides", llvm::SmallVector(rank - 2, 1))) { return failure(); } if (strides.size() != 1 && strides.size() != rank - 2) { diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index a7bdddbc8d78..1101723aefcc 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -970,17 +970,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } if (!constantValue) { - auto dataTensorType = data.getType().cast(); - if (dataTensorType.getDtype().isa()) - constantValue = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - if (dataTensorType.getDtype().isa()) - constantValue = rewriter.create( - loc, rewriter.getF64FloatAttr(0.0f)); - - if (!constantValue) - return rewriter.notifyMatchFailure( - binder.op, "expected integer or float data tensor"); + constantValue = rewriter.create( + loc, rewriter.getF64FloatAttr(0.0f)); } // Extract all the values of 1-D pad tensor and create a list of all diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index e4bf1886bb91..512123fe43fe 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -1457,56 +1457,15 @@ class ConvertAtenPermuteOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "all dimensions must be constant"); Value inVector = adaptor.getSelf(); - auto inType = inVector.getType().cast(); - int64_t inputRank = inType.getRank(); - auto outType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); - Type elementType = inType.getElementType(); - - // Check if the dimensions are a valid constants. - int64_t numDimensions = dimensions.size(); - if (inputRank != numDimensions) + Value result; + if (failed(torch_to_linalg::permuteTensor(op, rewriter, op->getLoc(), + dimensions, inVector, result))) return rewriter.notifyMatchFailure( - op, "size of `dims` must be equal to the rank of the input"); - for (unsigned i = 0; i < numDimensions; i++) { - if (dimensions[i] < 0) - dimensions[i] = toPositiveDim(dimensions[i], inputRank); - if (!isValidDim(dimensions[i], inputRank)) - return rewriter.notifyMatchFailure(op, "dimension out of range"); - } - - Location loc = op.getLoc(); - - SmallVector outputDims; - for (unsigned i = 0; i < inputRank; i++) - outputDims.push_back(getDimOp(rewriter, loc, inVector, dimensions[i])); + op, "failed to perform permutation of tensor"); - Value outVector = rewriter.create( - loc, getAsOpFoldResult(outputDims), elementType); - SmallVector idExprs; - SmallVector swapExprs; - for (unsigned i = 0; i < inputRank; i++) - idExprs.push_back(getAffineDimExpr(i, rewriter.getContext())); - for (unsigned i = 0; i < inputRank; i++) - swapExprs.push_back(idExprs[dimensions[i]]); - - AffineMap inputMap = - AffineMap::get(inputRank, /*symbolCount=*/0, idExprs, op->getContext()); - AffineMap outputMap = AffineMap::get(inputRank, /*symbolCount=*/0, - swapExprs, op->getContext()); - SmallVector indexingMaps{inputMap, outputMap}; - SmallVector iteratorTypes( - inputRank, utils::IteratorType::parallel); - auto transpose = rewriter - .create( - loc, outVector.getType(), inVector, outVector, - indexingMaps, iteratorTypes, - [](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args[0]); - }) - .getResult(0); - rewriter.replaceOpWithNewOp(op, outType, transpose); + auto outType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); + rewriter.replaceOpWithNewOp(op, outType, result); return success(); } }; diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index e795d2ea9fb8..b1f114af8c72 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -114,8 +114,22 @@ static Value padInputTensor(Operation *op, ConversionPatternRewriter &rewriter, SmallVectorImpl &paddingInts, Value initValue) { SmallVector lowPaddingIncludingNC = {0, 0}; - lowPaddingIncludingNC.append(paddingInts); - SmallVector highPaddingIncludingNC = lowPaddingIncludingNC; + SmallVector highPaddingIncludingNC = {0, 0}; + + unsigned selfRank = self.getType().cast().getRank(); + unsigned paddingIntsSize = paddingInts.size(); + + if (paddingIntsSize == 2 * (selfRank - 2)) { + // This condition being true means that the `paddingInts` contain seperate + // values for low padding and high padding. + for (unsigned i = 0; i < paddingIntsSize / 2; i++) + lowPaddingIncludingNC.push_back(paddingInts[i]); + for (unsigned i = paddingIntsSize / 2; i < paddingIntsSize; i++) + highPaddingIncludingNC.push_back(paddingInts[i]); + } else { + lowPaddingIncludingNC.append(paddingInts); + highPaddingIncludingNC = lowPaddingIncludingNC; + } if (ceilMode) { for (int64_t i = 0; i < dimensionality; ++i) { @@ -159,11 +173,42 @@ static LogicalResult createPoolingOp( Value windowTensor = rewriter.create( loc, getAsOpFoldResult(shape), elementType); - result = rewriter - .create(loc, outTensorInitialized.getType(), - ValueRange{paddedInput, windowTensor}, - outTensorInitialized, stridesAttr, dilationAttr) - .getResult(0); + Value permutedInput = paddedInput, permutedOutput = outTensorInitialized; + if (dimensionality == 3) { + // Permute input and output tensor as follows: + // (n,c,d,h,w) -> (n,d,h,w,c) + SmallVector dimensions = {0, 2, 3, 4, 1}; + if (failed(torch_to_linalg::permuteTensor(op, rewriter, op->getLoc(), + dimensions, paddedInput, + permutedInput))) + return rewriter.notifyMatchFailure( + op, "failed to perform permutation of tensor"); + + if (failed(torch_to_linalg::permuteTensor(op, rewriter, op->getLoc(), + dimensions, outTensorInitialized, + permutedOutput))) + return rewriter.notifyMatchFailure( + op, "failed to perform permutation of tensor"); + } + + Value poolingResult = + rewriter + .create(loc, permutedOutput.getType(), + ValueRange{permutedInput, windowTensor}, permutedOutput, + stridesAttr, dilationAttr) + .getResult(0); + + result = poolingResult; + if (dimensionality == 3) { + // Permute output tensor as follows: + // (n,d,h,w,c) -> (n,c,d,h,w) + SmallVector dimensions = {0, 4, 1, 2, 3}; + if (failed(torch_to_linalg::permuteTensor( + op, rewriter, op->getLoc(), dimensions, poolingResult, result))) + return rewriter.notifyMatchFailure( + op, "failed to perform permutation of tensor"); + } + return success(); } @@ -574,16 +619,17 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { paddingInts, dilationInts, rewriter.getZeroAttr(inputElementType), outTensorShape, paddedInput, sumPool))) return rewriter.notifyMatchFailure(op, "unable to compute sumpool"); - Value divisor; - if constexpr (std::is_same()) { - Value kHtimeskW = rewriter.create( - loc, kernelSizeIntValues[0], kernelSizeIntValues[1]); + // } + + Value divisor = kernelSizeIntValues[0]; + for (uint32_t i = 1; i < kernelSizeIntValues.size(); i++) { divisor = - op.getDivisorOverride().getType().template isa() - ? kHtimeskW - : adaptor.getDivisorOverride(); - } else { - divisor = kernelSizeIntValues[0]; + rewriter.create(loc, divisor, kernelSizeIntValues[i]); + } + if constexpr (!std::is_same()) { + divisor = isa(op.getDivisorOverride().getType()) + ? divisor + : adaptor.getDivisorOverride(); } divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType); @@ -629,159 +675,191 @@ This is problematic for linalg ops for a few reasons: h! Although it is a bit like using a hammer to paint, our workaround is to use tensor.extract to access the elements of the input tensor inside our linalg generic op's payload. - -Current TODO's: - 1. gather most of the boilerplate out of this op and make it into an -adaptive pooling helper function. - 2. figure out what to do with the conflicting decompositions in -DecomposeComplexOps.cpp - 3. Implement more efficient passes for when the kernel-size, input spatial -dims, and output spatial dims are constant. */ namespace { -class ConvertAtenAdaptiveAvgPool1dOp - : public OpConversionPattern { + +class AdaptivePoolingHelper { public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(AtenAdaptiveAvgPool1dOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + AdaptivePoolingHelper(ConversionPatternRewriter &cpr, int64_t rnk, + int64_t nsp, Type elt) + : rewriter(cpr), rank(rnk), nonSpatial(nsp), elementType(elt) {} + + // Variables that are used in various helper functions in the derived classes + // are stored as members of the base class (to reduce the number of arguments + // passed to helper functions). + ConversionPatternRewriter &rewriter; + const int64_t rank; + const int64_t nonSpatial; + Type elementType; +}; - Location loc = op->getLoc(); - const TypeConverter *typeConverter = getTypeConverter(); +// The following two derived helper classes are used to store the differing +// logic between adaptive avg pooling and adaptive max pooling. +// 1. auxTensorSetup initializes a tensor for storing either indices (max) or +// kernel volumes (avg) +// 2. payloadCustomization customizes those features of the main linalg generic +// op that are not generically "AdaptivePooling". Specifically, for switching +// between sum/max and writing the code for computing the aux tensor elements. +// 3. customizedOpReplacement finishes the op replacement. In the adaptive avg +// case, it includes an additional generic op to divide the sum pool by the +// kernel volume. +// To access these helper functions in the conversion pattern, we +// have an AdaptivePoolingOpTraits class that stores the number of dimensions +// and aliases the associated helper class to a more generic name. + +template +class AdaptiveMaxPoolingHelper : public AdaptivePoolingHelper { + + // This member variable is templated, so I've chosen not to make it part of + // the base class (to keep the base class non-templated). + const OpConversionPattern &opConversionPattern; - // get rank of input (same as rank of output) - int64_t rank = - adaptor.getSelf().getType().cast().getRank(); - // input operand should be NCH (i.e. rank 3) - if (rank != 3) { - return rewriter.notifyMatchFailure(op, "only supports input type NCH"); +public: + // Constructor for AdaptiveMaxPoolingHelper. Just forwards all arguments + // (except the OpConversionPattern) to the base class constructor. + template + AdaptiveMaxPoolingHelper(const OpConversionPattern &ocp, Args &&...args) + : AdaptivePoolingHelper(std::forward(args)...), + opConversionPattern(ocp) {} + + LogicalResult auxTensorSetup(OpTy op, const SmallVector &outputSizes, + const SmallVector &outShapeIndexVector, + RankedTensorType &outputType, + RankedTensorType &auxTensorType, Value &buffVal, + Value &auxTensor, + SmallVector &auxTensorExprs) { + + Location loc = op->getLoc(); + const TypeConverter *typeConverter = opConversionPattern.getTypeConverter(); + outputType = typeConverter->convertType(op.getResult0().getType()) + .template cast(); + auxTensorType = typeConverter->convertType(op.getResult1().getType()) + .template cast(); + Type auxTensorElementType = auxTensorType.getElementType(); + auto smallestFPValueAttr = rewriter.getFloatAttr( + elementType, + APFloat::getInf(elementType.cast().getFloatSemantics(), + /*Negative=*/true)); + buffVal = rewriter.create(loc, elementType, + smallestFPValueAttr); + auxTensor = rewriter.create( + loc, getAsOpFoldResult(outputSizes), auxTensorElementType); + for (unsigned i = 0; i < rank; i++) { + auxTensorExprs.push_back(rewriter.getAffineDimExpr(i)); } + return success(); + } - // input tensor and output shape - Value input = adaptor.getSelf(); - Value outputShape = op.getOutputSize(); - SmallVector outShapeVector; - getListConstructElements(outputShape, outShapeVector); - outShapeVector = - getTypeConvertedValues(rewriter, loc, typeConverter, outShapeVector); - Value hIn = getDimOp(rewriter, loc, input, 2); - Value hOut = outShapeVector[0]; - Value hOutIndex = castIntToIndex(rewriter, loc, hOut); - RankedTensorType inputType = input.getType().cast(); - RankedTensorType outputType = - typeConverter->convertType(op.getResult().getType()) - .cast(); + LogicalResult payloadCustomization( + OpBuilder &b, Location loc, const Value &inElt, const Value &res, + const Value &maxIndex, const SmallVector &inputElementIndices, + const SmallVector &inputSpatialSizes, const Value &indexOne, + const SmallVector &starts, const SmallVector &ends, + Value &out2, Value &auxOut) { + // compute max using select, since cond1 will be used for indices + Value cond1 = + b.create(loc, arith::CmpFPredicate::OGT, inElt, res); + out2 = b.create(loc, cond1, inElt, res); + // index in different dims (n x c x d x h x w) + // 1d: (iw) + // 2d: (ih*W + iw) + // 3d: (id*H*W + ih*W + iw) + Value currIndex = inputElementIndices[nonSpatial]; + for (unsigned i = 0; i < rank - nonSpatial - 1; i++) { + Value prevTimesNewSize = + b.create(loc, currIndex, inputSpatialSizes[i + 1]); + currIndex = b.create( + loc, prevTimesNewSize, inputElementIndices[nonSpatial + i + 1]); + } + Value indexOut1Int = castIndexToInt64(b, loc, currIndex); + auxOut = b.create(loc, cond1, indexOut1Int, maxIndex); + return success(); + } - // get elementType of input tensor - Type elementType = inputType.getElementType(); + LogicalResult + customizedOpReplacement(OpTy op, const RankedTensorType &outputType, + const RankedTensorType &auxTensorType, + const Value &adaptivePoolOutput, + const Value &auxTensorReturn, + const SmallVector &auxTensorExprs, + const SmallVector &outputExprs) { + Location loc = op->getLoc(); + Value maxValues = + rewriter.create(loc, outputType, adaptivePoolOutput); + Value outputIndices = + rewriter.create(loc, auxTensorType, auxTensorReturn); + rewriter.replaceOp(op, {maxValues, outputIndices}); + return success(); + } +}; - // make an iteration space of size kMax = 1 + ceildiv (hIn - 1) , hOut - Type boolType = rewriter.getI1Type(); - Value kIter; - Value constantOne = - rewriter.create(loc, rewriter.getIndexAttr(1)); - Value hInPlusOne = rewriter.create(loc, hIn, constantOne); - Value kMaxMinusOne = - rewriter.create(loc, hInPlusOne, hOutIndex); - Value kMax = rewriter.create(loc, constantOne, kMaxMinusOne); - kIter = rewriter.create( - loc, getAsOpFoldResult(ValueRange({kMax})), boolType); - - // need to buffer input, else there will possibly be an out of bounds access - // later buffVal = 0 for avg pooling and -inf for max pooling - Value buffVal = rewriter.create( - loc, elementType, rewriter.getFloatAttr(elementType, 0)); - SmallVector lowPadding = {0, 0, 0}; - SmallVector highPadding = {0, 0, 1}; - Value buffInput = torch_to_linalg::getPaddedTensor( - op, rewriter, input, lowPadding, highPadding, buffVal); +template +class AdaptiveAvgPoolingHelper : public AdaptivePoolingHelper { - // make a list of outputSizes - SmallVector outputSizes; - for (unsigned i = 0; i < rank - 1; i++) { - outputSizes.push_back(getDimOp(rewriter, loc, input, i)); - } - outputSizes.push_back(hOutIndex); + const OpConversionPattern &opConversionPattern; - // initialize a kernel size tensor (only for avg pooling) - Value kSizeTensor = rewriter.create( - loc, getAsOpFoldResult(ValueRange({hOutIndex})), elementType); +public: + template + AdaptiveAvgPoolingHelper(const OpConversionPattern &ocp, Args &&...args) + : AdaptivePoolingHelper(std::forward(args)...), + opConversionPattern(ocp) {} + + LogicalResult auxTensorSetup(OpTy op, const SmallVector &outputSizes, + const SmallVector &outShapeIndexVector, + RankedTensorType &outputType, + RankedTensorType &auxTensorType, Value &buffVal, + Value &auxTensor, + SmallVector &auxTensorExprs) { - // initialize an output tensor - Value initOutput = - createInitTensor(rewriter, loc, outputSizes, elementType, buffVal); + Location loc = op->getLoc(); + const TypeConverter *typeConverter = opConversionPattern.getTypeConverter(); + outputType = typeConverter->convertType(op.getResult().getType()) + .template cast(); + buffVal = rewriter.create( + loc, elementType, rewriter.getFloatAttr(elementType, 0)); + auxTensor = rewriter.create( + loc, getAsOpFoldResult(outShapeIndexVector), elementType); + for (unsigned i = nonSpatial; i < rank; i++) { + auxTensorExprs.push_back(rewriter.getAffineDimExpr(i)); + } + return success(); + } - // setup indexing maps and iterator types for linalg generic op - // for kIter (d0,d1,d2,d3) -> (d3) - // for output (d0,d1,d2,d3) -> (d0,d1,d2) - // for kSizeTensor (d0,d1,d2,d3) -> (d2) - SmallVector kIterExprs, outputExprs, kSizeTensorExprs; - for (unsigned i = 0; i < 3; i++) { - outputExprs.push_back(rewriter.getAffineDimExpr(i)); + LogicalResult payloadCustomization( + OpBuilder &b, Location loc, const Value &inElt, const Value &res, + const Value &maxIndex, const SmallVector &inputElementIndices, + const SmallVector &inputSpatialSizes, const Value &indexOne, + const SmallVector &starts, const SmallVector &ends, + Value &out2, Value &auxOut) { + out2 = b.create(loc, inElt, res); + Value kernelVolume = indexOne; + for (unsigned i = 0; i < rank - nonSpatial; i++) { + Value currSize = b.create(loc, ends[i], starts[i]); + kernelVolume = b.create(loc, kernelVolume, currSize); } - kSizeTensorExprs.push_back(rewriter.getAffineDimExpr(2)); - kIterExprs.push_back(rewriter.getAffineDimExpr(3)); - SmallVector indexingMaps = AffineMap::inferFromExprList( - {kIterExprs, outputExprs, kSizeTensorExprs}, rewriter.getContext()); - SmallVector iteratorTypes( - 3, utils::IteratorType::parallel); - iteratorTypes.push_back(utils::IteratorType::reduction); + Value auxOutSI = castIndexToInt64(b, loc, kernelVolume); + auxOut = b.create(loc, elementType, auxOutSI); + return success(); + } - Value indexOne = rewriter.create(loc, 1); - auto sumPool = rewriter.create( - loc, /*resultTensorTypes=*/ - TypeRange({initOutput.getType(), kSizeTensor.getType()}), - /*inputs=*/ValueRange({kIter}), - /*outputs=*/ValueRange({initOutput, kSizeTensor}), - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value res = args[1]; - Value ind0 = b.create(loc, 0); - Value ind1 = b.create(loc, 1); - Value ind2 = b.create(loc, 2); - Value ind3 = b.create(loc, 3); - // compute start and end indices - // st = s1( s0(ind2 * Hin) // Hout ) - Value s0 = b.create(loc, ind2, hIn); - Value s1 = b.create(loc, s0, hOutIndex); - // en = e4( 1 + e3( e2( e1( e0(ind2 + 1) * hIn ) - 1 ) // hOut ) ) - Value e0 = b.create(loc, ind2, indexOne); - Value e1 = b.create(loc, e0, hIn); - Value e2 = b.create(loc, e1, indexOne); - Value e3 = b.create(loc, e2, hOutIndex); - Value e4 = b.create(loc, indexOne, e3); - // get input element @ st + ind3: - Value wIndex = b.create(loc, s1, ind3); - Value inElt = b.create( - loc, elementType, buffInput, ValueRange({ind0, ind1, wIndex})); - // check if we extracted at windex < end index - Value cond = - b.create(loc, arith::CmpIPredicate(6), wIndex, e4); - // if inElt is in bounds, include it in the computation - // else, use buffVal = 0 (for max pool use -infinity) - Value out1 = b.create(loc, cond, inElt, buffVal); - // compute Kernel size: we store this to kwTensor - Value kSize = b.create(loc, e4, s1); - Value kSizeInt = castIndexToInt64(b, loc, kSize); - Value kSizeF = b.create(loc, elementType, kSizeInt); - // accumulate out2 to res = args[1] - Value out2 = b.create(loc, res, out1); - b.create(loc, ValueRange({out2, kSizeF})); - }); + LogicalResult + customizedOpReplacement(OpTy op, const RankedTensorType &outputType, + const RankedTensorType &auxTensorType, + const Value &adaptivePoolOutput, + const Value &auxTensorReturn, + const SmallVector &auxTensorExprs, + const SmallVector &outputExprs) { - // make a linalg generic to divide each element by the corresponding - // Kernel Width. This step is only necessary for avg pooling. + Location loc = op->getLoc(); SmallVector indexingMaps1 = AffineMap::inferFromExprList( - {kSizeTensorExprs, outputExprs}, rewriter.getContext()); + {auxTensorExprs, outputExprs}, op.getContext()); SmallVector iteratorTypes1( - 3, utils::IteratorType::parallel); + rank, utils::IteratorType::parallel); auto output = rewriter.create( - loc, /*resultTensorTypes=*/initOutput.getType(), - /*inputs=*/sumPool.getResultTensors()[1], - /*outputs=*/sumPool.getResultTensors()[0], + loc, /*resultTensorTypes=*/adaptivePoolOutput.getType(), + /*inputs=*/auxTensorReturn, + /*outputs=*/adaptivePoolOutput, /*indexingMaps=*/indexingMaps1, /*iteratorTypes=*/iteratorTypes1, [&](OpBuilder &b, Location loc, ValueRange args) { @@ -794,65 +872,103 @@ class ConvertAtenAdaptiveAvgPool1dOp return success(); } }; -} // namespace -// The logic for this conversion is similar to the AdaptiveAvgPool1dOp -// conversion. Before writing any more adaptive pooling conversions, the logic -// in this should be off-loaded to a helper function, since each of the adaptive -// ops are essentially the same with some minor tweaks. Instead of kSizeTensor, -// we named the additional output of the linalg generic op auxTensor. -// For max pooling, auxTensor holds the indices of max values, and for -// avg pooling, the auxTensor will be kSizeTensor, used to later divide the -// sum pool by the kernel size. -namespace { -class ConvertAtenAdaptiveMaxPool2dOp - : public OpConversionPattern { +// stores Dim = spatial dims and aliases helper class to a generic name +template struct AdaptivePoolingOpTraits {}; + +template <> struct AdaptivePoolingOpTraits { + static constexpr int64_t Dim = 1; + using AdaptivePoolingHelper = + AdaptiveMaxPoolingHelper; +}; + +template <> struct AdaptivePoolingOpTraits { + static constexpr int64_t Dim = 2; + using AdaptivePoolingHelper = + AdaptiveMaxPoolingHelper; +}; + +template <> struct AdaptivePoolingOpTraits { + static constexpr int64_t Dim = 3; + using AdaptivePoolingHelper = + AdaptiveMaxPoolingHelper; +}; + +template <> struct AdaptivePoolingOpTraits { + static constexpr int64_t Dim = 1; + using AdaptivePoolingHelper = + AdaptiveAvgPoolingHelper; +}; + +template <> struct AdaptivePoolingOpTraits { + static constexpr int64_t Dim = 2; + using AdaptivePoolingHelper = + AdaptiveAvgPoolingHelper; +}; + +template <> struct AdaptivePoolingOpTraits { + static constexpr int64_t Dim = 3; + using AdaptivePoolingHelper = + AdaptiveAvgPoolingHelper; +}; + +template <> struct AdaptivePoolingOpTraits { + static constexpr int64_t Dim = 3; + using AdaptivePoolingHelper = + AdaptiveAvgPoolingHelper; +}; + +template +class ConvertAtenAdaptivePoolOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +private: + static const int64_t Dim = AdaptivePoolingOpTraits::Dim; + public: - using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(AtenAdaptiveMaxPool2dOp op, OpAdaptor adaptor, + matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - const TypeConverter *typeConverter = getTypeConverter(); + const TypeConverter *typeConverter = this->getTypeConverter(); + + Value input = adaptor.getSelf(); + RankedTensorType inputType = input.getType().cast(); + const Type elementType = inputType.getElementType(); // get rank of input (same as rank of output) - int64_t rank = - adaptor.getSelf().getType().cast().getRank(); - // input operand should be NCHW (i.e. rank 4) - if (rank != 4) { - return rewriter.notifyMatchFailure(op, "only supports input type NCHW"); + const int64_t rank = inputType.getRank(); + // get number of non-spatial dims + const int64_t nonSpatial = rank - Dim; + if (nonSpatial < 0) { + return rewriter.notifyMatchFailure(op, + "input has insufficient spatial dims"); } - // input tensor and output shape - Value input = adaptor.getSelf(); + typename AdaptivePoolingOpTraits::AdaptivePoolingHelper + adaptivePoolingHelper(*this, rewriter, rank, nonSpatial, elementType); + + // get input and output spatial dimensions as index values Value outputShape = op.getOutputSize(); SmallVector outShapeVector; getListConstructElements(outputShape, outShapeVector); outShapeVector = getTypeConvertedValues(rewriter, loc, typeConverter, outShapeVector); SmallVector inputSpatialSizes; - for (unsigned i = 2; i < rank; i++) { + for (unsigned i = nonSpatial; i < rank; i++) { inputSpatialSizes.push_back(getDimOp(rewriter, loc, input, i)); } SmallVector outShapeIndexVector; for (auto v : outShapeVector) { outShapeIndexVector.push_back(castIntToIndex(rewriter, loc, v)); } - RankedTensorType inputType = input.getType().cast(); - RankedTensorType outputType = - typeConverter->convertType(op.getResult0().getType()) - .cast(); - - // get elementType of input tensor - Type elementType = inputType.getElementType(); // make an iteration space of size kMax = 1 + ceildiv (hIn - 1) , hOut Type boolType = rewriter.getI1Type(); SmallVector kIterSizeVector; Value constantOne = rewriter.create(loc, rewriter.getIndexAttr(1)); - for (int i = 0; i < rank - 2; i++) { + for (int i = 0; i < rank - nonSpatial; i++) { Value hInPlusOne = rewriter.create( loc, inputSpatialSizes[i], constantOne); Value kMaxMinusOne = rewriter.create( @@ -864,67 +980,66 @@ class ConvertAtenAdaptiveMaxPool2dOp Value kIter = rewriter.create( loc, getAsOpFoldResult(kIterSizeVector), boolType); - // need to buffer input, else there will possibly be an out of bounds access - // later buffVal = 0 for avg pooling and -inf for max pooling - auto smallestFPValueAttr = rewriter.getFloatAttr( - elementType, - APFloat::getInf(elementType.cast().getFloatSemantics(), - /*Negative=*/true)); - Value buffVal = rewriter.create(loc, elementType, - smallestFPValueAttr); - SmallVector lowPadding(rank, 0); - SmallVector highPadding(2, 0); - for (int i = 0; i < rank - 2; i++) { - highPadding.push_back(1); - } - Value buffInput = torch_to_linalg::getPaddedTensor( - op, rewriter, input, lowPadding, highPadding, buffVal); - - // make a list of outputSizes + // get output sizes used for initializing some tensors SmallVector outputSizes; - for (unsigned i = 0; i < 2; i++) { + for (unsigned i = 0; i < nonSpatial; i++) { outputSizes.push_back(getDimOp(rewriter, loc, input, i)); } - for (unsigned i = 2; i < rank; i++) { - outputSizes.push_back(outShapeIndexVector[i - 2]); + for (unsigned i = 0; i < rank - nonSpatial; i++) { + outputSizes.push_back(outShapeIndexVector[i]); } - // for avg pooling the auxTensor should hold kernel widths (kSizeTensor) - // for max Pooling, it should hold the indices - RankedTensorType outputType1 = - typeConverter->convertType(op.getResult1().getType()) - .cast(); - Type indicesType = outputType1.getElementType(); - Value auxTensor = rewriter.create( - loc, getAsOpFoldResult(outputSizes), indicesType); + // get outputType and initialize an auxTensor + // the auxTensor is customizable: + // avg pooling -> auxTensor = kernelVolumes + // max pooling -> auxTensor = indices + RankedTensorType outputType, auxTensorType; + Value buffVal, auxTensor; + SmallVector auxTensorExprs; + if (failed(adaptivePoolingHelper.auxTensorSetup( + op, outputSizes, outShapeIndexVector, outputType, auxTensorType, + buffVal, auxTensor, auxTensorExprs))) { + return rewriter.notifyMatchFailure(op, "failed auxTensor setup"); + } - // initialize an output tensor + // initialize output tensor Value initOutput = createInitTensor(rewriter, loc, outputSizes, elementType, buffVal); - // setup indexing maps and iterator types for linalg generic op (outputShape - // (rank),kIter (rank -2)) for kIter (d0,d1,d2,d3,d4,d5) -> (d4,d5) for - // output (d0,d1,d2,d3,d4,d5) -> (d0,d1,d2,d3) for auxTensor - // (d0,d1,d2,d3,d4,d5) -> (d0,d1,d2,d3) (or (d2,d3) for avg pooling) - SmallVector kIterExprs, outputExprs, auxTensorExprs; + // pad the input with buffVal = 0 (avg) or -inf (max) + SmallVector lowPadding(rank, 0); + SmallVector highPadding(nonSpatial, 0); + for (int i = 0; i < rank - nonSpatial; i++) { + highPadding.push_back(1); + } + Value buffInput = torch_to_linalg::getPaddedTensor( + op, rewriter, input, lowPadding, highPadding, buffVal); + + // setup indexing maps and iterator types for linalg generic op + // for example, with rank = 4 and nonSpatial = 2: + // kIter (d0,d1,d2,d3,d4,d5) -> (d4,d5) + // output (d0,d1,d2,d3,d4,d5) -> (d0,d1,d2,d3) + SmallVector kIterExprs, outputExprs; // batch + channel + output spatial dims for (unsigned i = 0; i < rank; i++) { outputExprs.push_back(rewriter.getAffineDimExpr(i)); - auxTensorExprs.push_back(rewriter.getAffineDimExpr(i)); } // kIter covers last rank-2 indices - for (unsigned i = rank; i < 2 * rank - 2; i++) { + for (unsigned i = rank; i < 2 * rank - nonSpatial; i++) { kIterExprs.push_back(rewriter.getAffineDimExpr(i)); } SmallVector indexingMaps = AffineMap::inferFromExprList( {kIterExprs, outputExprs, auxTensorExprs}, rewriter.getContext()); SmallVector iteratorTypes( rank, utils::IteratorType::parallel); - for (unsigned i = 0; i < rank - 2; i++) { + for (unsigned i = 0; i < rank - nonSpatial; i++) { iteratorTypes.push_back(utils::IteratorType::reduction); } Value indexOne = rewriter.create(loc, 1); - auto maxPool = rewriter.create( + + bool failedCustomization = false; + // adaptive pooling generic op + auto adaptivePool = rewriter.create( loc, /*resultTensorTypes=*/ TypeRange({initOutput.getType(), auxTensor.getType()}), /*inputs=*/ValueRange({kIter}), @@ -935,64 +1050,70 @@ class ConvertAtenAdaptiveMaxPool2dOp Value res = args[1]; Value maxIndex = args[2]; SmallVector ind; - for (unsigned i = 0; i < 2 * rank - 2; i++) { + for (unsigned i = 0; i < 2 * rank - nonSpatial; i++) { ind.push_back(b.create(loc, i)); } // compute start and end indices // st = s1( s0(ind2 * Hin) // Hout ) SmallVector starts; SmallVector ends; - for (unsigned i = 2; i < rank; i++) { - Value s0 = - b.create(loc, ind[i], inputSpatialSizes[i - 2]); + for (unsigned i = nonSpatial; i < rank; i++) { + Value s0 = b.create( + loc, ind[i], inputSpatialSizes[i - nonSpatial]); Value s1 = b.create( - loc, s0, outShapeIndexVector[i - 2]); + loc, s0, outShapeIndexVector[i - nonSpatial]); starts.push_back(s1); // en = e4( 1 + e3( e2( e1( e0(ind2 + 1) * hIn ) - 1 ) // hOut ) ) Value e0 = b.create(loc, ind[i], indexOne); - Value e1 = - b.create(loc, e0, inputSpatialSizes[i - 2]); + Value e1 = b.create( + loc, e0, inputSpatialSizes[i - nonSpatial]); Value e2 = b.create(loc, e1, indexOne); Value e3 = b.create( - loc, e2, outShapeIndexVector[i - 2]); + loc, e2, outShapeIndexVector[i - nonSpatial]); Value e4 = b.create(loc, indexOne, e3); ends.push_back(e4); } + // extract input element SmallVector inputElementIndices; - inputElementIndices.push_back(ind[0]); - inputElementIndices.push_back(ind[1]); - for (unsigned i = 2; i < rank; i++) { - inputElementIndices.push_back( - b.create(loc, starts[i - 2], ind[rank - 2 + i])); + for (unsigned i = 0; i < nonSpatial; i++) { + inputElementIndices.push_back(ind[i]); + } + for (unsigned i = nonSpatial; i < rank; i++) { + inputElementIndices.push_back(b.create( + loc, starts[i - nonSpatial], ind[rank - nonSpatial + i])); } Value inElt = b.create(loc, elementType, buffInput, inputElementIndices); // check if we extracted at windex < end index - for (unsigned i = 0; i < rank - 2; i++) { - Value cond = - b.create(loc, arith::CmpIPredicate(6), - inputElementIndices[i + 2], ends[i]); + for (unsigned i = 0; i < rank - nonSpatial; i++) { + Value cond = b.create( + loc, arith::CmpIPredicate(6), + inputElementIndices[i + nonSpatial], ends[i]); + // if out-of-bounds, replace the extracted element with buffVal inElt = b.create(loc, cond, inElt, buffVal); } - Value cond1 = b.create(loc, arith::CmpFPredicate::OGT, - inElt, res); - // index location is (ih * input_width + iw) - Value indexOut0 = b.create(loc, inputElementIndices[2], - inputSpatialSizes[1]); - Value indexOut1 = - b.create(loc, indexOut0, inputElementIndices[3]); - Value indexOut1Int = castIndexToInt64(b, loc, indexOut1); - Value indexOut2 = - b.create(loc, cond1, indexOut1Int, maxIndex); - Value out2 = b.create(loc, cond1, inElt, res); - b.create(loc, ValueRange({out2, indexOut2})); + Value out2, auxOut; + // customize for max vs. avg: + if (failed(adaptivePoolingHelper.payloadCustomization( + b, loc, inElt, res, maxIndex, inputElementIndices, + inputSpatialSizes, indexOne, starts, ends, out2, auxOut))) { + failedCustomization = true; + } + b.create(loc, ValueRange({out2, auxOut})); }); - Value maxValues = rewriter.create( - loc, outputType, maxPool.getResultTensors()[0]); - Value outputIndices = rewriter.create( - loc, outputType1, maxPool.getResultTensors()[1]); - rewriter.replaceOp(op, {maxValues, outputIndices}); + if (failedCustomization) { + return rewriter.notifyMatchFailure( + op, "failed linalg generic payload customization."); + } + Value adaptivePoolOutput = adaptivePool.getResultTensors()[0]; + Value auxTensorReturn = adaptivePool.getResultTensors()[1]; + + if (failed(adaptivePoolingHelper.customizedOpReplacement( + op, outputType, auxTensorType, adaptivePoolOutput, auxTensorReturn, + auxTensorExprs, outputExprs))) { + return rewriter.notifyMatchFailure(op, "failed customizedOpReplacement."); + } return success(); } }; @@ -1009,15 +1130,32 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality( target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); + target.addIllegalOp(); patterns .add>( typeConverter, context); patterns .add>( typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); + patterns + .add>( + typeConverter, context); + target.addIllegalOp(); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + target.addIllegalOp(); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); } diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 366f5492aa6d..c83025e42e67 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -576,3 +576,55 @@ bool torch_to_linalg::isUnsignedTorchType(Type type) { llvm_unreachable("Unknown type checked for signedness"); return false; } + +LogicalResult torch_to_linalg::permuteTensor(Operation *op, + PatternRewriter &rewriter, + Location loc, + SmallVector dimensions, + Value input, Value &result) { + auto inType = cast(input.getType()); + int64_t inputRank = inType.getRank(); + Type elementType = inType.getElementType(); + + // Check if the dimensions are a valid constants. + int64_t numDimensions = dimensions.size(); + if (inputRank != numDimensions) + return rewriter.notifyMatchFailure( + op, "size of `dims` must be equal to the rank of the input"); + for (uint32_t i = 0; i < numDimensions; i++) { + if (dimensions[i] < 0) + dimensions[i] = toPositiveDim(dimensions[i], inputRank); + if (!isValidDim(dimensions[i], inputRank)) + return rewriter.notifyMatchFailure(op, "dimension out of range"); + } + + SmallVector outputDims; + for (uint32_t i = 0; i < inputRank; i++) + outputDims.push_back(getDimOp(rewriter, loc, input, dimensions[i])); + + Value outVector = rewriter.create( + loc, getAsOpFoldResult(outputDims), elementType); + SmallVector idExprs; + SmallVector swapExprs; + for (uint32_t i = 0; i < inputRank; i++) + idExprs.push_back(getAffineDimExpr(i, rewriter.getContext())); + for (uint32_t i = 0; i < inputRank; i++) + swapExprs.push_back(idExprs[dimensions[i]]); + + AffineMap inputMap = + AffineMap::get(inputRank, /*symbolCount=*/0, idExprs, op->getContext()); + AffineMap outputMap = + AffineMap::get(inputRank, /*symbolCount=*/0, swapExprs, op->getContext()); + SmallVector indexingMaps{inputMap, outputMap}; + SmallVector iteratorTypes(inputRank, + utils::IteratorType::parallel); + result = rewriter + .create( + loc, outVector.getType(), input, outVector, indexingMaps, + iteratorTypes, + [](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args[0]); + }) + .getResult(0); + return success(); +} diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index a65c446b5fe9..ee9fe6e26d44 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -5977,6 +5977,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp) INSERT_UNARY_PATTERN(AtenErfOp, tosa::ErfOp) INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp) + INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp) + INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp) #undef INSERT_UNARY_PATTERN #define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \ diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 65aeb6ddad4f..ac1c08594e9a 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8155,24 +8155,192 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %38 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.avg_pool3d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.list {\n" +" %0 = call @__torch__.avg_pool3d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.optional) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__.avg_pool3d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.list {\n" +" %int-1 = torch.constant.int -1\n" +" %int-2 = torch.constant.int -2\n" +" %int-3 = torch.constant.int -3\n" +" %int-4 = torch.constant.int -4\n" +" %int-5 = torch.constant.int -5\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %str_0 = torch.constant.str \"AssertionError: max_pool3d: padding must either be a single int, or a tuple of thee ints\"\n" +" %str_1 = torch.constant.str \"AssertionError: max_pool3d: stride must either be omitted, a single int, or a tuple of three ints\"\n" +" %none = torch.constant.none\n" +" %str_2 = torch.constant.str \"AssertionError: max_pool3d: kernel_size must either be a single int, or a tuple of three ints\"\n" +" %true = torch.constant.bool true\n" +" %int1 = torch.constant.int 1\n" +" %int3 = torch.constant.int 3\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int4 = torch.constant.int 4\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %38 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %39 = torch.aten.eq.int %38, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %39 : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.tuple) {\n" +" %38 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %39 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %40 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %41 = torch.prim.TupleConstruct %38, %39, %40 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %41 : !torch.tuple\n" +" } else {\n" +" %38 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %39 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %40 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %41 = torch.prim.TupleConstruct %38, %39, %40 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %41 : !torch.tuple\n" +" }\n" +" %6:3 = torch.prim.TupleUnpack %5 : !torch.tuple -> !torch.int, !torch.int, !torch.int\n" +" %7 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %38 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %39 = torch.aten.eq.int %38, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %39 : !torch.bool\n" +" }\n" +" %10 = torch.prim.If %9 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %38 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %39 = torch.aten.eq.int %38, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %39 : !torch.bool\n" +" }\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %11 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %12 = torch.aten.eq.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %13:3 = torch.prim.If %12 -> (!torch.int, !torch.int, !torch.int) {\n" +" torch.prim.If.yield %6#0, %6#0, %6#0 : !torch.int, !torch.int, !torch.int\n" +" } else {\n" +" %38 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %39 = torch.aten.eq.int %38, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %40:3 = torch.prim.If %39 -> (!torch.int, !torch.int, !torch.int) {\n" +" %41 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %42 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %43 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %41, %42, %43 : !torch.int, !torch.int, !torch.int\n" +" } else {\n" +" %41 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %42 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %43 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %41, %42, %43 : !torch.int, !torch.int, !torch.int\n" +" }\n" +" torch.prim.If.yield %40#0, %40#1, %40#2 : !torch.int, !torch.int, !torch.int\n" +" }\n" +" %14 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %15 = torch.aten.eq.int %14, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %38 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %39 = torch.aten.eq.int %38, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %39 : !torch.bool\n" +" }\n" +" torch.prim.If %16 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %17 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %18 = torch.aten.eq.int %17, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %19 = torch.prim.If %18 -> (!torch.tuple) {\n" +" %38 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %39 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %40 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %41 = torch.prim.TupleConstruct %38, %39, %40 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %41 : !torch.tuple\n" +" } else {\n" +" %38 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %39 = torch.aten.__getitem__.t %arg3, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %40 = torch.aten.__getitem__.t %arg3, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %41 = torch.prim.TupleConstruct %38, %39, %40 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %41 : !torch.tuple\n" +" }\n" +" %20:3 = torch.prim.TupleUnpack %19 : !torch.tuple -> !torch.int, !torch.int, !torch.int\n" +" %21 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %22 = torch.aten.eq.int %21, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %23 = torch.prim.If %22 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %38 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %39 = torch.aten.eq.int %38, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %39 : !torch.bool\n" +" }\n" +" torch.prim.If %23 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %24 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %25 = torch.aten.eq.int %24, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %26 = torch.prim.If %25 -> (!torch.int) {\n" +" %38 = torch.aten.__getitem__.t %arg0, %int-5 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %38 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %27 = torch.aten.__getitem__.t %arg0, %int-4 : !torch.list, !torch.int -> !torch.int\n" +" %28 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list, !torch.int -> !torch.int\n" +" %29 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" +" %30 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %31 = call @__torch__.torch.jit._shape_functions.pooling_output_shape(%28, %6#0, %20#0, %13#0, %int1, %arg4) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" %32 = call @__torch__.torch.jit._shape_functions.pooling_output_shape(%29, %6#1, %20#1, %13#1, %int1, %arg4) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" %33 = call @__torch__.torch.jit._shape_functions.pooling_output_shape(%30, %6#2, %20#2, %13#2, %int1, %arg4) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" %34 = call @__torch__._pool3d_shape_check(%arg0, %6#0, %6#1, %6#2, %13#0, %13#1, %13#2, %20#0, %20#1, %20#2, %int1, %int1, %int1, %31, %32, %33) : (!torch.list, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.none\n" +" %35 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %36 = torch.aten.eq.int %35, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %37 = torch.prim.If %36 -> (!torch.list) {\n" +" %38 = torch.prim.ListConstruct %27, %31, %32, %33 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %38 : !torch.list\n" +" } else {\n" +" %38 = torch.prim.ListConstruct %26, %27, %31, %32, %33 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %38 : !torch.list\n" +" }\n" +" return %37 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool2d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.adaptive_avg_pool2d(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.adaptive_max_pool2d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.tuple, list> {\n" -" %0 = call @__torch__.adaptive_max_pool2d(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.tuple, list>\n" +" func.func @\"__torch_mlir_shape_fn.aten.adaptive_max_pool1d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.tuple, list> {\n" +" %int1 = torch.constant.int 1\n" +" %0 = call @__torch__.adaptive_pool(%arg0, %arg1, %int1) : (!torch.list, !torch.list, !torch.int) -> !torch.tuple, list>\n" " return %0 : !torch.tuple, list>\n" " }\n" -" func.func @__torch__.adaptive_max_pool2d(%arg0: !torch.list, %arg1: !torch.list) -> !torch.tuple, list> {\n" +" func.func @__torch__.adaptive_pool(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int) -> !torch.tuple, list> {\n" " %true = torch.constant.bool true\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" +" %int1 = torch.constant.int 1\n" " %int2 = torch.constant.int 2\n" -" %int3 = torch.constant.int 3\n" -" %int4 = torch.constant.int 4\n" " %int0 = torch.constant.int 0\n" " %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" -" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %1 = torch.aten.eq.int %0, %arg2 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" @@ -8180,26 +8348,28 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield\n" " }\n" " %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n" -" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %3 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %4 = torch.aten.eq.int %2, %3 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %11 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %12 = torch.aten.eq.int %11, %int4 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %12 : !torch.bool\n" +" %12 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %13 = torch.aten.add.int %arg2, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.aten.eq.int %12, %13 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %14 : !torch.bool\n" " }\n" -" torch.prim.If %4 -> () {\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" torch.prim.Loop %5, %true, init() {\n" -" ^bb0(%arg2: !torch.int):\n" -" %11 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" -" %12 = torch.aten.ne.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %12 -> () {\n" +" %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %6, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %12 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.ne.int %12, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %13 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" @@ -8207,24 +8377,41 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " torch.prim.Loop.condition %true, iter()\n" " } : (!torch.int, !torch.bool) -> ()\n" -" %6 = torch.prim.ListConstruct : () -> !torch.list\n" -" %7 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %8 = torch.aten.sub.int %7, %int2 : !torch.int, !torch.int -> !torch.int\n" -" torch.prim.Loop %8, %true, init() {\n" -" ^bb0(%arg2: !torch.int):\n" -" %11 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" -" %12 = torch.aten.append.t %6, %11 : !torch.list, !torch.int -> !torch.list\n" +" %7 = torch.prim.ListConstruct : () -> !torch.list\n" +" %8 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %9 = torch.aten.sub.int %8, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %9, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %12 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.append.t %7, %12 : !torch.list, !torch.int -> !torch.list\n" " torch.prim.Loop.condition %true, iter()\n" " } : (!torch.int, !torch.bool) -> ()\n" -" %9 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" -" torch.prim.Loop %9, %true, init() {\n" -" ^bb0(%arg2: !torch.int):\n" -" %11 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int\n" -" %12 = torch.aten.append.t %6, %11 : !torch.list, !torch.int -> !torch.list\n" +" %10 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" torch.prim.Loop %10, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %12 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.append.t %7, %12 : !torch.list, !torch.int -> !torch.list\n" " torch.prim.Loop.condition %true, iter()\n" " } : (!torch.int, !torch.bool) -> ()\n" -" %10 = torch.prim.TupleConstruct %6, %6 : !torch.list, !torch.list -> !torch.tuple, list>\n" -" return %10 : !torch.tuple, list>\n" +" %11 = torch.prim.TupleConstruct %7, %7 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %11 : !torch.tuple, list>\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.adaptive_max_pool2d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.tuple, list> {\n" +" %int2 = torch.constant.int 2\n" +" %0 = call @__torch__.adaptive_pool(%arg0, %arg1, %int2) : (!torch.list, !torch.list, !torch.int) -> !torch.tuple, list>\n" +" return %0 : !torch.tuple, list>\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.adaptive_max_pool3d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.tuple, list> {\n" +" %int3 = torch.constant.int 3\n" +" %0 = call @__torch__.adaptive_pool(%arg0, %arg1, %int3) : (!torch.list, !torch.list, !torch.int) -> !torch.tuple, list>\n" +" return %0 : !torch.tuple, list>\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool3d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %int3 = torch.constant.int 3\n" +" %int0 = torch.constant.int 0\n" +" %0 = call @__torch__.adaptive_pool(%arg0, %arg1, %int3) : (!torch.list, !torch.list, !torch.int) -> !torch.tuple, list>\n" +" %1 = torch.prim.TupleIndex %0, %int0 : !torch.tuple, list>, !torch.int -> !torch.list\n" +" return %1 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.flatten.using_ints\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.flatten(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.int) -> !torch.list\n" @@ -9809,6 +9996,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.adaptive_avg_pool3d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.avg_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -10270,12 +10461,24 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" " return %1 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.adaptive_max_pool1d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.adaptive_max_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.tuple {\n" " %int4 = torch.constant.int 4\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" " return %1 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.adaptive_max_pool3d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mish\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 39d198c1dac7..66ca5e12c9d4 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -5248,6 +5248,14 @@ class DecomposeAtenPadOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenPadOp op, PatternRewriter &rewriter) const override { + std::string mode; + if (!matchPattern(op.getMode(), m_TorchConstantStr(mode))) { + return rewriter.notifyMatchFailure(op, "Unsupported value of mode"); + } + + if (mode != "constant") { + return rewriter.notifyMatchFailure(op, "Unsupported mode: " + mode); + } Value value = op.getValue(); if (value.getType().isa()) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 9c362df4a928..14b30bcd5519 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -258,6 +258,9 @@ "ElementwiseDivRoundingModeTruncModule_basic", "AdaptiveAvgPool1dStaticLargerOutput_basic", "AdaptiveAvgPool1dGeneralDynamic_basic", + "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", + "AdaptiveAvgPool2dDynamic_basic", + "AdaptiveAvgPool2dDynamicNoBatch_basic", # ERROR: Exception: Unsupported op: get_attr "NumToTensorFloatModule_basic", @@ -460,6 +463,7 @@ "AtenToDtypeModule_basic", "AvgPool1dStaticModule_basic", "AvgPool2dStaticModule_basic", + "AvgPool3dStaticModule_basic", "BaddbmmBroadcast1DInputModule_basic", "BaddbmmBroadcast2DInputModule_basic", "BaddbmmStaticModule_basic", @@ -919,6 +923,7 @@ STABLEHLO_CRASHING_SET = { "AtenEmbeddingBagSumExample_basic", + "AvgPool3dStaticModule_basic" } # Write the TOSA set as a "passing" set as it is very early in development @@ -1690,6 +1695,19 @@ "AdaptiveMaxPool2dDynamic_basic", "AdaptiveMaxPool2dStaticWithIndices_basic", "AdaptiveMaxPool2dStatic_basic", + "AdaptiveMaxPool3dStatic_basic", + "AdaptiveMaxPool3dStaticWithIndices_basic", + "AdaptiveMaxPool3dDynamic_basic", + "AdaptiveMaxPool3dDynamicWithIndices_basic", + "AdaptiveMaxPool3dDynamicNoBatch_basic", + "AdaptiveMaxPool2dDynamicNoBatch_basic", + "AdaptiveMaxPool1dStatic_basic", + "AdaptiveMaxPool1dDynamic_basic", + "AdaptiveMaxPool1dDynamicNoBatch_basic", + "AdaptiveAvgPool3dDynamic_basic", + "AdaptiveAvgPool3dDynamicNoBatch_basic", + "AdaptiveAvgPool2dDynamic_basic", + "AdaptiveAvgPool2dDynamicNoBatch_basic", "AddCDivModule_basic", "AddIntModule_basic", "Add_Module_basic", @@ -2060,17 +2078,8 @@ "LinalgNormModule_basic", # Failure - onnx_lowering: onnx.AveragePool - "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool1dStaticEvenMultiple_basic", - "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", - "AvgPool1dFloatModule_basic", - "AvgPool1dIntModule_basic", - "AvgPool1dStaticModule_basic", - "AvgPool2dCeilModeTrueModule_basic", + "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", "AvgPool2dDivisorOverrideModule_basic", - "AvgPool2dFloatModule_basic", - "AvgPool2dIntModule_basic", - "AvgPool2dStaticModule_basic", # Failure - onnx_lowering: onnx.Cast "BucketizeTensorOutInt32RightModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index f21d2d57fcb5..7b5630e73a53 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -877,6 +877,69 @@ def aten〇max_pool2d_with_indices_backward〡shape(grad_output: List[int], self def aten〇upsample_nearest2d_backward〡shape(grad_output: List[int], output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]: return input_size +# TODO: This should be upstreamed. +# See https://github.com/pytorch/pytorch/pull/76889 for an example. +def avg_pool3d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool, divisor_override: Optional[int]): + assert ( + len(kernel_size) == 1 or len(kernel_size) == 3 + ), "max_pool3d: kernel_size must either be a single int, or a tuple of three ints" + (kD, kH, kW) = (kernel_size[0], kernel_size[0], kernel_size[0]) if len(kernel_size) == 1 else (kernel_size[0], kernel_size[1], kernel_size[2]) + + assert ( + len(stride) == 0 or len(stride) == 1 or len(stride) == 3 + ), "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints" + + if len(stride) == 0: + (dD, dH, dW) = (kD, kD, kD) + elif len(stride) == 1: + (dD, dH, dW) = (stride[0], stride[0], stride[0]) + else: # len(stride) == 3 + (dD, dH, dW) = (stride[0], stride[1], stride[2]) + + assert ( + len(padding) == 1 or len(padding) == 3 + ), "max_pool3d: padding must either be a single int, or a tuple of thee ints" + (padD, padH, padW) = (padding[0], padding[0], padding[0]) if len(padding) == 1 else (padding[0], padding[1], padding[2]) + + dilationD = 1 + dilationH = 1 + dilationW = 1 + + assert len(input) == 4 or len(input) == 5 + nbatch = input[-5] if len(input) == 5 else 1 + nInputPlane = input[-4] + inputDepth = input[-3] + inputHeight = input[-2] + inputWidth = input[-1] + + outputDepth = upstream_shape_functions.pooling_output_shape(inputDepth, kD, padD, dD, dilationD, ceil_mode) + outputHeight = upstream_shape_functions.pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode) + outputWidth = upstream_shape_functions.pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode) + + _pool3d_shape_check( + input, + kD, + kH, + kW, + dD, + dH, + dW, + padD, + padH, + padW, + dilationD, + dilationH, + dilationW, + outputDepth, + outputHeight, + outputWidth, + ) + + if len(input) == 4: + return [nInputPlane, outputDepth, outputHeight, outputWidth] + else: + return [nbatch, nInputPlane, outputDepth, outputHeight, outputWidth] + # TODO: This should be upstreamed. # See https://github.com/pytorch/pytorch/pull/76889 for an example. def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool, divisor_override: Optional[int]): @@ -974,26 +1037,38 @@ def aten〇adaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) def aten〇avg_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> List[int]: return avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) +def aten〇avg_pool3d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> List[int]: + return avg_pool3d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + def aten〇adaptive_avg_pool2d〡shape(self: List[int], output_size: List[int]) -> List[int]: return upstream_shape_functions.adaptive_avg_pool2d(self, output_size) -def adaptive_max_pool2d(self: List[int], out: List[int]): - assert len(out) == 2 - assert len(self) == 3 or len(self) == 4 +def adaptive_pool(self: List[int], out: List[int], dim: int): + assert len(out) == dim + assert len(self) == dim + 1 or len(self) == dim + 2 for i in range(len(self)): assert self[i] != 0 shape: List[int] = [] - for i in range(len(self) - 2): + for i in range(len(self) - dim): shape.append(self[i]) for j in range(len(out)): shape.append(out[j]) return shape, shape +def aten〇adaptive_max_pool1d〡shape(self: List[int], output_size: List[int]) -> Tuple[List[int], List[int]]: + return adaptive_pool(self, output_size, 1) + def aten〇adaptive_max_pool2d〡shape(self: List[int], output_size: List[int]) -> Tuple[List[int], List[int]]: - return adaptive_max_pool2d(self, output_size) + return adaptive_pool(self, output_size, 2) + +def aten〇adaptive_max_pool3d〡shape(self: List[int], output_size: List[int]) -> Tuple[List[int], List[int]]: + return adaptive_pool(self, output_size, 3) + +def aten〇adaptive_avg_pool3d〡shape(self: List[int], output_size: List[int]) -> List[int]: + return adaptive_pool(self, output_size, 3)[0] def aten〇flatten〇using_ints〡shape(self: List[int], start_dim: int = 0, end_dim: int = -1) -> List[int]: return upstream_shape_functions.flatten(self, start_dim, end_dim) @@ -2103,6 +2178,11 @@ def aten〇adaptive_avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], output_ self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7, 9)], output_size=[2, 2, 2])) +def aten〇adaptive_avg_pool3d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) def aten〇avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int: self_rank, self_dtype = self_rank_dtype @@ -2509,11 +2589,21 @@ def aten〇max_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], ker self_rank, self_dtype = self_rank_dtype return self_dtype, torch.int64 +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], output_size=[2])) +def aten〇adaptive_max_pool1d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + return self_dtype, torch.int64 + @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[2, 2])) def aten〇adaptive_max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> Tuple[int, int]: self_rank, self_dtype = self_rank_dtype return self_dtype, torch.int64 +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7, 13)], output_size=[2, 2, 2])) +def aten〇adaptive_max_pool3d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + return self_dtype, torch.int64 + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇mish〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 7db3ea511164..e14dd6dc9159 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -514,7 +514,9 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)") emit("aten::_adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)") emit("aten::_adaptive_avg_pool3d_backward : (Tensor, Tensor) -> (Tensor)") + emit("aten::adaptive_max_pool1d : (Tensor, int[]) -> (Tensor, Tensor)") emit("aten::adaptive_max_pool2d : (Tensor, int[]) -> (Tensor, Tensor)") + emit("aten::adaptive_max_pool3d : (Tensor, int[]) -> (Tensor, Tensor)") emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)") emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)") emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir/torchscript.py b/projects/pt1/python/torch_mlir/torchscript.py index 33eddb6b1dd8..9bb696e54895 100644 --- a/projects/pt1/python/torch_mlir/torchscript.py +++ b/projects/pt1/python/torch_mlir/torchscript.py @@ -252,7 +252,7 @@ def _get_for_tracing( # compiler where each backend can "own" its set of legal ops. BACKEND_LEGAL_OPS = { OutputType.TOSA: ['aten.flatten.using_ints', 'aten.native_layer_norm', 'aten.linear'], - OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints','aten.adaptive_avg_pool1d', 'aten.unflatten.int'], + OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints','aten.adaptive_avg_pool1d','aten.adaptive_avg_pool2d', 'aten.unflatten.int'], OutputType.STABLEHLO: [], } diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index 22ff3bb330ad..8ab03ddeb019 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -873,6 +873,38 @@ def AvgPool2dWithoutPadModule_basic(module, tu: TestUtils): # ============================================================================== +class AvgPool3dStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool3d( + kernel_size=[2, 2, 2], + stride=[2, 2, 2], + padding=[0, 0, 0], + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([2, 2, 4, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool3dStaticModule()) +def AvgPool3dStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 2, 4, 4, 4, low=-1)) + + +# ============================================================================== + + class AvgPool1dFloatModule(torch.nn.Module): def __init__(self): @@ -1004,6 +1036,26 @@ def AdaptiveAvgPool1dGeneralDynamic_basic( module, tu: TestUtils): module.forward(tu.rand(1, 512, 10)) +class AdaptiveAvgPool1dGeneralDynamicNoBatches(torch.nn.Module): + + def __init__(self): + super().__init__() + self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7) + + @export + @annotate_args([ + None, + ([-1,-1], torch.float32, True) + ]) + def forward(self,x): + return self.aap1d(x) + +@register_test_case( + module_factory=lambda: AdaptiveAvgPool1dGeneralDynamicNoBatches()) +def AdaptiveAvgPool1dGeneralDynamicNoBatches_basic( + module, tu: TestUtils): + module.forward(tu.rand(512, 10)) + class AdaptiveAvgPool1dNonUnitOutputSizeStaticModule(torch.nn.Module): def __init__(self): @@ -1084,6 +1136,155 @@ def AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic( module, tu: TestUtils): module.forward(tu.rand(1, 512, 7)) +# AdaptiveAvgPool2d + + +class AdaptiveAvgPool2dDynamic(torch.nn.Module): + + def __init__(self): + super().__init__() + self.aap2d = torch.nn.AdaptiveAvgPool2d(output_size=(7,13)) + + @export + @annotate_args([ + None, + ([-1,-1,-1,-1], torch.float32, True) + ]) + def forward(self,x): + return self.aap2d(x) + +@register_test_case( + module_factory=lambda: AdaptiveAvgPool2dDynamic()) +def AdaptiveAvgPool2dDynamic_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10, 16)) + +class AdaptiveAvgPool2dDynamicNoBatch(torch.nn.Module): + + def __init__(self): + super().__init__() + self.aap2d = torch.nn.AdaptiveAvgPool2d(output_size=(7,13)) + + @export + @annotate_args([ + None, + ([-1,-1,-1], torch.float32, True) + ]) + def forward(self,x): + return self.aap2d(x) + +@register_test_case( + module_factory=lambda: AdaptiveAvgPool2dDynamicNoBatch()) +def AdaptiveAvgPool2dDynamicNoBatch_basic( + module, tu: TestUtils): + module.forward(tu.rand(512, 10, 16)) + +# AdaptiveAvgPool3d + +class AdaptiveAvgPool3dDynamic(torch.nn.Module): + + def __init__(self): + super().__init__() + self.aap3d = torch.nn.AdaptiveAvgPool3d(output_size=(7,13,15)) + + @export + @annotate_args([ + None, + ([-1,-1,-1,-1,-1], torch.float32, True) + ]) + def forward(self,x): + return self.aap3d(x) + +@register_test_case( + module_factory=lambda: AdaptiveAvgPool3dDynamic()) +def AdaptiveAvgPool3dDynamic_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10, 16, 17)) + +class AdaptiveAvgPool3dDynamicNoBatch(torch.nn.Module): + + def __init__(self): + super().__init__() + self.aap3d = torch.nn.AdaptiveAvgPool3d(output_size=(7,13,15)) + + @export + @annotate_args([ + None, + ([-1,-1,-1,-1], torch.float32, True) + ]) + def forward(self,x): + return self.aap3d(x) + +@register_test_case( + module_factory=lambda: AdaptiveAvgPool3dDynamicNoBatch()) +def AdaptiveAvgPool3dDynamicNoBatch_basic( + module, tu: TestUtils): + module.forward(tu.rand(512, 10, 16, 17)) + +# AdaptiveMaxPool1d + +class AdaptiveMaxPool1dDynamic(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp1d = torch.nn.AdaptiveMaxPool1d(output_size=(7), return_indices=False) + + @export + @annotate_args([ + None, + ([-1,-1,-1], torch.float32, True) + ]) + def forward(self,x): + return self.amp1d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool1dDynamic()) +def AdaptiveMaxPool1dDynamic_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10)) + +class AdaptiveMaxPool1dDynamicNoBatch(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp1d = torch.nn.AdaptiveMaxPool1d(output_size=(7), return_indices=False) + + @export + @annotate_args([ + None, + ([-1,-1], torch.float32, True) + ]) + def forward(self,x): + return self.amp1d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool1dDynamicNoBatch()) +def AdaptiveMaxPool1dDynamicNoBatch_basic( + module, tu: TestUtils): + module.forward(tu.rand(512, 10)) + +class AdaptiveMaxPool1dStatic(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp1d = torch.nn.AdaptiveMaxPool1d(output_size=(7), return_indices=False) + + @export + @annotate_args([ + None, + ([1, 512, 10], torch.float32, True) + ]) + def forward(self,x): + return self.amp1d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool1dStatic()) +def AdaptiveMaxPool1dStatic_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10)) + +# AdaptiveMaxPool2d + class AdaptiveMaxPool2dDynamic(torch.nn.Module): def __init__(self): @@ -1104,6 +1305,26 @@ def AdaptiveMaxPool2dDynamic_basic( module, tu: TestUtils): module.forward(tu.rand(1, 512, 10, 16)) +class AdaptiveMaxPool2dDynamicNoBatch(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=False) + + @export + @annotate_args([ + None, + ([-1,-1,-1], torch.float32, True) + ]) + def forward(self,x): + return self.amp2d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool2dDynamicNoBatch()) +def AdaptiveMaxPool2dDynamicNoBatch_basic( + module, tu: TestUtils): + module.forward(tu.rand(512, 10, 16)) + class AdaptiveMaxPool2dDynamicWithIndices(torch.nn.Module): def __init__(self): @@ -1164,3 +1385,106 @@ def forward(self,x): def AdaptiveMaxPool2dStaticWithIndices_basic( module, tu: TestUtils): module.forward(tu.rand(1, 512, 10, 16)) + +# AdaptiveMaxPool3d + +class AdaptiveMaxPool3dDynamic(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp3d = torch.nn.AdaptiveMaxPool3d(output_size=(7,13,15), return_indices=False) + + @export + @annotate_args([ + None, + ([-1,-1,-1,-1,-1], torch.float32, True) + ]) + def forward(self,x): + return self.amp3d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool3dDynamic()) +def AdaptiveMaxPool3dDynamic_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10, 16, 17)) + +class AdaptiveMaxPool3dDynamicNoBatch(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp3d = torch.nn.AdaptiveMaxPool3d(output_size=(7,13,15), return_indices=False) + + @export + @annotate_args([ + None, + ([-1,-1,-1,-1], torch.float32, True) + ]) + def forward(self,x): + return self.amp3d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool3dDynamicNoBatch()) +def AdaptiveMaxPool3dDynamicNoBatch_basic( + module, tu: TestUtils): + module.forward(tu.rand(512, 10, 16, 17)) + +class AdaptiveMaxPool3dDynamicWithIndices(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp3d = torch.nn.AdaptiveMaxPool3d(output_size=(7,13,15), return_indices=True) + + @export + @annotate_args([ + None, + ([-1,-1,-1,-1,-1], torch.float32, True) + ]) + def forward(self,x): + return self.amp3d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool3dDynamicWithIndices()) +def AdaptiveMaxPool3dDynamicWithIndices_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10, 16, 17)) + + +class AdaptiveMaxPool3dStatic(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp3d = torch.nn.AdaptiveMaxPool3d(output_size=(7,13,15), return_indices=False) + + @export + @annotate_args([ + None, + ([1, 512, 10, 9, 5], torch.float32, True) + ]) + def forward(self,x): + return self.amp3d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool3dStatic()) +def AdaptiveMaxPool3dStatic_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10, 9, 5)) + +class AdaptiveMaxPool3dStaticWithIndices(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp3d = torch.nn.AdaptiveMaxPool3d(output_size=(7,13,15), return_indices=True) + + @export + @annotate_args([ + None, + ([1, 512, 10, 16, 17], torch.float32, True) + ]) + def forward(self,x): + return self.amp3d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool3dStaticWithIndices()) +def AdaptiveMaxPool3dStaticWithIndices_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10, 16, 17)) diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 74025cfc6342..317b5c9efe86 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1448,3 +1448,29 @@ func.func @forward(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !torch.vtensor<[5,5] %0 = torch.aten.isclose %arg0, %arg1, %float1.000000e-05, %float1.000000e-08, %false : !torch.vtensor<[5,5],f32>, !torch.vtensor<[5,5],f32>, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[5,5],i1> return %0 : !torch.vtensor<[5,5],i1> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.sin$basic( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.sin %[[ARG_BUILTIN]] : (tensor) -> tensor +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.sin$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.sin %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.cos$basic( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.cos %[[ARG_BUILTIN]] : (tensor) -> tensor +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.cos$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.cos %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} \ No newline at end of file diff --git a/test/Dialect/Torch/decompose-complex-ops-illegal.mlir b/test/Dialect/Torch/decompose-complex-ops-illegal.mlir new file mode 100644 index 000000000000..773c0f5c3c30 --- /dev/null +++ b/test/Dialect/Torch/decompose-complex-ops-illegal.mlir @@ -0,0 +1,41 @@ +// RUN: torch-mlir-opt -torch-decompose-complex-ops -split-input-file %s | FileCheck %s + +func.func @torch.aten.pad.reflect(%input: !torch.tensor<[2],f32>, %pads: !torch.vtensor<[2],si64>) -> !torch.tensor<[4],f32> { + %int0 = torch.constant.int 0 + %float0.000000e00 = torch.constant.float 0.000000e+00 + %1 = torch.aten.select.int %pads, %int0, %int0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.item %1 : !torch.vtensor<[],si64> -> !torch.int + %pad = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list + %str = torch.constant.str "reflect" + // CHECK: torch.aten.pad %{{.*}} %{{.*}} %{{.*}} %{{.*}} : !torch.tensor<[2],f32>, !torch.list, !torch.str, !torch.float -> !torch.tensor<[4],f32> + %ret = torch.aten.pad %input, %pad, %str, %float0.000000e00 : !torch.tensor<[2],f32>, !torch.list, !torch.str, !torch.float -> !torch.tensor<[4],f32> + return %ret : !torch.tensor<[4],f32> +} + +// ----- + +func.func @torch.aten.pad.edge(%input: !torch.tensor<[2],f32>, %pads: !torch.vtensor<[2],si64>) -> !torch.tensor<[4],f32> { + %int0 = torch.constant.int 0 + %float0.000000e00 = torch.constant.float 0.000000e+00 + %1 = torch.aten.select.int %pads, %int0, %int0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.item %1 : !torch.vtensor<[],si64> -> !torch.int + %pad = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list + %str = torch.constant.str "edge" + // CHECK: torch.aten.pad %{{.*}} %{{.*}} %{{.*}} %{{.*}} : !torch.tensor<[2],f32>, !torch.list, !torch.str, !torch.float -> !torch.tensor<[4],f32> + %ret = torch.aten.pad %input, %pad, %str, %float0.000000e00 : !torch.tensor<[2],f32>, !torch.list, !torch.str, !torch.float -> !torch.tensor<[4],f32> + return %ret : !torch.tensor<[4],f32> +} + +// ----- + +func.func @torch.aten.pad.wrap(%input: !torch.tensor<[2],f32>, %pads: !torch.vtensor<[2],si64>) -> !torch.tensor<[4],f32> { + %int0 = torch.constant.int 0 + %float0.000000e00 = torch.constant.float 0.000000e+00 + %1 = torch.aten.select.int %pads, %int0, %int0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.item %1 : !torch.vtensor<[],si64> -> !torch.int + %pad = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list + %str = torch.constant.str "wrap" + // CHECK: torch.aten.pad %{{.*}} %{{.*}} %{{.*}} %{{.*}} : !torch.tensor<[2],f32>, !torch.list, !torch.str, !torch.float -> !torch.tensor<[4],f32> + %ret = torch.aten.pad %input, %pad, %str, %float0.000000e00 : !torch.tensor<[2],f32>, !torch.list, !torch.str, !torch.float -> !torch.tensor<[4],f32> + return %ret : !torch.tensor<[4],f32> +} diff --git a/test/Dialect/Torch/decompose-complex-ops-legal.mlir b/test/Dialect/Torch/decompose-complex-ops-legal.mlir index 9cf4c3e9babd..27a5b5647c94 100644 --- a/test/Dialect/Torch/decompose-complex-ops-legal.mlir +++ b/test/Dialect/Torch/decompose-complex-ops-legal.mlir @@ -8,3 +8,17 @@ func.func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torc %ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<[2,3],f32> return %ret : !torch.tensor<[2,3],f32> } + +// ----- + +func.func @torch.aten.pad.constant(%input: !torch.tensor<[2],f32>, %pads: !torch.vtensor<[2],si64>) -> !torch.tensor<[4],f32> { + %int0 = torch.constant.int 0 + %float0.000000e00 = torch.constant.float 0.000000e+00 + %1 = torch.aten.select.int %pads, %int0, %int0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.item %1 : !torch.vtensor<[],si64> -> !torch.int + %pad = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list + %str = torch.constant.str "constant" + // CHECK: torch.aten.constant_pad_nd %{{.*}}, %{{.*}}, %{{.*}} : !torch.tensor<[2],f32>, !torch.list, !torch.float -> !torch.tensor<[4],f32> + %ret = torch.aten.pad %input, %pad, %str, %float0.000000e00 : !torch.tensor<[2],f32>, !torch.list, !torch.str, !torch.float -> !torch.tensor<[4],f32> + return %ret : !torch.tensor<[4],f32> +} diff --git a/test/Dialect/TorchConversion/torch-backend-to-linalg-on-tensors-no-contract-check.mlir b/test/Dialect/TorchConversion/torch-backend-to-linalg-on-tensors-no-contract-check.mlir new file mode 100644 index 000000000000..33fbfcb90c66 --- /dev/null +++ b/test/Dialect/TorchConversion/torch-backend-to-linalg-on-tensors-no-contract-check.mlir @@ -0,0 +1,24 @@ +// RUN: torch-mlir-opt -p 'builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline{verify=0})' -split-input-file %s | FileCheck %s + +// CHECK: func.func @tosa +func.func @tosa(%arg0: tensor) -> tensor { + // CHECK: tosa.abs + %1 = tosa.abs %arg0 : (tensor) -> tensor + return %1 : tensor +} + +// ----- + +// CHECK: func.func @torch_gemm +func.func @torch_gemm(%arg0: tensor, %arg1: tensor<3x?xf32>, %arg2: tensor) -> (tensor {onnx.name = "gemm"}) attributes {torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch_c.from_builtin_tensor %arg0 : tensor -> !torch.vtensor<[?,3],f32> + %1 = torch_c.from_builtin_tensor %arg1 : tensor<3x?xf32> -> !torch.vtensor<[3,?],f32> + %2 = torch_c.from_builtin_tensor %arg2 : tensor -> !torch.vtensor<[?,?],f32> + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %3 = torch.aten.mm %0, %1 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[?,?],f32> + %4 = torch.aten.add.Tensor %3, %2, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> + %5 = torch_c.to_builtin_tensor %4 : !torch.vtensor<[?,?],f32> -> tensor + %6 = tosa.abs %5 : (tensor) -> tensor + return %6 : tensor +} diff --git a/test/Dialect/TorchConversion/torch-backend-to-linalg-on-tensors-no-mlprogram.mlir b/test/Dialect/TorchConversion/torch-backend-to-linalg-on-tensors-no-mlprogram.mlir new file mode 100644 index 000000000000..52280ecdfa0f --- /dev/null +++ b/test/Dialect/TorchConversion/torch-backend-to-linalg-on-tensors-no-mlprogram.mlir @@ -0,0 +1,17 @@ +// RUN: torch-mlir-opt -p 'builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline{use-mlprogram=0})' -split-input-file %s | FileCheck %s +// RUN: torch-mlir-opt -p 'builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline{use-mlprogram=1})' -split-input-file %s | FileCheck --check-prefix=YES-CHECK %s + +// CHECK-NOT: ml_program.global{{.*}}@global_seed +// YES-CHECK: ml_program.global{{.*}}@global_seed +// CHECK: func.func @torch_gemm +func.func @torch_gemm(%arg0: tensor, %arg1: tensor<3x?xf32>, %arg2: tensor) -> (tensor {onnx.name = "gemm"}) attributes {torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch_c.from_builtin_tensor %arg0 : tensor -> !torch.vtensor<[?,3],f32> + %1 = torch_c.from_builtin_tensor %arg1 : tensor<3x?xf32> -> !torch.vtensor<[3,?],f32> + %2 = torch_c.from_builtin_tensor %arg2 : tensor -> !torch.vtensor<[?,?],f32> + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %3 = torch.aten.mm %0, %1 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[?,?],f32> + %4 = torch.aten.add.Tensor %3, %2, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> + %5 = torch_c.to_builtin_tensor %4 : !torch.vtensor<[?,?],f32> -> tensor + return %5 : tensor +}