Use upstream matmul pack (#911)
Retires the downstream matmul packing logic and uses the upstream block pack
matmul patterns instead.
adam-smnk authored May 23, 2024
1 parent 3deda45 commit ec38d1d
Showing 7 changed files with 57 additions and 136 deletions.
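
For orientation, a minimal sketch of how the upstream entry point is driven. This is not the exact pass code; it assumes the upstream MLIR API names used in the diff below (linalg::BlockPackMatmulOptions, linalg::populateBlockPackMatmulPatterns) plus a made-up helper addBlockPackMatmulPatterns with fixed tile sizes for illustration:

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include <optional>

using namespace mlir;

static void addBlockPackMatmulPatterns(RewritePatternSet &patterns) {
  auto controlFn = [](linalg::LinalgOp op)
      -> std::optional<linalg::BlockPackMatmulOptions> {
    // Pack only the named matmul variants; leave everything else alone.
    if (!isa<linalg::MatmulOp, linalg::BatchMatmulOp>(op))
      return std::nullopt;
    linalg::BlockPackMatmulOptions options;
    // Fixed i/j/k tile sizes for illustration; the pass derives these from
    // user-provided blocking factors or shape-based defaults instead.
    options.blockFactors = {32, 32, 32};
    options.allowPadding = true;
    // XSMM-friendly layout: block-transpose only the RHS outer dimensions.
    options.lhsTransposeOuterBlocks = false;
    options.lhsTransposeInnerBlocks = false;
    options.rhsTransposeOuterBlocks = true;
    options.rhsTransposeInnerBlocks = false;
    return options;
  };
  linalg::populateBlockPackMatmulPatterns(patterns, controlFn);
}
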
8 changes: 0 additions & 8 deletions include/TPP/Transforms/Transforms.h
@@ -51,14 +51,6 @@ FailureOr<linalg::GenericOp>
packConv2DNhwcHwcfOp(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp linalgOp,
ArrayRef<OpFoldResult> tiles);

// Attempt to block a MatmulOp or a BatchMatmulOp.
FailureOr<linalg::LinalgOp> packMatmulOp(RewriterBase &rewriter,
linalg::MatmulOp linalgOp,
ArrayRef<OpFoldResult> tiles);
FailureOr<linalg::LinalgOp> packMatmulOp(RewriterBase &rewriter,
linalg::BatchMatmulOp linalgOp,
ArrayRef<OpFoldResult> tiles);

// Attempt to block a MatmulOp to VNNI format.
FailureOr<linalg::GenericOp> packVNNIMatmulOp(RewriterBase &rewriter,
linalg::GenericOp linalgOp);
170 changes: 50 additions & 120 deletions lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
@@ -304,101 +304,6 @@ mlir::linalgx::packConv2DNchwFchwOp(RewriterBase &rewriter,
return packConvolutions(rewriter, convOp, tiles);
}

template <typename OpTy>
static FailureOr<linalg::LinalgOp>
packMatmulOpImpl(RewriterBase &rewriter, OpTy matmulOp,
ArrayRef<OpFoldResult> tiles) {
static_assert(
llvm::is_one_of<OpTy, linalg::MatmulOp, linalg::BatchMatmulOp>::value,
"applies to only matmul or batch matmul operations");

OpBuilder::InsertionGuard guard(rewriter);
// The op is replaced, we need to set the insertion
// point after it.
rewriter.setInsertionPointAfter(matmulOp);

if (matmulOp.hasDynamicShape())
return rewriter.notifyMatchFailure(matmulOp, "require static shape");

if (matmulOp.hasPureBufferSemantics())
return rewriter.notifyMatchFailure(matmulOp, "require tensor semantics");

OpFoldResult tileOnI = tiles[0];
OpFoldResult tileOnJ = tiles[1];
OpFoldResult tileOnK = tiles[2];
bool isBatchMatmulOp = std::is_same_v<OpTy, linalg::BatchMatmulOp>;
size_t inc = isBatchMatmulOp ? 1 : 0;
size_t posI = 0 + inc;
size_t posJ = 1 + inc;
size_t posK = 2 + inc;
if (!linalgx::utils::validateFullTilesOnDims(
cast<TilingInterface>(matmulOp.getOperation()),
{tileOnI, tileOnJ, tileOnK}, {posI, posJ, posK})) {
return rewriter.notifyMatchFailure(matmulOp, "expect full tiles only");
}

// [..][IB][JB][ib][jb] += [..][IB][KB][ib][kb] * [..][KB][JB][jb][kb]
auto packedCanonicalMatmul = linalg::packMatmulGreedily(
rewriter, matmulOp, tiles, /*mnkPaddedSizesNextMultipleOf=*/{},
/*mnkPackedSizes=*/{0, 1, 2});
if (failed(packedCanonicalMatmul))
return failure();

assert(packedCanonicalMatmul->packOps.size() == 3);
assert(packedCanonicalMatmul->unPackOps.size() == 1);

SmallVector<int64_t> innerPerm = {1, 0};
SmallVector<int64_t> outerPerm = {1, 0};
if (isBatchMatmulOp)
outerPerm = {0, 2, 1};
auto packedMatmul =
linalg::packTranspose(rewriter, packedCanonicalMatmul->packOps[1],
packedCanonicalMatmul->packedLinalgOp,
/*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
if (failed(packedMatmul))
return failure();
return packedMatmul->transposedLinalgOp;
}

//===----------------------------------------------------------------------===//
// MatmulOp
//===----------------------------------------------------------------------===//
// i j i k k j
// [128 x 256] += [128 x 256] * [256 x 256]
//
// tile factor on i = 32
// tile factor on j = 16
// tile factor on k = 8
//
// [IB][JB][ib][jb] += [IB][KB][ib][kb] * [JB][KB][kb][jb]
// [4 ][16][32][16] += [4 ][32][32][8 ] * [16][32][8 ][16]
// KB is the batch reduce dimension.
FailureOr<linalg::LinalgOp>
mlir::linalgx::packMatmulOp(RewriterBase &rewriter, linalg::MatmulOp matmulOp,
ArrayRef<OpFoldResult> tiles) {
if (tiles.size() != 3)
return rewriter.notifyMatchFailure(matmulOp, "require 3 tile factors");

return packMatmulOpImpl<linalg::MatmulOp>(rewriter, matmulOp, tiles);
}

//===----------------------------------------------------------------------===//
// BatchMatmulOp
//===----------------------------------------------------------------------===//
// Original layout:
// [B][I][J] += [B][I][K] * [B][K][J]
// New layout:
// [B][IB][JB][ib][jb] += [B][IB][KB][ib][kb] * [B][JB][KB][kb][jb]
FailureOr<linalg::LinalgOp>
mlir::linalgx::packMatmulOp(RewriterBase &rewriter,
linalg::BatchMatmulOp matmulOp,
ArrayRef<OpFoldResult> tiles) {
if (tiles.size() != 3)
return rewriter.notifyMatchFailure(matmulOp, "require 3 tile factors");

return packMatmulOpImpl<linalg::BatchMatmulOp>(rewriter, matmulOp, tiles);
}

//===----------------------------------------------------------------------===//
// MatmulOp (VNNI packing)
//===----------------------------------------------------------------------===//
@@ -555,29 +460,6 @@ getDefaultBlockingFactors(linalg::LinalgOp linalgOp) {
// Passes
//===----------------------------------------------------------------------===//

// Pack MatmulOp and BatchMatmulOp.
template <typename OpTy> struct PackMatmulImpl : public OpRewritePattern<OpTy> {
PackMatmulImpl(MLIRContext *context, ArrayRef<int64_t> blockingFactors,
PatternBenefit benefit = 1)
: OpRewritePattern<OpTy>(context, benefit),
blockingFactors(blockingFactors) {}

LogicalResult matchAndRewrite(OpTy matmulOp,
PatternRewriter &rewriter) const override {
if (blockingFactors.empty())
blockingFactors = getDefaultBlockingFactors(matmulOp);
FailureOr<linalg::GenericOp> packedMatmul = mlir::linalgx::packMatmulOp(
rewriter, matmulOp,
getAsOpFoldResult(rewriter.getI64ArrayAttr(blockingFactors)));
if (failed(packedMatmul))
return failure();
return success();
}

private:
mutable SmallVector<int64_t> blockingFactors;
};

// Entry point for packing a matmul operation.
// Pack MatmulOp as following:
// [NB][KB][nb][kb] += [NB][CB][nb][cb] * [KB][CB][cb][kb]
@@ -591,9 +473,57 @@ struct PackMatmul : public tpp::impl::PackMatmulBase<PackMatmul> {
void runOnOperation() override {
MLIRContext *ctx = getOperation().getContext();
RewritePatternSet patterns(ctx);
patterns.add<PackMatmulImpl<linalg::MatmulOp>,
PackMatmulImpl<linalg::BatchMatmulOp>>(ctx, blockingFactors);

auto packControlFn = [&](linalg::LinalgOp linalgOp)
-> std::optional<linalg::BlockPackMatmulOptions> {
linalg::BlockPackMatmulOptions options;

// Pack only these two named matmul variants.
if (!(isa<linalg::MatmulOp>(linalgOp) ||
isa<linalg::BatchMatmulOp>(linalgOp))) {
return std::nullopt;
}

// Enforce user defined blocking factors or use defaults.
if (!blockingFactors.empty()) {
SmallVector<int64_t, 3> blockFactors{*blockingFactors};
options.blockFactors = blockFactors;
} else {
options.blockFactors = getDefaultBlockingFactors(linalgOp);
}

// Allow padding to avoid double checks.
options.allowPadding = true;

// Apply more restrictive packing validation.
OpBuilder builder(linalgOp);
SmallVector<OpFoldResult> tiles =
getAsOpFoldResult(builder.getI64ArrayAttr(options.blockFactors));
OpFoldResult tileOnI = tiles[0];
OpFoldResult tileOnJ = tiles[1];
OpFoldResult tileOnK = tiles[2];
bool isBatchMatmulOp = isa<linalg::BatchMatmulOp>(linalgOp);
size_t inc = isBatchMatmulOp ? 1 : 0;
size_t posI = 0 + inc;
size_t posJ = 1 + inc;
size_t posK = 2 + inc;
if (!linalgx::utils::validateFullTilesOnDims(
cast<TilingInterface>(linalgOp.getOperation()),
{tileOnI, tileOnJ, tileOnK}, {posI, posJ, posK})) {
return std::nullopt;
}

// Apply XSMM packing with block transpose only.
options.lhsTransposeOuterBlocks = false;
options.lhsTransposeInnerBlocks = false;
options.rhsTransposeOuterBlocks = true;
options.rhsTransposeInnerBlocks = false;

return options;
};
linalg::populateBlockPackMatmulPatterns(patterns, packControlFn);
linalg::populateLinalgDeGeneralizationPatterns(patterns);

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
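
The IR-level effect that the test updates below account for: packs produced through the upstream path carry an explicit outer_dims_perm on tensor.pack even when it is the identity permutation. A minimal, hypothetical MLIR example in the new form (shapes mirror pass-matmul-blocking.mlir):

func.func @pack_lhs(%arg0: tensor<128x128xf32>) -> tensor<4x4x32x32xf32> {
  %empty = tensor.empty() : tensor<4x4x32x32xf32>
  // Identity outer permutation, now printed explicitly.
  %packed = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1]
      inner_tiles = [32, 32] into %empty : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
  return %packed : tensor<4x4x32x32xf32>
}
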
2 changes: 1 addition & 1 deletion test/BF16/matmul-vnni.mlir
@@ -19,7 +19,7 @@ func.func @matmul_static(
// CHECK-LABEL: matmul_static
// CHECK-SAME: %[[ARG0:.+]]: tensor<256x512xbf16>, %[[ARG1:.+]]: tensor<512x1024xbf16>, %[[ARG2:.+]]: tensor<256x1024xbf16>
// CHECK: %[[EMPTY_0:.+]] = tensor.empty() : tensor<8x16x32x32xbf16>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
// CHECK-SAME: into %[[EMPTY_0]] : tensor<256x512xbf16> -> tensor<8x16x32x32xbf16>
// CHECK: %[[EMPTY_1:.+]] = tensor.empty() : tensor<32x16x32x32xbf16>
// CHECK: %[[PACK_0:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
1 change: 0 additions & 1 deletion test/Integration/xsmm-fusion-mlirgen.mlir
@@ -1,5 +1,4 @@
// RUN: mlir-gen --kernel=const --bias --relu --seed=123 | tpp-run -e entry --entry-point-result=void -print-mlir=mid 2>&1 | FileCheck %s

// CHECK: func.func @_entry(%arg0: memref<256x128xf32>) -> memref<256x512xf32> {
// CHECK: call @xsmm_fused_brgemm_dispatch
// CHECK: scf.parallel
2 changes: 1 addition & 1 deletion test/Passes/DefaultPipeline/default-tpp-passes.mlir
@@ -1,4 +1,4 @@
// RUN: tpp-opt %s -default-tpp-passes -split-input-file | FileCheck %s
// RUN: tpp-opt %s -default-tpp-passes -split-input-file 2>/dev/null | FileCheck %s

// CHECK: func.func @matmul(
// CHECK-SAME: %[[ARG0:.+]]: memref<4x8xf32>,
2 changes: 1 addition & 1 deletion test/Passes/pass-matmul-blocking-default.mlir
@@ -18,7 +18,7 @@ func.func @block_linalg_matmul(
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>) -> tensor<128x128xf32> {
// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<4x4x32x32xf32>
// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF1]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
8 changes: 4 additions & 4 deletions test/Passes/pass-matmul-blocking.mlir
@@ -18,7 +18,7 @@ func.func @block_linalg_matmul(
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>) -> tensor<128x128xf32> {
// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<4x4x32x32xf32>
// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF1]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
@@ -69,7 +69,7 @@ func.func @block_linalg_matmul(
// CHECK-SAME: outs(%[[ARG2]] : tensor<128x128xf32>) -> tensor<128x128xf32>
// CHECK: %[[EMPTY_ARG0:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 32]
// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
// CHECK-SAME: into %[[EMPTY_ARG0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[EMPTY_ARG1:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
// CHECK: %[[PACK_ARG1:.+]] = tensor.pack %[[ARG1]]
@@ -111,7 +111,7 @@ func.func @block_linalg_matmul(
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>) -> tensor<128x128xf32> {
// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<4x4x32x32xf32>
// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF1]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
@@ -137,7 +137,7 @@ func.func @batch_matmul_rewrite(%arg0: tensor<512x64x128xf32>, %arg1: tensor<512
// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<512x64x64xf32>
// CHECK: %[[ARG0_PACK_OUT:.+]] = tensor.empty() : tensor<512x2x4x32x32xf32>
// CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 32]
// CHECK-SAME: outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 32]
// CHECK-SAME: into %[[ARG0_PACK_OUT]] : tensor<512x64x128xf32> -> tensor<512x2x4x32x32xf32>
// CHECK: %[[ARG1_PACK_OUT:.+]] = tensor.empty() : tensor<512x2x4x32x32xf32>
// CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]