Pack matmuls with small dimensions (#963)
Extends the pack matmul control function to automatically adapt blocking
factors to small dimensions.

With the new behavior, the defined or default packing sizes act as an
upper bound on the tile size. Dimensions smaller than their block
factors are still packed, using their whole size as the tile.
This change allows processing matmuls on tall-and-skinny matrices and
partially decouples the decision of what and how to pack from the
packing mechanism itself. For simplicity, any small matmul is now packed.

Long-term, the packing driver should be augmented with a separate cost
function that governs operation selection, decides when padding is
needed, smartly adjusts blocking factors, etc.
adam-smnk authored Sep 6, 2024
1 parent 67f9830 commit 99cb584
Showing 8 changed files with 93 additions and 79 deletions.
3 changes: 3 additions & 0 deletions include/TPP/Transforms/Utils/TransformUtils.h
@@ -75,6 +75,9 @@ bool isBlockedMatmul(Operation *op);
FailureOr<linalg::ContractionDimensions>
isContraction(linalg::LinalgOp linalgOp);

// Return the constant range span, or nullopt otherwise.
std::optional<int64_t> getConstantRange(const Range &range);

// Validate a tile configuration for a linalgOp when we can statically do that.
// Specific dims can be passed using 'dims'. If dims is empty the validation
// will start from the outermost dimension, moving to innermost ones up to the
23 changes: 22 additions & 1 deletion lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
@@ -478,6 +478,7 @@ struct PackMatmul : public tpp::impl::PackMatmulBase<PackMatmul> {
MLIRContext *ctx = getOperation().getContext();
RewritePatternSet patterns(ctx);

// TODO: Add a cost function that decides whether to pack at all.
auto packControlFn = [&](linalg::LinalgOp linalgOp)
-> std::optional<linalg::BlockPackMatmulOptions> {
linalg::BlockPackMatmulOptions options;
@@ -501,8 +502,28 @@ struct PackMatmul : public tpp::impl::PackMatmulBase<PackMatmul> {
// Allow padding to avoid double checks.
options.allowPadding = true;

// Apply more restrictive packing validation.
// Adjust block factors to smaller dimensions.
// If a dimension is smaller than the blocking factor, then
// try to block by the dimension size.
auto dims = linalg::inferContractionDims(linalgOp);
if (failed(dims))
return std::nullopt;

OpBuilder builder(linalgOp);
auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);

if (std::optional<int64_t> dimM =
linalgx::utils::getConstantRange(iterationDomain[dims->m.back()]))
options.blockFactors[0] = std::min(*dimM, options.blockFactors[0]);
if (std::optional<int64_t> dimN =
linalgx::utils::getConstantRange(iterationDomain[dims->n.back()]))
options.blockFactors[1] = std::min(*dimN, options.blockFactors[1]);
if (std::optional<int64_t> dimK =
linalgx::utils::getConstantRange(iterationDomain[dims->k.back()]))
options.blockFactors[2] = std::min(*dimK, options.blockFactors[2]);

// Apply more restrictive packing validation.
SmallVector<OpFoldResult> tiles =
getAsOpFoldResult(builder.getI64ArrayAttr(options.blockFactors));
OpFoldResult tileOnI = tiles[0];
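The hunk ends before the callback is consumed; presumably it is handed to upstream MLIR's block-pack-matmul patterns further down. A minimal sketch of that wiring, assuming the upstream linalg::populateBlockPackMatmulPatterns entry point and eliding the clamping shown above:

#include <optional>

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

// Sketch only: the control callback decides per-op whether to pack and
// with which options; returning std::nullopt skips the op entirely.
void populatePackMatmulPatterns(mlir::RewritePatternSet &patterns) {
  mlir::linalg::ControlBlockPackMatmulFn controlFn =
      [](mlir::linalg::LinalgOp linalgOp)
      -> std::optional<mlir::linalg::BlockPackMatmulOptions> {
    mlir::linalg::BlockPackMatmulOptions options;
    // Upper bounds on tile sizes; the real pass clamps them to the
    // static dimension sizes as in the hunk above.
    options.blockFactors = {32, 32, 32};
    options.allowPadding = true;
    return options;
  };
  mlir::linalg::populateBlockPackMatmulPatterns(patterns, controlFn);
}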
2 changes: 1 addition & 1 deletion lib/TPP/Transforms/TransformUtils.cpp
@@ -281,7 +281,7 @@ isContraction(linalg::LinalgOp linalgOp) {
return dims;
}

static std::optional<int64_t> getConstantRange(const Range &range) {
std::optional<int64_t> getConstantRange(const Range &range) {
std::optional<int64_t> stride = getConstantIntValue(range.stride);
if (!stride || *stride != 1)
return std::nullopt;
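The hunk is cut off after the stride check. A hedged sketch of how the helper plausibly continues, inferred from its callers above, which treat the result as a constant trip count; this is an assumption, not necessarily the exact committed body:

// Assumed continuation, relying on TransformUtils.cpp's existing
// includes (getConstantIntValue lives in
// mlir/Dialect/Utils/StaticValueUtils.h).
std::optional<int64_t> getConstantRange(const Range &range) {
  std::optional<int64_t> stride = getConstantIntValue(range.stride);
  if (!stride || *stride != 1)
    return std::nullopt;
  std::optional<int64_t> offset = getConstantIntValue(range.offset);
  if (!offset)
    return std::nullopt;
  std::optional<int64_t> size = getConstantIntValue(range.size);
  if (!size)
    return std::nullopt;
  // With unit stride, the span is simply size - offset.
  return (*size - *offset);
}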
53 changes: 3 additions & 50 deletions test/Passes/DefaultPipeline/default-tpp-passes.mlir
@@ -8,15 +8,15 @@ func.func @matmul(%A: tensor<4x8xf32>,
%B: tensor<8x4xf32>, %C: tensor<4x4xf32>) -> tensor<4x4xf32> {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: call @xsmm_gemm_dispatch
// CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]]
// CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index
// CHECK-NEXT: %[[cast_ptr0:.*]] = arith.index_cast %[[ptr0]] : index to i64
// CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[cast_ptr0]] : i64 to !llvm.ptr

// CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]]
// CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index
// CHECK-NEXT: %[[cast_ptr1:.*]] = arith.index_cast %[[ptr1]] : index to i64
// CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[cast_ptr1]] : i64 to !llvm.ptr

// CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]]
// CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index
// CHECK-NEXT: %[[cast_ptr2:.*]] = arith.index_cast %[[ptr2]] : index to i64
// CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[cast_ptr2]] : i64 to !llvm.ptr

@@ -90,53 +90,6 @@ func.func @conv2d_1x1(

// -----

#map = affine_map<(d0, d1) -> (d0 + d1)>

// CHECK-LABEL: @conv2d_1x1_decomposed(
// CHECK-SAME: %[[arg:.*]]: memref<1x7x7x2048xf32>) -> memref<1x7x7x512xf32> {
func.func @conv2d_1x1_decomposed(
%arg0 : tensor<1x7x7x2048xf32>) -> tensor<1x7x7x512xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c7 = arith.constant 7 : index

// Conv2D weights
%cst = arith.constant dense<0.00332225906> : tensor<2048x512xf32>

// 1x1 Conv2D
// CHECK: call @xsmm_gemm_dispatch
// CHECK: scf.for
// CHECK: %[[ptr0:.*]] = llvm.inttoptr %{{.+}} : i64 to !llvm.ptr
// CHECK: %[[ptr1:.*]] = llvm.inttoptr %{{.+}} : i64 to !llvm.ptr
// CHECK: %[[ptr2:.*]] = llvm.inttoptr %{{.+}} : i64 to !llvm.ptr
// CHECK: call @xsmm_gemm_invoke({{.*}}%[[ptr0]], %{{.+}}, %[[ptr1]], %{{.+}}, %[[ptr2]], %{{.+}}
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<1x7x7x512xf32>
%1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<1x7x7x512xf32>) -> tensor<1x7x7x512xf32>
%2 = scf.for %arg1 = %c0 to %c1 step %c1 iter_args(%arg2 = %1) -> (tensor<1x7x7x512xf32>) {
%3 = scf.for %arg3 = %c0 to %c7 step %c1 iter_args(%arg4 = %arg2) -> (tensor<1x7x7x512xf32>) {
%4 = scf.for %arg5 = %c0 to %c1 step %c1 iter_args(%arg6 = %arg4) -> (tensor<1x7x7x512xf32>) {
%5 = scf.for %arg7 = %c0 to %c1 step %c1 iter_args(%arg8 = %arg6) -> (tensor<1x7x7x512xf32>) {
%6 = affine.apply #map(%arg3, %arg5)
%extracted_slice = tensor.extract_slice %arg0[%arg1, %6, %arg7, 0] [1, 1, 7, 2048] [1, 1, 1, 1] : tensor<1x7x7x2048xf32> to tensor<7x2048xf32>
%extracted_slice_1 = tensor.extract_slice %arg8[%arg1, %arg3, 0, 0] [1, 1, 7, 512] [1, 1, 1, 1] : tensor<1x7x7x512xf32> to tensor<7x512xf32>
%7 = linalg.matmul ins(%extracted_slice, %cst : tensor<7x2048xf32>, tensor<2048x512xf32>) outs(%extracted_slice_1 : tensor<7x512xf32>) -> tensor<7x512xf32>
%inserted_slice = tensor.insert_slice %7 into %arg8[%arg1, %arg3, 0, 0] [1, 1, 7, 512] [1, 1, 1, 1] : tensor<7x512xf32> into tensor<1x7x7x512xf32>
scf.yield %inserted_slice : tensor<1x7x7x512xf32>
}
scf.yield %5 : tensor<1x7x7x512xf32>
}
scf.yield %4 : tensor<1x7x7x512xf32>
}
scf.yield %3 : tensor<1x7x7x512xf32>
}

// CHECK: return {{.*}} : memref<1x7x7x512xf32>
return %2 : tensor<1x7x7x512xf32>
}

// -----

#map0 = affine_map<(d0, d1) -> (d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
7 changes: 3 additions & 4 deletions test/Passes/DefaultPipeline/linalg-to-xsmm.mlir
@@ -50,15 +50,14 @@ func.func @gemm_with_zero(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> ten
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : i64
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : i64
// CHECK-NOT: xsmm_unary_dispatch
// CHECK: %[[ALLOC:.+]] = memref.alloc() {alignment = 64 : i64} : memref<3x3xf32>
// CHECK: %[[DIS:.+]] = call @xsmm_gemm_dispatch(%[[C1]], %[[C3]], %[[C3]], %[[C3]], %[[C3]], %[[C3]], %[[C3]], %[[C4]])
// CHECK: %[[INT_PTR_ARG0:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<3x3xf32> -> index
// CHECK: %[[INT_PTR_ARG0:.+]] = memref.extract_aligned_pointer_as_index
// CHECK: %[[CAST_ARG0:.+]] = arith.index_cast %[[INT_PTR_ARG0]] : index to i64
// CHECK: %[[LLVM_PTR_ARG0:.+]] = llvm.inttoptr %[[CAST_ARG0]] : i64 to !llvm.ptr
// CHECK: %[[INT_PTR_ARG1:.+]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<3x3xf32> -> index
// CHECK: %[[INT_PTR_ARG1:.+]] = memref.extract_aligned_pointer_as_index
// CHECK: %[[CAST_ARG1:.+]] = arith.index_cast %[[INT_PTR_ARG1]] : index to i64
// CHECK: %[[LLVM_PTR_ARG1:.+]] = llvm.inttoptr %[[CAST_ARG1]] : i64 to !llvm.ptr
// CHECK: %[[INT_PTR_ALLOC:.+]] = memref.extract_aligned_pointer_as_index %[[ALLOC]] : memref<3x3xf32> -> index
// CHECK: %[[INT_PTR_ALLOC:.+]] = memref.extract_aligned_pointer_as_index
// CHECK: %[[CAST_ALLOC:.+]] = arith.index_cast %[[INT_PTR_ALLOC]] : index to i64
// CHECK: %[[LLVM_PTR_ALLOC:.+]] = llvm.inttoptr %[[CAST_ALLOC]] : i64 to !llvm.ptr
// CHECK: call @xsmm_gemm_invoke(%[[C1]], %[[DIS]], %[[LLVM_PTR_ARG0]], %[[C0]], %[[LLVM_PTR_ARG1]], %[[C0]], %[[LLVM_PTR_ALLOC]], %[[C0]])
36 changes: 36 additions & 0 deletions test/Passes/pass-matmul-blocking-default.mlir
@@ -84,3 +84,39 @@ func.func @block_linalg_matmul_transpose_b(
// CHECK: %[[VAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP5]]], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%[[PACK0]], %[[PACK1]] : tensor<4x4x32x32xf32>, tensor<4x4x32x32xf32>) outs(%[[PACK2]] : tensor<4x4x32x32xf32>)
// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[ARG2]] : tensor<4x4x32x32xf32> -> tensor<128x128xf32>
// CHECK: return %[[OUT]] : tensor<128x128xf32>

// -----

func.func @block_linalg_matmul_dynamic(
%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
-> tensor<?x?xf32> {
%0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2: tensor<?x?xf32>)
-> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}

// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
// CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// CHECK-LABEL: func @block_linalg_matmul_dynamic(
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-DAG: %[[PAD:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] padding_value(%[[PAD]] : f32)
// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1]
// CHECK-SAME: inner_tiles = [32, 32] into {{.*}} : tensor<?x?xf32> -> tensor<?x?x32x32xf32>
// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]] padding_value(%[[PAD]] : f32)
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
// CHECK-SAME: inner_tiles = [32, 32] into {{.*}} : tensor<?x?xf32> -> tensor<?x?x32x32xf32>
// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]] padding_value(%[[PAD]] : f32)
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 32]
// CHECK-SAME: into {{.*}} : tensor<?x?xf32> -> tensor<?x?x32x32xf32>
// CHECK: %[[VAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP5]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
// CHECK-SAME: ins(%[[PACK0]], %[[PACK1]] : tensor<?x?x32x32xf32>, tensor<?x?x32x32xf32>) outs(%[[PACK2]] : tensor<?x?x32x32xf32>)
// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
// CHECK-SAME: into %[[ARG2]] : tensor<?x?x32x32xf32> -> tensor<?x?xf32>
// CHECK: return %[[OUT]] : tensor<?x?xf32>
43 changes: 22 additions & 21 deletions test/Passes/pass-matmul-blocking.mlir
@@ -60,8 +60,10 @@ func.func @block_dims_equal_to_factors(

// -----

// We don't expect to block as the blocking factors do not create full tiles.
func.func @block_linalg_matmul(
// Adapt tile sizes to small dimensions.
// Assume that there is a separate cost function that controls
// whether packing should take place at all.
func.func @block_small_dims_matmul(
%arg0: tensor<5x6xf32>, %arg1: tensor<6x5xf32>, %arg2: tensor<5x5xf32>)
-> tensor<5x5xf32> {
%0 = linalg.matmul ins(%arg0, %arg1: tensor<5x6xf32>, tensor<6x5xf32>)
@@ -70,13 +72,24 @@ func.func @block_linalg_matmul(
return %0 : tensor<5x5xf32>
}

// CHECK-LABEL: func.func @block_linalg_matmul(
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<5x6xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<6x5xf32>,
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<5x5xf32>) -> tensor<5x5xf32> {
// CHECK: %{{.+}} = linalg.matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
// CHECK-SAME: outs(%[[ARG2]]
// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
// CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// CHECK-LABEL: func @block_small_dims_matmul(
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<5x6xf32>
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<6x5xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<5x5xf32>) -> tensor<5x5xf32> {
// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<1x1x5x6xf32>
// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [5, 6] into %[[BUF0]] : tensor<5x6xf32> -> tensor<1x1x5x6xf32>
// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<1x1x6x5xf32>
// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [6, 5] into %[[BUF1]] : tensor<6x5xf32> -> tensor<1x1x6x5xf32>
// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<1x1x5x5xf32>
// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]] inner_dims_pos = [0, 1] inner_tiles = [5, 5] into %[[BUF2]] : tensor<5x5xf32> -> tensor<1x1x5x5xf32>
// CHECK: %[[VAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP5]]], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%[[PACK0]], %[[PACK1]] : tensor<1x1x5x6xf32>, tensor<1x1x6x5xf32>) outs(%[[PACK2]] : tensor<1x1x5x5xf32>)
// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]] inner_dims_pos = [0, 1] inner_tiles = [5, 5] into %[[ARG2]] : tensor<1x1x5x5xf32> -> tensor<5x5xf32>
// CHECK: return %[[OUT]] : tensor<5x5xf32>
// CHECK: }

// -----

@@ -183,15 +196,3 @@ func.func @batch_matmul_rewrite(%arg0: tensor<512x64x128xf32>, %arg1: tensor<512
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[GEN]]
// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 32]
// CHECK-SAME: into %[[OUT]] : tensor<512x2x2x32x32xf32> -> tensor<512x64x64xf32>

// -----

// CHECK-LABEL: batch_matmul_invalid_tiles
func.func @batch_matmul_invalid_tiles(%arg0: tensor<5x5x5xf32>, %arg1: tensor<5x5x5xf32>) -> tensor<5x5x5xf32> {
%0 = tensor.empty() : tensor<5x5x5xf32>
// CHECK: linalg.batch_matmul
// CHECK-NOT: linalg.generic
%1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<5x5x5xf32>, tensor<5x5x5xf32>)
outs(%0 : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
return %1 : tensor<5x5x5xf32>
}
5 changes: 3 additions & 2 deletions test/Passes/tpp-mapping.mlir
@@ -15,9 +15,10 @@ func.func @conv_to_matmul(%img: tensor<1x5x5x3xf32>, %filter: tensor<3x3x3x8xf32
// CHECK: scf.for
// CHECK: tensor.extract_slice{{[^:]+}}: tensor<1x5x5x3xf32> to tensor<3x3xf32>
// CHECK: tensor.extract_slice{{[^:]+}}: tensor<3x3x3x8xf32> to tensor<3x8xf32>
// CHECK: tensor.extract_slice{{[^:]+}}: tensor<1x3x3x8xf32> to tensor<3x8xf32>
// CHECK: tensor.extract_slice{{[^:]+}}: tensor<1x1x3x8xf32> to tensor<3x8xf32>
// CHECK: linalg.matmul{{.*}} -> tensor<3x8xf32>
// CHECK: tensor.insert_slice{{[^:]+}}: tensor<3x8xf32> into tensor<1x3x3x8xf32>
// CHECK: tensor.insert_slice{{[^:]+}}: tensor<3x8xf32> into tensor<1x1x3x8xf32>
// CHECK: tensor.insert_slice{{[^:]+}}: tensor<1x1x3x8xf32> into tensor<1x3x3x8xf32>
// CHECK: }

// -----
