diff --git a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h index 14e9202222c6..b59d183b4084 100644 --- a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h +++ b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h @@ -101,6 +101,10 @@ LogicalResult permuteTensor(Operation *op, PatternRewriter &rewriter, Location loc, SmallVector dimensions, Value input, Value &result); +// Flips an input tensor based on the values of axis list. +Value flipTensor(PatternRewriter &rewriter, Location loc, Value input, + SmallVector axis); + } // namespace torch_to_linalg } // namespace torch } // namespace mlir diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 195f36e9aa1f..f41b8707b59d 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -635,18 +635,21 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // TODO: Implement max and min cases if (reduction == "mul") { - reduction = "multiply"; + reduction = "prod"; } else if (reduction == "max" || reduction == "min") { return rewriter.notifyMatchFailure( binder.op, "max/min reduction unsupported for scatter elements"); + } else if (reduction == "add") { + reduction = "sum"; } Value cstStrReduction = rewriter.create(binder.getLoc(), reduction); - - rewriter.replaceOpWithNewOp( + Value cstTrue = + rewriter.create(binder.getLoc(), true); + rewriter.replaceOpWithNewOp( binder.op, resultType, data, constAxis, indices, updates, - cstStrReduction); + cstStrReduction, cstTrue); return success(); }); patterns.onOp( diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 5542e0fc642f..902daa1cb5ad 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -40,6 +40,7 @@ static int64_t productReduce(ArrayRef a) { template LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, + int64_t &dim, SmallVector &resultShape, SmallVector &offsets, SmallVector &strides) { @@ -51,7 +52,6 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, Value one = rewriter.create(loc, 1); Value negone = rewriter.create(loc, -1); - int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return op->emitError("unimplemented: dim is not constant"); @@ -1658,10 +1658,17 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern { if (!isValidDim(dim, inputRank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - // TODO: Handle the case where the dim(th) dimension is dynamic. 
+ // assert dynamic squeeze dim size == 1 if (inputType.isDynamicDim(dim)) { - return rewriter.notifyMatchFailure( - op, "unimplemented: dim(th) dimension is not expected to be dynamic"); + Value cstDim = rewriter.create(op.getLoc(), dim); + Value dimVal = rewriter.create(op.getLoc(), input, cstDim); + Value cstOne = rewriter.create(op.getLoc(), 1); + Value cmp = rewriter.create( + op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne); + rewriter.create( + op.getLoc(), cmp, + rewriter.getStringAttr( + "Expected dynamic squeeze dim size to be statically 1")); } const TypeConverter *typeConverter = getTypeConverter(); @@ -1671,7 +1678,7 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern { // If the dim(th) dimension of operand tensor type is not statically unit, // `aten.squeeze` will behave as an identity operation. - if (inputType.getDimSize(dim) != 1) { + if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) { rewriter.replaceOpWithNewOp(op, resultType, input); return success(); } @@ -1857,14 +1864,46 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern { RankedTensorType resultType = cast( typeConverter->convertType(op->getResult(0).getType())); - SmallVector resultShape; - SmallVector offsets; - SmallVector strides; + SmallVector resultShape, offsets, strides; + int64_t dim; if (failed(prepareArgumentsForSlicingOp( - op, adaptor, rewriter, resultShape, offsets, strides))) { + op, adaptor, rewriter, dim, resultShape, offsets, strides))) { return failure(); } + + // If stride is negative, then flip the input tensor corresponding to that + // dim, update the stride for flipped tensor by multiplying it by -1, and + // update the offset as follows: + // flipped_offset = input_shape[dim] - (result_shape[dim] * flipped_stride) + // + // For example: + // Input = [0, 1, 2, 3, 4, 5] + // stride = [-2], result_shape = [2], offset = [3] + // Result = [3, 1] + // After flipping: + // Input = [5, 4, 3, 2, 1, 0] + // stride = [2], result_shape = [2], offset = [6 - (2 * 2)] = [2] + // Result = [3, 1] + + Value flippedInput = torch_to_linalg::flipTensor(rewriter, loc, input, + SmallVector{dim}); + Value cstDim = rewriter.create(loc, dim); + Value zero = rewriter.create(loc, 0); + Value isNegativeStride = rewriter.create( + loc, arith::CmpIPredicate::slt, strides[dim], zero); + strides[dim] = rewriter.create(loc, strides[dim]); + Value resShapeMulStride = + rewriter.create(loc, resultShape[dim], strides[dim]); + Value inputDim = rewriter.create(loc, input, cstDim); + Value flippedOffset = + rewriter.create(loc, inputDim, resShapeMulStride); + offsets[dim] = rewriter.create( + loc, isNegativeStride, flippedOffset, offsets[dim]); + + input = rewriter.create(loc, isNegativeStride, + flippedInput, input); + SmallVector dynShape(resultType.getRank(), ShapedType::kDynamic); auto sliceType = RankedTensorType::get( dynShape, resultType.getElementType(), resultType.getEncoding()); @@ -2095,12 +2134,11 @@ class ConvertAtenSliceScatterOp RankedTensorType resultType = cast( typeConverter->convertType(op->getResult(0).getType())); - SmallVector resultShape; - SmallVector offsets; - SmallVector strides; + SmallVector resultShape, offsets, strides; + int64_t dim; if (failed(prepareArgumentsForSlicingOp( - op, adaptor, rewriter, resultShape, offsets, strides))) { + op, adaptor, rewriter, dim, resultShape, offsets, strides))) { return failure(); } diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 52765411bd73..a4962d12abdc 100644 --- 
a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -222,14 +222,9 @@ class ConvertAtenFlipOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - MLIRContext *context = op.getContext(); Value self = adaptor.getSelf(); auto selfRank = cast(adaptor.getSelf().getType()).getRank(); - Type elementType = - cast(adaptor.getSelf().getType()).getElementType(); - Value c1 = - rewriter.create(loc, rewriter.getIndexAttr(1)); SmallVector axis; if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis))) @@ -242,40 +237,8 @@ class ConvertAtenFlipOp : public OpConversionPattern { } } - // Only used to calculate flipped values, i.e. those on the flip axes. Other - // dims won't be used. - SmallVector dims = getTensorSizes(rewriter, loc, self); - for (auto flipDim : axis) - dims[flipDim] = rewriter.create(loc, dims[flipDim], c1); - - Value initTensor = createZeroInitTensor( - rewriter, loc, getTensorSizes(rewriter, loc, self), elementType); - - SmallVector iteratorTypes( - selfRank, utils::IteratorType::parallel); - SmallVector indexingMaps( - 2, AffineMap::getMultiDimIdentityMap(selfRank, context)); - Value flipped = - rewriter - .create( - loc, self.getType(), self, initTensor, indexingMaps, - iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - SmallVector indices; - for (auto i = 0; i < selfRank; i++) - indices.push_back(b.create(loc, i)); - for (auto flipDim : axis) { - indices[flipDim] = b.create( - loc, dims[flipDim], indices[flipDim]); - } - Value res = b.create(loc, self, indices) - .getResult(); - b.create(loc, res); - }) - .getResult(0); - + Value flipped = torch_to_linalg::flipTensor(rewriter, loc, self, axis); rewriter.replaceOpWithNewOp(op, self.getType(), flipped); - return success(); } }; @@ -1221,10 +1184,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return success(); } - if (numSpatialDims != 2) - return rewriter.notifyMatchFailure( - op, "unimplemented: only 2D grouped convolution supported"); - // Special depthwise case: Cin = Cout = groups. 
// Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple // of groups) to be depthwise in their documentation, but the linalg ops @@ -1236,21 +1195,45 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { if (inShape[1] == numGroups && weightShape[0] == numGroups && weightShape[1] == 1) { // Collapse weight shape (C/G == 1) - SmallVector collapsedDims = {{0, 1}, {2}, {3}}; - SmallVector collapsedShape{weightShape[0] * weightShape[1], - weightShape[2], weightShape[3]}; + SmallVector collapsedDims = {{0, 1}}; + SmallVector collapsedShape{weightShape[0] * weightShape[1]}; + for (unsigned i = 0; i < numSpatialDims; i++) { + collapsedDims.push_back({i + 2}); + collapsedShape.push_back(weightShape[i + 2]); + } Type collapsedType = RankedTensorType::get( makeShapeLLVMCompatible(collapsedShape), weightDTy); Value collapsedWeight = rewriter.create( loc, collapsedType, weight, collapsedDims); if (!inputZp) { - conv = rewriter - .create( - loc, outputTensor.getType(), - ValueRange{paddedInput, collapsedWeight}, outputTensor, - stridesAttr, dilationAttr) - .getResult(0); + switch (numSpatialDims) { + case 1: + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, collapsedWeight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); + break; + case 2: + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, collapsedWeight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); + break; + default: + return rewriter.notifyMatchFailure( + op, "unimplemented: only 1D and 2D depthwise convolution " + "supported for special case of group convolution"); + }; } else { + if (numSpatialDims != 2) + return rewriter.notifyMatchFailure( + op, "unimplemented: only 2D depthwise quantized convolution " + "supported for special case of group convolution"); + // currently, the only named depthwise qconv op is nhwc_hwc // input: nchw -> nhwc; weight (collapsed): chw -> hwc // linalg conv result nhwc -> nchw @@ -1297,6 +1280,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return success(); } + if (numSpatialDims != 2) + return rewriter.notifyMatchFailure( + op, "unimplemented: only 2D grouped convolution supported"); + // Grouped case, use the grouped conv linalg op auto expandGroups = [&](Value tensor, size_t dim) { auto inType = cast(tensor.getType()); diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 6ef947d890cd..18e8fb449ef5 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -620,3 +620,44 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op, .getResult(0); return success(); } + +// Flips an input tensor based on the values of axis list. +Value torch_to_linalg::flipTensor(PatternRewriter &rewriter, Location loc, + Value input, SmallVector axis) { + Value c1 = rewriter.create(loc, rewriter.getIndexAttr(1)); + Type elementType = cast(input.getType()).getElementType(); + auto selfRank = cast(input.getType()).getRank(); + + // Only used to calculate flipped values, i.e. those on the flip axes. Other + // dims won't be used. 
+ SmallVector dims = getTensorSizes(rewriter, loc, input); + for (auto flipDim : axis) + dims[flipDim] = rewriter.create(loc, dims[flipDim], c1); + + Value initTensor = createZeroInitTensor( + rewriter, loc, getTensorSizes(rewriter, loc, input), elementType); + + SmallVector iteratorTypes(selfRank, + utils::IteratorType::parallel); + SmallVector indexingMaps( + 2, AffineMap::getMultiDimIdentityMap(selfRank, rewriter.getContext())); + Value flipped = + rewriter + .create( + loc, input.getType(), input, initTensor, indexingMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + SmallVector indices; + for (auto i = 0; i < selfRank; i++) + indices.push_back(b.create(loc, i)); + for (auto flipDim : axis) { + indices[flipDim] = b.create(loc, dims[flipDim], + indices[flipDim]); + } + Value res = b.create(loc, input, indices) + .getResult(); + b.create(loc, res); + }) + .getResult(0); + return flipped; +} diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index bf8434c93da1..3e1e73d9c418 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -153,11 +153,17 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, return rewriter.notifyMatchFailure(op, "Unable to extract the scalar constant"); + int64_t numElem = 1; + for (int64_t dim : dshape) + numElem *= dim; + if (isa(dtype)) { - tosaTensor = tosa::getConstTensor(rewriter, op, - (isFloat ? doubleValue : intValue), - dshape, dtype) - .value(); + tosaTensor = + tosa::getConstTensor( + rewriter, op, + SmallVector(numElem, (isFloat ? doubleValue : intValue)), + dshape, dtype) + .value(); } else if (auto intType = dyn_cast(dtype)) { auto w = intType.getWidth(); if (w != 1 && w != 32 && w != 64) @@ -173,8 +179,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, } bool d = isFloat ? static_cast(doubleValue) : static_cast(intValue); - tosaTensor = - tosa::getConstTensor(rewriter, op, {d}, dshape).value(); + tosaTensor = tosa::getConstTensor( + rewriter, op, SmallVector(numElem, d), dshape) + .value(); } else if (w == 32) { if (!isInValidRange(isFloat, doubleValue, isInt, intValue)) { return rewriter.notifyMatchFailure( @@ -183,8 +190,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, } int32_t d = isFloat ? static_cast(doubleValue) : static_cast(intValue); - tosaTensor = - tosa::getConstTensor(rewriter, op, {d}, dshape).value(); + tosaTensor = tosa::getConstTensor( + rewriter, op, SmallVector(numElem, d), dshape) + .value(); } else if (w == 64) { if (!isInValidRange(isFloat, doubleValue, isInt, intValue)) { return rewriter.notifyMatchFailure( @@ -192,8 +200,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, "of destination type"); } int64_t d = (isFloat ? static_cast(doubleValue) : intValue); - tosaTensor = - tosa::getConstTensor(rewriter, op, {d}, dshape).value(); + tosaTensor = tosa::getConstTensor( + rewriter, op, SmallVector(numElem, d), dshape) + .value(); } } else { return rewriter.notifyMatchFailure(op, "Usupported element type"); @@ -4179,6 +4188,124 @@ class SimplifyAten_IndexPutImplOp }; // Handle Aten_IndexPutImplOp on 1d tensors +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenIndexSelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Not a tensor type. 
+ auto input = adaptor.getSelf(); + auto inputType = dyn_cast(input.getType()); + if (!inputType) + return rewriter.notifyMatchFailure( + op, "Only RankedTensorType inputs are currently supported"); + + auto index = adaptor.getIndex(); + auto indexType = dyn_cast(index.getType()); + + if (!indexType) + return rewriter.notifyMatchFailure( + op, "Only RankedTensorType indices are currently supported"); + + auto inputShape = inputType.getShape(); + int inputRank = inputType.getRank(); + + if (indexType.getRank() == 0) + return rewriter.notifyMatchFailure( + op, "Rank 0 index tensor is currently not supported"); + + // Dynamic shape check + if (!inputType.hasStaticShape() || !indexType.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "AtenIndexSelectOp: support for dynamic input " + "shape not implemented"); + + // index i64 to i32 for tosa compatible + if (indexType.getElementType() != rewriter.getIntegerType(32)) { + index = rewriter.create( + op->getLoc(), + RankedTensorType::get(indexType.getShape(), + rewriter.getIntegerType(32)), + index); + } + + // Get positive dim + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "Value `dim` should be a torch constant int"); + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) + return rewriter.notifyMatchFailure(op, "Value `dim` is invalid"); + + // Get the output type + auto outType = getTypeConverter()->convertType(op.getType()); + + // Reshape and expand the index tensor to have same rank and same dimensions + // (except for the targeted dim) as the input + // + // For example: + // Input shape = (4, 5, 6) + // Index vector shape = (2) + // Targeted dim = 1 + // Reshaped and expanded index vector shape = (4, 2, 6) + // + // By reshaping and expanding the index vector, we can supply it into the + // gather op to mimic the functionality of aten.index_select + SmallVector indicesInputRankShape; + for (int64_t i = 0; i < inputRank; i++) { + if (i == dim) { + indicesInputRankShape.push_back(indexType.getShape()[0]); + } else { + indicesInputRankShape.push_back(1); + } + } + + auto indicesInputRankType = + RankedTensorType::get(makeShapeLLVMCompatible(indicesInputRankShape), + rewriter.getIntegerType(32)); + + auto reshapedIndices = rewriter.create( + op->getLoc(), indicesInputRankType, index, + rewriter.getDenseI64ArrayAttr(indicesInputRankShape)); + + SmallVector tileShape(indicesInputRankShape); + SmallVector expandedIndicesShape(indicesInputRankShape); + for (int64_t i = 0; i < inputRank; i++) { + if (tileShape[i] == 1 && i != dim) { + tileShape[i] = inputShape[i]; + expandedIndicesShape[i] = inputShape[i]; + } else { + tileShape[i] = 1; + } + } + + auto tileType = + RankedTensorType::get(makeShapeLLVMCompatible(expandedIndicesShape), + rewriter.getIntegerType(32)); + + auto expandedIndices = rewriter.create( + op->getLoc(), tileType, reshapedIndices.getResult(), + rewriter.getDenseI64ArrayAttr(tileShape)); + + // convert torch style index and dim into tf style indices + // tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64> + auto indicesTf = tosa::convertTorchIndexToTfIndices( + rewriter, op, input, expandedIndices.getResult(), dim); + if (!indicesTf) + return rewriter.notifyMatchFailure( + op, "Convert TorchIndex To TfIndices failed"); + + // do the tf gathernd algorithm with tf style indices as input. 
+ auto result = + tosa::convertGatherNdOp(rewriter, op, outType, input, indicesTf.value()); + + if (!result) { + return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed"); + } + rewriter.replaceOp(op, {result.value()}); + return success(); +} + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenIndexPutHackedTwinOp op, OpAdaptor adaptor, @@ -5755,7 +5882,7 @@ class ConvertAtenConstPatternOp : public OpConversionPattern { }; template -class ConvertAtenFillScalarOp : public OpConversionPattern { +class ConvertAtenFillOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; @@ -5771,20 +5898,48 @@ class ConvertAtenFillScalarOp : public OpConversionPattern { op, "Only Tensor types with static shapes are currently supported"); Type outElemTy = outType.getElementType(); - if (!outElemTy.isIntOrFloat()) { + if (!outElemTy.isIntOrFloat()) return rewriter.notifyMatchFailure( op, "Only floating-point or integer datatype legalization supported"); + + Value fillValueTargetTensor; + if constexpr (std::is_same()) { + // Reshape value tensor to have same rank and shape as input + auto inputRank = + cast(adaptor.getSelf().getType()).getRank(); + + auto fillValue = adaptor.getValue(); + auto fillValueType = dyn_cast(fillValue.getType()); + if (!fillValueType) + return rewriter.notifyMatchFailure(op, "Fill value is not a tensor"); + auto fillValueElemTy = fillValueType.getElementType(); + + SmallVector fillValueMatchedInputRankShape(inputRank, 1); + + auto fillValueMatchedInputRankType = RankedTensorType::get( + makeShapeTorchCompatible(fillValueMatchedInputRankShape), + fillValueElemTy); + + auto fillValueMatchedInputRankTensor = rewriter.create( + op->getLoc(), fillValueMatchedInputRankType, fillValue, + rewriter.getDenseI64ArrayAttr(fillValueMatchedInputRankShape)); + + fillValueTargetTensor = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeTorchCompatible(outType.getShape()), + fillValueElemTy), + fillValueMatchedInputRankTensor.getResult(), + makeShapeTorchCompatible(outType.getShape())); + } else { + if (failed(torchScalarToTosaTensor( + rewriter, op, op.getValue(), fillValueTargetTensor, outElemTy, + makeShapeTorchCompatible(outType.getShape())))) + return rewriter.notifyMatchFailure( + op, "Fill value must be a scalar constant"); } - Value constOp; - if (failed(torchScalarToTosaTensor( - rewriter, op, op.getValue(), constOp, outElemTy, - makeShapeTorchCompatible(outType.getShape())))) - return rewriter.notifyMatchFailure( - op, "Supplied value must be a Scalar constant"); - auto newOp = - rewriter.createOrFold(op.getLoc(), outType, constOp); - rewriter.replaceOp(op, newOp); + rewriter.replaceOpWithNewOp(op, outType, + fillValueTargetTensor); return success(); } @@ -6616,6 +6771,127 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// Legalization for aten.flip +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenFlipOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + auto self = adaptor.getSelf(); + + auto selfTy = dyn_cast(self.getType()); + if (!selfTy) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types are currently supported"); + + SmallVector dims; + if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(dims))) + return rewriter.notifyMatchFailure( + op, "Only constant dims are currently supported"); + + auto selfRank = selfTy.getRank(); + + auto resultTy = 
getTypeConverter()->convertType(op.getType()); + Value result = self; + + for (auto &dim : dims) { + dim = toPositiveDim(dim, selfRank); + if (!isValidDim(dim, selfRank)) + return rewriter.notifyMatchFailure(op, "Not all dims are valid"); + + result = rewriter.create(op->getLoc(), resultTy, result, + static_cast(dim)); + } + + rewriter.replaceOp(op, result); + return success(); +} + +// Legalization for aten.round: +// Rounds elements of input to the nearest integer. +// Implements "round half to even" to break ties when a number is equidistant +// from two integers. +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenRoundOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // To round to the nearest integer, we will consider the fractional part of + // the input element (= input element - integer part of element). If the + // fractional part is smaller than 0.5, round the number down. If the + // fractional part is 0.5, apply "round half to even" rule. If the fractional + // part is greater than 0.5, round up. + // + // if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)): + // res = floor(input) + // else: + // res = ceil(input) + + auto self = adaptor.getSelf(); + + auto selfTy = dyn_cast(self.getType()); + if (!selfTy) + return rewriter.notifyMatchFailure(op, "Only tensor types supported"); + + auto resultTy = + cast(getTypeConverter()->convertType(op.getType())); + + auto boolTy = + RankedTensorType::get(resultTy.getShape(), rewriter.getIntegerType(1)); + + auto resultElemTy = resultTy.getElementType(); + + auto oneHalf = + tosa::getConstTensor(rewriter, op, 0.5, {}, resultElemTy).value(); + + auto two = + tosa::getConstTensor(rewriter, op, 2, {}, resultElemTy).value(); + + auto floorInput = + rewriter.create(op->getLoc(), resultTy, self); + + // input - floor(input) + auto fractionalPart = rewriter.create( + op->getLoc(), resultTy, self, floorInput.getResult()); + + auto ceilInput = rewriter.create(op->getLoc(), resultTy, self); + + auto floorInputDivByTwo = rewriter.create( + op->getLoc(), resultTy, floorInput.getResult(), oneHalf, /*shift=*/0); + + auto floorDivResult = rewriter.create( + op->getLoc(), resultTy, floorInputDivByTwo.getResult()); + + // (floor(input) // 2) * 2 + auto evenComparison = rewriter.create( + op->getLoc(), resultTy, floorDivResult.getResult(), two, /*shift=*/0); + + // floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0 + auto floorInputEven = rewriter.create( + op->getLoc(), boolTy, floorInput.getResult(), evenComparison.getResult()); + + auto fracEqualOneHalf = rewriter.create( + op->getLoc(), boolTy, fractionalPart.getResult(), oneHalf); + + auto fracLtOneHalf = rewriter.create( + op->getLoc(), boolTy, oneHalf, fractionalPart.getResult()); + + // (frac == 0.5) && (floor(input) % 2 == 0) + auto fracEqualOneHalfCond = rewriter.create( + op->getLoc(), boolTy, fracEqualOneHalf.getResult(), + floorInputEven.getResult()); + + // (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0)) + auto floorResultCond = rewriter.create( + op->getLoc(), boolTy, fracLtOneHalf.getResult(), + fracEqualOneHalfCond.getResult()); + + rewriter.replaceOpWithNewOp( + op, resultTy, floorResultCond.getResult(), floorInput.getResult(), + ceilInput.getResult()); + + return success(); +} + // Template to create supporting diagonal mask tensor for aten.diagonal template Value createDiagonalMask(PatternRewriter &rewriter, Operation *op, @@ -6799,6 +7075,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } + } // 
namespace // ----------------------------------------------------------------------------- @@ -7047,12 +7324,13 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); #undef INSERT_CONSTANT_FILL_PATTERN -#define INSERT_FILL_SCALAR_PATTERN(AtenOp) \ +#define INSERT_FILL_PATTERN(AtenOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_FILL_SCALAR_PATTERN(AtenFill_ScalarOp); - INSERT_FILL_SCALAR_PATTERN(AtenFillScalarOp); -#undef INSERT_FILL_SCALAR_PATTERN + patterns.add>(typeConverter, context); + INSERT_FILL_PATTERN(AtenFill_ScalarOp); + INSERT_FILL_PATTERN(AtenFillScalarOp); + INSERT_FILL_PATTERN(AtenFillTensorOp); +#undef INSERT_FILL_PATTERN #define INSERT_MASKED_FILL_PATTERN(AtenOp) \ target.addIllegalOp(); \ @@ -7125,6 +7403,9 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp); INSERT_ATENOP_PATTERN(AtenTrilOp); INSERT_ATENOP_PATTERN(AtenDiagonalOp); + INSERT_ATENOP_PATTERN(AtenIndexSelectOp); + INSERT_ATENOP_PATTERN(AtenFlipOp); + INSERT_ATENOP_PATTERN(AtenRoundOp); #undef INSERT_ATENOP_PATTERN #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index 5ea4e4bc47dc..6e79fbc5df15 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -23,6 +23,15 @@ namespace tosa { using namespace mlir::torch::Torch; +// This function is a helper for `convertTorchIndexToTfIndices`. +// +// We convert PyTorch index to TensorFlow-style indices so that we can use +// `convertGatherNdOp` and `convertScatterNdOp` functions, which lower Gather +// and Scatter operators to TOSA using TensorFlow-style indices. +// The difference between PyTorch/ONNX Gather/Scatter and TensorFlow +// Gather/Scatter ops is that PyTorch/ONNX take in the dimension that you want +// to gather/scatter elements, while in TensorFlow, the indices point directly +// to positions that you want to gather/scatter elements. std::optional createOneDimTfIndices(PatternRewriter &rewriter, Operation *op, SmallVector indicesOneDimShape, int32_t dim, @@ -30,49 +39,55 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op, unsigned indexRank = indexShape.size(); SmallVector indicesVec; // input vec to create tosaConstant SmallVector indicesMetaElement; // torch.meshgrid inputs - int indicesMetaElementRepeatTimes{1}; // For torch.stack(torch.meshgrid) // Create torch.meshgrid inputs // Example: indexShape=[1,4,2] // dim0: indicesMetaElement = torch.arange(0, 1) = [0] // dim1: indicesMetaElement = torch.arange(0, 4) = [0,1,2,3] // dim2: indicesMetaElement = torch.arange(0, 2) = [0,1] - for (int i = 0; i < indexShape[dim]; i++) { + for (int i = 0; i < indexShape[dim]; i++) indicesMetaElement.push_back(i); - } - - // Compute total number of meta element repeat times: - // = product(indexShape[0:dim]) x product(indexShape[dim+1:-1]), skip dim - // dim0: indicesMetaElementRepeatTimes = 1 x 4*2 = 8 - // dim1: indicesMetaElementRepeatTimes = 1 *1 x 2 = 2 - // dim2: indicesMetaElementRepeatTimes = 1 *1*4 = 4 - for (int i = 0; i < static_cast(indexRank); i++) { - if (i == dim) { - continue; - } else { - indicesMetaElementRepeatTimes *= indexShape[i]; - } - } - if (dim != static_cast(indexShape.size()) - 1) { - // Create one dim indices for index except for last dim - // Create indices raw vector. 
- // torch.stack(torch.meshgrid) - // dim0: indicesVec = [0 0 0 0 0 0 0 0] - // dim0: indicesVec = [0 0 1 1 2 2 3 3] + int preDimMetaElementRepeatTimes = 1; + int postDimMetaElementRepeatTimes = 1; + + // Compute total number of times meta element range should repeat + // = product(indexShape[0:dim]) + // dim0: preDimMetaElementRepeatTimes = 1 + // dim1: preDimMetaElementRepeatTimes = 1 + // dim2: preDimMetaElementRepeatTimes = 1 x 4 = 4 + for (int i = 0; i < dim; i++) + preDimMetaElementRepeatTimes *= indexShape[i]; + + // Compute total number of times meta element repeat + // = product(indexShape[dim+1:indexRank]) + // dim0: postDimMetaElementRepeatTimes = 4 x 2 = 8 + // dim1: postDimMetaElementRepeatTimes = 2 + // dim2: postDimMetaElementRepeatTimes = 1 + for (int i = dim + 1; i < static_cast(indexRank); i++) + postDimMetaElementRepeatTimes *= indexShape[i]; + + // Example using dim1: + // preDimMetaElementRepeatTimes = 1 + // postDimMetaElementRepeatTimes = 2 + // Using postDimMetaElementRepeatTimes, we get the meta element range: + // [0 0 1 1 2 2 3 3] + // Using preDimMetaElementRepeatTimes, we get the full one dim indices: + // [0 0 1 1 2 2 3 3] + // + // Let's use a clearer example: + // indexShape = [3, 4, 2] + // Target dim = 1 + // => preDimMetaElementRepeatTimes = 3 + // postDimMetaElementRepeatTimes = 2 + // Using postDimMetaElementRepeatTimes, we get the meta element range: + // [0 0 1 1 2 2] + // Using preDimMetaElementRepeatTimes, we get the full one dim indices: + // [0 0 1 1 2 2 0 0 1 1 2 2 0 0 1 1 2 2] + for (int i = 0; i < preDimMetaElementRepeatTimes; i++) { for (size_t elementId = 0; elementId < indicesMetaElement.size(); elementId++) { - for (int i = 0; i < indicesMetaElementRepeatTimes; i++) { - indicesVec.push_back(indicesMetaElement[elementId]); - } - } - } else { // Create the one dim indices for last dim of index - // Create indices raw vector - // dim2: indicesVec= [0 1 0 1 0 1 0 1] - // Caution: indicesVec != [0 0 0 0 1 1 1 1] - for (int i = 0; i < indicesMetaElementRepeatTimes; i++) { - for (size_t elementId = 0; elementId < indicesMetaElement.size(); - elementId++) { + for (int j = 0; j < postDimMetaElementRepeatTimes; j++) { indicesVec.push_back(indicesMetaElement[elementId]); } } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index fa4f4d68e990..329b16447617 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1094,6 +1094,7 @@ "ContiguousModule_basic", "Conv1dNoPaddingGroupModule_basic", "Conv1dNoPaddingModule_basic", + "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_depthwise", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", @@ -1721,7 +1722,34 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. 
TOSA_PASS_SET = { + "ArangeZeroElementOutputModule_basic", + "AtenRoundFloatHalfToEvenModule_basic", + "AtenRoundFloatModule_basic", + "FakeQuantizePerTensorAffineCachemaskModule_basic", + "FakeQuantizePerTensorAffineDynamicShapeModule_basic", + "FakeQuantizePerTensorAffineModule_basic", + "FakeQuantizePerTensorAffineRoundToEvenModule_basic", + "Fill_TensorFloat64WithFloat32Static_basic", + "Fill_TensorFloat64WithInt64Static_basic", + "FlipModuleStaticShape_basic", + "FlipModule_basic", + "FlipNegativeIndexModule_basic", + "Rot90BasicModule_basic", + "Rot90DynamicDimsModule_basic", + "Rot90MultipleRotationsModule_basic", + "Rot90NegativeEvenRotationsModule_basic", + "Rot90NegativeOddRotationsModule_basic", + "AtenLinalgCrossBroadcast_basic", + "AtenLinalgCrossCustomDim_basic", + "AtenLinalgCrossFloat_basic", + "AtenLinalgCrossInt_basic", + "AtenLinalgCrossNegativeDim_basic", "BinaryCrossEntropyWithLogitsStaticModule_basic", + "IndexSelectNegativeDimModule_basic", + "IndexSelectSingleIdxModule_basic", + "IndexSelectTwoIdxModule_basic", + "IndexSelectWholeDimensionModule_basic", + "IndexSelectWholeTensorModule_basic", "DiagonalWithStaticShapeModule_basic", "EinsumStaticDiagonalDimensionModule_basic", "ElementwiseAtenFloorDivideBroadcastModule_basic", @@ -1879,7 +1907,6 @@ "ArangeStartOutViewModule_basic", "ArangeStartStepFloatModule_basic", "ArangeStartStepIntModule_basic", - "ArangeZeroElementOutputModule_basic", "ArangeDtypeIntModule_basic", "ArangeFalsePinMemoryModule_basic", "ArangeFloatModule_basic", @@ -1944,6 +1971,7 @@ "ConstantPadNdPartialStaticModule_basic", "ConstantPadNdStaticModule_basic", "ContiguousModule_basic", + "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", "Conv1dNoPaddingGroupModule_basic", "Conv1dNoPaddingModule_basic", "Conv2dBiasNoPaddingModule_basic", @@ -2196,6 +2224,7 @@ "IndexTensorModule3dInputStatic_basic", "IndexTensorMultiIndexStaticModule_basic", "IndexTensorStaticModule_basic", + "IndexSelectStaticModule_basic", "IscloseStaticModule_basic", "IscloseStaticModuleTrue_basic", "LayerNormNormalizeOverAllDimsModule_basic", @@ -2308,8 +2337,6 @@ "ReduceSumFloatModule_basic", "ReduceSumSignedIntModule_basic", "ReduceSumUnsignedIntModule_basic", - "RepeatInterleaveFillModule_basic", - "RepeatInterleaveStaticModule_basic", "RepeatModule_basic", "RepeatInterleaveSelfIntNoDimModule_basic", "ReshapeAliasCollapseModule_basic", @@ -2334,8 +2361,8 @@ "SiluModule_basic", "SliceOutOfLowerBoundStartIndexStaticModule_basic", "SliceOutOfUpperBoundIndexStaticModule_basic", - "SliceSizeTwoStepDivisibleStaticModule_basic", "SliceStaticModule_basic", + "SliceSizeTwoStepDivisibleStaticModule_basic", "SplitTensorGetItem_Module_basic", "SplitTensorLastSmallerModule_basic", "SplitTensorListUnpackModule_basic", @@ -2531,7 +2558,6 @@ "IndexSelectWholeTensorModule_basic", "IndexSelectNegativeDimModule_basic", "IndexSelectRank0IdxModule_basic", - "IndexSelectStaticModule_basic", "IndexSelectSingleIdxModule_basic", "IndexSelectTwoIdxModule_basic", "LinalgVectorNormModule_basic", @@ -3296,7 +3322,6 @@ "ScatterReduceIntMaxModuleIncludeSelf", "ScatterReduceIntMinModuleIncludeSelf", "ScatterValueFloatModule_basic", - "ScatterAddStaticModule_basic", # Failure - onnx_lowering: onnx.ScatterND "IndexPut1DFloatAccumulateModule_basic", "IndexPut1DIntAccumulateModule_basic", @@ -3461,6 +3486,23 @@ } FX_IMPORTER_TOSA_XFAIL_SET = { + "ArangeZeroElementOutputModule_basic", + "NumpyTRank0Module_basic", + "Permute0RankModule_basic", + "SliceOutOfUpperBoundIndexModule_basic", + 
"SliceOutOfUpperBoundIndexStaticModule_basic", + "SliceStartEqEndModule_basic", + "ChunkListUnpackDynamic_Module_basic", + "ChunkListUnpackUnevenDynamic_Module_basic", + "ChunkListUnpackUneven_Module_basic", + "ChunkListUnpack_Module_basic", + "SplitTensorGetItem_Module_basic", + "SplitTensorLastSmallerModule_basic", + "SplitTensorListUnpackModule_basic", + "SplitTensorNegativeDimModule_basic", + "SplitWithSizesListUnpackModule_basic", + "SplitWithSizes_Module_basic", + "ElementwiseCreateComplexModule_basic", "AdaptiveMaxPool1dDimOneStatic_basic", "AtenPolarDoubleModule_basic", "AtenPolarFloatModule_basic", @@ -3468,11 +3510,6 @@ "HstackBasicFloatModule_basic", "HstackBasicIntFloatModule_basic", "HstackBasicIntModule_basic", - "Rot90BasicModule_basic", - "Rot90DynamicDimsModule_basic", - "Rot90MultipleRotationsModule_basic", - "Rot90NegativeEvenRotationsModule_basic", - "Rot90NegativeOddRotationsModule_basic", "AtenIntMM_basic", "AtenKthvalueDynamicDimsModule_basic", "AtenKthvalueFloat64DynamicDimsModule_basic", @@ -3491,7 +3528,6 @@ "ElementwiseRreluEvalStaticModule_basic", "ElementwiseRreluTrainModule_basic", "ElementwiseRreluTrainStaticModule_basic", - "FakeQuantizePerTensorAffineCachemaskModule_basic", "IndexPutWithNoneAndBroadcastModule_basic", "MaskedScatterStaticBasic_basic", "MaxUnpool3dModulePad0_basic", @@ -3558,12 +3594,6 @@ "AtenIntTensorCharDtypeModule_basic", "AtenItemFpOpModule_basic", "AtenItemIntOpModule_basic", - "AtenLinalgCrossBroadcast_basic", - "AtenLinalgCrossCustomDim_basic", - "AtenLinalgCrossDynamic_basic", - "AtenLinalgCrossFloat_basic", - "AtenLinalgCrossInt_basic", - "AtenLinalgCrossNegativeDim_basic", "AtenMatmulQMixedSigni8Transpose_basic", "AtenMatmulQMixedSigni8_basic", "AtenMatmulQint8MV_basic", @@ -3576,8 +3606,6 @@ "AtenMmQuint8_basic", "AtenRealView128Module_basic", "AtenRealView64Module_basic", - "AtenRoundFloatHalfToEvenModule_basic", - "AtenRoundFloatModule_basic", "AtenSubFloatModule_basic", "AtenTopKModule_basic", "AtenTopKSmallestModule_basic", @@ -3619,6 +3647,7 @@ "ContainsIntList_False", "ContainsIntList_True", "Conv1dModule_basic", + "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", "Conv2dQInt8Module_basic", "Conv2dQInt8Module_depthwise", "Conv2dQInt8Module_grouped", @@ -3738,20 +3767,6 @@ "EqIntModule_basic", "ExpandModule_basic", "ExponentialModule_basic", - "FakeQuantizePerTensorAffineDynamicShapeModule_basic", - "FakeQuantizePerTensorAffineModule_basic", - "FakeQuantizePerTensorAffineRoundToEvenModule_basic", - "Fill_TensorFloat32WithFloat32_basic", - "Fill_TensorFloat32WithFloat64_basic", - "Fill_TensorFloat32WithInt64_basic", - "Fill_TensorFloat64WithFloat32Static_basic", - "Fill_TensorFloat64WithFloat32_basic", - "Fill_TensorFloat64WithFloat64_basic", - "Fill_TensorFloat64WithInt64Static_basic", - "Fill_TensorFloat64WithInt64_basic", - "FlipModuleStaticShape_basic", - "FlipModule_basic", - "FlipNegativeIndexModule_basic", "FloatImplicitModule_basic", "FullLikeModuleInt2D_basic", "FullLikeModuleInt3D_basic", @@ -3807,15 +3822,7 @@ "IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImplIndexWithNoneModule_basic", - "IndexSelectDynamicIndexSizeModule_basic", - "IndexSelectDynamicInputSizeModule_basic", - "IndexSelectDynamicModulebasic", - "IndexSelectNegativeDimModule_basic", "IndexSelectRank0IdxModule_basic", - "IndexSelectSingleIdxModule_basic", - "IndexSelectTwoIdxModule_basic", - "IndexSelectWholeDimensionModule_basic", - "IndexSelectWholeTensorModule_basic", 
"IndexTensorNegativeIndexModule_basic", "InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_nearest", @@ -4088,9 +4095,7 @@ "VarMeanUnbiasedModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewSizeFromOtherTensor_basic", - "ZeroFloat32Module_basic", - "ZeroInt32Module_basic", - "ZeroInt64Module_basic", + "VisionTransformerModule_basic", "ZerosLikeModule_falsePinMemory", } @@ -4103,6 +4108,14 @@ } ONNX_TOSA_XFAIL_SET = { + "ArangeZeroElementOutputModule_basic", + "LinspaceEmptyModule_basic", + "RepeatInterleaveSelfIntNoDimModule_basic", + "SliceOutOfUpperBoundIndexStaticModule_basic", + "TrilIndicesAllZerosModule_basic", + "TriuIndicesAllZerosModule_basic", + "ElementwiseCreateComplexModule_basic", + "ReduceAllDimFloatModule_basic", "AdaptiveMaxPool1dDimOneStatic_basic", "ScaledDotProductAttentionDifferentCausalModule_basic", "HstackBasicComplexModule_basic", @@ -4265,8 +4278,6 @@ "AtenPolarDoubleModule_basic", "AtenRealView128Module_basic", "AtenRealView64Module_basic", - "AtenRoundFloatHalfToEvenModule_basic", - "AtenRoundFloatModule_basic", "AtenSubFloatModule_basic", "AtenTopKModule_basic", "AtenTopKSmallestModule_basic", @@ -4310,8 +4321,6 @@ "BucketizeTensorFloatModule_basic", "BucketizeTensorModule_basic", "BucketizeTensorOutInt32RightModule_basic", - "BucketizeTensorStaticFloatModule_basic", - "BucketizeTensorStaticModule_basic", "CeilFloatModule_basic", "ChunkListUnpackDynamic_Module_basic", "ChunkListUnpackUnevenDynamic_Module_basic", @@ -4330,6 +4339,7 @@ "ContainsIntList_False", "ContainsIntList_True", "Conv1dModule_basic", + "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", "Conv2dBiasNoPaddingModule_basic", "Conv2dModule_basic", "Conv2dNoPaddingModule_basic", @@ -4524,7 +4534,6 @@ "ElementwiseWhereSelfModule_basic", "EmbeddingModule1DIndices_basic", "EmbeddingModuleF16_basic", - "EmbeddingModuleI32Static_basic", "EmbeddingModuleI32_basic", "EmbeddingModuleI64_basic", "EmptyLikeMemoryFormatModule_basic", @@ -4618,12 +4627,6 @@ "IndexSelectDynamicIndexSizeModule_basic", "IndexSelectDynamicInputSizeModule_basic", "IndexSelectDynamicModulebasic", - "IndexSelectNegativeDimModule_basic", - "IndexSelectRank0IdxModule_basic", - "IndexSelectSingleIdxModule_basic", - "IndexSelectTwoIdxModule_basic", - "IndexSelectWholeDimensionModule_basic", - "IndexSelectWholeTensorModule_basic", "IndexTensorDyanmicInputContiguousWithNoneModule_basic", "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic", "IndexTensorHackedTwinModule3dInput_basic", @@ -4641,10 +4644,8 @@ "IndexTensorMultiInputOneDim_basic", "IndexTensorMultiInputThreeIndexers_basic", "IndexTensorMultiInput_basic", - "IndexTensorNegativeIndexModule_basic", "IndexTensorSelectDimModule_basic", "IndexTensorStaticContiguousWithNoneModule_basic", - "IndexTensorStaticModule_basic", "IndexTensorStaticNonContiguousWithNoneModule_basic", "InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_nearest", @@ -4942,7 +4943,6 @@ "ScatterValueFloatModule_basic", "ScatterValueIntModule_basic", "SelectIntModule_basic", - "SelectIntNegativeDimAndIndexStaticModule_basic", "SelectScattertModule_basic", "SelectScattertStaticModule_basic", "SignAndLogarithmOfDeterminantModule_F32", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index 9024708fc654..81b3b842c8cc 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ 
-1184,6 +1184,33 @@ def Conv1dModule_basic(module, tu: TestUtils): module.forward(inputVec, weight, bias) +class Conv1dDepthwiseWithPaddingDilationStrideStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 4, 6], torch.float32, True), + ([4, 1, 3], torch.float32, True), + ] + ) + def forward(self, inputVec, weight): + return torch.ops.aten.conv1d( + inputVec, weight, bias=None, stride=[1], padding=[4], dilation=[1], groups=4 + ) + + +@register_test_case( + module_factory=lambda: Conv1dDepthwiseWithPaddingDilationStrideStaticModule() +) +def Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 4, 6) + weight = torch.randn(4, 1, 3) + module.forward(inputVec, weight) + + class Conv2dModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 7ca44e25dc3f..cfccf5d69871 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -261,15 +261,16 @@ func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %ar // CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[AXIS:.*]] = torch.constant.int 1 - // CHECK: %[[ZERO:.+]] = torch.constant.int 0 - // CHECK: %[[ONE:.+]] = torch.constant.int 1 - // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]] - // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]] - // CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] - // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 - // CHECK: %[[STR:.*]] = torch.constant.str "add" - // CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32> +// CHECK: %[[AXIS:.*]] = torch.constant.int 1 +// CHECK: %[[ZERO:.*]] = torch.constant.int 0 +// CHECK: %[[FIVE:.*]] = torch.constant.int 1 +// CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int +// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64> +// CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1> +// CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64> +// CHECK: %[[STR:.*]] = torch.constant.str "sum" +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32> %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) 
{torch.onnx.axis = 1 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> return %0 : !torch.vtensor<[1,5],f32> } @@ -294,15 +295,16 @@ func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>, // CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[AXIS:.*]] = torch.constant.int 1 - // CHECK: %[[ZERO:.+]] = torch.constant.int 0 - // CHECK: %[[ONE:.+]] = torch.constant.int 1 - // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]] - // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]] - // CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] - // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 - // CHECK: %[[STR:.*]] = torch.constant.str "multiply" - // CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32> +// CHECK: %[[AXIS:.*]] = torch.constant.int 1 +// CHECK: %[[ZERO:.*]] = torch.constant.int 0 +// CHECK: %[[FIVE:.*]] = torch.constant.int 1 +// CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int +// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64> +// CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1> +// CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64> +// CHECK: %[[STR:.*]] = torch.constant.str "prod" +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32> %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "mul"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> return %0 : !torch.vtensor<[1,5],f32> } diff --git a/test/Conversion/TorchToLinalg/squeeze.mlir b/test/Conversion/TorchToLinalg/squeeze.mlir new file mode 100644 index 000000000000..a8922eed5a9d --- /dev/null +++ b/test/Conversion/TorchToLinalg/squeeze.mlir @@ -0,0 +1,17 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func @torch.aten.squeeze.dim$dynamic +func.func @torch.aten.squeeze.dim$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "tf2onnx", torch.onnx_meta.producer_version = "1.5.2"} { + // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %arg0 
: !torch.vtensor<[?,?,?],f32> -> tensor + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_1:.*]] = arith.constant 0 : index + // CHECK: %[[DIM:.*]] = tensor.dim %[[BUILTIN_TENSOR]], %[[C0_1]] : tensor + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[CMPI:.*]] = arith.cmpi eq, %[[DIM]], %[[C1]] : index + // CHECK: cf.assert %[[CMPI]], "Expected dynamic squeeze dim size to be statically 1" + // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2]] : tensor into tensor + // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor -> !torch.vtensor<[?,?],f32> + %int0 = torch.constant.int 0 + %1 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> + return %1 : !torch.vtensor<[?,?],f32> +} diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 4647f1e7429b..085fd49a92af 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -2127,3 +2127,116 @@ func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) -> %0 = torch.aten.diagonal %arg0, %offset, %dim1, %dim2 : !torch.vtensor<[3,4,5,6],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[5,6,2],si32> return %0 : !torch.vtensor<[5,6,2],si32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.index_select( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5,6],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2],si64> -> tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32> +// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<2xi64>) -> tensor<2xi32> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<2xi32>) -> tensor<1x1x2xi32> +// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array} : (tensor<1x1x2xi32>) -> tensor<4x5x2xi32> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32> +// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32> +// CHECK: 
%[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32> +// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<40x1xi32>) -> tensor<1x40xi32> +// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32> +// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32> +// CHECK: return %[[VAL_20]] : !torch.vtensor<[4,5,2],f32> +// CHECK: } +func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> { + %int2 = torch.constant.int 2 + %0 = torch.aten.index_select %arg0, %int2, %arg1 : !torch.vtensor<[4,5,6],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,5,2],f32> + return %0 : !torch.vtensor<[4,5,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.fill.Scalar( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> { +// CHECK: %[[VAL_1:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x12x128x128xf32>}> : () -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[1,12,128,128],f32> +// CHECK: } +func.func @torch.aten.fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> { + %int0 = torch.constant.int 0 + %0 = torch.aten.fill.Scalar %arg0, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32> + return %0 : !torch.vtensor<[1,12,128,128],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.fill.Tensor( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1],si32> -> tensor<1xi32> +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1x1x1xi32> +// CHECK: %[[VAL_4:.*]] = tosa.tile %[[VAL_3]] {multiples = array} : (tensor<1x1x1x1xi32>) -> tensor<1x12x128x128xi32> +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<1x12x128x128xi32>) -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,12,128,128],f32> +// CHECK: } +func.func @torch.aten.fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> { + %0 = torch.aten.fill.Tensor %arg0, %arg1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1],si32> -> 
!torch.vtensor<[1,12,128,128],f32> + return %0 : !torch.vtensor<[1,12,128,128],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.flip( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_5:.*]] = tosa.reverse %[[VAL_1]] {axis = 1 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_6:.*]] = tosa.reverse %[[VAL_5]] {axis = 2 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4,5],f32> +// CHECK: } +func.func @torch.aten.flip(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.flip %arg0, %0 : !torch.vtensor<[3,4,5],f32>, !torch.list -> !torch.vtensor<[3,4,5],f32> + return %1 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.round( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor}> : () -> tensor +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.floor %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_5:.*]] = tosa.sub %[[VAL_1]], %[[VAL_4]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_6:.*]] = tosa.ceil %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], %[[VAL_2]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_8:.*]] = tosa.floor %[[VAL_7]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_8]], %[[VAL_3]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_10:.*]] = tosa.equal %[[VAL_4]], %[[VAL_9]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_11:.*]] = tosa.equal %[[VAL_5]], %[[VAL_2]] : (tensor<3x4x5xf32>, tensor) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_12:.*]] = tosa.greater %[[VAL_2]], %[[VAL_5]] : (tensor, tensor<3x4x5xf32>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_13:.*]] = tosa.logical_and %[[VAL_11]], %[[VAL_10]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_14:.*]] = tosa.logical_or %[[VAL_12]], %[[VAL_13]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_15:.*]] = tosa.select %[[VAL_14]], %[[VAL_4]], %[[VAL_6]] : (tensor<3x4x5xi1>, tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32> +// CHECK: return %[[VAL_16]] : !torch.vtensor<[3,4,5],f32> +// CHECK: } +func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { + %0 = 
torch.aten.round %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +}
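
The negative-stride path added to ConvertAtenSliceTensorOp above flips the sliced dimension, negates the stride, and recomputes the offset as input_shape[dim] - result_shape[dim] * |stride|. The sketch below re-checks the worked example from that comment in plain Python; the helper name and list-based slicing are illustrative only and not part of the patch.

def slice_via_flip(data, length, stride):
    # Mirror the rewrite: flip the data, negate the stride, and recompute the
    # offset as len(data) - length * |stride| (the formula from the comment).
    assert stride < 0
    flipped = data[::-1]
    new_stride = -stride
    new_offset = len(data) - length * new_stride
    return [flipped[new_offset + i * new_stride] for i in range(length)]

data = [0, 1, 2, 3, 4, 5]
# Original slice from the comment: offset = 3, stride = -2, result_shape = [2].
assert [data[3 + i * -2] for i in range(2)] == [3, 1]
assert slice_via_flip(data, length=2, stride=-2) == [3, 1]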
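
The new AtenIndexSelectOp lowering reshapes the 1-D index tensor to the input rank and tiles it across every non-selected dimension before gathering. Assuming a PyTorch environment, the same indexing trick can be checked directly with torch.gather; this is a sketch of the idea, not of the TOSA ops themselves.

import torch

x = torch.arange(4 * 5 * 6, dtype=torch.float32).reshape(4, 5, 6)
idx = torch.tensor([4, 1])
dim = 2

# Reshape the index to the input rank with size 1 on the non-selected dims,
# then broadcast it along those dims (the reshape + tile in the lowering).
view_shape = [1] * x.dim()
view_shape[dim] = idx.numel()
expand_shape = list(x.shape)
expand_shape[dim] = idx.numel()
expanded = idx.view(view_shape).expand(expand_shape)

# Gathering with the expanded index reproduces index_select.
assert torch.equal(torch.index_select(x, dim, idx), torch.gather(x, dim, expanded))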
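
The AtenRoundOp lowering selects between floor(input) and ceil(input) to implement round-half-to-even. The Python sketch below mirrors that selection rule and checks it against Python's built-in banker's rounding; the % 2 test stands in for the mul-by-0.5/floor/mul-by-2 equality used in the TOSA code, and the helper name is illustrative.

import math

def round_half_to_even(x: float) -> float:
    # Take floor(x) when the fractional part is < 0.5, or when it equals 0.5
    # and floor(x) is even; otherwise take ceil(x).
    frac = x - math.floor(x)
    if frac < 0.5 or (frac == 0.5 and math.floor(x) % 2 == 0):
        return float(math.floor(x))
    return float(math.ceil(x))

# Python's round() also rounds half to even, so it serves as a reference.
for v in [2.4, 2.5, 2.6, 3.5, -2.5, -0.5]:
    assert round_half_to_even(v) == round(v)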
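
The reworked createOneDimTfIndices computes, for each dimension, a pre-dim repeat count (product of the index shape before dim) and a post-dim repeat count (product after dim), and expands the arange over that dimension accordingly. A short Python model of that construction, checked against the indexShape = [1, 4, 2] examples in the surrounding comments; the function name is illustrative.

from math import prod

def one_dim_tf_indices(index_shape, dim):
    pre = prod(index_shape[:dim])       # how many times the whole block repeats
    post = prod(index_shape[dim + 1:])  # how many times each value repeats
    out = []
    for _ in range(pre):
        for v in range(index_shape[dim]):
            out.extend([v] * post)
    return out

# dim0 / dim1 / dim2 examples for indexShape = [1, 4, 2].
assert one_dim_tf_indices([1, 4, 2], 0) == [0] * 8
assert one_dim_tf_indices([1, 4, 2], 1) == [0, 0, 1, 1, 2, 2, 3, 3]
assert one_dim_tf_indices([1, 4, 2], 2) == [0, 1, 0, 1, 0, 1, 0, 1]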