From b058fca93d79a6a5a7814f16f36fc5c582611603 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk
Date: Tue, 4 Jun 2024 17:26:04 +0200
Subject: [PATCH] Switch to scf::tileUsingSCF

Drops TPP's usage of linalg::tileToForallOpUsingTileSizes in preparation
for the deprecation of this upstream Linalg API.

The corresponding test is updated, as the SCF tiling API folds affine
maps as part of tiling, resulting in simpler IR.

Fixes #676
---
 .../Transforms/RewriteBatchMatmulToMatmul.cpp | 11 ++++---
 .../pass-rewrite-batch-matmul-to-matmul.mlir  | 31 ++++++++----------
 2 files changed, 20 insertions(+), 22 deletions(-)

diff --git a/lib/TPP/Transforms/RewriteBatchMatmulToMatmul.cpp b/lib/TPP/Transforms/RewriteBatchMatmulToMatmul.cpp
index 3d0a1519f..0dbfed330 100644
--- a/lib/TPP/Transforms/RewriteBatchMatmulToMatmul.cpp
+++ b/lib/TPP/Transforms/RewriteBatchMatmulToMatmul.cpp
@@ -103,12 +103,15 @@ struct RewriteBatchMatmulToMatmul
       tiles[0] = getAsIndexOpFoldResult(rewriter.getContext(), 1);
       OpBuilder::InsertionGuard guard(rewriter);
       rewriter.setInsertionPoint(batchMatmulOp);
-      auto tilingResult = linalg::tileToForallOpUsingTileSizes(
-          rewriter, cast<TilingInterface>(batchMatmulOp.getOperation()), tiles,
-          /*mapping=*/std::nullopt);
+      scf::SCFTilingOptions tilingOpts;
+      tilingOpts.setTileSizes(tiles);
+      tilingOpts.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
+      auto tilingResult = scf::tileUsingSCF(
+          rewriter, cast<TilingInterface>(batchMatmulOp.getOperation()),
+          tilingOpts);
       if (failed(tilingResult))
         return signalPassFailure();
-      rewriter.replaceOp(batchMatmulOp, tilingResult->tileOp->getResults());
+      rewriter.replaceOp(batchMatmulOp, tilingResult->replacements);
     });
 
   // Step2:
diff --git a/test/Passes/pass-rewrite-batch-matmul-to-matmul.mlir b/test/Passes/pass-rewrite-batch-matmul-to-matmul.mlir
index 91448ee6b..7dc85d31e 100644
--- a/test/Passes/pass-rewrite-batch-matmul-to-matmul.mlir
+++ b/test/Passes/pass-rewrite-batch-matmul-to-matmul.mlir
@@ -53,17 +53,14 @@ func.func @batch_matmul_rewrite(%arg0: tensor<512x?x?xf32>,
 
 // -----
 
-// TODO: tiling using scf.forall introduces the affine.min that prevents
-// rank reducing the tensor and map to brgemm. See: #676
 func.func @batch_matmul_rewrite(%arg0: tensor<?x?x?xf32>,
-    %arg1: tensor<?x?x?xf32>, %dim0: index, %dim1: index, %bacth: index) -> tensor<?x?x?xf32> {
-  %0 = tensor.empty(%bacth, %dim0, %dim1) : tensor<?x?x?xf32>
+    %arg1: tensor<?x?x?xf32>, %dim0: index, %dim1: index, %batch: index) -> tensor<?x?x?xf32> {
+  %0 = tensor.empty(%batch, %dim0, %dim1) : tensor<?x?x?xf32>
   %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
                            outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 }
 
-// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 1)>
 // CHECK-LABEL: batch_matmul_rewrite
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>,
 // CHECK-SAME: %[[ARG2:.+]]: index, %[[ARG3:.+]]: index, %[[ARG4:.+]]: index
@@ -72,18 +69,16 @@ func.func @batch_matmul_rewrite(%arg0: tensor<?x?x?xf32>,
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
 // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[ARG4]], %[[ARG2]], %[[ARG3]]) : tensor<?x?x?xf32>
 // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM3:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
+// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
+// CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
+// CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
 // CHECK: %{{.+}} = scf.forall (%[[ARG5:.+]]) in (%[[DIM]])
 // CHECK-SAME: shared_outs(%[[ARG6:.+]] = %[[EMPTY]]) -> (tensor<?x?x?xf32>) {
-// CHECK: %[[MIN:.+]] = affine.min #[[MAP]](%[[ARG5]])[%[[DIM0]]]
-// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG5]], 0, 0] [%[[MIN]], %[[DIM1]], %[[DIM2]]] [1, 1, 1]
-// CHECK-SAME: : tensor<?x?x?xf32> to tensor<?x?x?xf32>
-// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]][%[[ARG5]], 0, 0] [%[[MIN]], %[[DIM2]], %[[DIM3]]] [1, 1, 1]
-// CHECK-SAME: : tensor<?x?x?xf32> to tensor<?x?x?xf32>
-// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG6]][%[[ARG5]], 0, 0] [%[[MIN]], %[[DIM1]], %[[DIM3]]] [1, 1, 1]
-// CHECK-SAME: : tensor<?x?x?xf32> to tensor<?x?x?xf32>
-// CHECK: %{{.+}} = linalg.batch_matmul ins(%[[SLICE]], %[[SLICE1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
-// CHECK-SAME: outs(%[[SLICE2]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG5]], 0, 0] [1, %[[DIM0]], %[[DIM1]]] [1, 1, 1]
+// CHECK-SAME: : tensor<?x?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]][%[[ARG5]], 0, 0] [1, %[[DIM1]], %[[DIM2]]] [1, 1, 1]
+// CHECK-SAME: : tensor<?x?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG6]][%[[ARG5]], 0, 0] [1, %[[DIM0]], %[[DIM2]]] [1, 1, 1]
+// CHECK-SAME: : tensor<?x?x?xf32> to tensor<?x?xf32>
+// CHECK: %{{.+}} = linalg.matmul ins(%[[SLICE]], %[[SLICE1]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[SLICE2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
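For reference, the tiling pattern the patch switches to can be sketched in isolation roughly as follows, assuming an MLIR revision contemporary with this change. The helper name tileBatchDimToForall, its TilingInterface parameter, and the local construction of the tile-size vector are illustrative and not taken from the pass; the scf::SCFTilingOptions and scf::tileUsingSCF calls follow the same shape as the hunk above.

// Minimal sketch (helper name and tile-size setup are illustrative
// assumptions): tile only the outermost (batch) dimension of a
// TilingInterface op into an scf.forall loop.
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"

using namespace mlir;

static LogicalResult tileBatchDimToForall(RewriterBase &rewriter,
                                          TilingInterface op) {
  // Tile size 1 on the outermost loop; tile size 0 leaves a loop untiled.
  SmallVector<OpFoldResult> tiles(
      op.getLoopIteratorTypes().size(),
      getAsIndexOpFoldResult(rewriter.getContext(), 0));
  tiles[0] = getAsIndexOpFoldResult(rewriter.getContext(), 1);

  // Request an scf.forall loop instead of the default scf.for nest.
  scf::SCFTilingOptions tilingOpts;
  tilingOpts.setTileSizes(tiles);
  tilingOpts.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(op);
  FailureOr<scf::SCFTilingResult> tilingResult =
      scf::tileUsingSCF(rewriter, op, tilingOpts);
  if (failed(tilingResult))
    return failure();

  // The result carries the values that replace the original op's results.
  rewriter.replaceOp(op, tilingResult->replacements);
  return success();
}

Compared with linalg::tileToForallOpUsingTileSizes, the SCF driver hands back the replacement values directly (tilingResult->replacements) instead of the loop op's results, and it folds the affine bounds for the unit batch tile size, which is what lets the updated test drop the #[[MAP]] check and expect a rank-reduced linalg.matmul inside the scf.forall.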