diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index f174da4f43b7..a65c446b5fe9 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -1173,8 +1173,8 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "Matmul: input datatypes mismatched"); - auto outputElemType = getMatMulOutputType(lhsElemTy, rewriter); - if (!outputElemType) { + auto outputElemTy = getMatMulOutputType(lhsElemTy, rewriter); + if (!outputElemTy) { return rewriter.notifyMatchFailure( op, "Only i8 and i16 integer and bf16, f16 and " "f32 float types are valid"); @@ -1553,12 +1553,6 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { SmallVector matmulOutputShape( {matmulLhsShape[0], matmulLhsShape[1], matmulRhsShape[2]}); - Type outputElemTy; - if (lhsElemTy.isa()) { - outputElemTy = lhsElemTy; - } else { // qint8 emits i32 matmul output - outputElemTy = rewriter.getIntegerType(32); - } auto mmOutputTy = RankedTensorType::get( makeShapeLLVMCompatible(matmulOutputShape), outputElemTy); @@ -1571,6 +1565,14 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { matmulLhs, matmulRhs) .getResult(); + auto castOutputTy = RankedTensorType::get( + makeShapeLLVMCompatible(matmulOutputShape), lhsElemTy); + auto castResult = rewriter.createOrFold( + op->getLoc(), + OpConversionPattern::getTypeConverter() + ->convertType(castOutputTy), + mmOpResult); + // Perform the reshape to output shape. This is always required unless max // input rank=3 and there was no broadcasting, in which case the tosa.matmul // output itself is correctly shaped. @@ -1671,12 +1673,12 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { // Perform reshape auto reshapedOpType = RankedTensorType::get( - makeShapeLLVMCompatible(reshapedOpShape), outputElemTy); + makeShapeLLVMCompatible(reshapedOpShape), lhsElemTy); auto reshapedOp = rewriter.create( op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( reshapedOpType), - mmOpResult, rewriter.getDenseI64ArrayAttr(reshapedOpShape)); + castResult, rewriter.getDenseI64ArrayAttr(reshapedOpShape)); if (opNeedsTranspose) { @@ -1700,7 +1702,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { output = reshapedOp.getResult(); } } else { - output = mmOpResult; + output = castResult; } return success(); @@ -1722,13 +1724,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "Failed to perform matmul operation"); - rewriter.replaceOpWithNewOp( - op, - OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(), - output); - + rewriter.replaceOp(op, output); return success(); } }; @@ -1898,14 +1894,7 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { matmulOutput, bias) .getResult(); } - - rewriter.replaceOpWithNewOp( - op, - OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(), - matmulPlusBias); - + rewriter.replaceOp(op, matmulPlusBias); return success(); } }; diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index e57467ba2416..74025cfc6342 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -30,9 +30,9 @@ func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch. // CHECK-NEXT: %[[VAL_3:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<8x16xf32>) -> tensor<1x8x16xf32> // CHECK-NEXT: %[[VAL_4:.+]] = tosa.matmul %[[VAL_2]], %[[VAL_3]] : (tensor<1x4x8xf32>, tensor<1x8x16xf32>) -> tensor<1x4x16xf32> // CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1x4x16xf32>) -> tensor<4x16xf32> -func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[4,8],f32>, %arg1: !torch.vtensor<[8,16],f32>) -> !torch.vtensor<[?,?],f32> { - %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[4,8],f32>, !torch.vtensor<[8,16],f32> -> !torch.vtensor<[?,?],f32> - return %0 : !torch.vtensor<[?,?],f32> +func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[4,8],f32>, %arg1: !torch.vtensor<[8,16],f32>) -> !torch.vtensor<[4,16],f32> { + %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[4,8],f32>, !torch.vtensor<[8,16],f32> -> !torch.vtensor<[4,16],f32> + return %0 : !torch.vtensor<[4,16],f32> } // ----- @@ -55,9 +55,9 @@ func.func @torch.aten.matmul_1d(%arg0 : !torch.vtensor<[6],f32>, %arg1 : !torch. // CHECK-NEXT: %[[VAL_4:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<6x1xf32>) -> tensor<1x6x1xf32> // CHECK-NEXT: %[[VAL_5:.+]] = tosa.matmul %[[VAL_3]], %[[VAL_4]] : (tensor<1x1x6xf32>, tensor<1x6x1xf32>) -> tensor<1x1x1xf32> // CHECK-NEXT: %[[VAL_6:.+]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<1x1x1xf32>) -> tensor<1xf32> -func.func @torch.aten.matmul_12d(%arg0 : !torch.vtensor<[6],f32>, %arg1 : !torch.vtensor<[6,1],f32>) -> !torch.vtensor<[?],f32> { - %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[6],f32>, !torch.vtensor<[6,1],f32> -> !torch.vtensor<[?],f32> - return %0 : !torch.vtensor<[?],f32> +func.func @torch.aten.matmul_12d(%arg0 : !torch.vtensor<[6],f32>, %arg1 : !torch.vtensor<[6,1],f32>) -> !torch.vtensor<[1],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[6],f32>, !torch.vtensor<[6,1],f32> -> !torch.vtensor<[1],f32> + return %0 : !torch.vtensor<[1],f32> } // ----- @@ -67,9 +67,9 @@ func.func @torch.aten.matmul_12d(%arg0 : !torch.vtensor<[6],f32>, %arg1 : !torch // CHECK-NEXT: %[[VAL_4:.+]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<6x1xf32>) -> tensor<1x6x1xf32> // CHECK-NEXT: %[[VAL_5:.+]] = tosa.matmul %[[VAL_3]], %[[VAL_4]] : (tensor<1x2x6xf32>, tensor<1x6x1xf32>) -> tensor<1x2x1xf32> // CHECK-NEXT: %[[VAL_6:.+]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<1x2x1xf32>) -> tensor<2xf32> -func.func @torch.aten.matmul_21d(%arg0 : !torch.vtensor<[2,6],f32>, %arg1 : !torch.vtensor<[6],f32>) -> !torch.vtensor<[?],f32> { - %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[2,6],f32>, !torch.vtensor<[6],f32> -> !torch.vtensor<[?],f32> - return %0 : !torch.vtensor<[?],f32> +func.func @torch.aten.matmul_21d(%arg0 : !torch.vtensor<[2,6],f32>, %arg1 : !torch.vtensor<[6],f32>) -> !torch.vtensor<[2],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[2,6],f32>, !torch.vtensor<[6],f32> -> !torch.vtensor<[2],f32> + return %0 : !torch.vtensor<[2],f32> } // ----- @@ -78,9 +78,9 @@ func.func @torch.aten.matmul_21d(%arg0 : !torch.vtensor<[2,6],f32>, %arg1 : !tor // CHECK-NEXT: %[[VAL_3:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<6x8xf32>) -> tensor<1x6x8xf32> // CHECK-NEXT: %[[VAL_4:.+]] = tosa.matmul %[[VAL_2]], %[[VAL_3]] : (tensor<1x2x6xf32>, tensor<1x6x8xf32>) -> tensor<1x2x8xf32> // CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1x2x8xf32>) -> tensor<2x8xf32> -func.func @torch.aten.mm_2d(%arg0 : !torch.vtensor<[2,6],f32>, %arg1 : !torch.vtensor<[6,8],f32>) -> !torch.vtensor<[?,?],f32> { - %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,6],f32>, !torch.vtensor<[6,8],f32> -> !torch.vtensor<[?,?],f32> - return %0 : !torch.vtensor<[?,?],f32> +func.func @torch.aten.mm_2d(%arg0 : !torch.vtensor<[2,6],f32>, %arg1 : !torch.vtensor<[6,8],f32>) -> !torch.vtensor<[2,8],f32> { + %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,6],f32>, !torch.vtensor<[6,8],f32> -> !torch.vtensor<[2,8],f32> + return %0 : !torch.vtensor<[2,8],f32> } // ----- @@ -89,9 +89,27 @@ func.func @torch.aten.mm_2d(%arg0 : !torch.vtensor<[2,6],f32>, %arg1 : !torch.vt // CHECK-NEXT: %[[VAL_3:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<10x10x2x6xf32>) -> tensor<100x2x6xf32> // CHECK-NEXT: %[[VAL_4:.+]] = tosa.matmul %[[VAL_2]], %[[VAL_3]] : (tensor<100x6x2xf32>, tensor<100x2x6xf32>) -> tensor<100x6x6xf32> // CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<100x6x6xf32>) -> tensor<10x10x6x6xf32> -func.func @torch.aten.matmul_4d(%arg0 : !torch.vtensor<[10,10,6,2],f32>, %arg1 : !torch.vtensor<[10,10,2,6],f32>) -> !torch.vtensor<[?,?,?,?],f32> { - %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[10,10,6,2],f32>, !torch.vtensor<[10,10,2,6],f32> -> !torch.vtensor<[?,?,?,?],f32> - return %0 : !torch.vtensor<[?,?,?,?],f32> +func.func @torch.aten.matmul_4d(%arg0 : !torch.vtensor<[10,10,6,2],f32>, %arg1 : !torch.vtensor<[10,10,2,6],f32>) -> !torch.vtensor<[10,10,6,6],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[10,10,6,2],f32>, !torch.vtensor<[10,10,2,6],f32> -> !torch.vtensor<[10,10,6,6],f32> + return %0 : !torch.vtensor<[10,10,6,6],f32> +} + +// ----- + +// CHECK: %[[VAL_2:.+]] = tosa.reshape %0 {new_shape = array} : (tensor<10x6x2xf32>) -> tensor<1x10x6x2xf32> +// CHECK-NEXT: %[[VAL_3:.+]] = "tosa.const"() <{value = dense<[1, 0, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-NEXT: %[[VAL_4:.+]] = tosa.transpose %[[VAL_2]], %[[VAL_3]] : (tensor<1x10x6x2xf32>, tensor<4xi32>) -> tensor<10x1x6x2xf32> +// CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<10x1x6x2xf32>) -> tensor<10x6x2xf32> +// CHECK-NEXT: %[[VAL_6:.+]] = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-NEXT: %[[VAL_7:.+]] = tosa.transpose %1, %[[VAL_6]] : (tensor<10x10x2x6xf32>, tensor<4xi32>) -> tensor<10x2x10x6xf32> +// CHECK-NEXT: %[[VAL_8:.+]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<10x2x10x6xf32>) -> tensor<10x2x60xf32> +// CHECK-NEXT: %[[VAL_9:.+]] = tosa.matmul %[[VAL_5]], %[[VAL_8]] : (tensor<10x6x2xf32>, tensor<10x2x60xf32>) -> tensor<10x6x60xf32> +// CHECK-NEXT: %[[VAL_10:.+]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<10x6x60xf32>) -> tensor<10x6x10x6xf32> +// CHECK-NEXT: %[[VAL_11:.+]] = "tosa.const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-NEXT: %[[VAL_12:.+]] = tosa.transpose %[[VAL_10]], %[[VAL_11]] : (tensor<10x6x10x6xf32>, tensor<4xi32>) -> tensor<10x10x6x6xf32> +func.func @torch.aten.matmul_4d_broadcast(%arg0 : !torch.vtensor<[10,6,2],f32>, %arg1 : !torch.vtensor<[10,10,2,6],f32>) -> !torch.vtensor<[10,10,6,6],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[10,6,2],f32>, !torch.vtensor<[10,10,2,6],f32> -> !torch.vtensor<[10,10,6,6],f32> + return %0 : !torch.vtensor<[10,10,6,6],f32> } // ----- @@ -104,9 +122,9 @@ func.func @torch.aten.matmul_4d(%arg0 : !torch.vtensor<[10,10,6,2],f32>, %arg1 : // CHECK-NEXT: %[[VAL_7:.+]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x20x21xf32>) -> tensor<4x5x3x7xf32> // CHECK-NEXT: %[[VAL_8:.+]] = "tosa.const"() <{value = dense<[0, 2, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK-NEXT: %[[VAL_9:.+]] = tosa.transpose %[[VAL_7]], %[[VAL_8]] : (tensor<4x5x3x7xf32>, tensor<4xi32>) -> tensor<4x3x5x7xf32> -func.func @torch.aten.matmul_4d_broadcast_2(%arg0 : !torch.vtensor<[4,1,5,6],f32>, %arg1 : !torch.vtensor<[1,3,6,7],f32>) -> !torch.vtensor<[?,?,?,?],f32> { - %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[4,1,5,6],f32>, !torch.vtensor<[1,3,6,7],f32> -> !torch.vtensor<[?,?,?,?],f32> - return %0 : !torch.vtensor<[?,?,?,?],f32> +func.func @torch.aten.matmul_4d_broadcast_2(%arg0 : !torch.vtensor<[4,1,5,6],f32>, %arg1 : !torch.vtensor<[1,3,6,7],f32>) -> !torch.vtensor<[4,3,5,7],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[4,1,5,6],f32>, !torch.vtensor<[1,3,6,7],f32> -> !torch.vtensor<[4,3,5,7],f32> + return %0 : !torch.vtensor<[4,3,5,7],f32> } // ----- @@ -118,38 +136,37 @@ func.func @torch.aten.matmul_4d_broadcast_2(%arg0 : !torch.vtensor<[4,1,5,6],f32 // CHECK-NEXT: %[[VAL_6:.+]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<8x1x16xf32>) -> tensor<1x8x16xf32> // CHECK-NEXT: %[[VAL_7:.+]] = tosa.matmul %[[VAL_3]], %[[VAL_6]] : (tensor<1x400x8xf32>, tensor<1x8x16xf32>) -> tensor<1x400x16xf32> // CHECK-NEXT: %[[VAL_8:.+]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<1x400x16xf32>) -> tensor<100x4x16xf32> -func.func @torch.aten.matmul_3d_broadcast(%arg0 : !torch.vtensor<[100,4,8],f32>, %arg1 : !torch.vtensor<[8,16],f32>) -> !torch.vtensor<[?,?,?],f32> { - %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[100,4,8],f32>, !torch.vtensor<[8,16],f32> -> !torch.vtensor<[?,?,?],f32> - return %0 : !torch.vtensor<[?,?,?],f32> +func.func @torch.aten.matmul_3d_broadcast(%arg0 : !torch.vtensor<[100,4,8],f32>, %arg1 : !torch.vtensor<[8,16],f32>) -> !torch.vtensor<[100,4,16],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[100,4,8],f32>, !torch.vtensor<[8,16],f32> -> !torch.vtensor<[100,4,16],f32> + return %0 : !torch.vtensor<[100,4,16],f32> } // ----- -// CHECK-LABEL: torch.aten.bmm_3d_fp16 -// CHECK: %[[VAL_2:.+]] = tosa.matmul %0, %1 : (tensor<100x4x8xf16>, tensor<100x8x16xf16>) -> tensor<100x4x16xf16> -func.func @torch.aten.bmm_3d_fp16(%arg0 : !torch.vtensor<[100,4,8],f16>, %arg1 : !torch.vtensor<[100,8,16],f16>) -> !torch.vtensor<[?,?,?],f16> { - %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[100,4,8],f16>, !torch.vtensor<[100,8,16],f16> -> !torch.vtensor<[?,?,?],f16> - return %0 : !torch.vtensor<[?,?,?],f16> +// CHECK: %[[VAL_2:.+]] = tosa.matmul %0, %1 : (tensor<100x4x8xf16>, tensor<100x8x16xf16>) -> tensor<100x4x16xf32> +// CHECK-NEXT: %[[VAL_3:.+]] = tosa.cast %[[VAL_2]] : (tensor<100x4x16xf32>) -> tensor<100x4x16xf16> +func.func @torch.aten.bmm_3d_fp16(%arg0 : !torch.vtensor<[100,4,8],f16>, %arg1 : !torch.vtensor<[100,8,16],f16>) -> !torch.vtensor<[100,4,16],f16> { + %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[100,4,8],f16>, !torch.vtensor<[100,8,16],f16> -> !torch.vtensor<[100,4,16],f16> + return %0 : !torch.vtensor<[100,4,16],f16> } // ----- -// CHECK-LABEL: torch.aten.bmm_3d_bf16 -// CHECK: %[[VAL_2:.+]] = tosa.matmul %0, %1 : (tensor<100x4x8xbf16>, tensor<100x8x16xbf16>) -> tensor<100x4x16xbf16> -func.func @torch.aten.bmm_3d_bf16(%arg0 : !torch.vtensor<[100,4,8],bf16>, %arg1 : !torch.vtensor<[100,8,16],bf16>) -> !torch.vtensor<[?,?,?],bf16> { - %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[100,4,8],bf16>, !torch.vtensor<[100,8,16],bf16> -> !torch.vtensor<[?,?,?],bf16> - return %0 : !torch.vtensor<[?,?,?],bf16> + +// CHECK: %[[VAL_2:.+]] = tosa.matmul %0, %1 : (tensor<100x4x8xbf16>, tensor<100x8x16xbf16>) -> tensor<100x4x16xf32> +// CHECK-NEXT: %[[VAL_3:.+]] = tosa.cast %[[VAL_2]] : (tensor<100x4x16xf32>) -> tensor<100x4x16xbf16> +func.func @torch.aten.bmm_3d_bf16(%arg0 : !torch.vtensor<[100,4,8],bf16>, %arg1 : !torch.vtensor<[100,8,16],bf16>) -> !torch.vtensor<[100,4,16],bf16> { + %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[100,4,8],bf16>, !torch.vtensor<[100,8,16],bf16> -> !torch.vtensor<[100,4,16],bf16> + return %0 : !torch.vtensor<[100,4,16],bf16> } // ----- // CHECK: %[[VAL_2:.+]] = tosa.matmul %0, %1 : (tensor<100x4x8xf32>, tensor<100x8x16xf32>) -> tensor<100x4x16xf32> -func.func @torch.aten.bmm_3d_fp32(%arg0 : !torch.vtensor<[100,4,8],f32>, %arg1 : !torch.vtensor<[100,8,16],f32>) -> !torch.vtensor<[?,?,?],f32> { - %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[100,4,8],f32>, !torch.vtensor<[100,8,16],f32> -> !torch.vtensor<[?,?,?],f32> - return %0 : !torch.vtensor<[?,?,?],f32> +func.func @torch.aten.bmm_3d_fp32(%arg0 : !torch.vtensor<[100,4,8],f32>, %arg1 : !torch.vtensor<[100,8,16],f32>) -> !torch.vtensor<[100,4,16],f32> { + %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[100,4,8],f32>, !torch.vtensor<[100,8,16],f32> -> !torch.vtensor<[100,4,16],f32> + return %0 : !torch.vtensor<[100,4,16],f32> } - - // ----- // CHECK-LABEL: func.func @torch.aten.relu$basic(