From e9ed4af9ced23c201f3d72b81f4ec3060bc99d8e Mon Sep 17 00:00:00 2001 From: Justin Ngo Date: Fri, 4 Oct 2024 12:24:22 -0700 Subject: [PATCH 01/12] [TOSA] Add legalization for aten.index_select (#3760) - Add Torch to TOSA legalization for aten.index_select - Fix createOneDimTfIndices function in TosaLegalizeCommon.cpp to correctly convert Torch indices to TF-style indices, which is used in convertGatherNdOp - Update e2e tests in xfail_sets.py - Update basic.mlir with new LIT test for aten.index_select Signed-off-by: Justin Ngo Change-Id: I52519246183949353a3cf22f0a685fe3df8ec8ff Signed-off-by: Justin Ngo --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 119 ++++++++++++++++++ .../TorchToTosa/TosaLegalizeCommon.cpp | 81 +++++++----- projects/pt1/e2e_testing/xfail_sets.py | 55 ++++---- test/Conversion/TorchToTosa/basic.mlir | 32 +++++ 4 files changed, 230 insertions(+), 57 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index e451f73826e6..5664ebc7152d 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -3821,6 +3821,124 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenIndexSelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Not a tensor type. + auto input = adaptor.getSelf(); + auto inputType = dyn_cast(input.getType()); + if (!inputType) + return rewriter.notifyMatchFailure( + op, "Only RankedTensorType inputs are currently supported"); + + auto index = adaptor.getIndex(); + auto indexType = dyn_cast(index.getType()); + + if (!indexType) + return rewriter.notifyMatchFailure( + op, "Only RankedTensorType indices are currently supported"); + + auto inputShape = inputType.getShape(); + int inputRank = inputType.getRank(); + + if (indexType.getRank() == 0) + return rewriter.notifyMatchFailure( + op, "Rank 0 index tensor is currently not supported"); + + // Dynamic shape check + if (!inputType.hasStaticShape() || !indexType.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "AtenIndexSelectOp: support for dynamic input " + "shape not implemented"); + + // index i64 to i32 for tosa compatible + if (indexType.getElementType() != rewriter.getIntegerType(32)) { + index = rewriter.create( + op->getLoc(), + RankedTensorType::get(indexType.getShape(), + rewriter.getIntegerType(32)), + index); + } + + // Get positive dim + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "Value `dim` should be a torch constant int"); + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) + return rewriter.notifyMatchFailure(op, "Value `dim` is invalid"); + + // Get the output type + auto outType = getTypeConverter()->convertType(op.getType()); + + // Reshape and expand the index tensor to have same rank and same dimensions + // (except for the targeted dim) as the input + // + // For example: + // Input shape = (4, 5, 6) + // Index vector shape = (2) + // Targeted dim = 1 + // Reshaped and expanded index vector shape = (4, 2, 6) + // + // By reshaping and expanding the index vector, we can supply it into the + // gather op to mimic the functionality of aten.index_select + SmallVector indicesInputRankShape; + for (int64_t i = 0; i < inputRank; i++) { + if (i == dim) { + indicesInputRankShape.push_back(indexType.getShape()[0]); + } else { + indicesInputRankShape.push_back(1); + } + } + + 
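+  // For the example above (input shape (4, 5, 6), index shape (2), dim = 1),
+  // indicesInputRankShape is now [1, 2, 1]; the tile op below then expands it
+  // to the final (4, 2, 6) shape.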
auto indicesInputRankType = + RankedTensorType::get(makeShapeLLVMCompatible(indicesInputRankShape), + rewriter.getIntegerType(32)); + + auto reshapedIndices = rewriter.create( + op->getLoc(), indicesInputRankType, index, + rewriter.getDenseI64ArrayAttr(indicesInputRankShape)); + + SmallVector tileShape(indicesInputRankShape); + SmallVector expandedIndicesShape(indicesInputRankShape); + for (int64_t i = 0; i < inputRank; i++) { + if (tileShape[i] == 1 && i != dim) { + tileShape[i] = inputShape[i]; + expandedIndicesShape[i] = inputShape[i]; + } else { + tileShape[i] = 1; + } + } + + auto tileType = + RankedTensorType::get(makeShapeLLVMCompatible(expandedIndicesShape), + rewriter.getIntegerType(32)); + + auto expandedIndices = rewriter.create( + op->getLoc(), tileType, reshapedIndices.getResult(), + rewriter.getDenseI64ArrayAttr(tileShape)); + + // convert torch style index and dim into tf style indices + // tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64> + auto indicesTf = tosa::convertTorchIndexToTfIndices( + rewriter, op, input, expandedIndices.getResult(), dim); + if (!indicesTf) + return rewriter.notifyMatchFailure( + op, "Convert TorchIndex To TfIndices failed"); + + // do the tf gathernd algorithm with tf style indices as input. + auto result = + tosa::convertGatherNdOp(rewriter, op, outType, input, indicesTf.value()); + + if (!result) { + return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed"); + } + rewriter.replaceOp(op, {result.value()}); + return success(); +} + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenIndexPutHackedTwinOp op, OpAdaptor adaptor, @@ -6240,6 +6358,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp); INSERT_ATENOP_PATTERN(AtenTrilOp); INSERT_ATENOP_PATTERN(AtenDiagonalOp); + INSERT_ATENOP_PATTERN(AtenIndexSelectOp); #undef INSERT_ATENOP_PATTERN #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index b3e7f480a327..4df8a221d556 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -23,6 +23,15 @@ namespace tosa { using namespace mlir::torch::Torch; +// This function is a helper for `convertTorchIndexToTfIndices`. +// +// We convert PyTorch index to TensorFlow-style indices so that we can use +// `convertGatherNdOp` and `convertScatterNdOp` functions, which lower Gather +// and Scatter operators to TOSA using TensorFlow-style indices. +// The difference between PyTorch/ONNX Gather/Scatter and TensorFlow +// Gather/Scatter ops is that PyTorch/ONNX take in the dimension that you want +// to gather/scatter elements, while in TensorFlow, the indices point directly +// to positions that you want to gather/scatter elements. 
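+// For example, gathering along dim = 1 of a (2, 3) tensor with a Torch index
+// of [[2, 0], [1, 2]] corresponds to TF-style indices
+// [[[0, 2], [0, 0]], [[1, 1], [1, 2]]], where each innermost pair spells out a
+// full (row, column) coordinate. (Illustrative values, not taken from this
+// patch.)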
std::optional createOneDimTfIndices(PatternRewriter &rewriter, Operation *op, SmallVector indicesOneDimShape, int32_t dim, @@ -30,49 +39,55 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op, unsigned indexRank = indexShape.size(); SmallVector indicesVec; // input vec to create tosaConstant SmallVector indicesMetaElement; // torch.meshgrid inputs - int indicesMetaElementRepeatTimes{1}; // For torch.stack(torch.meshgrid) // Create torch.meshgrid inputs // Example: indexShape=[1,4,2] // dim0: indicesMetaElement = torch.arange(0, 1) = [0] // dim1: indicesMetaElement = torch.arange(0, 4) = [0,1,2,3] // dim2: indicesMetaElement = torch.arange(0, 2) = [0,1] - for (int i = 0; i < indexShape[dim]; i++) { + for (int i = 0; i < indexShape[dim]; i++) indicesMetaElement.push_back(i); - } - - // Compute total number of meta element repeat times: - // = product(indexShape[0:dim]) x product(indexShape[dim+1:-1]), skip dim - // dim0: indicesMetaElementRepeatTimes = 1 x 4*2 = 8 - // dim1: indicesMetaElementRepeatTimes = 1 *1 x 2 = 2 - // dim2: indicesMetaElementRepeatTimes = 1 *1*4 = 4 - for (int i = 0; i < static_cast(indexRank); i++) { - if (i == dim) { - continue; - } else { - indicesMetaElementRepeatTimes *= indexShape[i]; - } - } - if (dim != static_cast(indexShape.size()) - 1) { - // Create one dim indices for index except for last dim - // Create indices raw vector. - // torch.stack(torch.meshgrid) - // dim0: indicesVec = [0 0 0 0 0 0 0 0] - // dim0: indicesVec = [0 0 1 1 2 2 3 3] + int preDimMetaElementRepeatTimes = 1; + int postDimMetaElementRepeatTimes = 1; + + // Compute total number of times meta element range should repeat + // = product(indexShape[0:dim]) + // dim0: preDimMetaElementRepeatTimes = 1 + // dim1: preDimMetaElementRepeatTimes = 1 + // dim2: preDimMetaElementRepeatTimes = 1 x 4 = 4 + for (int i = 0; i < dim; i++) + preDimMetaElementRepeatTimes *= indexShape[i]; + + // Compute total number of times meta element repeat + // = product(indexShape[dim+1:indexRank]) + // dim0: postDimMetaElementRepeatTimes = 4 x 2 = 8 + // dim1: postDimMetaElementRepeatTimes = 2 + // dim2: postDimMetaElementRepeatTimes = 1 + for (int i = dim + 1; i < static_cast(indexRank); i++) + postDimMetaElementRepeatTimes *= indexShape[i]; + + // Example using dim1: + // preDimMetaElementRepeatTimes = 1 + // postDimMetaElementRepeatTimes = 2 + // Using postDimMetaElementRepeatTimes, we get the meta element range: + // [0 0 1 1 2 2 3 3] + // Using preDimMetaElementRepeatTimes, we get the full one dim indices: + // [0 0 1 1 2 2 3 3] + // + // Let's use a clearer example: + // indexShape = [3, 4, 2] + // Target dim = 1 + // => preDimMetaElementRepeatTimes = 3 + // postDimMetaElementRepeatTimes = 2 + // Using postDimMetaElementRepeatTimes, we get the meta element range: + // [0 0 1 1 2 2] + // Using preDimMetaElementRepeatTimes, we get the full one dim indices: + // [0 0 1 1 2 2 0 0 1 1 2 2 0 0 1 1 2 2] + for (int i = 0; i < preDimMetaElementRepeatTimes; i++) { for (size_t elementId = 0; elementId < indicesMetaElement.size(); elementId++) { - for (int i = 0; i < indicesMetaElementRepeatTimes; i++) { - indicesVec.push_back(indicesMetaElement[elementId]); - } - } - } else { // Create the one dim indices for last dim of index - // Create indices raw vector - // dim2: indicesVec= [0 1 0 1 0 1 0 1] - // Caution: indicesVec != [0 0 0 0 1 1 1 1] - for (int i = 0; i < indicesMetaElementRepeatTimes; i++) { - for (size_t elementId = 0; elementId < indicesMetaElement.size(); - elementId++) { + for (int j = 0; j < 
postDimMetaElementRepeatTimes; j++) { indicesVec.push_back(indicesMetaElement[elementId]); } } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 7dd6f3cd50a7..237a2ac96651 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1663,6 +1663,17 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "AtenLinalgCrossBroadcast_basic", + "AtenLinalgCrossCustomDim_basic", + "AtenLinalgCrossFloat_basic", + "AtenLinalgCrossInt_basic", + "AtenLinalgCrossNegativeDim_basic", + "BinaryCrossEntropyWithLogitsStaticModule_basic", + "IndexSelectNegativeDimModule_basic", + "IndexSelectSingleIdxModule_basic", + "IndexSelectTwoIdxModule_basic", + "IndexSelectWholeDimensionModule_basic", + "IndexSelectWholeTensorModule_basic", "DiagonalWithStaticShapeModule_basic", "EinsumStaticDiagonalDimensionModule_basic", "ElementwiseAtenFloorDivideBroadcastModule_basic", @@ -2342,6 +2353,13 @@ } ) - { ### Test failing in make_fx_tosa but not in tosa + "ChunkListUnpackUneven_Module_basic", + "ChunkListUnpack_Module_basic", + "SplitTensorGetItem_Module_basic", + "SplitTensorLastSmallerModule_basic", + "SplitTensorListUnpackModule_basic", + "SplitTensorNegativeDimModule_basic", + "SplitWithSizesListUnpackModule_basic", # Dynamic shape, has extra unsupported broadcast ops "Matmul_3d", "MatmulStaticBroadcast_basic", @@ -3205,6 +3223,17 @@ } FX_IMPORTER_TOSA_XFAIL_SET = { + "ChunkListUnpackDynamic_Module_basic", + "ChunkListUnpackUnevenDynamic_Module_basic", + "ChunkListUnpackUneven_Module_basic", + "ChunkListUnpack_Module_basic", + "SplitTensorGetItem_Module_basic", + "SplitTensorLastSmallerModule_basic", + "SplitTensorListUnpackModule_basic", + "SplitTensorNegativeDimModule_basic", + "SplitWithSizesListUnpackModule_basic", + "SplitWithSizes_Module_basic", + "ElementwiseCreateComplexModule_basic", "AdaptiveMaxPool1dDimOneStatic_basic", "AtenPolarDoubleModule_basic", "AtenPolarFloatModule_basic", @@ -3302,12 +3331,6 @@ "AtenIntTensorCharDtypeModule_basic", "AtenItemFpOpModule_basic", "AtenItemIntOpModule_basic", - "AtenLinalgCrossBroadcast_basic", - "AtenLinalgCrossCustomDim_basic", - "AtenLinalgCrossDynamic_basic", - "AtenLinalgCrossFloat_basic", - "AtenLinalgCrossInt_basic", - "AtenLinalgCrossNegativeDim_basic", "AtenMatmulQMixedSigni8Transpose_basic", "AtenMatmulQMixedSigni8_basic", "AtenMatmulQint8MV_basic", @@ -3551,15 +3574,7 @@ "IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImplIndexWithNoneModule_basic", - "IndexSelectDynamicIndexSizeModule_basic", - "IndexSelectDynamicInputSizeModule_basic", - "IndexSelectDynamicModulebasic", - "IndexSelectNegativeDimModule_basic", "IndexSelectRank0IdxModule_basic", - "IndexSelectSingleIdxModule_basic", - "IndexSelectTwoIdxModule_basic", - "IndexSelectWholeDimensionModule_basic", - "IndexSelectWholeTensorModule_basic", "IndexTensorNegativeIndexModule_basic", "InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_nearest", @@ -3848,6 +3863,8 @@ } ONNX_TOSA_XFAIL_SET = { + "ElementwiseCreateComplexModule_basic", + "ReduceAllDimFloatModule_basic", "AdaptiveMaxPool1dDimOneStatic_basic", "ScaledDotProductAttentionDifferentCausalModule_basic", "HstackBasicComplexModule_basic", @@ -4269,7 +4286,6 @@ "ElementwiseWhereSelfModule_basic", "EmbeddingModule1DIndices_basic", "EmbeddingModuleF16_basic", - "EmbeddingModuleI32Static_basic", 
"EmbeddingModuleI32_basic", "EmbeddingModuleI64_basic", "EmptyLikeMemoryFormatModule_basic", @@ -4363,12 +4379,6 @@ "IndexSelectDynamicIndexSizeModule_basic", "IndexSelectDynamicInputSizeModule_basic", "IndexSelectDynamicModulebasic", - "IndexSelectNegativeDimModule_basic", - "IndexSelectRank0IdxModule_basic", - "IndexSelectSingleIdxModule_basic", - "IndexSelectTwoIdxModule_basic", - "IndexSelectWholeDimensionModule_basic", - "IndexSelectWholeTensorModule_basic", "IndexTensorDyanmicInputContiguousWithNoneModule_basic", "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic", "IndexTensorHackedTwinModule3dInput_basic", @@ -4386,10 +4396,8 @@ "IndexTensorMultiInputOneDim_basic", "IndexTensorMultiInputThreeIndexers_basic", "IndexTensorMultiInput_basic", - "IndexTensorNegativeIndexModule_basic", "IndexTensorSelectDimModule_basic", "IndexTensorStaticContiguousWithNoneModule_basic", - "IndexTensorStaticModule_basic", "IndexTensorStaticNonContiguousWithNoneModule_basic", "InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_nearest", @@ -4688,7 +4696,6 @@ "ScatterValueFloatModule_basic", "ScatterValueIntModule_basic", "SelectIntModule_basic", - "SelectIntNegativeDimAndIndexStaticModule_basic", "SelectScattertModule_basic", "SelectScattertStaticModule_basic", "SignAndLogarithmOfDeterminantModule_F32", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 90d48489092e..6690868af510 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1885,3 +1885,35 @@ func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) -> %0 = torch.aten.diagonal %arg0, %offset, %dim1, %dim2 : !torch.vtensor<[3,4,5,6],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[5,6,2],si32> return %0 : !torch.vtensor<[5,6,2],si32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.index_select( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5,6],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2],si64> -> tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32> +// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<2xi64>) -> tensor<2xi32> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<2xi32>) -> tensor<1x1x2xi32> +// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array} : (tensor<1x1x2xi32>) -> tensor<4x5x2xi32> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], 
{{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32> +// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32> +// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<40x1xi32>) -> tensor<1x40xi32> +// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32> +// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32> +// CHECK: return %[[VAL_20]] : !torch.vtensor<[4,5,2],f32> +// CHECK: } +func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> { + %int2 = torch.constant.int 2 + %0 = torch.aten.index_select %arg0, %int2, %arg1 : !torch.vtensor<[4,5,6],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,5,2],f32> + return %0 : !torch.vtensor<[4,5,2],f32> +} From 53f7532e76b29a660ab989b9292a93521d135881 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 4 Oct 2024 14:48:02 -0700 Subject: [PATCH 02/12] Revert "[TorchToLinalg] perform rank0 elementwise computations outside linalg generic ops (#3762)" (#3767) Reverted due to downstream model changes. Will reland with fixes post integration. This reverts commit 6e8c7bed4b12117764274e79bc60a93443d5bdd5. --- .../TorchToLinalg/Uncategorized.cpp | 19 ------------------- .../Conversion/TorchToLinalg/elementwise.mlir | 12 +++++++----- 2 files changed, 7 insertions(+), 24 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 0532b4b19d94..0f6f92bd7c2c 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1627,25 +1627,6 @@ class ConvertElementwiseOp : public ConversionPattern { operands, [](Value v) { return isa(v.getType()); })); auto resultType = cast( getTypeConverter()->convertType(op->getResult(0).getType())); - bool isScalarOp = resultType.getShape().size() == 0; - if (isScalarOp) { - // for elementwise ops that are actually rank0 scalar computations, - // perform the payload outside a linalg generic op. 
- SmallVector payloadArgs; - for (auto t : tensorOperands) { - payloadArgs.push_back(rewriter.create(loc, t)); - } - Value scalarResult = createLinalgPayloadCalculationForElementwiseOp( - rewriter, loc, getTypeConverter(), payloadArgs, op, operands); - if (!scalarResult) - return rewriter.notifyMatchFailure( - op, "Failed to create payload for scalar elementwise op"); - Value rank0Result = - createInitTensor(rewriter, loc, ValueRange{}, - resultType.getElementType(), scalarResult); - rewriter.replaceOpWithNewOp(op, resultType, rank0Result); - return success(); - } bool hadErrorCreatingPayload = false; Value generic = torch_to_linalg::createElementwiseLinalgGeneric( rewriter, loc, tensorOperands, resultType.getElementType(), diff --git a/test/Conversion/TorchToLinalg/elementwise.mlir b/test/Conversion/TorchToLinalg/elementwise.mlir index ecf4caa58389..aa2be74f5d7e 100644 --- a/test/Conversion/TorchToLinalg/elementwise.mlir +++ b/test/Conversion/TorchToLinalg/elementwise.mlir @@ -4,11 +4,13 @@ // CHECK-LABEL: func.func @elementwise$unary( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { // CHECK-DAG: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor -// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[BUILTIN_TENSOR]][] : tensor -// CHECK: %[[TANH:.*]] = math.tanh %[[EXTRACT]] : f32 -// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[TANH]] : f32) outs(%[[EMPTY]] : tensor) -> tensor -// CHECK: %[[CASTED:.*]] = tensor.cast %[[FILL:.*]] : tensor to tensor +// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor +// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%[[BUILTIN_TENSOR]] : tensor) outs(%[[INIT_TENSOR]] : tensor) { +// CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32): +// CHECK: %[[TANH:.*]] = math.tanh %[[BBARG0]] : f32 +// CHECK: linalg.yield %[[TANH]] : f32 +// CHECK: } -> tensor +// CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor to tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor -> !torch.vtensor<[],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[],f32> // CHECK: } From f4840ed886f39db5bcb3bf20d37e79f8c4657746 Mon Sep 17 00:00:00 2001 From: Chi_Liu <22491986+AmosLewis@users.noreply.github.com> Date: Sat, 5 Oct 2024 22:22:41 -0700 Subject: [PATCH 03/12] [ONNX] Fix onnx.ScatterElements with AtenScatterReduceTwoOp lowering to tm_tensor/linalg_ext dialect (#3754) - To fix issue onnx.ScatterElements: https://github.com/nod-ai/SHARK-ModelDev/issues/823 - E2E test: https://github.com/nod-ai/SHARK-TestSuite/pull/363 --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 11 ++++-- projects/pt1/e2e_testing/xfail_sets.py | 1 - .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 38 ++++++++++--------- 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 95413b080343..a7f357349ecf 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -635,18 +635,21 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // TODO: Implement max and min cases if (reduction == "mul") { - reduction = "multiply"; + reduction = "prod"; } else if (reduction == "max" || reduction == "min") { return rewriter.notifyMatchFailure( binder.op, "max/min reduction unsupported for scatter 
elements"); + } else if (reduction == "add") { + reduction = "sum"; } Value cstStrReduction = rewriter.create(binder.getLoc(), reduction); - - rewriter.replaceOpWithNewOp( + Value cstTrue = + rewriter.create(binder.getLoc(), true); + rewriter.replaceOpWithNewOp( binder.op, resultType, data, constAxis, indices, updates, - cstStrReduction); + cstStrReduction, cstTrue); return success(); }); patterns.onOp( diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 237a2ac96651..bd8d1994d9b4 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -3084,7 +3084,6 @@ "ScatterReduceIntMaxModuleIncludeSelf", "ScatterReduceIntMinModuleIncludeSelf", "ScatterValueFloatModule_basic", - "ScatterAddStaticModule_basic", # Failure - onnx_lowering: onnx.ScatterND "IndexPut1DFloatAccumulateModule_basic", "IndexPut1DIntAccumulateModule_basic", diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index bd2a92874843..30fd60dbde3a 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -261,15 +261,16 @@ func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %ar // CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[AXIS:.*]] = torch.constant.int 1 - // CHECK: %[[ZERO:.+]] = torch.constant.int 0 - // CHECK: %[[ONE:.+]] = torch.constant.int 1 - // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]] - // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]] - // CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] - // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 - // CHECK: %[[STR:.*]] = torch.constant.str "add" - // CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32> +// CHECK: %[[AXIS:.*]] = torch.constant.int 1 +// CHECK: %[[ZERO:.*]] = torch.constant.int 0 +// CHECK: %[[FIVE:.*]] = torch.constant.int 1 +// CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int +// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64> +// CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1> +// CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64> +// CHECK: %[[STR:.*]] = torch.constant.str "sum" +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32> %0 = torch.operator 
"onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> return %0 : !torch.vtensor<[1,5],f32> } @@ -294,15 +295,16 @@ func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>, // CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[AXIS:.*]] = torch.constant.int 1 - // CHECK: %[[ZERO:.+]] = torch.constant.int 0 - // CHECK: %[[ONE:.+]] = torch.constant.int 1 - // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]] - // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]] - // CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] - // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 - // CHECK: %[[STR:.*]] = torch.constant.str "multiply" - // CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32> +// CHECK: %[[AXIS:.*]] = torch.constant.int 1 +// CHECK: %[[ZERO:.*]] = torch.constant.int 0 +// CHECK: %[[FIVE:.*]] = torch.constant.int 1 +// CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int +// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64> +// CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1> +// CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64> +// CHECK: %[[STR:.*]] = torch.constant.str "prod" +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32> %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "mul"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> return %0 : !torch.vtensor<[1,5],f32> } From b08d08682f2b3a32ba0b9c0130396cb9d684b135 Mon Sep 17 00:00:00 2001 From: Justin Ngo Date: Mon, 7 Oct 2024 10:28:26 -0700 Subject: [PATCH 04/12] [TOSA] Add legalization for fill, flip, and round (#3768) - Add Torch to TOSA lowering for aten.fill.Scalar/Tensor, aten.flip, and aten.round - Fix torchScalarToTosaTensor function to correctly convert Torch scalar input to TOSA tensor - Update xfail_sets.py with new e2e results - Update basic.mlir with LIT tests for new ops Change-Id: If1e42c2e582710dd8ad0465eed29806fbcdbde41 Signed-off-by: Justin Ngo --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 211 ++++++++++++++++++--- projects/pt1/e2e_testing/xfail_sets.py | 62 +++--- 
test/Conversion/TorchToTosa/basic.mlir | 81 ++++++++ 3 files changed, 298 insertions(+), 56 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 5664ebc7152d..77672181416f 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -153,11 +153,17 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, return rewriter.notifyMatchFailure(op, "Unable to extract the scalar constant"); + int64_t numElem = 1; + for (int64_t dim : dshape) + numElem *= dim; + if (isa(dtype)) { - tosaTensor = tosa::getConstTensor(rewriter, op, - (isFloat ? doubleValue : intValue), - dshape, dtype) - .value(); + tosaTensor = + tosa::getConstTensor( + rewriter, op, + SmallVector(numElem, (isFloat ? doubleValue : intValue)), + dshape, dtype) + .value(); } else if (auto intType = dyn_cast(dtype)) { auto w = intType.getWidth(); if (w != 1 && w != 32 && w != 64) @@ -173,8 +179,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, } bool d = isFloat ? static_cast(doubleValue) : static_cast(intValue); - tosaTensor = - tosa::getConstTensor(rewriter, op, {d}, dshape).value(); + tosaTensor = tosa::getConstTensor( + rewriter, op, SmallVector(numElem, d), dshape) + .value(); } else if (w == 32) { if (!isInValidRange(isFloat, doubleValue, isInt, intValue)) { return rewriter.notifyMatchFailure( @@ -183,8 +190,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, } int32_t d = isFloat ? static_cast(doubleValue) : static_cast(intValue); - tosaTensor = - tosa::getConstTensor(rewriter, op, {d}, dshape).value(); + tosaTensor = tosa::getConstTensor( + rewriter, op, SmallVector(numElem, d), dshape) + .value(); } else if (w == 64) { if (!isInValidRange(isFloat, doubleValue, isInt, intValue)) { return rewriter.notifyMatchFailure( @@ -192,8 +200,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, "of destination type"); } int64_t d = (isFloat ? 
static_cast(doubleValue) : intValue); - tosaTensor = - tosa::getConstTensor(rewriter, op, {d}, dshape).value(); + tosaTensor = tosa::getConstTensor( + rewriter, op, SmallVector(numElem, d), dshape) + .value(); } } else { return rewriter.notifyMatchFailure(op, "Usupported element type"); @@ -5320,7 +5329,7 @@ class ConvertAtenConstPatternOp : public OpConversionPattern { }; template -class ConvertAtenFillScalarOp : public OpConversionPattern { +class ConvertAtenFillOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; @@ -5336,18 +5345,48 @@ class ConvertAtenFillScalarOp : public OpConversionPattern { op, "Only Tensor types with static shapes are currently supported"); Type outElemTy = outType.getElementType(); - if (!outElemTy.isIntOrFloat()) { + if (!outElemTy.isIntOrFloat()) return rewriter.notifyMatchFailure( op, "Only floating-point or integer datatype legalization supported"); + + Value fillValueTargetTensor; + if constexpr (std::is_same()) { + // Reshape value tensor to have same rank and shape as input + auto inputRank = + cast(adaptor.getSelf().getType()).getRank(); + + auto fillValue = adaptor.getValue(); + auto fillValueType = dyn_cast(fillValue.getType()); + if (!fillValueType) + return rewriter.notifyMatchFailure(op, "Fill value is not a tensor"); + auto fillValueElemTy = fillValueType.getElementType(); + + SmallVector fillValueMatchedInputRankShape(inputRank, 1); + + auto fillValueMatchedInputRankType = RankedTensorType::get( + makeShapeTorchCompatible(fillValueMatchedInputRankShape), + fillValueElemTy); + + auto fillValueMatchedInputRankTensor = rewriter.create( + op->getLoc(), fillValueMatchedInputRankType, fillValue, + rewriter.getDenseI64ArrayAttr(fillValueMatchedInputRankShape)); + + fillValueTargetTensor = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeTorchCompatible(outType.getShape()), + fillValueElemTy), + fillValueMatchedInputRankTensor.getResult(), + makeShapeTorchCompatible(outType.getShape())); + } else { + if (failed(torchScalarToTosaTensor( + rewriter, op, op.getValue(), fillValueTargetTensor, outElemTy, + makeShapeTorchCompatible(outType.getShape())))) + return rewriter.notifyMatchFailure( + op, "Fill value must be a scalar constant"); } - Value constOp; - if (failed(torchScalarToTosaTensor( - rewriter, op, op.getValue(), constOp, outElemTy, - makeShapeTorchCompatible(outType.getShape())))) - return rewriter.notifyMatchFailure( - op, "Supplied value must be a Scalar constant"); - rewriter.replaceOpWithNewOp(op, outType, constOp); + rewriter.replaceOpWithNewOp(op, outType, + fillValueTargetTensor); return success(); } @@ -5869,6 +5908,127 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// Legalization for aten.flip +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenFlipOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + auto self = adaptor.getSelf(); + + auto selfTy = dyn_cast(self.getType()); + if (!selfTy) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types are currently supported"); + + SmallVector dims; + if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(dims))) + return rewriter.notifyMatchFailure( + op, "Only constant dims are currently supported"); + + auto selfRank = selfTy.getRank(); + + auto resultTy = getTypeConverter()->convertType(op.getType()); + Value result = self; + + for (auto &dim : dims) { + dim = toPositiveDim(dim, selfRank); + if (!isValidDim(dim, 
selfRank)) + return rewriter.notifyMatchFailure(op, "Not all dims are valid"); + + result = rewriter.create(op->getLoc(), resultTy, result, + static_cast(dim)); + } + + rewriter.replaceOp(op, result); + return success(); +} + +// Legalization for aten.round: +// Rounds elements of input to the nearest integer. +// Implements "round half to even" to break ties when a number is equidistant +// from two integers. +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenRoundOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // To round to the nearest integer, we will consider the fractional part of + // the input element (= input element - integer part of element). If the + // fractional part is smaller than 0.5, round the number down. If the + // fractional part is 0.5, apply "round half to even" rule. If the fractional + // part is greater than 0.5, round up. + // + // if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)): + // res = floor(input) + // else: + // res = ceil(input) + + auto self = adaptor.getSelf(); + + auto selfTy = dyn_cast(self.getType()); + if (!selfTy) + return rewriter.notifyMatchFailure(op, "Only tensor types supported"); + + auto resultTy = + cast(getTypeConverter()->convertType(op.getType())); + + auto boolTy = + RankedTensorType::get(resultTy.getShape(), rewriter.getIntegerType(1)); + + auto resultElemTy = resultTy.getElementType(); + + auto oneHalf = + tosa::getConstTensor(rewriter, op, 0.5, {}, resultElemTy).value(); + + auto two = + tosa::getConstTensor(rewriter, op, 2, {}, resultElemTy).value(); + + auto floorInput = + rewriter.create(op->getLoc(), resultTy, self); + + // input - floor(input) + auto fractionalPart = rewriter.create( + op->getLoc(), resultTy, self, floorInput.getResult()); + + auto ceilInput = rewriter.create(op->getLoc(), resultTy, self); + + auto floorInputDivByTwo = rewriter.create( + op->getLoc(), resultTy, floorInput.getResult(), oneHalf, /*shift=*/0); + + auto floorDivResult = rewriter.create( + op->getLoc(), resultTy, floorInputDivByTwo.getResult()); + + // (floor(input) // 2) * 2 + auto evenComparison = rewriter.create( + op->getLoc(), resultTy, floorDivResult.getResult(), two, /*shift=*/0); + + // floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0 + auto floorInputEven = rewriter.create( + op->getLoc(), boolTy, floorInput.getResult(), evenComparison.getResult()); + + auto fracEqualOneHalf = rewriter.create( + op->getLoc(), boolTy, fractionalPart.getResult(), oneHalf); + + auto fracLtOneHalf = rewriter.create( + op->getLoc(), boolTy, oneHalf, fractionalPart.getResult()); + + // (frac == 0.5) && (floor(input) % 2 == 0) + auto fracEqualOneHalfCond = rewriter.create( + op->getLoc(), boolTy, fracEqualOneHalf.getResult(), + floorInputEven.getResult()); + + // (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0)) + auto floorResultCond = rewriter.create( + op->getLoc(), boolTy, fracLtOneHalf.getResult(), + fracEqualOneHalfCond.getResult()); + + rewriter.replaceOpWithNewOp( + op, resultTy, floorResultCond.getResult(), floorInput.getResult(), + ceilInput.getResult()); + + return success(); +} + // Template to create supporting diagonal mask tensor for aten.diagonal template Value createDiagonalMask(PatternRewriter &rewriter, Operation *op, @@ -6052,6 +6212,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } + } // namespace // ----------------------------------------------------------------------------- @@ -6283,11 +6444,13 @@ class ConvertTorchToTosa : public 
ConvertTorchToTosaBase { INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); #undef INSERT_CONSTANT_FILL_PATTERN -#define INSERT_FILL_SCALAR_PATTERN(AtenOp) \ +#define INSERT_FILL_PATTERN(AtenOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_FILL_SCALAR_PATTERN(AtenFill_ScalarOp); -#undef INSERT_FILL_SCALAR_PATTERN + patterns.add>(typeConverter, context); + INSERT_FILL_PATTERN(AtenFill_ScalarOp); + INSERT_FILL_PATTERN(AtenFillScalarOp); + INSERT_FILL_PATTERN(AtenFillTensorOp); +#undef INSERT_FILL_PATTERN #define INSERT_MASKED_FILL_PATTERN(AtenOp) \ target.addIllegalOp(); \ @@ -6359,6 +6522,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(AtenTrilOp); INSERT_ATENOP_PATTERN(AtenDiagonalOp); INSERT_ATENOP_PATTERN(AtenIndexSelectOp); + INSERT_ATENOP_PATTERN(AtenFlipOp); + INSERT_ATENOP_PATTERN(AtenRoundOp); #undef INSERT_ATENOP_PATTERN #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index bd8d1994d9b4..09db1098e4b1 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1663,6 +1663,22 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "AtenRoundFloatHalfToEvenModule_basic", + "AtenRoundFloatModule_basic", + "FakeQuantizePerTensorAffineCachemaskModule_basic", + "FakeQuantizePerTensorAffineDynamicShapeModule_basic", + "FakeQuantizePerTensorAffineModule_basic", + "FakeQuantizePerTensorAffineRoundToEvenModule_basic", + "Fill_TensorFloat64WithFloat32Static_basic", + "Fill_TensorFloat64WithInt64Static_basic", + "FlipModuleStaticShape_basic", + "FlipModule_basic", + "FlipNegativeIndexModule_basic", + "Rot90BasicModule_basic", + "Rot90DynamicDimsModule_basic", + "Rot90MultipleRotationsModule_basic", + "Rot90NegativeEvenRotationsModule_basic", + "Rot90NegativeOddRotationsModule_basic", "AtenLinalgCrossBroadcast_basic", "AtenLinalgCrossCustomDim_basic", "AtenLinalgCrossFloat_basic", @@ -1819,7 +1835,6 @@ "ArangeStartOutModule_basic", "ArangeStartOutViewModule_basic", "ArangeStartStepIntModule_basic", - "ArangeZeroElementOutputModule_basic", "ArangeDtypeIntModule_basic", "ArangeFalsePinMemoryModule_basic", "ArangeFloatModule_basic", @@ -2120,7 +2135,6 @@ "NormScalarOptDimModule_basic", "NumToTensorFloatModule_basic", "NumToTensorIntModule_basic", - "NumpyTRank0Module_basic", "NumpyTRank1Module_basic", "NumpyTRank2Module_basic", "NumpyTRankNDynamicModule_basic", @@ -2132,7 +2146,6 @@ "OnesModuleInt_basic", "PadModule_basic", "PadWithNoneValModule_basic", - "Permute0RankModule_basic", "PermuteModule_basic", "PermuteNegativeIndexModule_basic", "PrimListUnpackNumMismatchModule_basic", @@ -2171,7 +2184,6 @@ "ScalarTensorInt64Module_basic", "SelectIntNegativeDimAndIndexStaticModule_basic", "SiluModule_basic", - "SliceOutOfUpperBoundIndexStaticModule_basic", "SliceStaticModule_basic", "SplitTensorGetItem_Module_basic", "SplitTensorLastSmallerModule_basic", @@ -3222,6 +3234,12 @@ } FX_IMPORTER_TOSA_XFAIL_SET = { + "ArangeZeroElementOutputModule_basic", + "NumpyTRank0Module_basic", + "Permute0RankModule_basic", + "SliceOutOfUpperBoundIndexModule_basic", + "SliceOutOfUpperBoundIndexStaticModule_basic", + "SliceStartEqEndModule_basic", "ChunkListUnpackDynamic_Module_basic", "ChunkListUnpackUnevenDynamic_Module_basic", "ChunkListUnpackUneven_Module_basic", @@ -3240,11 +3258,6 @@ "HstackBasicFloatModule_basic", "HstackBasicIntFloatModule_basic", 
"HstackBasicIntModule_basic", - "Rot90BasicModule_basic", - "Rot90DynamicDimsModule_basic", - "Rot90MultipleRotationsModule_basic", - "Rot90NegativeEvenRotationsModule_basic", - "Rot90NegativeOddRotationsModule_basic", "AtenIntMM_basic", "AtenKthvalueDynamicDimsModule_basic", "AtenKthvalueFloat64DynamicDimsModule_basic", @@ -3263,7 +3276,6 @@ "ElementwiseRreluEvalStaticModule_basic", "ElementwiseRreluTrainModule_basic", "ElementwiseRreluTrainStaticModule_basic", - "FakeQuantizePerTensorAffineCachemaskModule_basic", "IndexPutWithNoneAndBroadcastModule_basic", "MaskedScatterStaticBasic_basic", "MaxUnpool3dModulePad0_basic", @@ -3342,8 +3354,6 @@ "AtenMmQuint8_basic", "AtenRealView128Module_basic", "AtenRealView64Module_basic", - "AtenRoundFloatHalfToEvenModule_basic", - "AtenRoundFloatModule_basic", "AtenSubFloatModule_basic", "AtenTopKModule_basic", "AtenTopKSmallestModule_basic", @@ -3504,20 +3514,6 @@ "EqIntModule_basic", "ExpandModule_basic", "ExponentialModule_basic", - "FakeQuantizePerTensorAffineDynamicShapeModule_basic", - "FakeQuantizePerTensorAffineModule_basic", - "FakeQuantizePerTensorAffineRoundToEvenModule_basic", - "Fill_TensorFloat32WithFloat32_basic", - "Fill_TensorFloat32WithFloat64_basic", - "Fill_TensorFloat32WithInt64_basic", - "Fill_TensorFloat64WithFloat32Static_basic", - "Fill_TensorFloat64WithFloat32_basic", - "Fill_TensorFloat64WithFloat64_basic", - "Fill_TensorFloat64WithInt64Static_basic", - "Fill_TensorFloat64WithInt64_basic", - "FlipModuleStaticShape_basic", - "FlipModule_basic", - "FlipNegativeIndexModule_basic", "FloatImplicitModule_basic", "FullLikeModuleInt2D_basic", "FullLikeModuleInt3D_basic", @@ -3847,9 +3843,7 @@ "VarMeanUnbiasedModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewSizeFromOtherTensor_basic", - "ZeroFloat32Module_basic", - "ZeroInt32Module_basic", - "ZeroInt64Module_basic", + "VisionTransformerModule_basic", "ZerosLikeModule_falsePinMemory", } @@ -3862,6 +3856,12 @@ } ONNX_TOSA_XFAIL_SET = { + "ArangeZeroElementOutputModule_basic", + "LinspaceEmptyModule_basic", + "RepeatInterleaveSelfIntNoDimModule_basic", + "SliceOutOfUpperBoundIndexStaticModule_basic", + "TrilIndicesAllZerosModule_basic", + "TriuIndicesAllZerosModule_basic", "ElementwiseCreateComplexModule_basic", "ReduceAllDimFloatModule_basic", "AdaptiveMaxPool1dDimOneStatic_basic", @@ -4026,8 +4026,6 @@ "AtenPolarDoubleModule_basic", "AtenRealView128Module_basic", "AtenRealView64Module_basic", - "AtenRoundFloatHalfToEvenModule_basic", - "AtenRoundFloatModule_basic", "AtenSubFloatModule_basic", "AtenTopKModule_basic", "AtenTopKSmallestModule_basic", @@ -4071,8 +4069,6 @@ "BucketizeTensorFloatModule_basic", "BucketizeTensorModule_basic", "BucketizeTensorOutInt32RightModule_basic", - "BucketizeTensorStaticFloatModule_basic", - "BucketizeTensorStaticModule_basic", "CeilFloatModule_basic", "ChunkListUnpackDynamic_Module_basic", "ChunkListUnpackUnevenDynamic_Module_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 6690868af510..e569fed7fa93 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1917,3 +1917,84 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !t %0 = torch.aten.index_select %arg0, %int2, %arg1 : !torch.vtensor<[4,5,6],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,5,2],f32> return %0 : !torch.vtensor<[4,5,2],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.fill.Scalar( +// CHECK-SAME: 
%[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> { +// CHECK: %[[VAL_1:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x12x128x128xf32>}> : () -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[1,12,128,128],f32> +// CHECK: } +func.func @torch.aten.fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> { + %int0 = torch.constant.int 0 + %0 = torch.aten.fill.Scalar %arg0, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32> + return %0 : !torch.vtensor<[1,12,128,128],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.fill.Tensor( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1],si32> -> tensor<1xi32> +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1x1x1xi32> +// CHECK: %[[VAL_4:.*]] = tosa.tile %[[VAL_3]] {multiples = array} : (tensor<1x1x1x1xi32>) -> tensor<1x12x128x128xi32> +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<1x12x128x128xi32>) -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,12,128,128],f32> +// CHECK: } +func.func @torch.aten.fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> { + %0 = torch.aten.fill.Tensor %arg0, %arg1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[1,12,128,128],f32> + return %0 : !torch.vtensor<[1,12,128,128],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.flip( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_5:.*]] = tosa.reverse %[[VAL_1]] {axis = 1 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_6:.*]] = tosa.reverse %[[VAL_5]] {axis = 2 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4,5],f32> +// CHECK: } +func.func @torch.aten.flip(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.flip %arg0, %0 : !torch.vtensor<[3,4,5],f32>, !torch.list -> !torch.vtensor<[3,4,5],f32> + return %1 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.round( +// CHECK-SAME: %[[VAL_0:.*]]: 
!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor}> : () -> tensor +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.floor %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_5:.*]] = tosa.sub %[[VAL_1]], %[[VAL_4]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_6:.*]] = tosa.ceil %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], %[[VAL_2]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_8:.*]] = tosa.floor %[[VAL_7]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_8]], %[[VAL_3]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_10:.*]] = tosa.equal %[[VAL_4]], %[[VAL_9]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_11:.*]] = tosa.equal %[[VAL_5]], %[[VAL_2]] : (tensor<3x4x5xf32>, tensor) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_12:.*]] = tosa.greater %[[VAL_2]], %[[VAL_5]] : (tensor, tensor<3x4x5xf32>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_13:.*]] = tosa.logical_and %[[VAL_11]], %[[VAL_10]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_14:.*]] = tosa.logical_or %[[VAL_12]], %[[VAL_13]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_15:.*]] = tosa.select %[[VAL_14]], %[[VAL_4]], %[[VAL_6]] : (tensor<3x4x5xi1>, tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32> +// CHECK: return %[[VAL_16]] : !torch.vtensor<[3,4,5],f32> +// CHECK: } +func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { + %0 = torch.aten.round %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} From f6721e599961a36d67236fce9f58cdd719c9cef4 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 8 Oct 2024 10:34:27 +0530 Subject: [PATCH 05/12] [MLIR][TORCH] Add support for negative step in aten.slice.Tensor op (#3763) This commit adds the support for negative step values in aten.slice.Tensor op. Although, PyTorch does not allow negative step value for slice op but the Onnx.Slice op supports negative step value which eventually lowers to torch.aten.slice.Tensor op. Hence, the support is added for handling those kind of values during the Torch->Linalg lowering of aten.slice.Tensor op. 
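As a rough illustration of the flip-based handling (a NumPy sketch of the
same arithmetic, not code from this change):

    import numpy as np
    x = np.array([0, 1, 2, 3, 4, 5])
    # slice with offset 3, step -2, and 2 result elements -> [3, 1]
    direct = x[3::-2][:2]
    # equivalent formulation after flipping the input
    flipped = x[::-1]                  # [5, 4, 3, 2, 1, 0]
    offset = len(x) - 2 * 2            # input_size - result_size * |step|
    via_flip = flipped[offset:offset + 2 * 2:2]
    assert (direct == via_flip).all()  # both yield [3, 1]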
Signed-Off By: Vivek Khandelwal --- .../Conversion/TorchToLinalg/Utils.h | 4 ++ lib/Conversion/TorchToLinalg/DataMovement.cpp | 49 +++++++++++++++---- lib/Conversion/TorchToLinalg/Linear.cpp | 39 +-------------- lib/Conversion/TorchToLinalg/Utils.cpp | 41 ++++++++++++++++ 4 files changed, 86 insertions(+), 47 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h index 14e9202222c6..b59d183b4084 100644 --- a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h +++ b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h @@ -101,6 +101,10 @@ LogicalResult permuteTensor(Operation *op, PatternRewriter &rewriter, Location loc, SmallVector dimensions, Value input, Value &result); +// Flips an input tensor based on the values of axis list. +Value flipTensor(PatternRewriter &rewriter, Location loc, Value input, + SmallVector axis); + } // namespace torch_to_linalg } // namespace torch } // namespace mlir diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 5542e0fc642f..ac1707ec23a6 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -40,6 +40,7 @@ static int64_t productReduce(ArrayRef a) { template LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, + int64_t &dim, SmallVector &resultShape, SmallVector &offsets, SmallVector &strides) { @@ -51,7 +52,6 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, Value one = rewriter.create(loc, 1); Value negone = rewriter.create(loc, -1); - int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return op->emitError("unimplemented: dim is not constant"); @@ -1857,14 +1857,46 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern { RankedTensorType resultType = cast( typeConverter->convertType(op->getResult(0).getType())); - SmallVector resultShape; - SmallVector offsets; - SmallVector strides; + SmallVector resultShape, offsets, strides; + int64_t dim; if (failed(prepareArgumentsForSlicingOp( - op, adaptor, rewriter, resultShape, offsets, strides))) { + op, adaptor, rewriter, dim, resultShape, offsets, strides))) { return failure(); } + + // If stride is negative, then flip the input tensor corresponding to that + // dim, update the stride for flipped tensor by multiplying it by -1, and + // update the offset as follows: + // flipped_offset = input_shape[dim] - (result_shape[dim] * flipped_stride) + // + // For example: + // Input = [0, 1, 2, 3, 4, 5] + // stride = [-2], result_shape = [2], offset = [3] + // Result = [3, 1] + // After flipping: + // Input = [5, 4, 3, 2, 1, 0] + // stride = [2], result_shape = [2], offset = [6 - (2 * 2)] = [2] + // Result = [3, 1] + + Value flippedInput = torch_to_linalg::flipTensor(rewriter, loc, input, + SmallVector{dim}); + Value cstDim = rewriter.create(loc, dim); + Value zero = rewriter.create(loc, 0); + Value isNegativeStride = rewriter.create( + loc, arith::CmpIPredicate::slt, strides[dim], zero); + strides[dim] = rewriter.create(loc, strides[dim]); + Value resShapeMulStride = + rewriter.create(loc, resultShape[dim], strides[dim]); + Value inputDim = rewriter.create(loc, input, cstDim); + Value flippedOffset = + rewriter.create(loc, inputDim, resShapeMulStride); + offsets[dim] = rewriter.create( + loc, isNegativeStride, flippedOffset, offsets[dim]); + + input = rewriter.create(loc, isNegativeStride, + flippedInput, input); + 
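+    // For a non-negative stride, the two arith.select ops above keep the
+    // original offset and input, so the existing positive-stride path is
+    // unchanged.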
SmallVector dynShape(resultType.getRank(), ShapedType::kDynamic); auto sliceType = RankedTensorType::get( dynShape, resultType.getElementType(), resultType.getEncoding()); @@ -2095,12 +2127,11 @@ class ConvertAtenSliceScatterOp RankedTensorType resultType = cast( typeConverter->convertType(op->getResult(0).getType())); - SmallVector resultShape; - SmallVector offsets; - SmallVector strides; + SmallVector resultShape, offsets, strides; + int64_t dim; if (failed(prepareArgumentsForSlicingOp( - op, adaptor, rewriter, resultShape, offsets, strides))) { + op, adaptor, rewriter, dim, resultShape, offsets, strides))) { return failure(); } diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 52765411bd73..fc910fa9d3f2 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -222,14 +222,9 @@ class ConvertAtenFlipOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - MLIRContext *context = op.getContext(); Value self = adaptor.getSelf(); auto selfRank = cast(adaptor.getSelf().getType()).getRank(); - Type elementType = - cast(adaptor.getSelf().getType()).getElementType(); - Value c1 = - rewriter.create(loc, rewriter.getIndexAttr(1)); SmallVector axis; if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis))) @@ -242,40 +237,8 @@ class ConvertAtenFlipOp : public OpConversionPattern { } } - // Only used to calculate flipped values, i.e. those on the flip axes. Other - // dims won't be used. - SmallVector dims = getTensorSizes(rewriter, loc, self); - for (auto flipDim : axis) - dims[flipDim] = rewriter.create(loc, dims[flipDim], c1); - - Value initTensor = createZeroInitTensor( - rewriter, loc, getTensorSizes(rewriter, loc, self), elementType); - - SmallVector iteratorTypes( - selfRank, utils::IteratorType::parallel); - SmallVector indexingMaps( - 2, AffineMap::getMultiDimIdentityMap(selfRank, context)); - Value flipped = - rewriter - .create( - loc, self.getType(), self, initTensor, indexingMaps, - iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - SmallVector indices; - for (auto i = 0; i < selfRank; i++) - indices.push_back(b.create(loc, i)); - for (auto flipDim : axis) { - indices[flipDim] = b.create( - loc, dims[flipDim], indices[flipDim]); - } - Value res = b.create(loc, self, indices) - .getResult(); - b.create(loc, res); - }) - .getResult(0); - + Value flipped = torch_to_linalg::flipTensor(rewriter, loc, self, axis); rewriter.replaceOpWithNewOp(op, self.getType(), flipped); - return success(); } }; diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 6ef947d890cd..18e8fb449ef5 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -620,3 +620,44 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op, .getResult(0); return success(); } + +// Flips an input tensor based on the values of axis list. +Value torch_to_linalg::flipTensor(PatternRewriter &rewriter, Location loc, + Value input, SmallVector axis) { + Value c1 = rewriter.create(loc, rewriter.getIndexAttr(1)); + Type elementType = cast(input.getType()).getElementType(); + auto selfRank = cast(input.getType()).getRank(); + + // Only used to calculate flipped values, i.e. those on the flip axes. Other + // dims won't be used. 
+ SmallVector dims = getTensorSizes(rewriter, loc, input); + for (auto flipDim : axis) + dims[flipDim] = rewriter.create(loc, dims[flipDim], c1); + + Value initTensor = createZeroInitTensor( + rewriter, loc, getTensorSizes(rewriter, loc, input), elementType); + + SmallVector iteratorTypes(selfRank, + utils::IteratorType::parallel); + SmallVector indexingMaps( + 2, AffineMap::getMultiDimIdentityMap(selfRank, rewriter.getContext())); + Value flipped = + rewriter + .create( + loc, input.getType(), input, initTensor, indexingMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + SmallVector indices; + for (auto i = 0; i < selfRank; i++) + indices.push_back(b.create(loc, i)); + for (auto flipDim : axis) { + indices[flipDim] = b.create(loc, dims[flipDim], + indices[flipDim]); + } + Value res = b.create(loc, input, indices) + .getResult(); + b.create(loc, res); + }) + .getResult(0); + return flipped; +} From 614fcdd153bdb716bf17ea0e1227d10f31896da0 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 8 Oct 2024 10:48:47 +0530 Subject: [PATCH 06/12] [MLIR][TORCH] Add support for 1-d group convolution (#3770) This commit adds the support for the 1-d depthwise convolution as a special case of 1-d group convolution. Signed-Off By: Vivek Khandelwal --- lib/Conversion/TorchToLinalg/Linear.cpp | 50 ++++++++++++++----- projects/pt1/e2e_testing/xfail_sets.py | 3 ++ .../torch_mlir_e2e_test/test_suite/conv.py | 27 ++++++++++ 3 files changed, 67 insertions(+), 13 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index fc910fa9d3f2..a4962d12abdc 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -1184,10 +1184,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return success(); } - if (numSpatialDims != 2) - return rewriter.notifyMatchFailure( - op, "unimplemented: only 2D grouped convolution supported"); - // Special depthwise case: Cin = Cout = groups. 
// Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple // of groups) to be depthwise in their documentation, but the linalg ops @@ -1199,21 +1195,45 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { if (inShape[1] == numGroups && weightShape[0] == numGroups && weightShape[1] == 1) { // Collapse weight shape (C/G == 1) - SmallVector collapsedDims = {{0, 1}, {2}, {3}}; - SmallVector collapsedShape{weightShape[0] * weightShape[1], - weightShape[2], weightShape[3]}; + SmallVector collapsedDims = {{0, 1}}; + SmallVector collapsedShape{weightShape[0] * weightShape[1]}; + for (unsigned i = 0; i < numSpatialDims; i++) { + collapsedDims.push_back({i + 2}); + collapsedShape.push_back(weightShape[i + 2]); + } Type collapsedType = RankedTensorType::get( makeShapeLLVMCompatible(collapsedShape), weightDTy); Value collapsedWeight = rewriter.create( loc, collapsedType, weight, collapsedDims); if (!inputZp) { - conv = rewriter - .create( - loc, outputTensor.getType(), - ValueRange{paddedInput, collapsedWeight}, outputTensor, - stridesAttr, dilationAttr) - .getResult(0); + switch (numSpatialDims) { + case 1: + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, collapsedWeight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); + break; + case 2: + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, collapsedWeight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); + break; + default: + return rewriter.notifyMatchFailure( + op, "unimplemented: only 1D and 2D depthwise convolution " + "supported for special case of group convolution"); + }; } else { + if (numSpatialDims != 2) + return rewriter.notifyMatchFailure( + op, "unimplemented: only 2D depthwise quantized convolution " + "supported for special case of group convolution"); + // currently, the only named depthwise qconv op is nhwc_hwc // input: nchw -> nhwc; weight (collapsed): chw -> hwc // linalg conv result nhwc -> nchw @@ -1260,6 +1280,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return success(); } + if (numSpatialDims != 2) + return rewriter.notifyMatchFailure( + op, "unimplemented: only 2D grouped convolution supported"); + // Grouped case, use the grouped conv linalg op auto expandGroups = [&](Value tensor, size_t dim) { auto inType = cast(tensor.getType()); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 09db1098e4b1..83c9ef855e75 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1048,6 +1048,7 @@ "ContainsIntList_False", "ContainsIntList_True", "ContiguousModule_basic", + "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_depthwise", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", @@ -3395,6 +3396,7 @@ "ContainsIntList_False", "ContainsIntList_True", "Conv1dModule_basic", + "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", "Conv2dQInt8Module_basic", "Conv2dQInt8Module_depthwise", "Conv2dQInt8Module_grouped", @@ -4087,6 +4089,7 @@ "ContainsIntList_False", "ContainsIntList_True", "Conv1dModule_basic", + "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", "Conv2dBiasNoPaddingModule_basic", "Conv2dModule_basic", "Conv2dNoPaddingModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py 
b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index 4fe50243db60..3bc176048946 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -1067,6 +1067,33 @@ def Conv1dModule_basic(module, tu: TestUtils): module.forward(inputVec, weight, bias) +class Conv1dDepthwiseWithPaddingDilationStrideStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 4, 6], torch.float32, True), + ([4, 1, 3], torch.float32, True), + ] + ) + def forward(self, inputVec, weight): + return torch.ops.aten.conv1d( + inputVec, weight, bias=None, stride=[1], padding=[4], dilation=[1], groups=4 + ) + + +@register_test_case( + module_factory=lambda: Conv1dDepthwiseWithPaddingDilationStrideStaticModule() +) +def Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 4, 6) + weight = torch.randn(4, 1, 3) + module.forward(inputVec, weight) + + class Conv2dModule(torch.nn.Module): def __init__(self): super().__init__() From 58489faf7fdd3e3f20fb849fd89e7bfffe6540fe Mon Sep 17 00:00:00 2001 From: jinchen <49575973+jinchen62@users.noreply.github.com> Date: Tue, 8 Oct 2024 10:37:31 -0700 Subject: [PATCH 07/12] torch.aten.squeeze.dim lowering with dynamic dims (#3749) Address https://github.com/nod-ai/SHARK-ModelDev/issues/846 Assume the dynamic squeezed dim is 1. --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 15 +++++++++++---- test/Conversion/TorchToLinalg/squeeze.mlir | 17 +++++++++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) create mode 100644 test/Conversion/TorchToLinalg/squeeze.mlir diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index ac1707ec23a6..902daa1cb5ad 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -1658,10 +1658,17 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern { if (!isValidDim(dim, inputRank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - // TODO: Handle the case where the dim(th) dimension is dynamic. + // assert dynamic squeeze dim size == 1 if (inputType.isDynamicDim(dim)) { - return rewriter.notifyMatchFailure( - op, "unimplemented: dim(th) dimension is not expected to be dynamic"); + Value cstDim = rewriter.create(op.getLoc(), dim); + Value dimVal = rewriter.create(op.getLoc(), input, cstDim); + Value cstOne = rewriter.create(op.getLoc(), 1); + Value cmp = rewriter.create( + op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne); + rewriter.create( + op.getLoc(), cmp, + rewriter.getStringAttr( + "Expected dynamic squeeze dim size to be statically 1")); } const TypeConverter *typeConverter = getTypeConverter(); @@ -1671,7 +1678,7 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern { // If the dim(th) dimension of operand tensor type is not statically unit, // `aten.squeeze` will behave as an identity operation. 
- if (inputType.getDimSize(dim) != 1) { + if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) { rewriter.replaceOpWithNewOp(op, resultType, input); return success(); } diff --git a/test/Conversion/TorchToLinalg/squeeze.mlir b/test/Conversion/TorchToLinalg/squeeze.mlir new file mode 100644 index 000000000000..a8922eed5a9d --- /dev/null +++ b/test/Conversion/TorchToLinalg/squeeze.mlir @@ -0,0 +1,17 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func @torch.aten.squeeze.dim$dynamic +func.func @torch.aten.squeeze.dim$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "tf2onnx", torch.onnx_meta.producer_version = "1.5.2"} { + // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?],f32> -> tensor + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_1:.*]] = arith.constant 0 : index + // CHECK: %[[DIM:.*]] = tensor.dim %[[BUILTIN_TENSOR]], %[[C0_1]] : tensor + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[CMPI:.*]] = arith.cmpi eq, %[[DIM]], %[[C1]] : index + // CHECK: cf.assert %[[CMPI]], "Expected dynamic squeeze dim size to be statically 1" + // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2]] : tensor into tensor + // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor -> !torch.vtensor<[?,?],f32> + %int0 = torch.constant.int 0 + %1 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> + return %1 : !torch.vtensor<[?,?],f32> +} From 604aaec294b51324554b1e46ff75c012ec512294 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Thu, 2 Jan 2025 12:53:03 +0100 Subject: [PATCH 08/12] Update xfail --- projects/pt1/e2e_testing/xfail_sets.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 0be7f5b524f1..f2a6f29d1158 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2562,13 +2562,6 @@ } ) - { ### Test failing in make_fx_tosa but not in tosa - "ChunkListUnpackUneven_Module_basic", - "ChunkListUnpack_Module_basic", - "SplitTensorGetItem_Module_basic", - "SplitTensorLastSmallerModule_basic", - "SplitTensorListUnpackModule_basic", - "SplitTensorNegativeDimModule_basic", - "SplitWithSizesListUnpackModule_basic", # Dynamic shape, has extra unsupported broadcast ops "Matmul_3d", # Unimplemented operator 'aten._index_put_impl_.hacked_twin' From 76a95f275a88be1deae98d1f43df2cae63106bfd Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Fri, 3 Jan 2025 10:58:45 +0100 Subject: [PATCH 09/12] Fix xfail --- projects/pt1/e2e_testing/xfail_sets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index f2a6f29d1158..e9d345773284 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2206,6 +2206,7 @@ "IndexTensorModule3dInputStatic_basic", "IndexTensorMultiIndexStaticModule_basic", "IndexTensorStaticModule_basic", + "IndexSelectStaticModule_basic", "IscloseStaticModule_basic", "IscloseStaticModuleTrue_basic", "LayerNormNormalizeOverAllDimsModule_basic", @@ -2541,7 +2542,6 @@ "IndexSelectWholeTensorModule_basic", 
"IndexSelectNegativeDimModule_basic", "IndexSelectRank0IdxModule_basic", - "IndexSelectStaticModule_basic", "IndexSelectSingleIdxModule_basic", "IndexSelectTwoIdxModule_basic", "LinalgVectorNormModule_basic", From c0eb38e9379685cb78c86d40fdd0139c6925c1b4 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Fri, 3 Jan 2025 17:32:10 +0100 Subject: [PATCH 10/12] Update xfail --- projects/pt1/e2e_testing/xfail_sets.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 9d0a7392919d..865db5077481 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1721,6 +1721,7 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "ArangeZeroElementOutputModule_basic", "AtenRoundFloatHalfToEvenModule_basic", "AtenRoundFloatModule_basic", "FakeQuantizePerTensorAffineCachemaskModule_basic", @@ -2288,6 +2289,7 @@ "NormScalarOptDimModule_basic", "NumToTensorFloatModule_basic", "NumToTensorIntModule_basic", + "NumpyTRank0Module_basic", "NumpyTRank1Module_basic", "NumpyTRank2Module_basic", "NumpyTRankNDynamicModule_basic", @@ -2301,6 +2303,7 @@ "OnesModuleInt_basic", "PadModule_basic", "PadWithNoneValModule_basic", + "Permute0RankModule_basic", "PermuteModule_basic", "PermuteNegativeIndexModule_basic", "PowFloatFloatModule_basic", @@ -2332,8 +2335,6 @@ "ReduceSumFloatModule_basic", "ReduceSumSignedIntModule_basic", "ReduceSumUnsignedIntModule_basic", - "RepeatInterleaveFillModule_basic", - "RepeatInterleaveStaticModule_basic", "RepeatModule_basic", "RepeatInterleaveSelfIntNoDimModule_basic", "ReshapeAliasCollapseModule_basic", @@ -2357,6 +2358,8 @@ "SelectIntNegativeDimAndIndexStaticModule_basic", "SiluModule_basic", "SliceOutOfLowerBoundStartIndexStaticModule_basic", + "SliceOutOfUpperBoundIndexStaticModule_basic", + "SliceStaticModule_basic", "SliceSizeTwoStepDivisibleStaticModule_basic", "SplitTensorGetItem_Module_basic", "SplitTensorLastSmallerModule_basic", From 40a686a750a3a4f3ae48fe99de031d311ad30643 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Mon, 6 Jan 2025 09:20:45 +0100 Subject: [PATCH 11/12] bump --- externals/llvm-project | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/externals/llvm-project b/externals/llvm-project index 69364a9a16fc..37878445e55c 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 69364a9a16fc7e2465e107a2ff4255beeba6e821 +Subproject commit 37878445e55cbeb1ba6fc60b6b1dff701dfd9691 From ef59423240438f42e372916452911ec7fd07bd7a Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Mon, 6 Jan 2025 11:40:42 +0100 Subject: [PATCH 12/12] bump --- externals/llvm-project | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/externals/llvm-project b/externals/llvm-project index 37878445e55c..b3562f34da70 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 37878445e55cbeb1ba6fc60b6b1dff701dfd9691 +Subproject commit b3562f34da706226e2c2aeda75ebf60b7bf73abd