diff --git a/include/TPP/Transforms/Transforms.h b/include/TPP/Transforms/Transforms.h
index 8cdcd2b9f..93ed2e176 100644
--- a/include/TPP/Transforms/Transforms.h
+++ b/include/TPP/Transforms/Transforms.h
@@ -51,14 +51,6 @@ FailureOr<linalg::GenericOp>
 packConv2DNhwcHwcfOp(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp linalgOp,
                      ArrayRef<OpFoldResult> tiles);
 
-// Attempt to block a MatmulOp or a BatchMatmulOp.
-FailureOr<linalg::LinalgOp> packMatmulOp(RewriterBase &rewriter,
-                                         linalg::MatmulOp linalgOp,
-                                         ArrayRef<OpFoldResult> tiles);
-FailureOr<linalg::LinalgOp> packMatmulOp(RewriterBase &rewriter,
-                                         linalg::BatchMatmulOp linalgOp,
-                                         ArrayRef<OpFoldResult> tiles);
-
 // Attempt to block a MatmulOp to VNNI format.
 FailureOr<linalg::GenericOp> packVNNIMatmulOp(RewriterBase &rewriter,
                                               linalg::GenericOp linalgOp);
diff --git a/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp b/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
index 44eb10e45..1b4600d74 100644
--- a/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
+++ b/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
@@ -304,101 +304,6 @@ mlir::linalgx::packConv2DNchwFchwOp(RewriterBase &rewriter,
   return packConvolutions(rewriter, convOp, tiles);
 }
 
-template <typename OpTy>
-static FailureOr<linalg::LinalgOp>
-packMatmulOpImpl(RewriterBase &rewriter, OpTy matmulOp,
-                 ArrayRef<OpFoldResult> tiles) {
-  static_assert(
-      llvm::is_one_of<OpTy, linalg::MatmulOp, linalg::BatchMatmulOp>::value,
-      "applies to only matmul or batch matmul operations");
-
-  OpBuilder::InsertionGuard guard(rewriter);
-  // The op is replaced, we need to set the insertion
-  // point after it.
-  rewriter.setInsertionPointAfter(matmulOp);
-
-  if (matmulOp.hasDynamicShape())
-    return rewriter.notifyMatchFailure(matmulOp, "require static shape");
-
-  if (matmulOp.hasPureBufferSemantics())
-    return rewriter.notifyMatchFailure(matmulOp, "require tensor semantics");
-
-  OpFoldResult tileOnI = tiles[0];
-  OpFoldResult tileOnJ = tiles[1];
-  OpFoldResult tileOnK = tiles[2];
-  bool isBatchMatmulOp = std::is_same_v<OpTy, linalg::BatchMatmulOp>;
-  size_t inc = isBatchMatmulOp ? 1 : 0;
-  size_t posI = 0 + inc;
-  size_t posJ = 1 + inc;
-  size_t posK = 2 + inc;
-  if (!linalgx::utils::validateFullTilesOnDims(
-          cast<TilingInterface>(matmulOp.getOperation()),
-          {tileOnI, tileOnJ, tileOnK}, {posI, posJ, posK})) {
-    return rewriter.notifyMatchFailure(matmulOp, "expect full tiles only");
-  }
-
-  // [..][IB][JB][ib][jb] += [..][IB][KB][ib][kb] * [..][KB][JB][jb][kb]
-  auto packedCanonicalMatmul = linalg::packMatmulGreedily(
-      rewriter, matmulOp, tiles, /*mnkPaddedSizesNextMultipleOf=*/{},
-      /*mnkPackedSizes=*/{0, 1, 2});
-  if (failed(packedCanonicalMatmul))
-    return failure();
-
-  assert(packedCanonicalMatmul->packOps.size() == 3);
-  assert(packedCanonicalMatmul->unPackOps.size() == 1);
-
-  SmallVector<int64_t> innerPerm = {1, 0};
-  SmallVector<int64_t> outerPerm = {1, 0};
-  if (isBatchMatmulOp)
-    outerPerm = {0, 2, 1};
-  auto packedMatmul =
-      linalg::packTranspose(rewriter, packedCanonicalMatmul->packOps[1],
-                            packedCanonicalMatmul->packedLinalgOp,
-                            /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
-  if (failed(packedMatmul))
-    return failure();
-  return packedMatmul->transposedLinalgOp;
-}
-
-//===----------------------------------------------------------------------===//
-// MatmulOp
-//===----------------------------------------------------------------------===//
-//  i     j         i     k         k     j
-// [128 x 256] += [128 x 256] * [256 x 256]
-//
-// tile factor on i = 32
-// tile factor on j = 16
-// tile factor on k = 8
-//
-// [IB][JB][ib][jb] += [IB][KB][ib][kb] * [JB][KB][kb][jb]
-// [4 ][16][32][16] += [4 ][32][32][8 ] * [16][32][8 ][16]
-// KB is the batch reduce dimension.
-FailureOr<linalg::LinalgOp>
-mlir::linalgx::packMatmulOp(RewriterBase &rewriter, linalg::MatmulOp matmulOp,
-                            ArrayRef<OpFoldResult> tiles) {
-  if (tiles.size() != 3)
-    return rewriter.notifyMatchFailure(matmulOp, "require 3 tile factors");
-
-  return packMatmulOpImpl<linalg::MatmulOp>(rewriter, matmulOp, tiles);
-}
-
-//===----------------------------------------------------------------------===//
-// BatchMatmulOp
-//===----------------------------------------------------------------------===//
-// Original layout:
-//  [B][I][J] += [B][I][K] * [B][K][J]
-// New layout:
-//  [B][IB][JB][ib][jb] += [B][IB][KB][ib][kb] * [B][JB][KB][kb][jb]
-FailureOr<linalg::LinalgOp>
-mlir::linalgx::packMatmulOp(RewriterBase &rewriter,
-                            linalg::BatchMatmulOp matmulOp,
-                            ArrayRef<OpFoldResult> tiles) {
-  if (tiles.size() != 3)
-    return rewriter.notifyMatchFailure(matmulOp, "require 3 tile factors");
-
-  return packMatmulOpImpl<linalg::BatchMatmulOp>(rewriter, matmulOp, tiles);
-}
-
 //===----------------------------------------------------------------------===//
 // MatmulOp (VNNI packing)
 //===----------------------------------------------------------------------===//
@@ -555,29 +460,6 @@ getDefaultBlockingFactors(linalg::LinalgOp linalgOp) {
 // Passes
 //===----------------------------------------------------------------------===//
 
-// Pack MatmulOp and BatchMatmulOp.
-template <typename OpTy> struct PackMatmulImpl : public OpRewritePattern<OpTy> {
-  PackMatmulImpl(MLIRContext *context, ArrayRef<int64_t> blockingFactors,
-                 PatternBenefit benefit = 1)
-      : OpRewritePattern<OpTy>(context, benefit),
-        blockingFactors(blockingFactors) {}
-
-  LogicalResult matchAndRewrite(OpTy matmulOp,
-                                PatternRewriter &rewriter) const override {
-    if (blockingFactors.empty())
-      blockingFactors = getDefaultBlockingFactors(matmulOp);
-    FailureOr<linalg::LinalgOp> packedMatmul = mlir::linalgx::packMatmulOp(
-        rewriter, matmulOp,
-        getAsOpFoldResult(rewriter.getI64ArrayAttr(blockingFactors)));
-    if (failed(packedMatmul))
-      return failure();
-    return success();
-  }
-
-private:
-  mutable SmallVector<int64_t> blockingFactors;
-};
-
 // Entry point for packing a matmul operation.
 // Pack MatmulOp as following:
 // [NB][KB][nb][kb] += [NB][CB][nb][cb] * [KB][CB][cb][kb]
@@ -591,9 +473,57 @@ struct PackMatmul : public tpp::impl::PackMatmulBase<PackMatmul> {
   void runOnOperation() override {
     MLIRContext *ctx = getOperation().getContext();
     RewritePatternSet patterns(ctx);
-    patterns.add<PackMatmulImpl<linalg::MatmulOp>,
-                 PackMatmulImpl<linalg::BatchMatmulOp>>(ctx, blockingFactors);
+
+    auto packControlFn = [&](linalg::LinalgOp linalgOp)
+        -> std::optional<linalg::BlockPackMatmulOptions> {
+      linalg::BlockPackMatmulOptions options;
+
+      // Pack only these two named matmul variants.
+      if (!(isa<linalg::MatmulOp>(linalgOp) ||
+            isa<linalg::BatchMatmulOp>(linalgOp))) {
+        return std::nullopt;
+      }
+
+      // Enforce user defined blocking factors or use defaults.
+      if (!blockingFactors.empty()) {
+        SmallVector<int64_t, 3> blockFactors{*blockingFactors};
+        options.blockFactors = blockFactors;
+      } else {
+        options.blockFactors = getDefaultBlockingFactors(linalgOp);
+      }
+
+      // Allow padding to avoid double checks.
+      options.allowPadding = true;
+
+      // Apply more restrictive packing validation.
+      OpBuilder builder(linalgOp);
+      SmallVector<OpFoldResult> tiles =
+          getAsOpFoldResult(builder.getI64ArrayAttr(options.blockFactors));
+      OpFoldResult tileOnI = tiles[0];
+      OpFoldResult tileOnJ = tiles[1];
+      OpFoldResult tileOnK = tiles[2];
+      bool isBatchMatmulOp = isa<linalg::BatchMatmulOp>(linalgOp);
+      size_t inc = isBatchMatmulOp ? 1 : 0;
+      size_t posI = 0 + inc;
+      size_t posJ = 1 + inc;
+      size_t posK = 2 + inc;
+      if (!linalgx::utils::validateFullTilesOnDims(
+              cast<TilingInterface>(linalgOp.getOperation()),
+              {tileOnI, tileOnJ, tileOnK}, {posI, posJ, posK})) {
+        return std::nullopt;
+      }
+
+      // Apply XSMM packing with block transpose only.
+      options.lhsTransposeOuterBlocks = false;
+      options.lhsTransposeInnerBlocks = false;
+      options.rhsTransposeOuterBlocks = true;
+      options.rhsTransposeInnerBlocks = false;
+
+      return options;
+    };
+    linalg::populateBlockPackMatmulPatterns(patterns, packControlFn);
     linalg::populateLinalgDeGeneralizationPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
 
diff --git a/test/BF16/matmul-vnni.mlir b/test/BF16/matmul-vnni.mlir
index 701562bff..547d6d1d8 100644
--- a/test/BF16/matmul-vnni.mlir
+++ b/test/BF16/matmul-vnni.mlir
@@ -19,7 +19,7 @@ func.func @matmul_static(
 // CHECK-LABEL: matmul_static
 // CHECK-SAME: %[[ARG0:.+]]: tensor<256x512xbf16>, %[[ARG1:.+]]: tensor<512x1024xbf16>, %[[ARG2:.+]]: tensor<256x1024xbf16>
 // CHECK: %[[EMPTY_0:.+]] = tensor.empty() : tensor<8x16x32x32xbf16>
-// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
 // CHECK-SAME: into %[[EMPTY_0]] : tensor<256x512xbf16> -> tensor<8x16x32x32xbf16>
 // CHECK: %[[EMPTY_1:.+]] = tensor.empty() : tensor<32x16x32x32xbf16>
 // CHECK: %[[PACK_0:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
diff --git a/test/Integration/xsmm-fusion-mlirgen.mlir b/test/Integration/xsmm-fusion-mlirgen.mlir
index d6465b2f9..349ee9e9d 100644
--- a/test/Integration/xsmm-fusion-mlirgen.mlir
+++ b/test/Integration/xsmm-fusion-mlirgen.mlir
@@ -1,5 +1,4 @@
 // RUN: mlir-gen --kernel=const --bias --relu --seed=123 | tpp-run -e entry --entry-point-result=void -print-mlir=mid 2>&1 | FileCheck %s
-
 // CHECK: func.func @_entry(%arg0: memref<256x128xf32>) -> memref<256x512xf32> {
 // CHECK: call @xsmm_fused_brgemm_dispatch
 // CHECK: scf.parallel
diff --git a/test/Passes/DefaultPipeline/default-tpp-passes.mlir b/test/Passes/DefaultPipeline/default-tpp-passes.mlir
index baa863c5f..36bd8d61a 100644
--- a/test/Passes/DefaultPipeline/default-tpp-passes.mlir
+++ b/test/Passes/DefaultPipeline/default-tpp-passes.mlir
@@ -1,4 +1,4 @@
-// RUN: tpp-opt %s -default-tpp-passes -split-input-file | FileCheck %s
+// RUN: tpp-opt %s -default-tpp-passes -split-input-file 2>/dev/null | FileCheck %s
 
 // CHECK: func.func @matmul(
 // CHECK-SAME: %[[ARG0:.+]]: memref<4x8xf32>,
diff --git a/test/Passes/pass-matmul-blocking-default.mlir b/test/Passes/pass-matmul-blocking-default.mlir
index c8b8311bd..26db13fc0 100644
--- a/test/Passes/pass-matmul-blocking-default.mlir
+++ b/test/Passes/pass-matmul-blocking-default.mlir
@@ -18,7 +18,7 @@ func.func @block_linalg_matmul(
 // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
 // CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>) -> tensor<128x128xf32> {
 // CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
-// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
 // CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<4x4x32x32xf32>
 // CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF1]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
 // CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
diff --git a/test/Passes/pass-matmul-blocking.mlir b/test/Passes/pass-matmul-blocking.mlir
index d7032eba4..c73fb8911 100644
--- a/test/Passes/pass-matmul-blocking.mlir
+++ b/test/Passes/pass-matmul-blocking.mlir
@@ -18,7 +18,7 @@ func.func @block_linalg_matmul(
 // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
 // CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>) -> tensor<128x128xf32> {
 // CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
-// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
 // CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<4x4x32x32xf32>
 // CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF1]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
 // CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
@@ -69,7 +69,7 @@ func.func @block_linalg_matmul(
 // CHECK-SAME: outs(%[[ARG2]] : tensor<128x128xf32>) -> tensor<128x128xf32>
 // CHECK: %[[EMPTY_ARG0:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
 // CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
 // CHECK-SAME: into %[[EMPTY_ARG0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
 // CHECK: %[[EMPTY_ARG1:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
 // CHECK: %[[PACK_ARG1:.+]] = tensor.pack %[[ARG1]]
@@ -111,7 +111,7 @@ func.func @block_linalg_matmul(
 // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
 // CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>) -> tensor<128x128xf32> {
 // CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
-// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
 // CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<4x4x32x32xf32>
 // CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF1]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
 // CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
@@ -137,7 +137,7 @@ func.func @batch_matmul_rewrite(%arg0: tensor<512x64x128xf32>, %arg1: tensor<512
 // CHECK: %[[OUT:.+]] = tensor.empty() : tensor<512x64x64xf32>
 // CHECK: %[[ARG0_PACK_OUT:.+]] = tensor.empty() : tensor<512x2x4x32x32xf32>
 // CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 32]
+// CHECK-SAME: outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 32]
 // CHECK-SAME: into %[[ARG0_PACK_OUT]] : tensor<512x64x128xf32> -> tensor<512x2x4x32x32xf32>
 // CHECK: %[[ARG1_PACK_OUT:.+]] = tensor.empty() : tensor<512x2x4x32x32xf32>
 // CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
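
Reviewer note, not part of the patch: the sketch below isolates the control-function hook that the reworked PackMatmul pass relies on, so the XSMM layout choice is visible in one self-contained place. It is a minimal illustration only: it assumes plain upstream MLIR, hard-codes 32x32x32 blocking factors instead of reading the pass options, and drops the tpp-specific full-tile validation; populateExampleBlockPacking is an illustrative name, not an existing helper in this repository.

// Minimal sketch (hypothetical helper, fixed blocking, no full-tile check).
#include <optional>

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Wires a control function into the upstream block-pack-matmul patterns so
// that only named (batch) matmuls get packed and the RHS outer blocks are
// transposed, mirroring the layout the pass above requests.
static void populateExampleBlockPacking(RewritePatternSet &patterns) {
  auto controlFn = [](linalg::LinalgOp op)
      -> std::optional<linalg::BlockPackMatmulOptions> {
    // Skip anything that is not a named matmul or batch matmul.
    if (!isa<linalg::MatmulOp, linalg::BatchMatmulOp>(op))
      return std::nullopt;

    linalg::BlockPackMatmulOptions options;
    options.blockFactors = {32, 32, 32}; // i, j, k tiles (hard-coded here).
    options.allowPadding = true;
    // Keep LHS blocks in place; transpose only the outer blocks of the RHS.
    options.lhsTransposeOuterBlocks = false;
    options.lhsTransposeInnerBlocks = false;
    options.rhsTransposeOuterBlocks = true;
    options.rhsTransposeInnerBlocks = false;
    return options;
  };
  linalg::populateBlockPackMatmulPatterns(patterns, controlFn);
}

Returning std::nullopt from the callback leaves an op untouched, which is how the pass keeps generic contractions and other linalg ops out of the packing path while still delegating the actual rewrite to upstream.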