diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 3f0a895c4738..ee80e982c885 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -5019,6 +5019,54 @@ def Torch_AtenLogSigmoidOp : Torch_Op<"aten.log_sigmoid", [ }]; } +def Torch_AtenHardshrinkOp : Torch_Op<"aten.hardshrink", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::hardshrink : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$lambd + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenHardshrinkOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenHardshrinkOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenSoftshrinkOp : Torch_Op<"aten.softshrink", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::softshrink : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$lambd + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSoftshrinkOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenSoftshrinkOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index b60320e4f938..501d9119718b 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6514,6 +6514,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.hardshrink\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.softshrink\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.mish\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -10064,6 +10072,23 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardshrink\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.softshrink\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.logit\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 2956d9c3cd7b..83ad93c5e879 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1913,6 +1913,120 @@ class DecomposeAtenLogSigmoidOp : public OpRewritePattern { }; } // namespace +// SoftShrink(x, lambda) function: +// Applies a shrinkage function where: +// - If x > lambda, returns x - lambda +// - If x < -lambda, returns x + lambda +// - Otherwise, returns 0 +namespace { +class DecomposeAtenSoftshrinkOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSoftshrinkOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value lambdValue = op.getLambd(); + + auto resTy = dyn_cast(op.getType()); + if (!resTy || !resTy.hasDtype() || !resTy.hasSizes()) { + return rewriter.notifyMatchFailure(op, + "result should have dtype and size"); + } + + double lambd; + if (!matchPattern(lambdValue, m_TorchConstantFloat(&lambd))) { + return rewriter.notifyMatchFailure( + op, "expected lambd to be a constant float"); + } + + Value zero = + rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + Value neglambd = rewriter.create( + loc, rewriter.getF64FloatAttr(-lambd)); + Value poslambd = rewriter.create( + loc, rewriter.getF64FloatAttr(lambd)); + + Value constOneFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + + auto boolResType = + resTy.getWithSizesAndDtype(resTy.getSizes(), rewriter.getI1Type()); + + Value posMask = + rewriter.create(loc, boolResType, self, poslambd); + Value negMask = + rewriter.create(loc, boolResType, self, neglambd); + + Value posValue = rewriter.create(loc, resTy, self, + poslambd, constOneFloat); + Value negValue = rewriter.create(loc, resTy, self, + neglambd, constOneFloat); + + Value result = rewriter.create(loc, resTy, posMask, + posValue, zero); + result = + rewriter.create(loc, resTy, negMask, negValue, result); + + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + +// HardShrink(x, lambda) function: +// Applies a shrinkage function where: +// - If x > lambda, returns x +// - If x < -lambda, returns x +// - Otherwise, returns 0 +namespace { +class DecomposeAtenHardshrinkOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenHardshrinkOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value lambdValue = op.getLambd(); + + auto resTy = dyn_cast(op.getType()); + if (!resTy || !resTy.hasDtype() || !resTy.hasSizes()) { + return rewriter.notifyMatchFailure(op, + "result should have dtype and size"); + } + + double lambd; + if (!matchPattern(lambdValue, m_TorchConstantFloat(&lambd))) { + return rewriter.notifyMatchFailure( + op, "expected lambd to be a constant float"); + } + + Value zero = + rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + Value neglambd = rewriter.create( + loc, rewriter.getF64FloatAttr(-lambd)); + Value poslambd = rewriter.create( + loc, rewriter.getF64FloatAttr(lambd)); + + auto boolResType = + resTy.getWithSizesAndDtype(resTy.getSizes(), rewriter.getI1Type()); + + Value posMask = + rewriter.create(loc, boolResType, self, poslambd); + Value negMask = + rewriter.create(loc, boolResType, self, neglambd); + + Value result = rewriter.create(loc, resTy, posMask, + self, zero); + result = + rewriter.create(loc, resTy, negMask, self, result); + + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + // Decompose aten.matmul into: aten.mm and aten.bmm according to ranks. namespace { class DecomposeAtenMatmulOp : public OpRewritePattern { @@ -7803,6 +7917,8 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal< DecomposeConstantTensorAllocLikeOp>(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 7bb84abc5a23..cceee7e82dd1 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -371,6 +371,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index d3897078ee6c..70de9c28c200 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1493,6 +1493,8 @@ "ElementwiseTruncIntModule_basic", "ElementwiseTruncModule_basic", "ElementwiseLogSigmoidModule_basic", + "ElementwiseHardshrinkStaticModule_basic", + "ElementwiseSoftshrinkStaticModule_basic", } STABLEHLO_CRASHING_SET = { @@ -1773,6 +1775,10 @@ "ElementwiseSeluModule_basic", "ElementwiseSigmoidModule_basic", "ElementwiseSignModule_basic", + "ElementwiseHardshrinkModule_basic", + "ElementwiseHardshrinkStaticModule_basic", + "ElementwiseSoftshrinkModule_basic", + "ElementwiseSoftshrinkStaticModule_basic", "ElementwiseSqrtIntModule_basic", "ElementwiseSqrtModule_basic", "ElementwiseSubScalarFloatModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 2f6691025278..d648ac3178a6 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -254,6 +254,12 @@ def aten〇log〡shape(self: List[int]) -> List[int]: def aten〇log_sigmoid〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇hardshrink〡shape(self: List[int], lambd: float = 0.5) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇softshrink〡shape(self: List[int], lambd: float = 0.5) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇mish〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -2202,6 +2208,18 @@ def aten〇log_sigmoid〡dtype(self_rank_dtype: Tuple[int, int]) -> int: assert not self_dtype == torch.bool return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, lambd=0.5)) +def aten〇hardshrink〡dtype(self_rank_dtype: Tuple[int, int], lambd: Union[int, float, complex] = 0.5) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.bool: + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, lambd=0.5)) +def aten〇softshrink〡dtype(self_rank_dtype: Tuple[int, int], lambd: Union[int, float, complex] = 0.5) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇logit〡dtype(self_rank_dtype: Tuple[int, int], eps: Optional[float] = None) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index d9c3625a0c2d..3773a68670b3 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -480,6 +480,8 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)") emit("aten::glu : (Tensor, int) -> (Tensor)") emit("aten::log_sigmoid : (Tensor) -> (Tensor)") + emit("aten::hardshrink : (Tensor, Scalar) -> (Tensor)") + emit("aten::softshrink : (Tensor, Scalar) -> (Tensor)") # Ops with dynamic number of outputs emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 4ef16df72f2d..b4497e6bc4b7 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -2352,6 +2352,98 @@ def ElementwiseLogSigmoidModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseSoftshrinkModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.softshrink(a) + + +@register_test_case(module_factory=lambda: ElementwiseSoftshrinkModule()) +def ElementwiseSoftshrinkModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseSoftshrinkStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([4, 5, 6], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.softshrink(a, 2.0) + + +@register_test_case(module_factory=lambda: ElementwiseSoftshrinkStaticModule()) +def ElementwiseSoftshrinkStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 5, 6)) + + +# ============================================================================== + + +class ElementwiseHardshrinkModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.hardshrink(a, 1.0) + + +@register_test_case(module_factory=lambda: ElementwiseHardshrinkModule()) +def ElementwiseHardshrinkModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + + +# ============================================================================== + + +class ElementwiseHardshrinkStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([4, 5, 6], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.hardshrink(a, 2.0) + + +@register_test_case(module_factory=lambda: ElementwiseHardshrinkStaticModule()) +def ElementwiseHardshrinkStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 5, 6)) + + +# ============================================================================== + + class ElementwiseErfModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 7e94a28a7fb7..4cb372ceab57 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -94,11 +94,11 @@ def sparse_export( is addressed, this wrapper provides support for the sparse tensor types by first converting all operands to dense tensors, - building the traced graph as for the dense case, and then - annotation sparse parameters with their actual sparse layout - attributes. This temporary solution accelerates testing - torch-mlir with PyTorch sparse tensors until the issue is - resolved. + building the traced graph as for the dense case, then annotating + sparse parameters with their actual sparse layout attributes, + followed by some simple propagation rules. This temporary solution + accelerates testing torch-mlir with PyTorch sparse tensors until + the issue is resolved upstream. """ # Convert all arguments to dense. dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args) @@ -106,21 +106,23 @@ def sparse_export( # Build the regular FX traced graph with only dense arguments # (the current version would crash otherwise, see issue above). prog = torch.export.export(f, dargs, kwargs) - # Annotate sparse arguments in the graph. Note that we currently - # only account for sparsity defined by the user inputs to the model. - # TODO: support sparsity in model parameters (weights, biases) - # TODO: propagate sparsity into the layers + # Annotate sparse arguments in the graph and apply some very + # basic propagation rules for sparsity. specs = prog.graph_signature.input_specs alen = len(specs) k = 0 for i, node in enumerate(prog.graph.nodes): - if i >= alen: - break - spec = specs[i] - if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: - if mask[k]: - node.meta["sparsity"] = sparse_metadata(args[k]) - k = k + 1 + if node.op == "placeholder": + # Argument. + spec = specs[i] + if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: + if mask[k]: + node.meta["sparsity"] = sparse_metadata(args[k]) + k = k + 1 + elif node.op == "call_function": + # Zero preserving elt-wise unary op. + if node.name in {"abs", "neg", "relu", "sin"}: + node.meta["sparsity"] = node.args[0].meta.get("sparsity", None) return prog @@ -170,8 +172,8 @@ def sparse_jit(f, *args, **kwargs): # Construct the additional position array required by MLIR with data # array([0, nnz]). The COO format always uses int64 indices. xargs.append(np.array([0, a._nnz()], dtype=np.int64)) - # Transform a tensor into [tensor x ndim] to conform - # MLIR SoA COO representation. + # Transform a tensor into ndim x tensor to conform + # to the MLIR SoA COO representation. for idx in a._indices(): xargs.append(idx.numpy()) xargs.append(a._values().numpy()) @@ -204,13 +206,16 @@ def run(f): # CHECK: return %[[A]] : !torch.vtensor<[10,20],f64,#[[$COO]]> # CHECK: } # -# CHECK: torch.sparse -# CHECK: tensor(indices=tensor({{\[}}[ 0, 1, 2, 9], -# CHECK: [ 0, 1, 10, 19]{{\]}}), -# CHECK: values=tensor([-1000., -1., 1., 1000.]), -# CHECK: size=(10, 20), nnz=4, dtype=torch.float64, layout=torch.sparse_coo) -# CHECK: torch.mlir -# CHECK: (array([0, 4]), array([0, 1, 2, 9]), array([ 0, 1, 10, 19]), array([-1000., -1., 1., 1000.])) +# CHECK: torch.sparse +# CHECK: tensor(indices=tensor({{\[}}[ 0, 1, 2, 9], +# CHECK: [ 0, 1, 10, 19]{{\]}}), +# CHECK: values=tensor([-1000., -1., 1., 1000.]), +# CHECK: size=(10, 20), nnz=4, dtype=torch.float64, layout=torch.sparse_coo) +# CHECK: torch.mlir +# CHECK: [0 4] +# CHECK: [0 1 2 9] +# CHECK: [ 0 1 10 19] +# CHECK: [-1000. -1. 1. 1000.] # def test_sparse_id(): class IdNet(torch.nn.Module): @@ -233,7 +238,10 @@ def forward(self, x): print("torch.sparse") print(res1) print("torch.mlir") - print(res2) + print(res2[0]) + print(res2[1]) + print(res2[2]) + print(res2[3]) @run @@ -315,14 +323,14 @@ def forward(self, x, v): # CHECK: return %[[R]] : !torch.vtensor<[8,8],f32> # CHECK: } # -# CHECK: torch.sparse -# CHECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.], -# CHECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.], -# CHECK: [8., 8., 8., 8., 8., 8., 8., 8.]{{\]}}) -# CHECK: torch.mlir -# CHECK: {{\[}}[8. 8. 8. 8. 8. 8. 8. 8.] -# CHECK-COUNT-6: [8. 8. 8. 8. 8. 8. 8. 8.] -# CHECK: [8. 8. 8. 8. 8. 8. 8. 8.]{{\]}} +# CHECK: torch.sparse +# CHECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.], +# CHECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.], +# CHECK: [8., 8., 8., 8., 8., 8., 8., 8.]{{\]}}) +# CHECK: torch.mlir +# CHECK: {{\[}}[8. 8. 8. 8. 8. 8. 8. 8.] +# CHECK-COUNT-6: [8. 8. 8. 8. 8. 8. 8. 8.] +# CHECK: [8. 8. 8. 8. 8. 8. 8. 8.]{{\]}} # def test_sparse_SpMM(): class MatMulNet(torch.nn.Module): @@ -349,41 +357,30 @@ def forward(self, x, y): @run # CHECK-LABEL: test_sparse_eltwise -# CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }> +# CHECK: #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }> # CHECK: func.func @main( -# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[8,4,2],f32> { -# CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> -> !torch.vtensor<[8,4,2],f32> -# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32> +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> { +# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> -> !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> +# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> # CHECK: } -# CHECK: #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }> +# CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }> # CHECK: func.func @main( -# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[8,4,2],f32> { -# CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> -> !torch.vtensor<[8,4,2],f32> -# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32> +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> { +# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> -> !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> +# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> # CHECK: } # -# CHECK: torch.sparse -# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]), -# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, -# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]), -# CHECK: values=tensor({{\[}}[ -1., -2.], -# CHECK: [ -3., -4.], -# ... -# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32, -# CHECK: layout=torch.sparse_csr) -# CHECK: torch.mlir -# CHECK: {{\[\[}}[ -1. -2.] -# CHECK: [ -3. -4.] -# ... -# CHECK: [-61. -62.] -# CHECK: [-63. -64.]{{\]\]}} +# CHECK: torch.sparse +# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]), +# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, +# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]), +# CHECK: values=tensor({{\[}}[ -1., -2.], +# ... +# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32, +# CHECK: layout=torch.sparse_csr) +# CHECK: torch.mlir +# CHECK: torch.mlir.batch # -# CHECK: torch.mlir.batch -# CHECK: {{\[\[}}[ -1. -2.] -# CHECK: [ -3. -4.] -# ... -# CHECK: [-61. -62.] -# CHECK: [-63. -64.]{{\]\]}} def test_sparse_eltwise(): class EltNet(torch.nn.Module): def __init__(self): @@ -397,40 +394,43 @@ def forward(self, x): torch.arange(1, 65, dtype=torch.float32), shape=(8, 4, 2) ) - # This yields a **batched** CSR. - batch_input = dense_input.to_sparse_csr(dense_dim=0) - m = export_and_import(net, batch_input) - print(m) - # This yields a plain CSR with dense **sub**tensor sparse_input = dense_input.to_sparse_csr(dense_dim=1) m = export_and_import(net, sparse_input) print(m) + # This yields a **batched** CSR. + batch_input = dense_input.to_sparse_csr(dense_dim=0) + m = export_and_import(net, batch_input) + print(m) + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. - # - # TODO: propagate sparsity into elt-wise (instead of dense result) res1 = net(sparse_input) - res2 = sparse_jit(net, sparse_input) - res3 = sparse_jit(net, batch_input) + # TODO: make these work + # res2 = sparse_jit(net, sparse_input) + # res3 = sparse_jit(net, batch_input) print("torch.sparse") print(res1) print("torch.mlir") - print(res2) print("torch.mlir.batch") - print(res3) @run # CHECK-LABEL: test_sparse_coo3 # CHECK: #[[$COO3:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }> # CHECK: func.func @main( -# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[10,20,30],f64,#sparse>) -> !torch.vtensor<[10,20,30],f64> { -# CHECK: %[[R:.*]] = torch.aten.relu %[[A]] : !torch.vtensor<[10,20,30],f64,#sparse> -> !torch.vtensor<[10,20,30],f64> -# CHECK: return %[[R]] : !torch.vtensor<[10,20,30],f64> +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[10,20,30],f64,#[[$COO3]]>) -> !torch.vtensor<[10,20,30],f64,#[[$COO3]]> { +# CHECK: %[[R:.*]] = torch.aten.relu %[[A]] : !torch.vtensor<[10,20,30],f64,#[[$COO3]]> -> !torch.vtensor<[10,20,30],f64,#[[$COO3]]> +# CHECK: return %[[R]] : !torch.vtensor<[10,20,30],f64,#[[$COO3]]> # CHECK: } # -# TODO: make sure sparsity propagates through relu into the output and test actual JIT output +# CHECK: torch.sparse +# CHECK: tensor(indices=tensor({{\[}}[ 0, 1, 1, 4, 9, 9], +# CHECK: [ 0, 1, 1, 5, 19, 19], +# CHECK: [ 0, 1, 3, 6, 28, 29]{{\]}}), +# CHECK: values=tensor([ 0., 0., 1., 2., 3., 1000.]), +# CHECK: size=(10, 20, 30), nnz=6, dtype=torch.float64, layout=torch.sparse_coo) +# CHECK: torch.mlir # def test_sparse_coo3(): class COO3Net(torch.nn.Module): @@ -450,3 +450,11 @@ def forward(self, x): m = export_and_import(net, sparse_input) print(m) + + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. + res1 = net(sparse_input) + # TODO: make coo3 work + # res2 = sparse_jit(net, sparse_input) + print("torch.sparse") + print(res1) + print("torch.mlir")