Skip to content

Commit

Permalink
ElementwiseOpFusion: option for disable empty (#430)
Browse files Browse the repository at this point in the history
* ElementwiseOpFusion: option for disable empty

---------

Co-authored-by: Matthias Gehre <matthias.gehre@amd.com>
  • Loading branch information
josel-amd and mgehre-amd authored Dec 18, 2024
1 parent 992dad3 commit 5518042
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 5 deletions.
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ def LinalgElementwiseOpFusionPass : Pass<"linalg-fuse-elementwise-ops"> {
let dependentDialects = [
"affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
];
let options = [
Option<"removeOutsDependency", "remove-outs-dependency", "bool",
/*default=*/"true",
"Replace out by tensor.empty">,
];
}

def LinalgNamedOpConversionPass: Pass<"linalg-named-op-conversion"> {
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1701,7 +1701,8 @@ using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>;
/// when both operations are fusable elementwise operations.
void populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
const ControlFusionFn &controlElementwiseOpFusion);
const ControlFusionFn &controlElementwiseOpFusion,
bool replaceOutsDependency = true);

/// Function type which is used to control propagation of tensor.pack/unpack
/// ops.
Expand Down
11 changes: 7 additions & 4 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2134,11 +2134,13 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(

void mlir::linalg::populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
const ControlFusionFn &controlElementwiseOpsFusion) {
const ControlFusionFn &controlElementwiseOpsFusion,
bool removeOutsDependency) {
auto *context = patterns.getContext();
patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
RemoveOutsDependency>(context);
patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant>(context);
if (removeOutsDependency)
patterns.add<RemoveOutsDependency>(context);
// Add the patterns that clean up dead operands and results.
populateEraseUnusedOperandsAndResultsPatterns(patterns);
}
Expand Down Expand Up @@ -2180,7 +2182,8 @@ struct LinalgElementwiseOpFusionPass
};

// Add elementwise op fusion patterns.
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn,
removeOutsDependency);
populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
tensor::populateBubbleUpExpandShapePatterns(patterns);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: mlir-opt %s -p 'builtin.module(func.func(linalg-fuse-elementwise-ops{remove-outs-dependency=0}))' -split-input-file | FileCheck %s

#identity = affine_map<(d0) -> (d0)>

func.func @keep_outs_dependency(%arg: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: tensor.empty
%1 = linalg.generic {indexing_maps = [#identity, #identity], iterator_types = ["parallel"] } ins(%arg: tensor<4xf32>) outs(%arg: tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%exp = arith.negf %in: f32
linalg.yield %exp : f32
} -> tensor<4xf32>
return %1 : tensor<4xf32>
}

0 comments on commit 5518042

Please sign in to comment.