diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 926ebb1a2cea87..ba44e5674fb21f 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1536,6 +1536,12 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( tiledAndFusedOps.insert(tiledAndFusedOp); } + // Drop the extract_slice if it has been replaced by the tiled producer, and + // is no longer used. + if (worklistItem.candidateSlice->use_empty()) { + rewriter.eraseOp(worklistItem.candidateSlice); + } + if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) { return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed"); } diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir index 4115f2857a20c6..2cea815ac2b047 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir @@ -651,6 +651,46 @@ module { } } +// ----- + +// This test checks that upon tiling and fusion, Linalg ops that have been tiled +// through fusion and are not used elsewhere are indeed dead code and get +// dropped. 
+
+// CHECK-LABEL: func @tile_fuse_drop_dead_producer(
+// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<10x10xf32>) -> tensor<10x10xf32> {
+func.func @tile_fuse_drop_dead_producer(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
+  %c2f = arith.constant 2.0 : f32  // Addend used inside the producer's body.
+
+  // CHECK-NOT: linalg.generic {{{[^\}]*}}} ins(%[[TA]] : tensor<10x10xf32>) outs(%{{.*}} : tensor<10x10xf32>) {
+  %empty = tensor.empty() : tensor<10x10xf32>  // Output operand of the producer.
+  %0 = linalg.generic {indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>], iterator_types = ["parallel", "parallel"]}
+    ins(%arg0: tensor<10x10xf32>) outs(%empty: tensor<10x10xf32>) {
+    ^bb0(%a: f32, %b: f32):
+      %res = arith.addf %a, %c2f : f32  // Elementwise %arg0 + 2.0.
+      linalg.yield %res : f32
+  } -> tensor<10x10xf32>  // Producer result %0; its only use is the negf below.
+
+  %empty2 = tensor.empty() : tensor<10x10xf32>  // Output operand of the consumer.
+  // CHECK: scf.for {{.*}} {
+  // CHECK: scf.for {{.*}} {
+  // CHECK: linalg.generic
+  // CHECK: linalg.negf
+  // CHECK: }
+  // CHECK: }
+  %1 = linalg.negf ins(%0 : tensor<10x10xf32>) outs(%empty2 : tensor<10x10xf32>) -> tensor<10x10xf32>
+
+  return %1 : tensor<10x10xf32>  // Only the consumer's result escapes the function.
+}
+
+module attributes {transform.with_named_sequence} {  // Transform script driving the test.
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.negf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %tiled_low, %loop1, %loop2 = transform.structured.fuse %0 [5, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield  // The fuse above targets the matched negf; [5, 2] presumably are its tile sizes.
+  }
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Tests below are expected to fail.