From 7423413ffa936cf293a754b740bc67090672c890 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Wed, 18 Dec 2024 15:08:49 +0100 Subject: [PATCH 1/3] ElementwiseOpFusion: add option to disable tensor.empty --- mlir/include/mlir/Dialect/Linalg/Passes.td | 5 +++++ .../mlir/Dialect/Linalg/Transforms/Transforms.h | 3 ++- .../Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 11 +++++++---- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td index d96ad919b65f0a..e32b28bb2af8b7 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -75,6 +75,11 @@ def LinalgElementwiseOpFusionPass : Pass<"linalg-fuse-elementwise-ops"> { let dependentDialects = [ "affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect" ]; + let options = [ + Option<"introduceTensorEmpty", "introduce-empty", "bool", + /*default=*/"true", + "Replace out by tensor.empty">, + ]; } def LinalgNamedOpConversionPass: Pass<"linalg-named-op-conversion"> { diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 0208f854f799ec..1ca09d97cdee63 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1701,7 +1701,8 @@ using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>; /// when both operations are fusable elementwise operations. void populateElementwiseOpsFusionPatterns( RewritePatternSet &patterns, - const ControlFusionFn &controlElementwiseOpFusion); + const ControlFusionFn &controlElementwiseOpFusion, + bool introduceTensorEmpty = true); /// Function type which is used to control propagation of tensor.pack/unpack /// ops. 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 6c806fb9828dc2..5f4938c787820c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -2134,11 +2134,13 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns( void mlir::linalg::populateElementwiseOpsFusionPatterns( RewritePatternSet &patterns, - const ControlFusionFn &controlElementwiseOpsFusion) { + const ControlFusionFn &controlElementwiseOpsFusion, + bool introduceTensorEmpty) { auto *context = patterns.getContext(); patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion); - patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant, - RemoveOutsDependency>(context); + patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant>(context); + if (introduceTensorEmpty) + patterns.add<RemoveOutsDependency>(context); // Add the patterns that clean up dead operands and results. populateEraseUnusedOperandsAndResultsPatterns(patterns); } @@ -2180,7 +2182,8 @@ struct LinalgElementwiseOpFusionPass }; // Add elementwise op fusion patterns. 
- populateElementwiseOpsFusionPatterns(patterns, defaultControlFn); + populateElementwiseOpsFusionPatterns(patterns, defaultControlFn, + introduceTensorEmpty); populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn); tensor::populateBubbleUpExpandShapePatterns(patterns); From 08dcd58a26d52f1aaee3885679cfe61f8f8acfd4 Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Wed, 18 Dec 2024 14:40:16 +0000 Subject: [PATCH 2/3] Address comments --- mlir/include/mlir/Dialect/Linalg/Passes.td | 2 +- .../Dialect/Linalg/Transforms/Transforms.h | 2 +- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 6 +++--- ...on-elementwise-ops-no-remove-outs-deps.mlir | 18 ++++++++++++++++++ 4 files changed, 23 insertions(+), 5 deletions(-) create mode 100644 mlir/test/Dialect/Linalg/fusion-elementwise-ops-no-remove-outs-deps.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td index e32b28bb2af8b7..0c315102f3fce5 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -76,7 +76,7 @@ def LinalgElementwiseOpFusionPass : Pass<"linalg-fuse-elementwise-ops"> { "affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect" ]; let options = [ - Option<"introduceTensorEmpty", "introduce-empty", "bool", + Option<"removeOutsDependency", "remove-outs-dependency", "bool", /*default=*/"true", "Replace out by tensor.empty">, ]; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 1ca09d97cdee63..d8920b34f7bcc8 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1702,7 +1702,7 @@ using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>; void populateElementwiseOpsFusionPatterns( RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion, - bool introduceTensorEmpty = true); + bool removeOutsDependency = true); /// 
Function type which is used to control propagation of tensor.pack/unpack /// ops. diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 5f4938c787820c..cd0de6dfaf9411 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -2135,11 +2135,11 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns( void mlir::linalg::populateElementwiseOpsFusionPatterns( RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpsFusion, - bool introduceTensorEmpty) { + bool removeOutsDependency) { auto *context = patterns.getContext(); patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion); patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant>(context); - if (introduceTensorEmpty) + if (removeOutsDependency) patterns.add<RemoveOutsDependency>(context); // Add the patterns that clean up dead operands and results. populateEraseUnusedOperandsAndResultsPatterns(patterns); @@ -2183,7 +2183,7 @@ struct LinalgElementwiseOpFusionPass // Add elementwise op fusion patterns. 
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn, - introduceTensorEmpty); + removeOutsDependency); populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn); tensor::populateBubbleUpExpandShapePatterns(patterns); diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops-no-remove-outs-deps.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops-no-remove-outs-deps.mlir new file mode 100644 index 00000000000000..cc2e9abe958d3b --- /dev/null +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops-no-remove-outs-deps.mlir @@ -0,0 +1,18 @@ +// RUN: mlir-opt %s -p 'builtin.module(func.func(linalg-fuse-elementwise-ops{remove-outs-dependency=0}))' -split-input-file | FileCheck %s + +#identity = affine_map<(d0) -> (d0)> + +func.func @redudant_copy_with_target_burst_size_two(%arg: tensor<4xf32>) -> tensor<4xf32> attributes {plhw.toplevel} { + // CHECK-NOT: tensor.empty + %1 = linalg.generic {indexing_maps = [#identity, #identity], iterator_types = ["parallel"] } ins(%arg: tensor<4xf32>) outs(%arg: tensor<4xf32>) { + ^bb0(%in: f32, %out: f32): + %exp = arith.negf %in: f32 + linalg.yield %exp : f32 + } -> tensor<4xf32> + %2 = linalg.generic {indexing_maps = [#identity, #identity], iterator_types = ["parallel"] } ins(%1: tensor<4xf32>) outs(%arg: tensor<4xf32>) { + ^bb0(%in: f32, %out: f32): + %exp = arith.mulf %in,%in: f32 + linalg.yield %exp : f32 + } -> tensor<4xf32> + return %2 : tensor<4xf32> +} \ No newline at end of file From cb794e64947c700522ecdf654aea2e7112133f5d Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Wed, 18 Dec 2024 14:52:15 +0000 Subject: [PATCH 3/3] Use single linalg.generic --- .../fusion-elementwise-ops-no-remove-outs-deps.mlir | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops-no-remove-outs-deps.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops-no-remove-outs-deps.mlir index cc2e9abe958d3b..cf40f3ff519014 100644 --- 
a/mlir/test/Dialect/Linalg/fusion-elementwise-ops-no-remove-outs-deps.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops-no-remove-outs-deps.mlir @@ -2,17 +2,12 @@ #identity = affine_map<(d0) -> (d0)> -func.func @redudant_copy_with_target_burst_size_two(%arg: tensor<4xf32>) -> tensor<4xf32> attributes {plhw.toplevel} { +func.func @keep_outs_dependency(%arg: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NOT: tensor.empty %1 = linalg.generic {indexing_maps = [#identity, #identity], iterator_types = ["parallel"] } ins(%arg: tensor<4xf32>) outs(%arg: tensor<4xf32>) { ^bb0(%in: f32, %out: f32): %exp = arith.negf %in: f32 linalg.yield %exp : f32 } -> tensor<4xf32> - %2 = linalg.generic {indexing_maps = [#identity, #identity], iterator_types = ["parallel"] } ins(%1: tensor<4xf32>) outs(%arg: tensor<4xf32>) { - ^bb0(%in: f32, %out: f32): - %exp = arith.mulf %in,%in: f32 - linalg.yield %exp : f32 - } -> tensor<4xf32> - return %2 : tensor<4xf32> -} \ No newline at end of file + return %1 : tensor<4xf32> +}