diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index b76efe869a0f..d21dd5504dcd 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -40,6 +40,8 @@ Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes, Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes, Type elemTy); +Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes, + Type elemTy); Value castIntToIndex(OpBuilder &b, Location loc, Value v); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index be4a858ce130..4f548a8768eb 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1662,10 +1662,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto shapeType = Torch::ValueTensorType::get( binder.op->getContext(), SmallVector{inputRank}, resultType.getOptionalDtype()); - Value shape = rewriter.create( binder.getLoc(), shapeType, operand); + if (inputRank == 0) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, shape); + return success(); + } + if (start == 0 && end == -1) { rewriter.replaceOp(binder.op, shape); return success(); @@ -1673,18 +1678,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value sv = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(start)); - Value ev = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(end)); - Value step = rewriter.create(binder.getLoc(), 1); - Value dim = rewriter.create(binder.getLoc(), 0); - shape = rewriter.create( - binder.getLoc(), resultType, shape, dim, sv, ev, step); - - rewriter.replaceOp(binder.op, shape); + rewriter.replaceOpWithNewOp( + binder.op, resultType, shape, dim, sv, ev, step); return success(); }); diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index 980cfdd53033..24884281e036 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -1500,6 +1500,79 @@ class ConvertAtenSortOp : public OpConversionPattern { }; } // namespace +namespace { +class ConvertAtenCumprodOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenCumprodOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + Value input = adaptor.getSelf(); + auto resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); + Type elementType = resultType.getElementType(); + Type inputElementType = + cast(input.getType()).getElementType(); + + // Converting the input element type to the result's element type. + // The only possible mismatch would be when the input element type is an + // integer but not `si64`. Therefore, we directly convert the input to + // `si64`. Rest all cases are handled in the dtype definition for this op. + if (elementType != inputElementType) { + Value torchInput = convertTensorToDtype( + rewriter, loc, op.getSelf(), + rewriter.getIntegerType(64, IntegerType::Signed)); + input = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(torchInput.getType()), + torchInput); + } + + int64_t inputRank = resultType.getRank(); + Value dtype = op.getDtype(); + if (!isa(dtype.getType())) + return rewriter.notifyMatchFailure( + op, "unsupported: dtype argument not supported"); + + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "unimplemented: only constant dim value is supported"); + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) + return rewriter.notifyMatchFailure(op, "invalid dim"); + + SmallVector sizes = getTensorSizes(rewriter, loc, input); + Value output = createOneInitTensor(rewriter, loc, sizes, elementType); + output = rewriter.create(loc, resultType, output); + + SmallVector accSizes(sizes); + accSizes.erase(accSizes.begin() + dim); + SmallVector accStatic( + makeShapeTorchCompatible(resultType.getShape())); + accStatic.erase(accStatic.begin() + dim); + Value acc = createOneInitTensor(rewriter, loc, accSizes, elementType); + Type accType = + RankedTensorType::get(makeShapeLLVMCompatible(accStatic), elementType); + acc = rewriter.create(loc, accType, acc); + + Value result = createTMTensorScanOp( + rewriter, loc, input, output, acc, dim, /*inclusive=*/true, + [](OpBuilder &b, Location loc, Value input, Value acc) { + Value prod = + (isa(input.getType()) + ? b.create(loc, input, acc)->getResult(0) + : b.create(loc, input, acc)->getResult(0)); + b.create(loc, prod); + }); + + rewriter.replaceOpWithNewOp(op, resultType, result); + return success(); + } +}; +} // namespace + namespace { class ConvertAtenCumsumOp : public OpConversionPattern { public: @@ -2243,6 +2316,8 @@ class ConvertTorchToTMTensor patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index 5ef0ab16963a..1a208f4ab127 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -138,6 +138,16 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes, return b.create(loc, c0, initTensor).getResult(0); } +Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes, + Type elemTy) { + Value initTensor = + b.create(loc, getAsOpFoldResult(sizes), elemTy); + RankedTensorType type = cast(initTensor.getType()); + Value c1 = + b.create(loc, b.getOneAttr(type.getElementType())); + return b.create(loc, c1, initTensor).getResult(0); +} + Value castIntToIndex(OpBuilder &b, Location loc, Value v) { assert(isa(v.getType()) && "must be called with integer type"); return b.createOrFold(loc, b.getIndexType(), v); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 8620ee183e56..7f85b891ff59 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -9155,6 +9155,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.cumsum\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.cumprod\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" @@ -11878,6 +11881,25 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cumprod\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2#1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.detach\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 0bbdea4eb9e8..d7a3903bf519 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -102,6 +102,7 @@ #### General TorchDynamo/PyTorch errors # torch._dynamo.exc.Unsupported: Tensor.item "CumsumModule_basic", + "CumprodModule_basic", # TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0 # RuntimeError: Failed running call_function aten.convolution_backward(... # https://github.com/pytorch/pytorch/issues/89629 @@ -471,6 +472,7 @@ "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "CumsumModule_basic", + "CumprodModule_basic", "DeformConv2D_basic", "DivFloatModule_basic", "DivIntModule_basic", @@ -713,6 +715,10 @@ "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "CumsumModule_basic", + "CumprodModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", "DeformConv2D_basic", "DeterminantBatchedModule_F32", "DeterminantDynamicModule_F32", @@ -1129,6 +1135,9 @@ "CumsumInputDtypeInt32Module_basic", "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", "DetachModule_basic", "DivFloatModule_basic", "DivIntModule_basic", @@ -3340,6 +3349,10 @@ "CopyWithDifferentDTypesModule_basic", "CosineSimilarityStaticBroadcastModule_basic", "CumsumInputDtypeInt32Module_basic", + "CumprodModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", "ElementwiseAcosIntModule_basic", "ElementwiseAsinIntModule_basic", "ElementwiseAtanTensorIntModule_basic", @@ -3639,6 +3652,10 @@ "CumsumModule_basic", "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", + "CumprodModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", "DeformConv2D_basic", "DeterminantBatchedModule_F32", "DeterminantDynamicModule_F32", @@ -4370,6 +4387,10 @@ "CumsumModule_basic", "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", + "CumprodModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", "DeformConv2D_basic", "DeterminantModule_F32", "DeterminantBatchedModule_F32", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index fb33b2c45d35..a1859f7cd9f6 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1441,6 +1441,9 @@ def aten〇multinomial〡shape(self: List[int], num_samples: int, replacement: b def aten〇cumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]: return self +def aten〇cumprod〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]: + return self + def aten〇rand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]: return self @@ -2947,6 +2950,18 @@ def aten〇cumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Opt return torch.int64 return self_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float32)) +def aten〇cumprod〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int: + if dtype is not None: + return dtype + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.int64 + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇detach〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index d5f7e3922400..a2972bd66573 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -5129,6 +5129,90 @@ def CumsumInputDtypeInt32Module_basic(module, tu: TestUtils): # ============================================================================== +class CumprodModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, val): + ones = torch.ones([1], dtype=torch.int32) + return torch.ops.aten.cumprod(val, ones.item()) + + +@register_test_case(module_factory=lambda: CumprodModule()) +def CumprodModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 7, 4)) + + +class CumprodStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 7, 4], torch.float32, True), + ] + ) + def forward(self, val): + return torch.ops.aten.cumprod(val, 1) + + +@register_test_case(module_factory=lambda: CumprodStaticModule()) +def CumprodStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 7, 4)) + + +class CumprodStaticNegativeDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 7, 4], torch.float32, True), + ] + ) + def forward(self, val): + return torch.ops.aten.cumprod(val, dim=-1) + + +@register_test_case(module_factory=lambda: CumprodStaticNegativeDimModule()) +def CumprodStaticNegativeDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 7, 4)) + + +class CumprodInputDtypeInt32Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 7, 4], torch.int32, True), + ] + ) + def forward(self, val): + return torch.ops.aten.cumprod(val, 1) + + +@register_test_case(module_factory=lambda: CumprodInputDtypeInt32Module()) +def CumprodInputDtypeInt32Module_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 7, 4).to(torch.int32)) + + +# ============================================================================== + + class AtenToDeviceModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 5ae4e1938816..7ca44e25dc3f 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2867,6 +2867,15 @@ func.func @test_shape_start_1_end_negative_1(%arg0: !torch.vtensor<[3,4,5],f32>) return %0 : !torch.vtensor<[1],si64> } +// ----- + +// CHECK-LABEL: func.func @test_shape_scalar +func.func @test_shape_scalar(%arg0: !torch.vtensor<[],si64> ) -> !torch.vtensor<[?],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} { + // CHECK: %[[SHAPE:.+]] = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[],si64> -> !torch.vtensor<[0],si64> + // CHECK: %[[CAST:.+]] = torch.tensor_static_info_cast %[[SHAPE]] : !torch.vtensor<[0],si64> to !torch.vtensor<[?],si64> + %0 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[],si64>) -> !torch.vtensor<[?],si64> + return %0: !torch.vtensor<[?],si64> +} // -----