From 9938abf25e1e7526ca7f43a8c49e9078c14fc55c Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Thu, 26 Sep 2024 18:17:22 -0400 Subject: [PATCH 1/4] AtenCumprodOp (#3737) --- include/torch-mlir/Conversion/Utils/Utils.h | 2 + .../TorchToTMTensor/TorchToTMTensor.cpp | 75 +++++++++++++++++ lib/Conversion/Utils/Utils.cpp | 10 +++ .../Transforms/AbstractInterpLibrary.cpp | 22 +++++ projects/pt1/e2e_testing/xfail_sets.py | 21 +++++ .../build_tools/abstract_interp_lib_gen.py | 15 ++++ .../torch_mlir_e2e_test/test_suite/basic.py | 84 +++++++++++++++++++ 7 files changed, 229 insertions(+) diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index b76efe869a0f..d21dd5504dcd 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -40,6 +40,8 @@ Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes, Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes, Type elemTy); +Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes, + Type elemTy); Value castIntToIndex(OpBuilder &b, Location loc, Value v); diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index b0b0b0df2ef0..94d7154115be 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -1497,6 +1497,79 @@ class ConvertAtenSortOp : public OpConversionPattern { }; } // namespace +namespace { +class ConvertAtenCumprodOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenCumprodOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + Value input = adaptor.getSelf(); + auto resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); + Type elementType = resultType.getElementType(); + Type inputElementType = + cast(input.getType()).getElementType(); + + // Converting the input element type to the result's element type. + // The only possible mismatch would be when the input element type is an + // integer but not `si64`. Therefore, we directly convert the input to + // `si64`. Rest all cases are handled in the dtype definition for this op. + if (elementType != inputElementType) { + Value torchInput = convertTensorToDtype( + rewriter, loc, op.getSelf(), + rewriter.getIntegerType(64, IntegerType::Signed)); + input = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(torchInput.getType()), + torchInput); + } + + int64_t inputRank = resultType.getRank(); + Value dtype = op.getDtype(); + if (!isa(dtype.getType())) + return rewriter.notifyMatchFailure( + op, "unsupported: dtype argument not supported"); + + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "unimplemented: only constant dim value is supported"); + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) + return rewriter.notifyMatchFailure(op, "invalid dim"); + + SmallVector sizes = getTensorSizes(rewriter, loc, input); + Value output = createOneInitTensor(rewriter, loc, sizes, elementType); + output = rewriter.create(loc, resultType, output); + + SmallVector accSizes(sizes); + accSizes.erase(accSizes.begin() + dim); + SmallVector accStatic( + makeShapeTorchCompatible(resultType.getShape())); + accStatic.erase(accStatic.begin() + dim); + Value acc = createOneInitTensor(rewriter, loc, accSizes, elementType); + Type accType = + RankedTensorType::get(makeShapeLLVMCompatible(accStatic), elementType); + acc = rewriter.create(loc, accType, acc); + + Value result = createTMTensorScanOp( + rewriter, loc, input, output, acc, dim, /*inclusive=*/true, + [](OpBuilder &b, Location loc, Value input, Value acc) { + Value prod = + (isa(input.getType()) + ? b.create(loc, input, acc)->getResult(0) + : b.create(loc, input, acc)->getResult(0)); + b.create(loc, prod); + }); + + rewriter.replaceOpWithNewOp(op, resultType, result); + return success(); + } +}; +} // namespace + namespace { class ConvertAtenCumsumOp : public OpConversionPattern { public: @@ -2240,6 +2313,8 @@ class ConvertTorchToTMTensor patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index 5ef0ab16963a..1a208f4ab127 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -138,6 +138,16 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes, return b.create(loc, c0, initTensor).getResult(0); } +Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes, + Type elemTy) { + Value initTensor = + b.create(loc, getAsOpFoldResult(sizes), elemTy); + RankedTensorType type = cast(initTensor.getType()); + Value c1 = + b.create(loc, b.getOneAttr(type.getElementType())); + return b.create(loc, c1, initTensor).getResult(0); +} + Value castIntToIndex(OpBuilder &b, Location loc, Value v) { assert(isa(v.getType()) && "must be called with integer type"); return b.createOrFold(loc, b.getIndexType(), v); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 59cf69393ded..995a7df283fd 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -9134,6 +9134,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.cumsum\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.cumprod\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" @@ -11844,6 +11847,25 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cumprod\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2#1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.detach\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 3b3e4611ea6b..0e741d0de36b 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -79,6 +79,7 @@ #### General TorchDynamo/PyTorch errors # torch._dynamo.exc.Unsupported: Tensor.item "CumsumModule_basic", + "CumprodModule_basic", # TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0 # RuntimeError: Failed running call_function aten.convolution_backward(... # https://github.com/pytorch/pytorch/issues/89629 @@ -432,6 +433,7 @@ "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "CumsumModule_basic", + "CumprodModule_basic", "DeformConv2D_basic", "DivFloatModule_basic", "DivIntModule_basic", @@ -667,6 +669,10 @@ "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "CumsumModule_basic", + "CumprodModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", "DeformConv2D_basic", "DeterminantBatchedModule_F32", "DeterminantDynamicModule_F32", @@ -1077,6 +1083,9 @@ "CumsumInputDtypeInt32Module_basic", "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", "DetachModule_basic", "DivFloatModule_basic", "DivIntModule_basic", @@ -3105,6 +3114,10 @@ "CopyWithDifferentDTypesModule_basic", "CosineSimilarityStaticBroadcastModule_basic", "CumsumInputDtypeInt32Module_basic", + "CumprodModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", "ElementwiseAcosIntModule_basic", "ElementwiseAsinIntModule_basic", "ElementwiseAtanTensorIntModule_basic", @@ -3378,6 +3391,10 @@ "CumsumModule_basic", "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", + "CumprodModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", "DeformConv2D_basic", "DeterminantBatchedModule_F32", "DeterminantDynamicModule_F32", @@ -4110,6 +4127,10 @@ "CumsumModule_basic", "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", + "CumprodModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", "DeformConv2D_basic", "DeterminantModule_F32", "DeterminantBatchedModule_F32", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index bc49757ee9d3..22fe8e299f07 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1434,6 +1434,9 @@ def aten〇multinomial〡shape(self: List[int], num_samples: int, replacement: b def aten〇cumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]: return self +def aten〇cumprod〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]: + return self + def aten〇rand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]: return self @@ -2926,6 +2929,18 @@ def aten〇cumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Opt return torch.int64 return self_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float32)) +def aten〇cumprod〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int: + if dtype is not None: + return dtype + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.int64 + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇detach〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index cb6aa7fc15d7..ef20079b6f75 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -4830,6 +4830,90 @@ def CumsumInputDtypeInt32Module_basic(module, tu: TestUtils): # ============================================================================== +class CumprodModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, val): + ones = torch.ones([1], dtype=torch.int32) + return torch.ops.aten.cumprod(val, ones.item()) + + +@register_test_case(module_factory=lambda: CumprodModule()) +def CumprodModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 7, 4)) + + +class CumprodStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 7, 4], torch.float32, True), + ] + ) + def forward(self, val): + return torch.ops.aten.cumprod(val, 1) + + +@register_test_case(module_factory=lambda: CumprodStaticModule()) +def CumprodStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 7, 4)) + + +class CumprodStaticNegativeDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 7, 4], torch.float32, True), + ] + ) + def forward(self, val): + return torch.ops.aten.cumprod(val, dim=-1) + + +@register_test_case(module_factory=lambda: CumprodStaticNegativeDimModule()) +def CumprodStaticNegativeDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 7, 4)) + + +class CumprodInputDtypeInt32Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 7, 4], torch.int32, True), + ] + ) + def forward(self, val): + return torch.ops.aten.cumprod(val, 1) + + +@register_test_case(module_factory=lambda: CumprodInputDtypeInt32Module()) +def CumprodInputDtypeInt32Module_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 7, 4).to(torch.int32)) + + +# ============================================================================== + + class AtenToDeviceModule(torch.nn.Module): def __init__(self): super().__init__() From a33d1232c5c67e82147126619d787d56521f8617 Mon Sep 17 00:00:00 2001 From: jinchen <49575973+jinchen62@users.noreply.github.com> Date: Fri, 27 Sep 2024 13:30:02 -0700 Subject: [PATCH 2/4] [onnx] Fix onnx.Shape lowering with scalar input (#3716) Address https://github.com/nod-ai/SHARK-Turbine/issues/826 --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 16 ++++++++-------- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 9 +++++++++ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 36c26f26c2ef..ea5156a0c878 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1662,10 +1662,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto shapeType = Torch::ValueTensorType::get( binder.op->getContext(), SmallVector{inputRank}, resultType.getOptionalDtype()); - Value shape = rewriter.create( binder.getLoc(), shapeType, operand); + if (inputRank == 0) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, shape); + return success(); + } + if (start == 0 && end == -1) { rewriter.replaceOp(binder.op, shape); return success(); @@ -1673,18 +1678,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value sv = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(start)); - Value ev = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(end)); - Value step = rewriter.create(binder.getLoc(), 1); - Value dim = rewriter.create(binder.getLoc(), 0); - shape = rewriter.create( - binder.getLoc(), resultType, shape, dim, sv, ev, step); - - rewriter.replaceOp(binder.op, shape); + rewriter.replaceOpWithNewOp( + binder.op, resultType, shape, dim, sv, ev, step); return success(); }); diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index af2a1e00299b..bd2a92874843 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2833,6 +2833,15 @@ func.func @test_shape_start_1_end_negative_1(%arg0: !torch.vtensor<[3,4,5],f32>) return %0 : !torch.vtensor<[1],si64> } +// ----- + +// CHECK-LABEL: func.func @test_shape_scalar +func.func @test_shape_scalar(%arg0: !torch.vtensor<[],si64> ) -> !torch.vtensor<[?],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} { + // CHECK: %[[SHAPE:.+]] = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[],si64> -> !torch.vtensor<[0],si64> + // CHECK: %[[CAST:.+]] = torch.tensor_static_info_cast %[[SHAPE]] : !torch.vtensor<[0],si64> to !torch.vtensor<[?],si64> + %0 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[],si64>) -> !torch.vtensor<[?],si64> + return %0: !torch.vtensor<[?],si64> +} // ----- From eb4e59e1899d4f3ed61e7ed3956e4fd9e1cc9aae Mon Sep 17 00:00:00 2001 From: yyp0 Date: Sun, 29 Sep 2024 17:41:20 +0800 Subject: [PATCH 3/4] [Torch] support binary_cross_entropy_with_logits decomposition (#3741) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 27 +++++++ .../Transforms/AbstractInterpLibrary.cpp | 16 ++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 73 +++++++++++++++++++ .../build_tools/abstract_interp_lib_gen.py | 12 +++ .../build_tools/torch_ods_gen.py | 3 + .../test_suite/reduction.py | 23 ++++++ 6 files changed, 154 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c9329ccb895d..6f02a94768d0 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -9224,6 +9224,33 @@ def Torch_AtenBinaryCrossEntropyBackwardOp : Torch_Op<"aten.binary_cross_entropy }]; } +def Torch_AtenBinaryCrossEntropyWithLogitsOp : Torch_Op<"aten.binary_cross_entropy_with_logits", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::binary_cross_entropy_with_logits : (Tensor, Tensor, Tensor?, Tensor?, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$target, + AnyTorchOptionalTensorType:$weight, + AnyTorchOptionalTensorType:$pos_weight, + Torch_IntType:$reduction + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBinaryCrossEntropyWithLogitsOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenBinaryCrossEntropyWithLogitsOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + def Torch_AtenLogSigmoidForwardOp : Torch_Op<"aten.log_sigmoid_forward", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 995a7df283fd..445d4e459013 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -10289,6 +10289,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.cross_entropy_loss(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list, !torch.list, !torch.optional>, !torch.int, !torch.int, !torch.float) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.binary_cross_entropy_with_logits\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.int) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" %1 = torch.aten.eq.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.list) {\n" +" %3 = func.call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list) -> !torch.list\n" +" torch.prim.If.yield %3 : !torch.list\n" +" } else {\n" +" torch.prim.If.yield %0 : !torch.list\n" +" }\n" +" return %2 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float) -> !torch.tuple, list, list> {\n" " %0 = call @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.tuple, list, list>\n" " return %0 : !torch.tuple, list, list>\n" @@ -14634,6 +14646,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.binary_cross_entropy_with_logits\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.renorm\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 1ee57b60f248..29c176f96afd 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -8799,6 +8799,77 @@ class DecomposeAtenCrossEntropyLossOp }; } // namespace +namespace { +class DecomposeAtenBinaryCrossEntropyWithLogitsOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenBinaryCrossEntropyWithLogitsOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto self = op.getSelf(); + auto target = op.getTarget(); + auto posWeight = op.getPosWeight(); + auto weight = op.getWeight(); + auto reduction = op.getReduction(); + + Value loss; + auto one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + auto _one = + rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + + auto _target = + rewriter.create(loc, target.getType(), target, _one); + auto _target_1 = rewriter.create(loc, _target.getType(), + _target, one, one); + Value mm = + rewriter.create(loc, self.getType(), _target_1, self); + Value logSigm = + rewriter.create(loc, self.getType(), self); + + if (!isa(posWeight.getType())) { + auto logWeight = rewriter.create( + loc, posWeight.getType(), + rewriter.create(loc, posWeight.getType(), posWeight, + one, one), + one, one); + loss = rewriter.create( + loc, mm.getType(), mm, + rewriter.create(loc, logWeight.getType(), logWeight, + logSigm), + one); + } else { + loss = + rewriter.create(loc, mm.getType(), mm, logSigm, one); + } + + if (!isa(weight.getType())) { + loss = + rewriter.create(loc, loss.getType(), loss, weight); + } + + // apply loss reduction. + int64_t reductionInt; + if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) { + return rewriter.notifyMatchFailure(op, "no reduction type is appointed!"); + } + + auto none = rewriter.create(loc); + Value res; + if (reductionInt == 1) { + res = rewriter.create(loc, op.getType(), loss, none); + } else if (reductionInt == 2) { + res = rewriter.create(loc, op.getType(), loss, none); + } else { + res = loss; + } + + rewriter.replaceOp(op, res); + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenOneHotOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -9936,6 +10007,8 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal( + patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 22fe8e299f07..d3ec25bcea70 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1993,6 +1993,14 @@ def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int = def aten〇cross_entropy_loss〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, reduction: int = 1, ignore_index: int = -100, label_smoothing: float = 0.) -> List[int]: return upstream_shape_functions.cross_entropy_loss(self, target, weight, reduction, ignore_index, label_smoothing) +def aten〇binary_cross_entropy_with_logits〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, pos_weight: Optional[List[int]] = None, reduction: int = 1) -> List[int]: + scalar_shape: List[int] = [] + if reduction == 0: + result_shape = upstream_shape_functions._copy(self) + else: + result_shape = scalar_shape + return result_shape + @check_shape_function([ Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case. ]) @@ -4958,6 +4966,10 @@ def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U return dtype return aten〇std〡dtype(self_rank_dtype) +def aten〇binary_cross_entropy_with_logits〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]] = None, pos_weight_rank_dtype: Optional[Tuple[int, int]] = None, reduction: int = 1) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function( _check_tensors_with_the_same_dtype( tensor_shapes=[(3,3)], diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index f3227f29b5ce..ea5c504284eb 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -743,6 +743,9 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)" ) + emit( + "aten::binary_cross_entropy_with_logits : (Tensor, Tensor, Tensor?, Tensor?, int) -> (Tensor)" + ) emit("aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)") emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)") emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 9a683e3c6219..e9b84ea0652c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -2294,6 +2294,29 @@ def CrossEntropyLossNoReductionModule_basic(module, tu: TestUtils): module.forward(tu.rand(8, 2), tu.randint(8, high=2)) +class BinaryCrossEntropyWithLogitsStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([8, 2], torch.float32, True), + ([8, 2], torch.float32, True), + ] + ) + def forward(self, input, target): + return torch.ops.aten.binary_cross_entropy_with_logits( + input, target, reduction=0 + ) + + +@register_test_case(module_factory=lambda: BinaryCrossEntropyWithLogitsStaticModule()) +def BinaryCrossEntropyWithLogitsStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(8, 2), tu.rand(8, 2)) + + # ============================================================================== From a76a787b5d91b7513f8d55ea1719880d4f80113b Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Mon, 16 Dec 2024 23:57:24 +0100 Subject: [PATCH 4/4] Update xfail --- projects/pt1/e2e_testing/xfail_sets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index d7a3903bf519..c0dded74796c 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1741,6 +1741,7 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "BinaryCrossEntropyWithLogitsStaticModule_basic", "ElementwiseAtenFloorDivideBroadcastModule_basic", "ElementwiseAtenFloorDivideScalarModule_basic", "ElementwiseAtenFloorDivideScalarNegativeModule_basic",