Skip to content

Commit

Permalink
Retire downstream pack-fill propagation (#912)
Browse files Browse the repository at this point in the history
The logic is covered by upstream fill op canonicalization pattern.
  • Loading branch information
adam-smnk authored May 15, 2024
1 parent 7ee8e7e commit 08caaf6
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 51 deletions.
52 changes: 2 additions & 50 deletions lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -539,49 +539,6 @@ mlir::linalgx::packVNNIBRGemmOp(RewriterBase &rewriter,

namespace {

//===----------------------------------------------------------------------===//
// BubbleUpThroughFillOp
//===----------------------------------------------------------------------===//

// Attempt to avoid packing a fill op. Instead create a 'packed' fill.
// %0 = tensor.empty
// %packed = tensor.empty
// %1 = linalg.fill ins(%cst) outs(%0)
// %2 = tensor.pack %1 into %packed
// %3 = some_packed_op %2
//
// --->
//
// %0 = tensor.empty
// %1 = linalg.fill ins(%cst) outs (%packed)
// %2 = some_packed_op %1
// %3 = tensor.unpack %2 into %0
//
struct BubbleUpThroughFillOp : public OpRewritePattern<tensor::PackOp> {
using OpRewritePattern<tensor::PackOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tensor::PackOp packOp,
PatternRewriter &rewriter) const override {
Value source = packOp.getSource();
auto fillOp = source.getDefiningOp<linalg::FillOp>();
if (!fillOp)
return failure();

Value fillRes = fillOp.getResult(0);
if (!fillRes.hasOneUse())
return failure();

// Replace result with output.
rewriter.replaceAllUsesWith(fillRes, fillOp.getOutputs()[0]);
auto empty = tensor::PackOp::createDestinationTensor(
rewriter, packOp.getLoc(), source, packOp.getMixedTiles(),
packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
rewriter.replaceOpWithNewOp<linalg::FillOp>(packOp, fillOp.getInputs(),
empty);
return success();
}
};

static SmallVector<int64_t>
getDefaultBlockingFactors(linalg::LinalgOp linalgOp) {
assert(linalgOp && "expect a valid linalgOp");
Expand Down Expand Up @@ -758,7 +715,8 @@ struct PropagatePackUnPack
void runOnOperation() override {
MLIRContext *ctx = getOperation().getContext();
RewritePatternSet patterns(ctx);
tpp::populateSinkPackPatterns(patterns);
linalg::populateDataLayoutPropagationPatterns(
patterns, [](Operation *op) { return true; });
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
Expand Down Expand Up @@ -1158,9 +1116,3 @@ void mlir::tpp::populateSimplifyPacking(RewritePatternSet &patterns) {
ForAllIterArgsFolder>(ctx);
tensor::populateReassociativeReshapeFoldingPatterns(patterns);
}

void mlir::tpp::populateSinkPackPatterns(RewritePatternSet &patterns) {
linalg::populateDataLayoutPropagationPatterns(
patterns, [](Operation *op) { return true; });
patterns.add<BubbleUpThroughFillOp>(patterns.getContext());
}
2 changes: 1 addition & 1 deletion test/Passes/pack-unpack-propagation.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: tpp-opt %s -propagate-pack-and-unpack -simplify-pack -split-input-file | FileCheck %s
// RUN: tpp-opt %s -propagate-pack-and-unpack -simplify-pack -canonicalize -split-input-file | FileCheck %s

#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
Expand Down

0 comments on commit 08caaf6

Please sign in to comment.