diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td
index e94ecc7ae..110ec9d24 100644
--- a/include/TPP/Passes.td
+++ b/include/TPP/Passes.td
@@ -153,7 +153,9 @@ def TileConsumerAndFuseProducers : Pass<"tile-consumer-and-fuse-producers",
            "Get producers till maxDepth">,
     Option<"numIters", "num-iters", "int64_t", "3",
            "Run fusion for the given number of iterations">,
-    Option<"useForAll", "use-for-all", "bool", "true", "Use parallel forAll">
+    Option<"useForAll", "use-for-all", "bool", "true", "Use parallel forAll">,
+    Option<"minTileFactor", "min-tile-factor", "int64_t", "2",
+           "Minimum factor between dimension size and a tile size">
   ];
   let dependentDialects = ["linalg::LinalgDialect",
                            "scf::SCFDialect", "tensor::TensorDialect"];
diff --git a/include/TPP/Transforms/Utils/TransformUtils.h b/include/TPP/Transforms/Utils/TransformUtils.h
index 9390db42f..919a1bfe5 100644
--- a/include/TPP/Transforms/Utils/TransformUtils.h
+++ b/include/TPP/Transforms/Utils/TransformUtils.h
@@ -78,9 +78,13 @@ isContraction(linalg::LinalgOp linalgOp);
 // Specific dims can be passed using 'dims'. If dims is empty the validation
 // will start from the outermost dimension, moving to innermost ones up to the
 // number of tiles.
+// Tiling application can be restricted based on the workload dimension size.
+// The tiling is applied only when all dimensions fulfill the predicate:
+// '(dimSize[i] / tiles[i]) >= minTileFactor'.
 bool validateFullTilesOnDims(TilingInterface tileOp,
                              ArrayRef<OpFoldResult> tiles,
-                             ArrayRef<size_t> dims = {});
+                             ArrayRef<size_t> dims = {},
+                             int64_t minTileFactor = 2);
 
 // Rewrite scf.for to scf.forall. Assumes the loop to be parallel and
 // marked with `kLoopId`.
diff --git a/lib/TPP/GPU/GpuPipeline.cpp b/lib/TPP/GPU/GpuPipeline.cpp
index 1664343a2..967a98ab6 100644
--- a/lib/TPP/GPU/GpuPipeline.cpp
+++ b/lib/TPP/GPU/GpuPipeline.cpp
@@ -147,7 +147,9 @@ struct GpuPipeline : public tpp::impl::GpuPipelineBase<GpuPipeline>,
     // Tile to split the kernel into threads and blocks.
     // Use default tiling to handle both packed and unpacked ops.
     pm.addPass(createCleanup());
-    pm.addPass(createTileConsumerAndFuseProducers());
+    TileConsumerAndFuseProducersOptions tilingOptions;
+    tilingOptions.minTileFactor = 1;
+    pm.addPass(createTileConsumerAndFuseProducers(tilingOptions));
     pm.addPass(createCleanup());
 
     // Preprocess and bufferize as further conversion requires memref
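Note on usage (not part of the patch): besides the programmatic override above, the same knob is reachable through the textual pass option `min-tile-factor` declared in Passes.td. The sketch below mirrors the GpuPipeline.cpp hunk; the `TPP/Passes.h` include path and the `mlir::tpp` namespace are assumptions, only the option name, the options struct, and the factory function come from this diff.

```cpp
// Minimal sketch: relax the tile-factor check when building a pipeline,
// as the GPU pipeline now does. Header path and namespace are assumed.
#include "TPP/Passes.h" // assumed location of the generated pass declarations
#include "mlir/Pass/PassManager.h"

static void addRelaxedTiling(mlir::PassManager &pm) {
  mlir::tpp::TileConsumerAndFuseProducersOptions options;
  options.minTileFactor = 1; // accept a single full tile per dimension
  pm.addPass(mlir::tpp::createTileConsumerAndFuseProducers(options));
}
```

With an opt-style driver, the same configuration would be spelled in the textual pipeline as `tile-consumer-and-fuse-producers{min-tile-factor=1}` (standard MLIR option syntax; the driver name is not part of this diff).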
diff --git a/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp b/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp
index 889fbab53..8cf894dd7 100644
--- a/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp
+++ b/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp
@@ -85,7 +85,8 @@ static bool isConvolutionLike(Operation *op) {
 // Return true if `op` can be tiled using `tileSizes`. Require to statically
 // know the range and the tile factor. The tile must be full.
 static bool canBeTiledWithCurrentSpec(Operation *op,
-                                      ArrayRef<OpFoldResult> tileSizes) {
+                                      ArrayRef<OpFoldResult> tileSizes,
+                                      int64_t minTileFactor) {
   assert(isa<TilingInterface>(op) &&
          "expect an op implementing the tiling interface");
   assert(!tileSizes.empty() && "expect tile sizes to be non-empty");
@@ -105,8 +106,8 @@ static bool canBeTiledWithCurrentSpec(Operation *op,
   }
 
   LLVM_DEBUG(llvm::dbgs() << "Running tile validations ----\n");
-  if (!linalgx::utils::validateFullTilesOnDims(cast<TilingInterface>(op),
-                                               tileSizes)) {
+  if (!linalgx::utils::validateFullTilesOnDims(
+          cast<TilingInterface>(op), tileSizes, /*dims=*/{}, minTileFactor)) {
     LLVM_DEBUG(llvm::dbgs() << "FAILED\n");
     return false;
   }
@@ -382,7 +383,8 @@ static llvm::SmallDenseSet<Operation *> collectFusableProducers(
 static FailureOr<scf::SCFTileAndFuseResult> fuseWithEltwise(
     RewriterBase &rewriter, TilingInterface consumer,
     llvm::DenseMap<Operation *, SmallVector<OpFoldResult>> &tileSizes,
-    llvm::SmallDenseSet<Operation *> &alreadyFusedOps, int64_t maxDepth) {
+    llvm::SmallDenseSet<Operation *> &alreadyFusedOps, int64_t maxDepth,
+    int64_t minTileFactor) {
   // Step 0. Early exit if tileSizes are empty.
   if (tileSizes.empty() || !tileSizes.count(consumer)) {
     LLVM_DEBUG(llvm::dbgs() << "EMPTY TILE SIZES\n");
@@ -397,7 +399,8 @@ static FailureOr<scf::SCFTileAndFuseResult> fuseWithEltwise(
   }
 
   // Step 2. Check if the tile configuration fits the consumer.
-  if (!canBeTiledWithCurrentSpec(consumer, tileSizes.at(consumer))) {
+  if (!canBeTiledWithCurrentSpec(consumer, tileSizes.at(consumer),
+                                 minTileFactor)) {
     LLVM_DEBUG(llvm::dbgs() << "CONSUMER: " << consumer
                             << "\nCANNOT BE TILED WITH CURRENT CONFIG\n");
     return failure();
   }
@@ -616,7 +619,8 @@ static Operation *getLastFusableEltWiseConsumer(
 
 // Run `fuseWithEltwise` on contraction-like operations.
 static void doFusion(RewriterBase &rewriter, func::FuncOp func,
-                     ArrayRef<int64_t> tileSizes, int64_t maxDepth) {
+                     ArrayRef<int64_t> tileSizes, int64_t maxDepth,
+                     int64_t minTileFactor) {
   // Set to keep track of fused ops.
   llvm::SmallDenseSet<Operation *> fusedOps;
 
@@ -673,7 +677,7 @@ static void doFusion(RewriterBase &rewriter, func::FuncOp func,
     LLVM_DEBUG(llvm::dbgs() << "\n\n");
     FailureOr<scf::SCFTileAndFuseResult> fuseAndTileResult =
         fuseWithEltwise(rewriter, cast<TilingInterface>(linalgOp),
-                        defaultTiles, fusedOps, maxDepth);
+                        defaultTiles, fusedOps, maxDepth, minTileFactor);
     LLVM_DEBUG(llvm::dbgs() << "\n\n");
     if (succeeded(fuseAndTileResult)) {
       rewriter.replaceOp(
@@ -703,7 +707,8 @@ struct TileConsumerAndFuseProducers
     do {
       func::FuncOp func = getOperation();
       IRRewriter rewriter(&getContext());
-      doFusion(rewriter, func, this->tileSizes, this->maxDepth);
+      doFusion(rewriter, func, this->tileSizes, this->maxDepth,
+               this->minTileFactor);
 
       {
         RewritePatternSet patterns(&ctx);
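The file above only threads `minTileFactor` from the pass member through `doFusion`, `fuseWithEltwise`, and `canBeTiledWithCurrentSpec` down to `validateFullTilesOnDims`. As a quick mental model of the gate being forwarded, the standalone snippet below re-implements the per-dimension check with plain integers; the helper names are hypothetical and the non-constant-range case handled by the real code is deliberately omitted.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-alone model of the gate added by this patch: a dimension
// passes when the tile divides it evenly *and* leaves at least minTileFactor
// full tiles. These helpers are illustrative, not TPP APIs.
static bool fullTileOnDim(int64_t dimSize, int64_t tileSize,
                          int64_t minTileFactor) {
  if (tileSize == 0) // tile of 0 means "do not tile this dim"
    return true;
  if (tileSize == 1 && dimSize == 1) // unit dim tiled by 1
    return true;
  return dimSize % tileSize == 0 && dimSize / tileSize >= minTileFactor;
}

static bool fullTilesOnDims(const std::vector<int64_t> &dims,
                            const std::vector<int64_t> &tiles,
                            int64_t minTileFactor) {
  assert(dims.size() == tiles.size() && "one tile size per dimension");
  for (size_t i = 0; i < dims.size(); ++i)
    if (!fullTileOnDim(dims[i], tiles[i], minTileFactor))
      return false;
  return true;
}

int main() {
  // A 32x32 iteration space with 32x32 tiles yields one tile per dimension:
  // rejected under the new default factor of 2, accepted with factor 1
  // (the value the GPU pipeline now sets).
  assert(!fullTilesOnDims({32, 32}, {32, 32}, /*minTileFactor=*/2));
  assert(fullTilesOnDims({32, 32}, {32, 32}, /*minTileFactor=*/1));
  // 64x64 with 32x32 tiles leaves two tiles per dimension: accepted either way.
  assert(fullTilesOnDims({64, 64}, {32, 32}, /*minTileFactor=*/2));
  return 0;
}
```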
-  if (!tileFactor || !rangeOnDim)
+  if (!tileSize || !rangeOnDim)
     return true;
 
   // Corner case: Tiling with '0' along 'dim' is valid - no tiling.
-  if (*tileFactor == 0)
+  if (*tileSize == 0)
     return true;
 
   // Corner case: Tiling '1' with '1' is valid.
-  if (*tileFactor == 1 && *rangeOnDim == 1)
+  if (*tileSize == 1 && *rangeOnDim == 1)
     return true;
 
-  return (*rangeOnDim % *tileFactor == 0);
+  return (*rangeOnDim % *tileSize == 0) &&
+         (*rangeOnDim / *tileSize >= minTileFactor);
 }
 
 bool validateFullTilesOnDims(TilingInterface tileOp,
                              ArrayRef<OpFoldResult> tiles,
-                             ArrayRef<size_t> dims) {
+                             ArrayRef<size_t> dims, int64_t minTileFactor) {
   if (!dims.empty() && dims.size() != tiles.size())
     return false;
 
@@ -333,7 +335,8 @@ bool validateFullTilesOnDims(TilingInterface tileOp,
   assert(dimsToCheck.size() == tiles.size());
 
   for (auto dim : llvm::enumerate(dimsToCheck)) {
-    if (!validateFullTilesOnDim(tileOp, tiles[dim.index()], dim.value()))
+    if (!validateFullTilesOnDim(tileOp, tiles[dim.index()], dim.value(),
+                                minTileFactor))
       return false;
   }
   return true;
diff --git a/test/Passes/DefaultPipeline/default-tpp-passes.mlir b/test/Passes/DefaultPipeline/default-tpp-passes.mlir
index fa632c664..baa863c5f 100644
--- a/test/Passes/DefaultPipeline/default-tpp-passes.mlir
+++ b/test/Passes/DefaultPipeline/default-tpp-passes.mlir
@@ -223,8 +223,11 @@ func.func @batch_matmul_rewrite(%arg0: tensor<512x32x64xf32>, %arg1: tensor<512x
   // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : i64
   // CHECK-DAG: %[[C64:.+]] = arith.constant 64 : i64
   // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i64
+  // CHECK-DAG: %[[C0_i:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1_i:.+]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C512_i:.+]] = arith.constant 512 : index
   // CHECK: %{{.+}} = call @xsmm_gemm_dispatch(%[[C1]], %[[C32]], %[[C32]], %[[C64]], %[[C64]], %[[C32]], %[[C32]], %[[C0]])
-  // CHECK: scf.parallel
+  // CHECK: scf.parallel{{.*}}(%[[C0_i]]) to (%[[C512_i]]) step (%[[C1_i]])
   // CHECK: xsmm_gemm_invoke
   %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<512x32x64xf32>, tensor<512x64x32xf32>)
                            outs(%0 : tensor<512x32x32xf32>) -> tensor<512x32x32xf32>
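Why the FileCheck lines tighten: with the new default `min-tile-factor=2`, a dimension that yields only a single full tile is no longer tiled, so for the 512x32x32 batch_matmul result only the 512-wide batch dimension remains in `scf.parallel`, which the updated CHECK pins to `(0) to (512) step (1)`. The arithmetic below is illustrative only; the 32x32 tile sizes and the batch tile of 1 are assumptions about defaults that are not visible in this diff.

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const int64_t minTileFactor = 2; // new default from Passes.td

  // M and N of the 512x32x32 batch_matmul result with assumed 32x32 tiles:
  // the tile divides the dim but leaves a single tile, below the factor,
  // so these dims are no longer tiled.
  const int64_t dimMN = 32, tileMN = 32;
  assert(dimMN % tileMN == 0 && dimMN / tileMN < minTileFactor);

  // The 512-wide batch dimension (assumed tile of 1) easily satisfies the
  // factor, consistent with the CHECK: scf.parallel ... (0) to (512) step (1).
  const int64_t dimBatch = 512, tileBatch = 1;
  assert(dimBatch % tileBatch == 0 && dimBatch / tileBatch >= minTileFactor);
  return 0;
}
```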