Skip to content

Commit

Permalink
[FXML-5417] TileUsingInterface: drop unused extract_slice (#432)
Browse files Browse the repository at this point in the history
  • Loading branch information
cferry-AMD authored Jan 6, 2025
1 parent 752540d commit bada367
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
6 changes: 6 additions & 0 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1536,6 +1536,12 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
tiledAndFusedOps.insert(tiledAndFusedOp);
}

// Drop the extract_slice if it has been replaced by the tiled producer, and
// is no longer used.
if (worklistItem.candidateSlice->use_empty()) {
rewriter.eraseOp(worklistItem.candidateSlice);
}

if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
}
Expand Down
40 changes: 40 additions & 0 deletions mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,46 @@ module {
}
}

// -----

// This test checks that upon tiling and fusion, Linalg ops that have been tiled
// through fusion and are not used elsewhere are indeed dead code and get
// dropped.

// CHECK-LABEL: func @tile_fuse_drop_dead_producer(
// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<10x10xf32>) -> tensor<10x10xf32> {
func.func @tile_fuse_drop_dead_producer(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
%c2f = arith.constant 2.0 : f32

// CHECK-NOT: linalg.generic {{{[^\}]*}}} ins(%[[TA]] : tensor<10x10xf32>) outs(%{{.*}} : tensor<10x10xf32>) {
%empty = tensor.empty() : tensor<10x10xf32>
%0 = linalg.generic {indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>], iterator_types = ["parallel", "parallel"]}
ins(%arg0: tensor<10x10xf32>) outs(%empty: tensor<10x10xf32>) {
^bb0(%a: f32, %b: f32):
%res = arith.addf %a, %c2f : f32
linalg.yield %res : f32
} -> tensor<10x10xf32>

%empty2 = tensor.empty() : tensor<10x10xf32>
// CHECK: scf.for {{.*}} {
// CHECK: scf.for {{.*}} {
// CHECK: linalg.generic
// CHECK: linalg.negf
// CHECK: }
// CHECK: }
%1 = linalg.negf ins(%0 : tensor<10x10xf32>) outs(%empty2 : tensor<10x10xf32>) -> tensor<10x10xf32>

return %1 : tensor<10x10xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.negf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%tiled_low, %loop1, %loop2 = transform.structured.fuse %0 [5, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}


////////////////////////////////////////////////////////////////////////////////
// Tests below are expected to fail.
Expand Down

0 comments on commit bada367

Please sign in to comment.