From bce800a3f4200f8a28e844a91c2be2d2ce71c174 Mon Sep 17 00:00:00 2001
From: Benoit Jacob
Date: Wed, 8 May 2024 14:43:06 -0400
Subject: [PATCH 1/3] Integrate llvm-project at
 dabdec1001dc368373dd581cf72f37a440873ce3 (#3300)

Co-authored-by: Jacques Pienaar
---
 CMakeLists.txt                                 |  7 +++++++
 externals/llvm-project                         |  2 +-
 .../TorchToTosa/TosaLegalizeCommon.h           |  6 +++---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp     |  9 ++++-----
 .../TorchToTosa/TosaLegalizeCommon.cpp         | 14 +++++++-------
 test/Conversion/TorchToLinalg/flatten.mlir     |  2 +-
 test/Conversion/TorchToLinalg/unsqueeze.mlir   | 10 +++++-----
 test/Conversion/TorchToLinalg/view.mlir        | 18 +++++++++++-------
 ...torch-backend-to-tosa-backend-pipeline.mlir |  2 +-
 .../convert-custom-quant-op.mlir               |  4 ++--
 10 files changed, 42 insertions(+), 32 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0c562fbe31c0..4740f2312394 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -54,6 +54,13 @@ cmake_dependent_option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF TORCH_MLI
 
 option(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER "Enables the ONNX C importer" OFF)
 
+# TODO(#3299): migrate from member x.cast<T>() to mlir::cast<T>(x).
+if(MSVC)
+  add_compile_options(/wd4996)
+else()
+  add_compile_options(-Wno-deprecated-declarations)
+endif()
+
 macro(torch_mlir_enable_werror)
   if(TORCH_MLIR_ENABLE_WERROR_FLAG)
     if(NOT MSVC)
diff --git a/externals/llvm-project b/externals/llvm-project
index 593f6fdcb4bb..dabdec1001dc 160000
--- a/externals/llvm-project
+++ b/externals/llvm-project
@@ -1 +1 @@
-Subproject commit 593f6fdcb4bb3ff81ba4e6f89d7b16540c4b9eaf
+Subproject commit dabdec1001dc368373dd581cf72f37a440873ce3
diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h
index 44b9bbdde3b2..398926e8168a 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h
@@ -40,9 +40,9 @@ TosaOpT createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op,
 // This specialization is for Div op. Unlike other binary ops, it doesn't
 // support floating type.
 template <>
-tosa::DivOp createBinaryOpAndCast<tosa::DivOp>(PatternRewriter &rewriter,
-                                               Operation *op, TensorType outType,
-                                               Value lhs, Value rhs);
+tosa::IntDivOp
+createBinaryOpAndCast<tosa::IntDivOp>(PatternRewriter &rewriter, Operation *op,
+                                      TensorType outType, Value lhs, Value rhs);
 
 std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
                                                   Operation *op,
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 168515051269..db8a435ab57a 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -511,8 +511,8 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
     } else {
       // The output type can be different than the input types (e.g. dividing an
       // int tensor results in a floating point tensor).
-      result = tosa::createBinaryOpAndCast<tosa::DivOp>(rewriter, op, outType,
-                                                        lhs, rhsTensor)
+      result = tosa::createBinaryOpAndCast<tosa::IntDivOp>(
+                   rewriter, op, outType, lhs, rhsTensor)
                    .getResult();
     }
 
@@ -4380,7 +4380,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
     self = rewriter.create<tosa::CastOp>(op.getLoc(), outType, self);
 
   auto divTensor = self;
-  // tosa::DivOp only supports int
   if (isa<mlir::FloatType>(outElemTy)) {
     auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
         op.getLoc(), otherTensor.getType(), otherTensor);
@@ -4388,8 +4387,8 @@
         op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0);
     divTensor = rewriter.create<tosa::FloorOp>(op.getLoc(), outType, divTensor);
   } else {
-    divTensor =
-        rewriter.create<tosa::DivOp>(op.getLoc(), outType, self, otherTensor);
+    divTensor = rewriter.create<tosa::IntDivOp>(op.getLoc(), outType, self,
+                                                otherTensor);
   }
 
   auto mulTensor =
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
index f9d3071fd10c..ae8d347e0cfd 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
@@ -114,20 +114,20 @@ tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
 }
 
 template <>
