diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index d96ad919b65f0a..0c315102f3fce5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -75,6 +75,11 @@ def LinalgElementwiseOpFusionPass : Pass<"linalg-fuse-elementwise-ops"> {
   let dependentDialects = [
     "affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
   ];
+  let options = [
+    Option<"removeOutsDependency", "remove-outs-dependency", "bool",
+           /*default=*/"true",
+           "Replace out by tensor.empty">,
+  ];
 }
 
 def LinalgNamedOpConversionPass: Pass<"linalg-named-op-conversion"> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 0208f854f799ec..d8920b34f7bcc8 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1701,7 +1701,10 @@ using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>;
 /// when both operations are fusable elementwise operations.
+/// When `removeOutsDependency` is true (the default), the pattern that
+/// replaces `outs` operands by `tensor.empty` is added as well.
 void populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns,
-    const ControlFusionFn &controlElementwiseOpFusion);
+    const ControlFusionFn &controlElementwiseOpFusion,
+    bool removeOutsDependency = true);
 
 /// Function type which is used to control propagation of tensor.pack/unpack
 /// ops.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 6c806fb9828dc2..cd0de6dfaf9411 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -2134,11 +2134,14 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
 
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns,
-    const ControlFusionFn &controlElementwiseOpsFusion) {
+    const ControlFusionFn &controlElementwiseOpsFusion,
+    bool removeOutsDependency) {
   auto *context = patterns.getContext();
   patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
-  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
-               RemoveOutsDependency>(context);
+  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant>(context);
+  // Optional so that callers can keep the original `outs` operands.
+  if (removeOutsDependency)
+    patterns.add<RemoveOutsDependency>(context);
   // Add the patterns that clean up dead operands and results.
   populateEraseUnusedOperandsAndResultsPatterns(patterns);
 }
@@ -2180,7 +2183,8 @@ struct LinalgElementwiseOpFusionPass
     };
 
     // Add elementwise op fusion patterns.
-    populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
+    populateElementwiseOpsFusionPatterns(patterns, defaultControlFn,
+                                         removeOutsDependency);
     populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
     tensor::populateBubbleUpExpandShapePatterns(patterns);
 
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops-no-remove-outs-deps.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops-no-remove-outs-deps.mlir
new file mode 100644
index 00000000000000..cf40f3ff519014
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops-no-remove-outs-deps.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -p 'builtin.module(func.func(linalg-fuse-elementwise-ops{remove-outs-dependency=0}))' -split-input-file | FileCheck %s
+
+#identity = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: func.func @keep_outs_dependency
+func.func @keep_outs_dependency(%arg: tensor<4xf32>) -> tensor<4xf32> {
+  // CHECK-NOT: tensor.empty
+  // CHECK: linalg.generic
+  %1 = linalg.generic {indexing_maps = [#identity, #identity], iterator_types = ["parallel"] } ins(%arg: tensor<4xf32>) outs(%arg: tensor<4xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %neg = arith.negf %in: f32
+    linalg.yield %neg : f32
+  } -> tensor<4xf32>
+  return %1 : tensor<4xf32>
+}