-tosa::DivOp createBinaryOpAndCast<tosa::DivOp>(PatternRewriter &rewriter,
-                                               Operation *op, TensorType outType,
-                                               Value lhs, Value rhs) {
+tosa::IntDivOp
+createBinaryOpAndCast<tosa::IntDivOp>(PatternRewriter &rewriter, Operation *op,
+                                      TensorType outType, Value lhs, Value rhs) {
   auto lhsElemTy = cast<TensorType>(lhs.getType()).getElementType();
   auto rhsElemTy = cast<TensorType>(rhs.getType()).getElementType();
   if (isa<mlir::FloatType>(lhsElemTy) || isa<mlir::FloatType>(rhsElemTy)) {
-    (void)rewriter.notifyMatchFailure(op,
-                                      "tosa.div only supports integer type");
+    (void)rewriter.notifyMatchFailure(
+        op, "tosa.int_div only supports integer type");
   }
 
   lhs = promoteType(rewriter, lhs, outType);
   rhs = promoteType(rewriter, rhs, outType);
-  return tosa::CreateOpAndInfer<tosa::DivOp>(rewriter, op->getLoc(), outType,
-                                             lhs, rhs);
+  return tosa::CreateOpAndInfer<tosa::IntDivOp>(rewriter, op->getLoc(), outType,
+                                                lhs, rhs);
 }
 
 std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
diff --git a/test/Conversion/TorchToLinalg/flatten.mlir b/test/Conversion/TorchToLinalg/flatten.mlir
index a2648e2b10f9..9c21db08d2cf 100644
--- a/test/Conversion/TorchToLinalg/flatten.mlir
+++ b/test/Conversion/TorchToLinalg/flatten.mlir
@@ -73,7 +73,7 @@ func.func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3
 // CHECK-LABEL: func.func @torch.aten.flatten.using_ints$rank0(
 // CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
-// CHECK: %[[COLLAPSED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: %[[COLLAPSED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] output_shape [1] : tensor<f32> into tensor<1xf32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
 
diff --git a/test/Conversion/TorchToLinalg/unsqueeze.mlir b/test/Conversion/TorchToLinalg/unsqueeze.mlir
index 44e028e9dd89..85eb5eee9c04 100644
--- a/test/Conversion/TorchToLinalg/unsqueeze.mlir
+++ b/test/Conversion/TorchToLinalg/unsqueeze.mlir
@@ -6,7 +6,7 @@
 // CHECK-LABEL: func.func @torch.aten.unsqueeze$basic(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] output_shape [1] : tensor<f32> into tensor<1xf32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
 func.func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
@@ -18,7 +18,7 @@ func.func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.v
 // CHECK-LABEL: func.func @torch.aten.unsqueeze$basic_negative(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] output_shape [1] : tensor<f32> into tensor<1xf32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
 func.func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
@@ -30,7 +30,7 @@ func.func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) ->
 // CHECK-LABEL: func.func @torch.aten.unsqueeze$higher_rank_front(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3]] : tensor<2x3x4xf32> into tensor<1x2x3x4xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3]] output_shape [1, 2, 3, 4] : tensor<2x3x4xf32> into tensor<1x2x3x4xf32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,3,4],f32>
 func.func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
@@ -42,7 +42,7 @@ func.func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],
 // CHECK-LABEL: func.func @torch.aten.unsqueeze$higher_rank_back(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> {
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x4x1xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] output_shape [2, 3, 4, 1] : tensor<2x3x4xf32> into tensor<2x3x4x1xf32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x4x1xf32> -> !torch.vtensor<[2,3,4,1],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4,1],f32>
 func.func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> {
@@ -54,7 +54,7 @@ func.func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f
 // CHECK-LABEL: func.func @torch.aten.unsqueeze$higher_rank_middle(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> {
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x1x4xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] output_shape [2, 3, 1, 4] : tensor<2x3x4xf32> into tensor<2x3x1x4xf32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x1x4xf32> -> !torch.vtensor<[2,3,1,4],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,1,4],f32>
 func.func @torch.aten.unsqueeze$higher_rank_middle(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> {
diff --git a/test/Conversion/TorchToLinalg/view.mlir b/test/Conversion/TorchToLinalg/view.mlir
index c606328d339b..20f0301cc491 100644
--- a/test/Conversion/TorchToLinalg/view.mlir
+++ b/test/Conversion/TorchToLinalg/view.mlir
@@ -6,7 +6,7 @@
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> {
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[3,2],f32> -> tensor<3x2xf32>
 // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1]] : tensor<3x2xf32> into tensor<6xf32>
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1]] : tensor<6xf32> into tensor<2x3xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1]] output_shape [2, 3] : tensor<6xf32> into tensor<2x3xf32>
 // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
 // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,3],f32>
 
@@ -64,7 +64,7 @@ func.func @torch.aten.view$dynamictest2(%arg0: !torch.vtensor<[?,6,?],f32>) -> !
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[1,?,128],f32> -> tensor<1x?x128xf32>
 // CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor<1x?x128xf32> to tensor<1x16x128xf32>
 // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CASTED]] {{\[\[}}0, 1], [2]] : tensor<1x16x128xf32> into tensor<16x128xf32>
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0], [1, 2]] : tensor<16x128xf32> into tensor<16x1x128xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0], [1, 2]] output_shape [16, 1, 128] : tensor<16x128xf32> into tensor<16x1x128xf32>
 // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<16x1x128xf32> -> !torch.vtensor<[16,1,128],f32>
 // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[16,1,128],f32>
 
@@ -83,7 +83,7 @@ func.func @torch.aten.view$dynamicVal(%arg0: !torch.vtensor<[1,?,128],f32>) -> !
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[8,1,?,1],f32> {
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32>
 // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2]] : tensor<4x5x6xf32> into tensor<120xf32>
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2, 3]] : tensor<120xf32> into tensor<8x1x15x1xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2, 3]] output_shape [8, 1, 15, 1] : tensor<120xf32> into tensor<8x1x15x1xf32>
 // CHECK: %[[CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<8x1x15x1xf32> to tensor<8x1x?x1xf32>
 // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<8x1x?x1xf32> -> !torch.vtensor<[8,1,?,1],f32>
 // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[8,1,?,1],f32>
@@ -103,7 +103,7 @@ func.func @torch.aten$dynamicValOutput(%arg0: !torch.vtensor<[4,5,6],f32>) -> !t
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[2,1,2,3,?],f32> {
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32>
 // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2]] : tensor<4x5x6xf32> into tensor<4x30xf32>
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2], [3, 4]] : tensor<4x30xf32> into tensor<2x1x2x3x10xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2], [3, 4]] output_shape [2, 1, 2, 3, 10] : tensor<4x30xf32> into tensor<2x1x2x3x10xf32>
 // CHECK: %[[CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<2x1x2x3x10xf32> to tensor<2x1x2x3x?xf32>
 // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<2x1x2x3x?xf32> -> !torch.vtensor<[2,1,2,3,?],f32>
 // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,1,2,3,?],f32>
@@ -125,7 +125,7 @@ func.func @torch.aten$dynamicValOutput2(%arg0: !torch.vtensor<[4,5,6],f32>) -> !
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32> {
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,6],f32> -> tensor<2x6xf32>
 // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1]] : tensor<2x6xf32> into tensor<12xf32>
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] : tensor<12xf32> into tensor<3x2x2xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] output_shape [3, 2, 2] : tensor<12xf32> into tensor<3x2x2xf32>
 // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<3x2x2xf32> -> !torch.vtensor<[3,2,2],f32>
 // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[3,2,2],f32>
 
@@ -144,7 +144,9 @@ func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) -
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[10,3,?,2,3],f32>) -> !torch.vtensor<[2,3,5,?,6],f32> {
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[10,3,?,2,3],f32> -> tensor<10x3x?x2x3xf32>
 // CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3, 4]] : tensor<10x3x?x2x3xf32> into tensor<30x?x6xf32>
-// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[COLLAPSE]] {{\[\[}}0, 1, 2], [3], [4]] : tensor<30x?x6xf32> into tensor<2x3x5x?x6xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[COLLAPSE]], %[[C1]] : tensor<30x?x6xf32>
+// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[COLLAPSE]] {{\[\[}}0, 1, 2], [3], [4]] output_shape [2, 3, 5, %[[DIM]], 6] : tensor<30x?x6xf32> into tensor<2x3x5x?x6xf32>
 // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor<2x3x5x?x6xf32> -> !torch.vtensor<[2,3,5,?,6],f32>
 // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,3,5,?,6],f32>
 
@@ -241,7 +243,9 @@ func.func @torch.aten.view$castingView (%arg0 : !torch.vtensor<[?,?,?], f32>) ->
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[10,?,2,3],f32>) -> !torch.vtensor<[2,5,?,6],f32> {
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[10,?,2,3],f32> -> tensor<10x?x2x3xf32>
 // CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<10x?x2x3xf32> into tensor<10x?x6xf32>
-// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[COLLAPSE]] {{\[\[}}0, 1], [2], [3]] : tensor<10x?x6xf32> into tensor<2x5x?x6xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[COLLAPSE]], %[[C1]] : tensor<10x?x6xf32>
+// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[COLLAPSE]] {{\[\[}}0, 1], [2], [3]] output_shape [2, 5, %[[DIM]], 6] : tensor<10x?x6xf32> into tensor<2x5x?x6xf32>
 // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor<2x5x?x6xf32> -> !torch.vtensor<[2,5,?,6],f32>
 // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,5,?,6],f32>
 
diff --git a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir
index 5813cd4351ac..08ec48448bab 100644
--- a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir
+++ b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir
@@ -109,7 +109,7 @@ func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32>
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi16>,
 // CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>
 // CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_0]] : (tensor<?x?xi16>) -> tensor<?x?xi32>
-// CHECK: %[[VAL_3:.*]] = tosa.div %[[VAL_2]], %[[VAL_1]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK: %[[VAL_3:.*]] = tosa.int_div %[[VAL_2]], %[[VAL_1]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
 func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],si32> {
   %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],si32>
   return %0 : !torch.vtensor<[?, ?],si32>
diff --git a/test/Dialect/TorchConversion/convert-custom-quant-op.mlir b/test/Dialect/TorchConversion/convert-custom-quant-op.mlir
index 7aca3551cfc2..b7a7087d7112 100644
--- a/test/Dialect/TorchConversion/convert-custom-quant-op.mlir
+++ b/test/Dialect/TorchConversion/convert-custom-quant-op.mlir
@@ -20,8 +20,8 @@ func.func @forward(%arg0: !torch.vtensor<[1,1,2],f16>) -> !torch.vtensor<[1,1,2]
   // CHECK: %[[SCALES:.*]] = torch_c.to_builtin_tensor %[[TENSOR2]] : !torch.vtensor<[2,1,1],f16> -> tensor<2x1x1xf16>
   // CHECK: %[[TENSOR3:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
   // CHECK: %[[ZPS:.*]] = torch_c.to_builtin_tensor %[[TENSOR3]] : !torch.vtensor<[2,1,1],f16> -> tensor<2x1x1xf16>
-  // CHECK: %[[EXPANDED_LHS:.*]] = tensor.expand_shape %[[LHS]] {{\[\[}}0], [1], [2, 3]] : tensor<1x1x2xf16> into tensor<1x1x1x2xf16>
-  // CHECK: %[[EXPANDED_RHS:.*]] = tensor.expand_shape %[[QUANT_RHS]] {{\[\[}}0], [1, 2]] : tensor<2x2xi8> into tensor<2x1x2xi8>
+  // CHECK: %[[EXPANDED_LHS:.*]] = tensor.expand_shape %[[LHS]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 1, 1, 2] : tensor<1x1x2xf16> into tensor<1x1x1x2xf16>
+  // CHECK: %[[EXPANDED_RHS:.*]] = tensor.expand_shape %[[QUANT_RHS]] {{\[\[}}0], [1, 2]] output_shape [2, 1, 2] : tensor<2x2xi8> into tensor<2x1x2xi8>
   // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f16
   // CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<2x1x2xf16>
   // CHECK: %[[EMPTY2:.*]] = tensor.empty() : tensor<1x1x2xf16>

From a9abc4ace746844da82f22a23207dfd6e1e5678b Mon Sep 17 00:00:00 2001
From: Matthias Gehre
Date: Thu, 22 Aug 2024 15:58:29 +0200
Subject: [PATCH 2/3] Bump LLVM

---
 externals/llvm-project | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/externals/llvm-project b/externals/llvm-project
index 03c95c3d38df..643434e85ec4 160000
--- a/externals/llvm-project
+++ b/externals/llvm-project
@@ -1 +1 @@
-Subproject commit 03c95c3d38df74d490e78ad449c06a319a6baff0
+Subproject commit 643434e85ec471807b044e0eafda30faddebac4c

From 1232fad68cf9a3b2e0b9ba338afe6b93699e101f Mon Sep 17 00:00:00 2001
From: Matthias Gehre
Date: Fri, 23 Aug 2024 11:21:02 +0200
Subject: [PATCH 3/3] Update to May 11

---
 externals/llvm-project | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/externals/llvm-project b/externals/llvm-project
index 643434e85ec4..64ba2b4bb742 160000
--- a/externals/llvm-project
+++ b/externals/llvm-project
@@ -1 +1 @@
-Subproject commit 643434e85ec471807b044e0eafda30faddebac4c
+Subproject commit 64ba2b4bb7427e4e62fa3718dc296ea6b73fa20b