From 9aab33a4e2e632b962ce00092822387bb1778914 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Mon, 24 Jun 2024 12:36:52 +0200 Subject: [PATCH] Linalg to XeGPU lowering (#915) Adds direct lowering from Linalg to XeGPU and extends GPU runner support with 'intel' target. The Intel GPU pipeline is designed with IMEX and IGC compatibility in mind. The lowering targets tiled operations and assumes that the input shapes are nicely divisible by hardware supported sizes e.g., tiles 32x32, 16x16 etc. This is the first step toward bridging XeGPU with higher abstraction dialect. Common patterns used in this conversion can be later split into more progressive lowering through other dialects like vector and memref. Supported conversion: - targets Vector Compute mode of XeGPU (subgroup-level kernel) - eltwise operations of any type split into SIMD sized computations - DPAS implementation for F16 matmul with output precision conversion --- include/TPP/PassBundles.td | 20 +- include/TPP/Passes.h | 4 + include/TPP/Passes.td | 25 + lib/TPP/DefaultPipeline.cpp | 7 + lib/TPP/GPU/CMakeLists.txt | 2 + lib/TPP/GPU/GpuConversion.cpp | 12 +- lib/TPP/GPU/GpuPipeline.cpp | 68 +- lib/TPP/GPU/LinalgToXeGPU.cpp | 1404 ++++++++++++++++++++++++ lib/TPP/Runner/MLIRBench.cpp | 35 +- test/GPU/Intel/gpu-pipeline-intel.mlir | 42 + test/GPU/linalg-to-xegpu-dpas.mlir | 87 ++ test/GPU/linalg-to-xegpu-stages.mlir | 53 + test/GPU/linalg-to-xegpu.mlir | 217 ++++ 13 files changed, 1956 insertions(+), 20 deletions(-) create mode 100644 lib/TPP/GPU/LinalgToXeGPU.cpp create mode 100644 test/GPU/Intel/gpu-pipeline-intel.mlir create mode 100644 test/GPU/linalg-to-xegpu-dpas.mlir create mode 100644 test/GPU/linalg-to-xegpu-stages.mlir create mode 100644 test/GPU/linalg-to-xegpu.mlir diff --git a/include/TPP/PassBundles.td b/include/TPP/PassBundles.td index 66cc514e4..d44f5d747 100644 --- a/include/TPP/PassBundles.td +++ b/include/TPP/PassBundles.td @@ -124,9 +124,27 @@ def GpuConversion : Pass<"gpu-conversion", "ModuleOp"> { ListOption<"warpTile", "warp-tile", "int64_t", "Warp tile sizes MxNxK">, ]; let dependentDialects = ["linalg::LinalgDialect", + "gpu::GPUDialect", "scf::SCFDialect", "memref::MemRefDialect", - "gpu::GPUDialect"]; + "xegpu::XeGPUDialect"]; + let options = [ + Option<"useWmma", "wmma", + "bool", /*default=*/"false", + "Use WMMA operations">, + ListOption<"warpTile", "warp-tile", "int64_t", "Warp tile sizes MxNxK">, + Option<"isIntel", "intel", + "bool", /*default=*/"false", + "Convert for Intel GPU">, + Option<"kTile", "k-tile", "int64_t", + /*default=*/"32", + "GEMM tile size for reduction dimension.">, + Option<"stages", "stages", "int64_t", + /*default=*/"1", + "Number of cooperative prefetch stages.">, + ListOption<"dpasTile", "dpas-tile", "int64_t", + "DPAS register block sizes MxNxK">, + ]; } def GpuToCuda : Pass<"gpu-to-cuda", "ModuleOp"> { diff --git a/include/TPP/Passes.h b/include/TPP/Passes.h index e1fae0d92..c48f9e310 100644 --- a/include/TPP/Passes.h +++ b/include/TPP/Passes.h @@ -88,6 +88,10 @@ namespace xsmm { class XsmmDialect; } // namespace xsmm +namespace xegpu { +class XeGPUDialect; +} // namespace xegpu + } // namespace mlir namespace mlir { diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td index 3f4de5935..8ec2bef5e 100644 --- a/include/TPP/Passes.td +++ b/include/TPP/Passes.td @@ -479,4 +479,29 @@ def TppRunnerWrapper : Pass<"tpp-runner-wrapper", "ModuleOp">{ ]; } +def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> { + let summary = "Convert linalg dialect to XeGPU dialect."; + let description = [{ + Lower linalg ops to XeGPU dialect. + }]; + let dependentDialects = ["linalg::LinalgDialect", + "gpu::GPUDialect", + "xegpu::XeGPUDialect", + "scf::SCFDialect", + "memref::MemRefDialect", + "arith::ArithDialect", + "math::MathDialect", + "vector::VectorDialect"]; + let options = [ + Option<"kTile", "k-tile", "int64_t", + /*default=*/"32", + "GEMM tile size for reduction dimension.">, + Option<"stages", "stages", "int64_t", + /*default=*/"1", + "Number of cooperative prefetch stages.">, + ListOption<"dpasTile", "dpas-tile", "int64_t", + "DPAS register block sizes MxNxK">, + ]; +} + #endif // TPP_DIALECT_TPP_PASSES diff --git a/lib/TPP/DefaultPipeline.cpp b/lib/TPP/DefaultPipeline.cpp index e7b29b4c0..d5c89a68f 100644 --- a/lib/TPP/DefaultPipeline.cpp +++ b/lib/TPP/DefaultPipeline.cpp @@ -132,6 +132,13 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase, if (print == PrintStage::Mid) pm.addPass(createPrintIRPass()); + // Bail out early for Intel GPU. + // The rest of the lowering is performed by IMEX. + if (gpuBackend == "intel") { + pm.addPass(createPrintIRPass()); + return; + } + // Partial Lowering pm.addPass(memref::createExpandStridedMetadataPass()); pm.addPass(createConvertTensorToLinalgPass()); diff --git a/lib/TPP/GPU/CMakeLists.txt b/lib/TPP/GPU/CMakeLists.txt index 44712b6a8..4a428cde8 100644 --- a/lib/TPP/GPU/CMakeLists.txt +++ b/lib/TPP/GPU/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_library(TPPGPU LinalgToGpu.cpp GpuDataTransfer.cpp GpuInlineConstants.cpp + LinalgToXeGPU.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/TPP @@ -22,6 +23,7 @@ add_mlir_library(TPPGPU LINK_LIBS PUBLIC MLIRGPUDialect + MLIRXeGPUDialect MLIRGPUTransforms MLIRGPUToSPIRV MLIRSCFToGPU diff --git a/lib/TPP/GPU/GpuConversion.cpp b/lib/TPP/GPU/GpuConversion.cpp index b40cc070e..dc5f451ee 100644 --- a/lib/TPP/GPU/GpuConversion.cpp +++ b/lib/TPP/GPU/GpuConversion.cpp @@ -15,7 +15,8 @@ #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/BuiltinOps.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/IR/Dialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" @@ -58,8 +59,13 @@ struct GpuConversion : public tpp::impl::GpuConversionBase, // First lower linalg using custom patterns then fall back to // the default lowering for any remaining ops. pm.addNestedPass(createLinalgDeGeneralize()); - pm.addNestedPass( - createLinalgToGpu(LinalgToGpuOptions{useWmma, warpTile})); + if (isIntel) { + pm.addNestedPass( + createLinalgToXeGPU(LinalgToXeGPUOptions{kTile, stages, dpasTile})); + } else { + pm.addNestedPass( + createLinalgToGpu(LinalgToGpuOptions{useWmma, warpTile, kTile})); + } pm.addNestedPass(createConvertLinalgToParallelLoopsPass()); // Map loops into GPU kernels. diff --git a/lib/TPP/GPU/GpuPipeline.cpp b/lib/TPP/GPU/GpuPipeline.cpp index b2fcb0a42..ae4a578bc 100644 --- a/lib/TPP/GPU/GpuPipeline.cpp +++ b/lib/TPP/GPU/GpuPipeline.cpp @@ -50,6 +50,29 @@ llvm::cl::list wmmaTileSizes( llvm::cl::list_init(SmallVector{16, 16, 16}), llvm::cl::CommaSeparated); +llvm::cl::list + gpuBlockTile("gpu-block-tile", llvm::cl::desc("GPU block tile size"), + llvm::cl::list_init(SmallVector{128, 128}), + llvm::cl::CommaSeparated); + +llvm::cl::list + gpuThreadTile("gpu-thread-tile", llvm::cl::desc("GPU thread tile size"), + llvm::cl::list_init(SmallVector{32, 32}), + llvm::cl::CommaSeparated); + +llvm::cl::opt kTile("k-tile", llvm::cl::desc("GEMM K dim tiling size"), + llvm::cl::init(32)); + +llvm::cl::opt stages("stages", + llvm::cl::desc("GEMM coop prefetch stages"), + llvm::cl::init(1)); + +// DPAS size defaults to PVC. +llvm::cl::list + gpuDpasTile("dpas-tile", llvm::cl::desc("DPAS register block sizes MxNxK"), + llvm::cl::list_init(SmallVector{8, 16, 16}), + llvm::cl::CommaSeparated); + namespace mlir { namespace tpp { #define GEN_PASS_DEF_GPUPIPELINE @@ -62,12 +85,14 @@ namespace { enum class GpuType { Cuda, Vulkan, + Intel, }; GpuType parseGpuOption(StringRef gpuStr) { auto type = llvm::StringSwitch>(gpuStr) .CaseLower("cuda", GpuType::Cuda) .CaseLower("vulkan", GpuType::Vulkan) + .CaseLower("intel", GpuType::Intel) .Default(std::nullopt); assert(type && "Unsupported GPU backend"); @@ -90,7 +115,8 @@ GpuOptions getGpuOptions(GpuType gpuType) { options.features = "+ptx60"; break; } - case GpuType::Vulkan: { + case GpuType::Vulkan: + case GpuType::Intel: { // No options needed at the moment. break; } @@ -145,22 +171,40 @@ struct GpuPipeline : public tpp::impl::GpuPipelineBase, // Tile to split the kernel into threads and blocks. // Use default tiling to handle both packed and unpacked ops. pm.addPass(createCleanup()); - TileConsumerAndFuseProducersOptions tilingOptions; - tilingOptions.minTileFactor = 1; - pm.addPass(createTileConsumerAndFuseProducers(tilingOptions)); + if (gpuType == GpuType::Intel) { + // First split computation into grid with blocks of specified size. + TileConsumerAndFuseProducersOptions blockTileOptions; + blockTileOptions.tileSizes = gpuBlockTile; + blockTileOptions.minTileFactor = 1; + pm.addPass(createTileConsumerAndFuseProducers(blockTileOptions)); + + // Then try to further split computation into subtiles. + // This allows to split larger computations across multiple + // threads/workitems. For smaller workloads, it provides another + // chance for outlining. + TileConsumerAndFuseProducersOptions threadTileOptions; + threadTileOptions.tileSizes = gpuThreadTile; + threadTileOptions.minTileFactor = 1; + pm.addPass(createTileConsumerAndFuseProducers(threadTileOptions)); + } else { + TileConsumerAndFuseProducersOptions tilingOptions; + tilingOptions.minTileFactor = 1; + pm.addPass(createTileConsumerAndFuseProducers(tilingOptions)); + } pm.addPass(createCleanup()); // Preprocess and bufferize as further conversion requires memref // abstraction. pm.addPass(createLowerPacksAndUnPacks()); - bool dealloc = gpuType != GpuType::Cuda; + bool dealloc = gpuType == GpuType::Vulkan; pm.addPass(createBufferize(BufferizeOptions{dealloc})); pm.addPass(createConvertForAllToParallelOp()); pm.addPass(createCleanup()); // Convert to generic GPU ops. - pm.addPass( - createGpuConversion(GpuConversionOptions{gpuWmma, wmmaTileSizes})); + pm.addPass(createGpuConversion( + GpuConversionOptions{gpuWmma, wmmaTileSizes, gpuType == GpuType::Intel, + kTile, stages, gpuDpasTile})); // Lower GPU ops to the chosen GPU backend. switch (gpuType) { @@ -177,6 +221,16 @@ struct GpuPipeline : public tpp::impl::GpuPipelineBase, pm.addPass(createGpuToVulkan()); break; } + case GpuType::Intel: + pm.addPass(xegpu::createXeGPUFoldAliasOps()); + + std::string clientApi = "intel"; + SetSPIRVCapabilitiesOptions capabilitiesOptions{clientApi}; + pm.addPass(tpp::createSetSPIRVCapabilities(capabilitiesOptions)); + SetSPIRVAbiAttributeOptions abiAttrOptions{clientApi}; + pm.addPass(tpp::createSetSPIRVAbiAttribute(abiAttrOptions)); + + break; } // Covert all local dialects like perf. diff --git a/lib/TPP/GPU/LinalgToXeGPU.cpp b/lib/TPP/GPU/LinalgToXeGPU.cpp new file mode 100644 index 000000000..983c04728 --- /dev/null +++ b/lib/TPP/GPU/LinalgToXeGPU.cpp @@ -0,0 +1,1404 @@ +//===- LinalgToXeGPU.cpp -----------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TPP/Passes.h" + +#include "TPP/IR/MatcherUtils.h" +#include "TPP/IR/StructuredOpMatcher.h" +#include "TPP/Transforms/Utils/ValueUtils.h" + +#include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/TransformOps/Utils.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/TypeSwitch.h" + +#include +#include + +using namespace mlir; +using namespace mlir::tpp; +using namespace mlir::xegpu; + +namespace mlir { +namespace tpp { +#define GEN_PASS_DEF_LINALGTOXEGPU +#include "TPP/Passes.h.inc" +} // namespace tpp +} // namespace mlir + +namespace { + +// Represents VNNI configuration for an operand. +struct VnniConfig { + int vnniFactor; + int vnniAxis; +}; + +// Helper struct to keep track of tiles' position with respect to whole matrix. +struct TilesArray { + TilesArray() = delete; + TilesArray(int numRows, int numCols) { + assert(((numRows > 0) && (numCols > 0)) && "Expected 2D array shape"); + for (int i = 0; i < numRows; i++) { + tileMatrix.push_back(SmallVector{}); + for (int j = 0; j < numCols; j++) + tileMatrix[i].push_back(Value{}); + } + } + ~TilesArray() = default; + + Value getTile(int row, int col) { return tileMatrix[row][col]; } + + void setTile(int row, int col, Value val) { tileMatrix[row][col] = val; } + + SmallVector toFlatVector() { + SmallVector flatVector; + for (auto row : tileMatrix) + flatVector.append(row); + return flatVector; + } + + SmallVector> tileMatrix; +}; + +// Return DPAS tile sizes if the gemm-like operation fits DPAS hardware. +static bool isDPASCompatible(linalg::LinalgOp linalgOp, int kTile, + ArrayRef dpasTile) { + if (!(isa(linalgOp) || + isa(linalgOp) || + isa(linalgOp))) { + return false; + } + + // Expect MxNxK DPAS register block sizes. + if (dpasTile.size() != 3) + return false; + + // Only static shapes are supported. + if (linalgOp.hasDynamicShape()) + return false; + + auto aType = cast(linalgOp.getDpsInputs()[0].getType()); + auto bType = cast(linalgOp.getDpsInputs()[1].getType()); + auto cType = cast(linalgOp.getDpsInits()[0].getType()); + + auto elemTypeA = aType.getElementType(); + auto elemTypeB = bType.getElementType(); + auto elemTypeC = cType.getElementType(); + + // TODO: Add more DPAS combinations. + bool isSupportedPrecision = + (elemTypeA.isF16() && elemTypeB.isF16() && elemTypeC.isF16()) || + (elemTypeA.isF16() && elemTypeB.isF16() && elemTypeC.isF32()); + if (!isSupportedPrecision) + return false; + + auto mDim = cType.getShape()[0]; + auto nDim = cType.getShape()[1]; + auto kDim = aType.getShape().back(); + + // Validate workload sizes. + // The computation dimensions must fit into the tiles. + // Reduction dimension tile size has to be compatible + // with the warp tile. + int dpasTileM = dpasTile[0]; + int dpasTileN = dpasTile[1]; + int dpasTileK = dpasTile[2]; + if ((mDim % dpasTileM != 0) || (nDim % dpasTileN != 0) || + (kDim % dpasTileK != 0) || (kTile % dpasTileK != 0)) { + return false; + } + + return true; +} + +// Verify if linalg operands fulfill lowering constraints. +static LogicalResult isValidMemrefOperand(linalg::LinalgOp linalgOp, + Value operand, + PatternRewriter &rewriter, + unsigned maxDims = 2) { + auto type = dyn_cast(operand.getType()); + if (!type) { + return rewriter.notifyMatchFailure( + linalgOp, "Expect memref operand for XeGPU lowering"); + } + + if (type.getShape().size() > maxDims) { + return rewriter.notifyMatchFailure( + linalgOp, "Too high dimensionality for XeGPU operations"); + } + + auto strides = utils::getStaticStrides(operand); + + if (failed(strides)) { + return rewriter.notifyMatchFailure( + linalgOp, "Expect static strides for XeGPU lowering"); + } + if (strides->back() != 1) { + return rewriter.notifyMatchFailure(linalgOp, + "Expect unit stride in the innermost " + "dimension for XeGPU operations"); + } + + return success(); +} + +// Match and, if possible, lower a generic operation to an XeGPU compatible op. +// Returns the result of the lowered op or nullopt, otherwise. +static std::optional lowerGenericOp(linalg::GenericOp genericOp, + ArrayRef operands, + VectorType resType, + PatternRewriter &rewriter) { + Location loc = genericOp.getLoc(); + + // Expect operands to be already loaded vectors. + for (auto operand : operands) { + if (!isa(operand.getType())) + return std::nullopt; + } + + if (structured_match::utils::isTwoDReluOp(genericOp, /*operands=*/nullptr)) { + assert(operands.size() == 1 && + "Invalid number of operands for generic 2D ReLU"); + + auto eltType = resType.getElementType(); + Value zeroConst; + + if (isa(eltType)) { + auto floatType = cast(eltType); + zeroConst = rewriter.create( + loc, APFloat::getZero(floatType.getFloatSemantics()), floatType); + } else if (isa(eltType)) { + zeroConst = rewriter.create(loc, 0, eltType); + } else { + // Unhandled type. Bail out. + return std::nullopt; + } + + auto zeroVec = + rewriter.create(loc, resType, zeroConst); + + return rewriter + .create(loc, resType, operands[0], zeroVec) + .getResult(); + } + + if (structured_match::utils::isTwoDAddOp(genericOp, /*operands=*/nullptr)) { + assert(operands.size() == 2 && + "Invalid number of operands for generic 2D add"); + return rewriter + .create(loc, resType, operands[0], operands[1]) + .getResult(); + } + + return std::nullopt; +} + +// Lower an elementwise operation to an XeGPU compatible op. +// Returns the result of the lowered op or nullopt, otherwise. +static std::optional lowerEltwiseOp(linalg::LinalgOp linalgOp, + ArrayRef operands, + PatternRewriter &rewriter) { + Location loc = linalgOp.getLoc(); + + assert(llvm::all_of(operands, + [&](Value tile) { + return tile.getType() == operands[0].getType(); + }) && + "All eltwise operands must have the same type."); + + // Expect operands to be already loaded vectors. + for (auto operand : operands) { + if (!isa(operand.getType())) + return std::nullopt; + } + + auto operandType = cast(operands[0].getType()); + auto resType = + VectorType::get(operandType.getShape(), operandType.getElementType()); + auto eltType = resType.getElementType(); + + return llvm::TypeSwitch>(linalgOp) + .Case([&](linalg::AbsOp absOp) -> std::optional { + assert(operands.size() == 1 && "Invalid number of operands for abs"); + if (isa(eltType)) { + return rewriter.create(loc, resType, operands[0]) + .getResult(); + } + if (isa(eltType)) { + return rewriter.create(loc, resType, operands[0]) + .getResult(); + } + // Unhandled type. Bail out. + return std::nullopt; + }) + .Case([&](linalg::AddOp addOp) -> std::optional { + assert(operands.size() == 2 && "Invalid number of operands for add"); + if (isa(eltType)) { + return rewriter + .create(loc, resType, operands[0], operands[1]) + .getResult(); + } + if (isa(eltType)) { + return rewriter + .create(loc, resType, operands[0], operands[1]) + .getResult(); + } + // Unhandled type. Bail out. + return std::nullopt; + }) + .Case([&](linalg::CeilOp ceilOp) -> std::optional { + assert(operands.size() == 1 && "Invalid number of operands for ceil"); + return rewriter.create(loc, resType, operands[0]) + .getResult(); + }) + .Case([&](linalg::DivOp divOp) -> std::optional { + assert(operands.size() == 2 && "Invalid number of operands for div"); + if (isa(eltType)) { + return rewriter + .create(loc, resType, operands[0], operands[1]) + .getResult(); + } + if (isa(eltType)) { + return rewriter + .create(loc, resType, operands[0], operands[1]) + .getResult(); + } + // Unhandled type. Bail out. + return std::nullopt; + }) + .Case([&](linalg::DivUnsignedOp divUnsignedOp) -> std::optional { + assert(operands.size() == 2 && + "Invalid number of operands for unsigned div"); + if (isa(eltType)) { + return rewriter + .create(loc, resType, operands[0], operands[1]) + .getResult(); + } + // Unhandled type. Bail out. + return std::nullopt; + }) + .Case([&](linalg::ExpOp expOp) -> std::optional { + assert(operands.size() == 1 && "Invalid number of operands for exp"); + return rewriter.create(loc, resType, operands[0]) + .getResult(); + }) + .Case([&](linalg::FloorOp floorOp) -> std::optional { + assert(operands.size() == 1 && "Invalid number of operands for floor"); + return rewriter.create(loc, resType, operands[0]) + .getResult(); + }) + .Case([&](linalg::MaxOp maxOp) -> std::optional { + assert(operands.size() == 2 && "Invalid number of operands for max"); + if (isa(eltType)) { + return rewriter + .create(loc, resType, operands[0], operands[1]) + .getResult(); + } + if (isa(eltType)) { + if (eltType.isUnsignedInteger()) { + return rewriter + .create(loc, resType, operands[0], operands[1]) + .getResult(); + } else { + return rewriter + .create(loc, resType, operands[0], operands[1]) + .getResult(); + } + } + // Unhandled type. Bail out. + return std::nullopt; + }) + .Case([&](linalg::MulOp mulOp) -> std::optional { + assert(operands.size() == 2 && "Invalid number of operands for mul"); + if (isa(eltType)) { + return rewriter + .create(loc, resType, operands[0], operands[1]) + .getResult(); + } + if (isa(eltType)) { + return rewriter + .create(loc, resType, operands[0], operands[1]) + .getResult(); + } + // Unhandled type. Bail out. + return std::nullopt; + }) + .Case([&](linalg::NegfOp negfOp) -> std::optional { + assert(operands.size() == 1 && "Invalid number of operands for negf"); + return rewriter.create(loc, resType, operands[0]) + .getResult(); + }) + .Case([&](linalg::SubOp subOp) -> std::optional { + assert(operands.size() == 2 && "Invalid number of operands for sub"); + if (isa(eltType)) { + return rewriter + .create(loc, resType, operands[0], operands[1]) + .getResult(); + } + if (isa(eltType)) { + return rewriter + .create(loc, resType, operands[0], operands[1]) + .getResult(); + } + // Unhandled type. Bail out. + return std::nullopt; + }) + .Case([&](linalg::GenericOp genericOp) -> std::optional { + return lowerGenericOp(genericOp, operands, resType, rewriter); + }) + .Default( + [&](Operation *op) -> std::optional { return std::nullopt; }); +} + +// Get static GPU block sizes represented by a surrounding operation +// like a kernel launch or parallel loop. +// Returns known block sizes if they are all static or failure, otherwise. +static FailureOr> getStaticBlockSizes(Operation *op) { + if (!op) + return failure(); + + auto getConstVal = [&](Value val) -> std::optional { + if (auto constOp = val.getDefiningOp()) { + return constOp.value(); + } + return std::nullopt; + }; + + if (auto launchOp = dyn_cast(op)) { + auto sizeX = getConstVal(launchOp.getBlockSizeX()); + auto sizeY = getConstVal(launchOp.getBlockSizeY()); + auto sizeZ = getConstVal(launchOp.getBlockSizeZ()); + if (!sizeX || !sizeY || !sizeZ) + return failure(); + + return SmallVector{*sizeX, *sizeY, *sizeZ}; + } + + // TODO: Remove when the lowering only occurs within a gpu.launch op. + // Manually computing this is brittle and duplicated parallel + // loops to gpu conversion. + if (auto blockLoop = dyn_cast(op)) { + auto gridLoop = blockLoop->getParentOfType(); + + // Blocks or number of threads are represented by the first parallel loop + // nested within another parallel loop. + // + // Fail if there is no outer parallel loop or current loop is nested more + // than once. + if (!gridLoop || (gridLoop->getParentOfType())) { + return failure(); + } + + SmallVector blockSizes; + for (auto [lb, ub, step] : + llvm::zip_equal(blockLoop.getLowerBound(), blockLoop.getUpperBound(), + blockLoop.getStep())) { + auto lbVal = getConstVal(lb); + auto ubVal = getConstVal(ub); + auto stepVal = getConstVal(step); + if (!lbVal || !ubVal || !stepVal) + return failure(); + + int64_t blockSize = (*ubVal - *lbVal) / *stepVal; + + // There must be at least one subgroup created for each dimension. + // Otherwise, bail out and let kernel outlining fail later. + if (blockSize <= 0) + return failure(); + blockSizes.push_back(blockSize); + } + + // Too many dimensions, something went wrong. Bail out. + if (blockSizes.size() > 3) + return failure(); + + return blockSizes; + } + + return failure(); +} + +// Get linearized GPU thread ID. +static Value getGpuLinearThreadId(PatternRewriter &rewriter, Location loc) { + SmallVector threadIds; + SmallVector blockDims; + + for (auto dim : {gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z}) { + threadIds.push_back(rewriter.create(loc, dim)); + blockDims.push_back(rewriter.create(loc, dim)); + } + + // The default GPU indexing is modeled after CUDA: + // linear index = (z * sizeY + y) * sizeX + x + Value threadId = + rewriter.create(loc, threadIds[2], blockDims[1]); + threadId = rewriter.create(loc, threadId, threadIds[1]); + threadId = rewriter.create(loc, threadId, blockDims[0]); + threadId = rewriter.create(loc, threadId, threadIds[0]); + + return threadId; +} + +// Create a GEMM input tile to be loaded by each subgroup in +// cooperative fashion. +// Optionally accepts batch IV for batched GEMM input loading. +// Returns failure if it is unable to split block/workgroup for +// prefetching. +static FailureOr +createGemmCoopPrefetchTile(PatternRewriter &rewriter, linalg::LinalgOp linalgOp, + unsigned inputPos, int64_t numThreads, + ArrayRef blockTile, ArrayRef threadTile, + int tileStep) { + assert(inputPos <= 1 && "Can handle only GEMM inputs: mat A or mat B"); + Location loc = linalgOp.getLoc(); + + Value src = linalgOp.getDpsInputs()[inputPos]; + + // Get a top level view into the whole matrix not only the thread slice. + if (auto subview = dyn_cast_or_null(src.getDefiningOp())) { + src = subview.getSource(); + } + + const int tileRows = inputPos == 0 ? blockTile[0] : tileStep; + const int tileCols = inputPos == 0 ? tileStep : blockTile[1]; + + const int numElements = tileRows * tileCols; + const int elementsPerThread = numElements / numThreads; + + // Limit the maximum prefetching row length to avoid very wide tiles. + // + // Currently, the max row size is capped by the hardware max load width. + // + // TODO: Expose as a tunable parameter or add some heuristics. + const int maxRowLength = 32; + + // Prioritize first loading contiguous elements (row lenght/number of + // columns) only then gather any remaining elements to be loaded from + // further rows. + // Also, ensure that the prefetch tile stays within the tile bounds. + // + // Ideally, prefetch tile sizes should be derived from total number of + // elements to be loaded, number of threads/workitems, and hardware load + // size limits. Large prefetch tiles might need to be split into sub-tiles. + const int numCols = + std::min(std::min(elementsPerThread, tileCols), maxRowLength); + const int numRows = elementsPerThread / numCols; + + // Bail on invalid prefetching tiles config. + if (numRows == 0 || + ((numRows * numCols * numThreads) > (tileRows * tileCols))) + return failure(); + + auto srcType = cast(src.getType()); + + auto prefetchType = + xegpu::TensorDescType::get({numRows, numCols}, srcType.getElementType()); + + Value threadId = getGpuLinearThreadId(rewriter, loc); + + // TODO: Simplify block offsets. + // Prefetching tile should be derived from the matmul op operands and + // exposed as a subview. + // + // Add offset if there are multiple blocks in the current tile's non-reduction + // dimension. + Value blockOffset = rewriter.create(loc, 0); + if (blockTile[inputPos] / threadTile[inputPos] > 1) { + Value blockSize = + rewriter.create(loc, blockTile[inputPos]); + + // For matrix B, pick correct block dimension. + // Block min X has to be used if there is no thread tiling in the rows + // (dim X) but only in columns (dim Y). + gpu::Dimension gpuDim = gpu::Dimension::x; + if ((inputPos == 1) && (blockTile[0] / threadTile[0] > 1)) { + gpuDim = gpu::Dimension::y; + } + Value blockId = rewriter.create(loc, gpuDim); + + blockOffset = rewriter.create(loc, blockId, blockSize); + } + + Value numColTiles = + rewriter.create(loc, tileStep / numCols); + if (inputPos == 1) { + numColTiles = + rewriter.create(loc, blockTile[1] / numCols); + } + Value tileRowOffset = + rewriter.create(loc, threadId, numColTiles); + Value tileColOffset = + rewriter.create(loc, threadId, numColTiles); + + Value tileRowSize = rewriter.create(loc, numRows); + Value tileColSize = rewriter.create(loc, numCols); + Value eltRowOffset = + rewriter.create(loc, tileRowOffset, tileRowSize); + Value eltColOffset = + rewriter.create(loc, tileColOffset, tileColSize); + + if (inputPos == 0) { + eltRowOffset = + rewriter.create(loc, eltRowOffset, blockOffset); + } else { + eltColOffset = + rewriter.create(loc, eltColOffset, blockOffset); + } + + SmallVector prefetchOffsets{eltRowOffset, eltColOffset}; + + return rewriter.create( + loc, prefetchType, dyn_cast>(src), + prefetchOffsets); +} + +// Insert prefetches for the given tensor descriptors. +static void prefetchTiles(PatternRewriter &rewriter, Location loc, + ValueRange prefetchTiles, + xegpu::CachePolicyAttr readCacheHint) { + // Prefetch the next set of input tiles. + for (auto tile : prefetchTiles) { + rewriter.create(loc, tile, + /*l1_hint=*/readCacheHint, + /*l2_hint=*/readCacheHint, + /*l3_hint=*/readCacheHint); + } +} + +// Update all tensor descriptors offsets with the fixed offsets. +static SmallVector updateTilesOffsets(PatternRewriter &rewriter, + Location loc, ValueRange tiles, + ArrayRef offsets) { + SmallVector updatedTiles; + for (auto tile : tiles) { + auto updatedTile = + rewriter + .create(loc, tile.getType(), tile, + /*offsets=*/ValueRange{}, offsets) + .getResult(); + updatedTiles.push_back(updatedTile); + } + + return updatedTiles; +} + +// Split a source into a series of descriptor tiles. +// +// The descriptors collectively load a 2D shape at the specified offsets from +// the given source. +// The offsets and the load shape must stay within the source boundaries. +// +// The descriptor sub-tiles are ordered in row-major fashion with respect to the +// whole load tile. +static SmallVector createDescriptorTiles(PatternRewriter &rewriter, + Location loc, Value src, + ArrayRef loadShape, + ArrayRef loadOffsets, + ArrayRef descTile, + int arrayLength = 1) { + assert(arrayLength == 1 && "Array descriptors are not supported"); + + auto type = cast(src.getType()); + auto descType = xegpu::TensorDescType::get(descTile, type.getElementType()); + + // Create the root descriptor. + // + // It is more efficient to create remainig descriptors by only updating its + // offsets compared to creating separate descriptors. + // The original tile is split into contiguous sub-tiles so, the first tile + // can be used as an anchor. + Value rootOffsetRow = + rewriter.create(loc, loadOffsets[0]); + Value rootOffsetCol = + rewriter.create(loc, loadOffsets[1]); + + mlir::SmallVector offsets{rootOffsetRow, rootOffsetCol}; + auto rootTile = + rewriter + .create( + loc, descType, dyn_cast>(src), offsets) + .getResult(); + + SmallVector tiles; + for (int i = 0; i < loadShape[0]; i += descTile[0]) { + for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) { + auto tile = rewriter + .create( + loc, descType, rootTile, + /*offsets=*/ValueRange{}, SmallVector{i, j}) + .getResult(); + tiles.push_back(tile); + } + } + + return tiles; +} + +// Create coarse sub-tiles to be loaded by the current subgroup. +// +// The shape to be loaded is split into the largest 2D loads supported +// by the hardware. +// +// The load subgroup tiles are ordered in row-major fashion with respect to the +// source shape. +static SmallVector createCoarseDscTiles(PatternRewriter &rewriter, + Location loc, Value src, + ArrayRef sgTile, + bool isVnni) { + assert(sgTile.size() <= 2 && + "Require at most 2D tile size for eltwise lowering"); + + // Ensure that load is 2D. + // TODO: Add support for 1D loads. + SmallVector sgTile2D{sgTile}; + if (sgTile.size() == 1) + sgTile2D.push_back(1); + + auto type = cast(src.getType()); + auto elemByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8; + + // TODO: Fetch actual list of supported load configs. + int64_t maxHeight = 32; + int64_t maxWidth = 64 / elemByteWidth; + // Assumes VNNI-factor 2. + // TODO: Make the VNNI-factor flexible. + if (isVnni) + maxWidth /= 2; + int64_t maxArrayLength = 4; + int64_t sgLoadRows = std::min(sgTile2D[0], maxHeight); + int64_t sgLoadCols = std::min(sgTile2D[1], maxWidth); + int64_t arrayLength = std::min(maxWidth / sgLoadCols, maxArrayLength); + // In case of partial fit, load only single tile. + if (maxWidth % sgLoadCols != 0 || arrayLength != 4 || arrayLength != 2) + arrayLength = 1; + + // TODO: Add variable array_length support. + arrayLength = 1; + + return createDescriptorTiles(rewriter, loc, src, sgTile2D, {0, 0}, + {sgLoadRows, sgLoadCols}, arrayLength); +} + +// Return vector type with specified VNNI shape. +static VectorType getVnniVector(ArrayRef shape, Type elementType, + VnniConfig vnniConf) { + assert(shape.size() == 2 && "Expected plain 2D shape"); + SmallVector vecShape{shape}; + vecShape[vnniConf.vnniAxis] /= vnniConf.vnniFactor; + vecShape.push_back(vnniConf.vnniFactor); + return VectorType::get(vecShape, elementType); +} + +// Loads n-D tiles from memory to registers. +static SmallVector +loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles, + xegpu::CachePolicyAttr hint, + std::optional vnniConf = std::nullopt, + DenseI64ArrayAttr transpose = nullptr) { + // Assume all tiles have the same shape. + auto tileType = cast(loadTiles[0].getType()); + assert(llvm::all_of(loadTiles, + [&](Value tile) { return tile.getType() == tileType; }) && + "All load tiles must have the same type."); + + VectorType vecLoadType = + VectorType::get(tileType.getShape(), tileType.getElementType()); + IntegerAttr vnniAxisAttr = nullptr; + if (vnniConf) { + vnniAxisAttr = IntegerAttr::get(rewriter.getI64Type(), vnniConf->vnniAxis); + vecLoadType = getVnniVector(tileType.getShape(), tileType.getElementType(), + *vnniConf); + } + + SmallVector loadVec; + for (auto tile : loadTiles) { + auto loadOp = rewriter.create( + loc, vecLoadType, tile, vnniAxisAttr, transpose, + /*l1_hint=*/hint, + /*l2_hint=*/hint, /*l3_hint=*/hint); + loadVec.push_back(loadOp); + } + // TODO: Add split over the array_length > 1. + // The split must preserve row-major ordering of the load tiles. + + return loadVec; +} + +// Splits loaded tiles of a larger 2D tile into individual subtiles and places +// them in their corresponding positions with respect to the original large +// tile. +// +// The loaded tiles must be perfectly divisible by the specified subtiles. +// Assumes row-major ordering for both the loaded tiles and the original tile. +// +// If the loaded tiles use VNNI layout, corresponding VNNI configuration must be +// provided. +static TilesArray +extractVecSubTiles(PatternRewriter &rewriter, Location loc, + ValueRange loadVecTiles, ArrayRef sgTotalTile, + ArrayRef loadTile, ArrayRef subTile, + std::optional vnniConf = std::nullopt) { + auto vecLoadType = cast(loadVecTiles[0].getType()); + assert(llvm::all_of(loadVecTiles, + [&](Value tile) { + return cast(tile.getType()) == vecLoadType; + }) && + "All loaded vectors must have the same type."); + assert(vecLoadType.getShape().size() == 2 || + vnniConf && "Requires VNNI config for non 2D loaded tiles"); + + // Accumulate all dimensions as the vector might have extra VNNI + // dimensions. + int loadVecSize = std::accumulate(vecLoadType.getShape().begin(), + vecLoadType.getShape().end(), 1, + std::multiplies()); + auto loadVecFlat = VectorType::get(loadVecSize, vecLoadType.getElementType()); + + VectorType vecSubTileType = + VectorType::get(subTile, vecLoadType.getElementType()); + if (vnniConf) { + vecSubTileType = + getVnniVector(subTile, vecLoadType.getElementType(), *vnniConf); + } + + const int totalTileRows = sgTotalTile[0] / loadTile[0]; + const int totalTileCols = sgTotalTile[1] / loadTile[1]; + + const int subTilesPerLoadRow = loadTile[0] / subTile[0]; + const int subTilePerLoadCol = loadTile[1] / subTile[1]; + + const int subTileRows = sgTotalTile[0] / subTile[0]; + const int subTileCols = sgTotalTile[1] / subTile[1]; + TilesArray subTiles(subTileRows, subTileCols); + + // Iterate over the total tile. + for (int m = 0; m < totalTileRows; m++) { + for (int k = 0; k < totalTileCols; k++) { + // Load tiles are ordered in row-major fashion. + int loadIdx = m * totalTileCols + k; + auto sgTotalTile = loadVecTiles[loadIdx]; + auto castFlat = + rewriter.create(loc, loadVecFlat, sgTotalTile); + + // Iterate over load tiles. + // Each load tile contains one or more sub-tiles. + for (int i = 0; i < subTilesPerLoadRow; i++) { + for (int j = 0; j < subTilePerLoadCol; j++) { + const int subTileSize = subTile[0] * subTile[1]; + int dpasIdx = i * subTilePerLoadCol + j; + int offset = dpasIdx * subTileSize; + + auto slice = rewriter.create( + loc, castFlat, /*offsets=*/ArrayRef{offset}, + /*sizes=*/ArrayRef{subTileSize}, + /*strides=*/ArrayRef{1}); + auto castTile = + rewriter.create(loc, vecSubTileType, slice); + + // Insert the sub-tiles in their position relative to the whole + // subgroup tile. + int rowIdx = m * subTilesPerLoadRow + i; + int colIdx = k * subTilePerLoadCol + j; + subTiles.setTile(rowIdx, colIdx, castTile); + } + } + } + } + + return subTiles; +} + +// Create XeGPU DPAS kernel out of GEMM-like operation. +static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp, + ArrayRef dpasTile, int kTile, + int prefetchStages, + PatternRewriter &rewriter) { + assert((isa(linalgOp) || + isa(linalgOp) || + isa(linalgOp)) && + "Requires a GEMM-like op for DPAS lowering"); + + Location loc = linalgOp.getLoc(); + auto ctx = linalgOp.getContext(); + + auto matA = linalgOp.getDpsInputs()[0]; + auto matB = linalgOp.getDpsInputs()[1]; + auto matC = linalgOp.getDpsInits()[0]; + + auto typeA = cast(matA.getType()); + auto typeC = cast(matC.getType()); + + int64_t dpasTileM = dpasTile[0]; + int64_t dpasTileN = dpasTile[1]; + int64_t dpasTileK = dpasTile[2]; + + // Cache hints for loads and stores. + auto readCacheHint = + xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED); + auto writeCacheHint = + xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::WRITE_BACK); + + bool isBrgemm = isa(linalgOp); + + Value zero = rewriter.create(loc, 0); + + int dimM = typeC.getShape()[0]; + int dimN = typeC.getShape()[1]; + int dimK = typeA.getShape().back(); + + // Create C sub-tiles. + auto dpasTypeC = xegpu::TensorDescType::get({dpasTileM, dpasTileN}, + typeC.getElementType()); + SmallVector tilesC = createDescriptorTiles( + rewriter, loc, matC, typeC.getShape(), {0, 0}, dpasTypeC.getShape()); + + // Load C sub-tiles. + // Fetch the inital values of the output accumulator. + SmallVector loadVecC = + loadNdDescTiles(rewriter, loc, tilesC, readCacheHint); + + // DPAS only works with F32 accumulators. + auto dpasResType = + VectorType::get(dpasTypeC.getShape(), FloatType::getF32(ctx)); + + // Extend the accumulation values if needed. + auto convOutPrecision = !typeC.getElementType().isF32(); + if (convOutPrecision) { + for (size_t i = 0; i < loadVecC.size(); i++) { + auto extOp = + rewriter.create(loc, dpasResType, loadVecC[i]); + loadVecC[i] = extOp.getOut(); + } + } + + // Create a loop and step into it. + auto startLoop = [&](int lb, int ub, int step, + ValueRange iterArgs) -> scf::ForOp { + Value lbCst = rewriter.create(loc, lb); + Value ubCst = rewriter.create(loc, ub); + Value stepCst = rewriter.create(loc, step); + scf::ForOp loopOp = + rewriter.create(loc, lbCst, ubCst, stepCst, iterArgs); + rewriter.setInsertionPointToStart(loopOp.getBody()); + return loopOp; + }; + auto getLoopIterValues = [&](scf::ForOp loopOp) -> SmallVector { + SmallVector loopIterVals; + for (auto iterArg : loopOp.getRegionIterArgs()) + loopIterVals.push_back(iterArg); + return loopIterVals; + }; + + OpBuilder::InsertionGuard guard(rewriter); + + // Construct and move into batch reduction loop. + // Propagate output values as iter args. + scf::ForOp batchLoop; + Value batchIv; + if (isBrgemm) { + batchLoop = startLoop(0, typeA.getShape()[0], 1, loadVecC); + batchIv = batchLoop.getInductionVar(); + loadVecC = getLoopIterValues(batchLoop); + // TODO: Replace input matrices A and B with subviews on the current + // batchIV as loads can only be performed on 2D memrefs. + } + + // Create A sub-tiles. + SmallVector tilesA = + createCoarseDscTiles(rewriter, loc, matA, {dimM, kTile}, /*isVnni=*/true); + + // Create B sub-tiles. + SmallVector tilesB = + createCoarseDscTiles(rewriter, loc, matB, {kTile, dimN}, /*isVnni=*/true); + + // Create input prefetch tiles. + int64_t numThreads = 1; + auto blockDims = + getStaticBlockSizes(linalgOp->getParentOfType()); + if (succeeded(blockDims)) { + numThreads = std::accumulate(blockDims->begin(), blockDims->end(), 1, + std::multiplies()); + } + // Disable prefetching when there is no block/workgroup parallelism. + bool isCoopPrefetch = numThreads > 1; + + Value prefetchA; + Value prefetchB; + xegpu::TensorDescType prefetchTypeA; + xegpu::TensorDescType prefetchTypeB; + if (isCoopPrefetch) { + // Return dimension size on which the whole block/workgroup operates. + auto getBlockLevelSize = [&](Value val, int dim) -> int { + if (auto subview = + dyn_cast_or_null(val.getDefiningOp())) { + val = subview.getSource(); + } + + return cast(val.getType()).getShape()[dim]; + }; + + int blockRows = getBlockLevelSize(matC, 0); + int blockCols = getBlockLevelSize(matC, 1); + + auto prefetchDescA = createGemmCoopPrefetchTile( + rewriter, linalgOp, /*inputPos=*/0, numThreads, {blockRows, blockCols}, + {dimM, dimN}, kTile); + auto prefetchDescB = createGemmCoopPrefetchTile( + rewriter, linalgOp, /*inputPos=*/1, numThreads, {blockRows, blockCols}, + {dimM, dimN}, kTile); + + if (succeeded(prefetchDescA) && succeeded(prefetchDescB)) { + prefetchA = prefetchDescA->getResult(); + prefetchTypeA = prefetchDescA->getType(); + prefetchB = prefetchDescB->getResult(); + prefetchTypeB = prefetchDescB->getType(); + + // Start data prefetching by multistage data load. + for (int i = 0; i < prefetchStages; i++) { + prefetchTiles(rewriter, loc, ValueRange{prefetchA}, readCacheHint); + prefetchTiles(rewriter, loc, ValueRange{prefetchB}, readCacheHint); + prefetchA = updateTilesOffsets(rewriter, loc, ValueRange{prefetchA}, + {0, kTile})[0]; + prefetchB = updateTilesOffsets(rewriter, loc, ValueRange{prefetchB}, + {kTile, 0})[0]; + } + } else { + // Disable coop prefetching on failure. + isCoopPrefetch = false; + } + } + + // Construct and move into GEMM reduction dimension tiling loop. + // Propagate output values as iter args. + SmallVector iterArgs; + iterArgs.append(loadVecC); + iterArgs.append(tilesA); + iterArgs.append(tilesB); + if (isCoopPrefetch) { + iterArgs.push_back(prefetchA); + iterArgs.push_back(prefetchB); + } + scf::ForOp kDimLoop = startLoop(0, dimK, kTile, iterArgs); + auto iterValues = getLoopIterValues(kDimLoop); + + loadVecC = SmallVector{iterValues.begin(), + iterValues.begin() + loadVecC.size()}; + tilesA = + SmallVector{iterValues.begin() + loadVecC.size(), + iterValues.begin() + loadVecC.size() + tilesA.size()}; + tilesB = SmallVector{iterValues.begin() + loadVecC.size() + tilesA.size(), + iterValues.begin() + loadVecC.size() + tilesA.size() + tilesB.size()}; + if (isCoopPrefetch) { + prefetchA = *(iterValues.end() - 2); + prefetchB = *(iterValues.end() - 1); + } + + // Periodically synchronize the block/workgroup to minimize impact on cache + // due to replacement of sub-tiles before all threads/workitems consumed + // inputs for reduction dimension step. + // + // TODO: Synchronization frequency should be derived from tile and cache size. + int syncFreq = 4; + int maxSyncStep = 1024; + int syncStep = std::min(std::max(dimK / syncFreq, maxSyncStep), maxSyncStep); + auto syncStepConst = rewriter.create(loc, syncStep); + auto loopStepMod = rewriter.create( + loc, kDimLoop.getInductionVar(), syncStepConst); + auto syncBlockCond = rewriter.create( + loc, arith::CmpIPredicate::eq, loopStepMod, zero); + rewriter.create( + loc, syncBlockCond, + /*thenBuilder=*/ + [](OpBuilder &b, Location loc) { + b.create(loc); + b.create(loc); + }, + /*elseBuilder=*/nullptr); + + // TODO: Add more possible types. + int vnniFactor = TypeSwitch(typeA.getElementType()) + .Case([](Float16Type type) { return 2; }) + .Default([](Type type) { return -1; }); + if (vnniFactor == -1) + return failure(); + + VnniConfig vnniConfA{.vnniFactor = vnniFactor, .vnniAxis = 1}; + VnniConfig vnniConfB{.vnniFactor = vnniFactor, .vnniAxis = 0}; + + // Load A sub-tiles. + SmallVector loadVecA = + loadNdDescTiles(rewriter, loc, tilesA, readCacheHint, vnniConfA); + auto tileTypeA = cast(tilesA[0].getType()); + + // Load B sub-tiles. + SmallVector loadVecB = + loadNdDescTiles(rewriter, loc, tilesB, readCacheHint, vnniConfB); + auto tileTypeB = cast(tilesB[0].getType()); + + // Update offsets of the input tiles. + // Shift along the reduction dimension. + tilesA = updateTilesOffsets(rewriter, loc, tilesA, {0, kTile}); + tilesB = updateTilesOffsets(rewriter, loc, tilesB, {kTile, 0}); + + // Prefetch the next set of input tiles. + if (isCoopPrefetch) { + // Prefetch all block/workgroup tiles cooperatively. + prefetchTiles(rewriter, loc, ValueRange{prefetchA}, readCacheHint); + prefetchTiles(rewriter, loc, ValueRange{prefetchB}, readCacheHint); + prefetchA = + updateTilesOffsets(rewriter, loc, ValueRange{prefetchA}, {0, kTile})[0]; + prefetchB = + updateTilesOffsets(rewriter, loc, ValueRange{prefetchB}, {kTile, 0})[0]; + } else { + // Apply naive prefetching for each subgroup separately. + prefetchTiles(rewriter, loc, tilesA, readCacheHint); + prefetchTiles(rewriter, loc, tilesB, readCacheHint); + } + + // Extract DPAS tiles from loaded sub-tiles. + TilesArray dpasVecA = extractVecSubTiles(rewriter, loc, loadVecA, + {dimM, kTile}, tileTypeA.getShape(), + {dpasTileM, dpasTileK}, vnniConfA); + TilesArray dpasVecB = extractVecSubTiles(rewriter, loc, loadVecB, + {kTile, dimN}, tileTypeB.getShape(), + {dpasTileK, dpasTileN}, vnniConfB); + + const int numTilesM = dimM / dpasTileM; + const int numTilesN = dimN / dpasTileN; + const int numTilesK = kTile / dpasTileK; + + // Compute sub-tiles of the C tile. + // + // Iterate over the reduction dimension sub-tiles as the outermost + // loop to minimize read after write conflicts between partial + // computations of the same C sub-tile. + SmallVector dpasResults = loadVecC; + + for (int k = 0; k < numTilesK; k++) { + for (int m = 0; m < numTilesM; m++) { + for (int n = 0; n < numTilesN; n++) { + int cIdx = m * numTilesN + n; + + Value result = rewriter + .create( + loc, dpasResType, dpasVecA.getTile(m, k), + dpasVecB.getTile(k, n), dpasResults[cIdx]) + .getResult(); + + // Update sub-tile partial result. + dpasResults[cIdx] = result; + } + } + } + + // Create loop terminator and exit the loop. + auto terminateLoop = [&](scf::ForOp loopOp, SmallVector resultValues) { + rewriter.setInsertionPointToEnd(loopOp.getBody()); + rewriter.create(loc, resultValues); + rewriter.setInsertionPointAfter(loopOp); + }; + + SmallVector yieldVals; + yieldVals.append(dpasResults); + yieldVals.append(tilesA); + yieldVals.append(tilesB); + if (isCoopPrefetch) { + yieldVals.push_back(prefetchA); + yieldVals.push_back(prefetchB); + } + + // Terminate and exit reduction dim loop. + terminateLoop(kDimLoop, yieldVals); + yieldVals = kDimLoop.getResults(); + + SmallVector results{yieldVals.begin(), + yieldVals.begin() + dpasResults.size()}; + + // Terminate and exit batch reduce loop. + if (isBrgemm) { + terminateLoop(batchLoop, results); + results = batchLoop.getResults(); + } + + // Truncate the result values if needed. + if (convOutPrecision) { + auto truncType = + VectorType::get(dpasTypeC.getShape(), typeC.getElementType()); + for (size_t i = 0; i < results.size(); i++) { + auto truncOp = + rewriter.create(loc, truncType, results[i]); + results[i] = truncOp.getOut(); + } + } + + // Write back the final C sub-tiles results to the output buffer. + SmallVector storeOps; + for (size_t i = 0; i < tilesC.size(); i++) { + auto storeOp = + rewriter.create(loc, results[i], tilesC[i], + /*l1_hint=*/writeCacheHint, + /*l2_hint=*/writeCacheHint, + /*l3_hint=*/writeCacheHint); + storeOps.push_back(storeOp); + } + + rewriter.eraseOp(linalgOp); + + return success(); +} + +// Create XeGPU kernel out of elementwise operation. +LogicalResult createEltwiseKernel(linalg::LinalgOp linalgOp, + PatternRewriter &rewriter) { + Location loc = linalgOp.getLoc(); + auto ctx = linalgOp.getContext(); + + auto output = linalgOp.getDpsInits()[0]; + auto outputShape = cast(output.getType()).getShape(); + + // Create descriptors and load values for all inputs. + SmallVector> loadedInputs; + for (auto input : linalgOp.getDpsInputs()) { + SmallVector inputTiles = createCoarseDscTiles( + rewriter, loc, input, outputShape, /*isVnni=*/false); + SmallVector loadedVals = + loadNdDescTiles(rewriter, loc, inputTiles, /*hint=*/nullptr); + loadedInputs.push_back(loadedVals); + } + + // Extract SIMD sized sub-tiles from loaded tiles. + // TODO: Fetch SIMD sizes from target descriptor. + int maxSizeSIMD = 256; + auto loadShape = cast(loadedInputs[0][0].getType()).getShape(); + // For sake of n-D loads and store, the vectorized operations are kept in 2D + // shape. The loaded tiles might be larger than what SIMD units can handle. + // Thus, split the registers into contiguous smaller slices. The current + // hardware load restrictions ensure that the loaded tile width will not + // exceed SIMD size. + // + // Take at least one whole row plus as many extra rows as can fit into + // a single SIMD instruction. + int64_t subTileCols = loadShape[1]; + int64_t subTileRows = std::min(loadShape[0], maxSizeSIMD / subTileCols); + + SmallVector> vecSubTiles; + for (auto inputTiles : loadedInputs) { + TilesArray subTiles = + extractVecSubTiles(rewriter, loc, inputTiles, outputShape, loadShape, + {subTileRows, subTileCols}); + vecSubTiles.push_back(subTiles.toFlatVector()); + } + + // Perform vectorized computations for each output tile. + SmallVector results; + for (size_t i = 0; i < vecSubTiles[0].size(); i++) { + // Operands are sub-tiles at the same location. + SmallVector operands; + for (auto inputs : vecSubTiles) { + operands.push_back(inputs[i]); + } + + // Create SIMD operations on the sub-tiles. + auto res = lowerEltwiseOp(linalgOp, operands, rewriter); + if (!res) + return failure(); + + results.push_back(*res); + } + + // Output descriptors for later stores. + SmallVector outputTiles = createDescriptorTiles( + rewriter, loc, output, outputShape, {0, 0}, {subTileRows, subTileCols}); + + // Store results. + auto writeCacheHint = + xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::WRITE_BACK); + for (size_t i = 0; i < outputTiles.size(); i++) { + rewriter.create(loc, results[i], outputTiles[i], + /*l1_hint=*/writeCacheHint, + /*l2_hint=*/writeCacheHint, + /*l3_hint=*/writeCacheHint); + } + + rewriter.eraseOp(linalgOp); + + return success(); +} + +// Convert a GEMM-like operation to an XeGPU kernel. +template +struct ConvertGemmLikeToXeGPU : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + // Constrain conversion to the supported GEMM-like ops. + static_assert( + llvm::is_one_of::value); + + ConvertGemmLikeToXeGPU(MLIRContext *ctx, LinalgToXeGPUOptions options) + : OpRewritePattern(ctx), options(options) {} + + LogicalResult matchAndRewrite(LinalgOpTy gemmLikeOp, + PatternRewriter &rewriter) const override { + if (!gemmLikeOp.hasPureBufferSemantics()) { + return rewriter.notifyMatchFailure( + gemmLikeOp, "Linalg GEMM-like to GPU expects memref type"); + } + if (gemmLikeOp.hasDynamicShape()) { + return rewriter.notifyMatchFailure( + gemmLikeOp, "Expect static shape when mapping to GPU"); + } + + using namespace structured_match; + auto matmulMatcher = + StructuredOpMatcher::make() + .operation(NumDpsInits(EqualsTo(1))) + .operation(NumDpsInputs(EqualsTo(2))) + .operation(NumRegions(EqualsTo(1))) + .operation(NumOfLoops(EqualsTo(3))) + .input(MatchAll(), HasStaticShape()) + .output(MatchAll(), HasStaticShape()) + .region(MatchOne(0), WithOpChain()); + if (isa(gemmLikeOp) && + !matmulMatcher.match(gemmLikeOp)) { + return rewriter.notifyMatchFailure( + gemmLikeOp, "Generic does not represent a GEMM-like operation"); + } + + for (auto input : gemmLikeOp.getDpsInputs()) { + // 3D inputs are also acceptable in case of brgemm. + auto isInputValid = + isValidMemrefOperand(gemmLikeOp, input, rewriter, /*maxDims=*/3); + if (failed(isInputValid)) + return isInputValid; + } + auto isOutputValid = + isValidMemrefOperand(gemmLikeOp, gemmLikeOp.getDpsInits()[0], rewriter); + if (failed(isOutputValid)) + return isOutputValid; + + // Ensure that reduction dimension tiling also works for smaller + // workloads. + auto aType = cast(gemmLikeOp.getDpsInputs()[0].getType()); + auto kDim = aType.getShape().back(); + auto kTile = kDim < options.kTile ? kDim : options.kTile; + + // DPAS hardware sizes in MxNxK format. + // TODO: In case more hardware configurations are available, + // add some automatic selection for optimal sizes. + if (options.dpasTile.empty()) { + return rewriter.notifyMatchFailure(gemmLikeOp, "Expect DPAS block sizes"); + } + + if (!isDPASCompatible(gemmLikeOp, kTile, options.dpasTile)) { + return rewriter.notifyMatchFailure( + gemmLikeOp, "GEMM-like compute does not fit in DPAS tiles"); + } + + return createDPASKernel(gemmLikeOp, options.dpasTile, kTile, options.stages, + rewriter); + } + +private: + LinalgToXeGPUOptions options; +}; + +// Convert a named elementwise operation to an XeGPU kernel. +template +struct ConvertNamedEltwiseToXeGPU : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + ConvertNamedEltwiseToXeGPU(MLIRContext *ctx, LinalgToXeGPUOptions options) + : OpRewritePattern(ctx), options(options) {} + + LogicalResult matchAndRewrite(LinalgOpTy eltwiseOp, + PatternRewriter &rewriter) const override { + if (!eltwiseOp.hasPureBufferSemantics()) { + return rewriter.notifyMatchFailure( + eltwiseOp, "Linalg eltwise to GPU expects memref type"); + } + if (eltwiseOp.hasDynamicShape()) { + return rewriter.notifyMatchFailure( + eltwiseOp, "Expect static shape when mapping to GPU"); + } + + for (auto input : eltwiseOp.getDpsInputs()) { + auto isInputValid = isValidMemrefOperand(eltwiseOp, input, rewriter); + if (failed(isInputValid)) + return isInputValid; + } + auto isOutputValid = + isValidMemrefOperand(eltwiseOp, eltwiseOp.getDpsInits()[0], rewriter); + if (failed(isOutputValid)) + return isOutputValid; + + return createEltwiseKernel(eltwiseOp, rewriter); + } + +private: + LinalgToXeGPUOptions options; +}; + +// TODO: Finalize BRGEMM support and register the pattern. +void populateLinalgGemmToXeGPUPatterns(RewritePatternSet &patterns, + LinalgToXeGPUOptions options) { + patterns.add, + ConvertGemmLikeToXeGPU>(patterns.getContext(), + options); +} + +void populateLinalgEltwiseToXeGPUPatterns(RewritePatternSet &patterns, + LinalgToXeGPUOptions options) { + patterns.add, + ConvertNamedEltwiseToXeGPU, + ConvertNamedEltwiseToXeGPU, + ConvertNamedEltwiseToXeGPU, + ConvertNamedEltwiseToXeGPU, + ConvertNamedEltwiseToXeGPU, + ConvertNamedEltwiseToXeGPU, + ConvertNamedEltwiseToXeGPU, + ConvertNamedEltwiseToXeGPU, + ConvertNamedEltwiseToXeGPU, + ConvertNamedEltwiseToXeGPU>(patterns.getContext(), + options); +} + +struct LinalgToXeGPU : public tpp::impl::LinalgToXeGPUBase { + using LinalgToXeGPUBase::LinalgToXeGPUBase; + + void runOnOperation() override { + LinalgToXeGPUOptions options{kTile, stages, dpasTile}; + + // Run GEMM pattern first to allow fusion with its consumers. + RewritePatternSet gemmPatterns(&getContext()); + populateLinalgGemmToXeGPUPatterns(gemmPatterns, options); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(gemmPatterns)); + + // Convert other remaining ops. + RewritePatternSet patterns(&getContext()); + populateLinalgEltwiseToXeGPUPatterns(patterns, options); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace diff --git a/lib/TPP/Runner/MLIRBench.cpp b/lib/TPP/Runner/MLIRBench.cpp index e269811f9..c352e9dab 100644 --- a/lib/TPP/Runner/MLIRBench.cpp +++ b/lib/TPP/Runner/MLIRBench.cpp @@ -193,18 +193,27 @@ Value MLIRBench::registerOnGpu(Value buf, MemRefType memRefTy) { } // Allocate an arg buffer on device and copy data from host - auto gpuAlloc = builder.create(unkLoc, memRefTy, ValueRange{}, - ValueRange{}, ValueRange{}); + // Use shared memory on Intel GPU and dedicated GPU allocation, otherwise + bool isHostShared = backend == "intel"; + auto gpuAlloc = + builder.create(unkLoc, memRefTy, ValueRange{}, ValueRange{}, + ValueRange{}, /*hostShared=*/isHostShared); auto gpuBuf = gpuAlloc.getResult(0); - auto gpuMemcpy = builder.create( - unkLoc, /*asyncToken=*/std::nullopt, ValueRange{}, gpuBuf, buf); + + Operation *memcpy; + if (backend == "intel") { + memcpy = builder.create(unkLoc, buf, gpuBuf); + } else { + memcpy = builder.create(unkLoc, /*asyncToken=*/std::nullopt, + ValueRange{}, gpuBuf, buf); + } // Dealloc the arg buffer at the end of program builder.setInsertionPointToEnd(&getMainBlock()); builder.create(unkLoc, /*asyncToken=*/std::nullopt, gpuBuf); // Continue inserting ops after the created kernel arg - builder.setInsertionPointAfter(gpuMemcpy); + builder.setInsertionPointAfter(memcpy); return gpuBuf; } @@ -384,7 +393,9 @@ LogicalResult MLIRBench::printResult(Operation *kernelCall) { // Kernels must return a single result Value result = kernelCall->getResult(0); - if (backend == "cuda" && offloadToDevice) { + + bool isIntel = (backend == "intel"); + if (((backend == "cuda") || isIntel) && offloadToDevice) { auto resType = cast(result.getType()); auto memrefType = MemRefType::get(resType.getShape(), resType.getElementType()); @@ -395,8 +406,14 @@ LogicalResult MLIRBench::printResult(Operation *kernelCall) { } auto outBuf = builder.create(unkLoc, memrefType); - auto gpuMemcpy = builder.create( - unkLoc, /*asyncToken=*/std::nullopt, ValueRange{}, outBuf, result); + + Operation *memcpy; + if (isIntel) { + memcpy = builder.create(unkLoc, result, outBuf); + } else { + memcpy = builder.create( + unkLoc, /*asyncToken=*/std::nullopt, ValueRange{}, outBuf, result); + } // Dealloc the output buffer at the end of program. // For now, automatic deallocation is disabled for GPUs. @@ -404,7 +421,7 @@ LogicalResult MLIRBench::printResult(Operation *kernelCall) { builder.create(unkLoc, outBuf); // Restore insertion point - builder.setInsertionPointAfter(gpuMemcpy); + builder.setInsertionPointAfter(memcpy); result = outBuf; } diff --git a/test/GPU/Intel/gpu-pipeline-intel.mlir b/test/GPU/Intel/gpu-pipeline-intel.mlir new file mode 100644 index 000000000..332daba12 --- /dev/null +++ b/test/GPU/Intel/gpu-pipeline-intel.mlir @@ -0,0 +1,42 @@ +// RUN: tpp-opt %s -gpu-pipeline=gpu=intel \ +// RUN: -gpu-block-tile=128,128 -gpu-thread-tile=32,32 -k-tile=32 -stages=1 \ +// RUN: -split-input-file | \ +// RUN: FileCheck %s + +func.func @linalg_matmul(%arg0: tensor<128x1024xf16>, + %arg1: tensor<1024x1024xf16>, + %arg2: tensor<128x1024xf16>) -> tensor<128x1024xf16> { + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x1024xf16>, tensor<1024x1024xf16>) + outs(%arg2 : tensor<128x1024xf16>) -> tensor<128x1024xf16> + return %0 : tensor<128x1024xf16> +} + +// CHECK: module attributes {gpu.container_module} +// CHECK-LABEL: func.func @linalg_matmul( +// CHECK-SAME: %[[arg0:.+]]: memref<128x1024xf16>, %[[arg1:.+]]: memref<1024x1024xf16>, %[[arg2:.+]]: memref<128x1024xf16> +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index +// CHECK-DAG: %[[c8:.+]] = arith.constant 8 : index +// CHECK: gpu.launch_func @linalg_matmul_kernel::@linalg_matmul_kernel blocks in (%[[c8]], %[[c1]], %[[c1]]) threads in (%[[c4]], %[[c4]], %[[c1]]) args(%[[arg2]] : memref<128x1024xf16>, %[[arg0]] : memref<128x1024xf16>, %[[arg1]] : memref<1024x1024xf16> +// +// CHECK-LABEL: gpu.func @linalg_matmul_kernel( +// CHECK-SAME: %[[C:.+]]: memref<128x1024xf16>, %[[A:.+]]: memref<128x1024xf16>, %[[B:.+]]: memref<1024x1024xf16> +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c32:.+]] = arith.constant 32 : index +// CHECK-DAG: %[[c1024:.+]] = arith.constant 1024 : index +// CHECK-COUNT-8: xegpu.load_nd +// CHECK-COUNT-8: arith.extf +// CHECK-COUNT-2: xegpu.prefetch_nd +// CHECK: %[[out:.+]]:14 = scf.for %[[iv:.+]] = %[[c0]] to %[[c1024]] step %[[c32]] +// CHECK-SAME: { +// CHECK-COUNT-4: xegpu.load_nd +// CHECK-COUNT-4: xegpu.update_nd_offset +// CHECK-COUNT-2: xegpu.prefetch_nd +// CHECK-COUNT-2: xegpu.update_nd_offset +// CHECK-COUNT-12: vector.extract_strided_slice +// CHECK-COUNT-16: xegpu.dpas +// CHECK: scf.yield +// CHECK: } +// CHECK-COUNT-8: arith.truncf +// CHECK-COUNT-8: xegpu.store_nd +// CHECK: gpu.return diff --git a/test/GPU/linalg-to-xegpu-dpas.mlir b/test/GPU/linalg-to-xegpu-dpas.mlir new file mode 100644 index 000000000..a64a93f7a --- /dev/null +++ b/test/GPU/linalg-to-xegpu-dpas.mlir @@ -0,0 +1,87 @@ +// RUN: tpp-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file | FileCheck %s + +func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf16>) { + %c1 = arith.constant 1 : index + gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c1, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) { + linalg.matmul ins(%arg0, %arg1 : memref<32x32xf16>, memref<32x32xf16>) + outs(%arg2 : memref<32x32xf16>) + gpu.terminator + } + return +} + +// CHECK-LABEL: func.func @matmul +// CHECK-SAME: %[[A:.+]]: memref<32x32xf16>, %[[B:.+]]: memref<32x32xf16>, %[[C:.+]]: memref<32x32xf16> +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 +// CHECK-DAG: %[[c16:.+]] = arith.constant 16 +// CHECK-DAG: %[[c32:.+]] = arith.constant 32 + +// Create output initial value load tiles. +// CHECK: %[[rootC:.+]] = xegpu.create_nd_tdesc %[[C]] +// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [0, 0] +// CHECK-COUNT-7: xegpu.update_nd_offset %[[rootC]] + +// Load initial accumulator values. +// CHECK: %[[vC:.+]] = xegpu.load_nd %[[tC]] +// CHECK-COUNT-7: xegpu.load_nd + +// Extend the type to match DPAS output precision. +// CHECK: %[[vC_f32:.+]] = arith.extf %[[vC]] +// CHECK-COUNT-7: arith.extf + +// Create input load tiles. +// CHECK: %[[rootA:.+]] = xegpu.create_nd_tdesc %[[A]] +// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [0, 0] +// CHECK: %[[rootB:.+]] = xegpu.create_nd_tdesc %[[B]] +// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [0, 0] +// CHECK-COUNT-1: xegpu.update_nd_offset %[[rootB]] + +// Create DPAS computation loop over tiled reduction dimension. +// CHECK: %[[res:.+]]:11 = scf.for{{.*}}%[[c0]] to %[[c32]] step %[[c16]] +// CHECK-SAME: iter_args(%[[acc:.+]] = %[[vC_f32]],{{.*}}%[[iterA:.+]] = %[[tA]],{{.*}}%[[iterB:.+]] = %[[tB]] +// CHECK-SAME: { + +// Periodically synchronize the workgroup. +// CHECK: scf.if +// CHECK-SAME: { +// CHECK: gpu.barrier +// CHECK: } + +// Load input values and update the load tile position. +// CHECK: %[[vA:.+]] = xegpu.load_nd %[[iterA]] +// CHECK: %[[vB:.+]] = xegpu.load_nd %[[iterB]] +// CHECK-COUNT-1: xegpu.load_nd +// CHECK: %[[new_tA:.+]] = xegpu.update_nd_offset %[[iterA]] +// CHECK: %[[new_tB:.+]] = xegpu.update_nd_offset %[[iterB]] +// CHECK-COUNT-1: xegpu.update_nd_offset + +// Apply simple prefetching scheme - start loading the next set of input +// tiles before computation is started. +// CHECK: xegpu.prefetch_nd %[[new_tA]] +// CHECK: xegpu.prefetch_nd %[[new_tB]] +// CHECK-COUNT-1: xegpu.prefetch_nd + +// Extract DPAS-sized chunks from larger loaded tile A. +// Tile B is already in the correct shape. +// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x8x2xf16> to vector<512xf16> +// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16> +// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x8x2xf16> +// CHECK-COUNT-3: vector.extract_strided_slice + +// Perform DPAS computation. +// CHECK: %[[dpas:.+]] = xegpu.dpas %[[vA_dpas]], %[[vB]], %[[acc]] +// CHECK-COUNT-7: xegpu.dpas + +// Yield the results to the next iteration. +// CHECK: scf.yield %[[dpas]],{{.*}}%[[new_tA]],{{.*}}%[[new_tB]] +// CHECK: } + +// Truncate results to the original output precision. +// CHECK: %[[res_f16:.+]] = arith.truncf %[[res]]#0 +// CHECK-COUNT-7: arith.truncf + +// Store back the final results. +// CHECH: xegpu.store_nd %[[res_f16]], %[[tC]] +// CHECK-COUNT-7: xegpu.store_nd + +// CHECK: gpu.terminator diff --git a/test/GPU/linalg-to-xegpu-stages.mlir b/test/GPU/linalg-to-xegpu-stages.mlir new file mode 100644 index 000000000..d7639f7f0 --- /dev/null +++ b/test/GPU/linalg-to-xegpu-stages.mlir @@ -0,0 +1,53 @@ +// RUN: tpp-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 stages=1" -canonicalize -split-input-file | FileCheck %s --check-prefix=STAGES-1 + +// RUN: tpp-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 stages=2" -canonicalize -split-input-file | FileCheck %s --check-prefix=STAGES-2 + +#map = affine_map<()[s0, s1] -> (s0 + s1)> +module { + func.func @matmul_multistage_coop_prefetch(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf16>) { + %c32 = arith.constant 32 : index + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1024, %c1024) step (%c128, %c128) { + %subview = memref.subview %arg2[%arg3, %arg4] [128, 128] [1, 1] : memref<1024x1024xf16> to memref<128x128xf16, strided<[1024, 1], offset: ?>> + scf.parallel (%arg5, %arg6) = (%c0, %c0) to (%c128, %c128) step (%c32, %c32) { + %subview_0 = memref.subview %subview[%arg5, %arg6] [32, 32] [1, 1] : memref<128x128xf16, strided<[1024, 1], offset: ?>> to memref<32x32xf16, strided<[1024, 1], offset: ?>> + %0 = affine.apply #map()[%arg5, %arg3] + %subview_1 = memref.subview %arg0[%0, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>> + %1 = affine.apply #map()[%arg6, %arg4] + %subview_2 = memref.subview %arg1[0, %1] [1024, 32] [1, 1] : memref<1024x1024xf16> to memref<1024x32xf16, strided<[1024, 1], offset: ?>> + linalg.matmul ins(%subview_1, %subview_2 : memref<32x1024xf16, strided<[1024, 1], offset: ?>>, memref<1024x32xf16, strided<[1024, 1], offset: ?>>) outs(%subview_0 : memref<32x32xf16, strided<[1024, 1], offset: ?>>) + scf.reduce + } + scf.reduce + } + return + } +} + +// STAGES-1-LABEL: func.func @matmul_multistage_coop_prefetch +// STAGES-1-SAME: %[[A:.+]]: memref<1024x1024xf16>, %[[B:.+]]: memref<1024x1024xf16>, %[[C:.+]]: memref<1024x1024xf16> +// STAGES-1: %[[s1_A:.+]] = xegpu.create_nd_tdesc %[[A]] +// STAGES-1: %[[s1_B:.+]] = xegpu.create_nd_tdesc %[[B]] +// STAGES-1: xegpu.prefetch_nd %[[s1_A]] +// STAGES-1: xegpu.prefetch_nd %[[s1_B]] +// STAGES-1: %[[loop_pref_A:.+]] = xegpu.update_nd_offset %[[s1_A]] +// STAGES-1: %[[loop_pref_B:.+]] = xegpu.update_nd_offset %[[s1_B]] +// STAGES-1-NOT: xegpu.prefetch_nd +// STAGES-1: scf.for + +// STAGES-2-LABEL: func.func @matmul_multistage_coop_prefetch +// STAGES-2-SAME: %[[A:.+]]: memref<1024x1024xf16>, %[[B:.+]]: memref<1024x1024xf16>, %[[C:.+]]: memref<1024x1024xf16> +// STAGES-2: %[[s1_A:.+]] = xegpu.create_nd_tdesc %[[A]] +// STAGES-2: %[[s1_B:.+]] = xegpu.create_nd_tdesc %[[B]] +// STAGES-2: xegpu.prefetch_nd %[[s1_A]] +// STAGES-2: xegpu.prefetch_nd %[[s1_B]] +// STAGES-2: %[[s2_A:.+]] = xegpu.update_nd_offset %[[s1_A]] +// STAGES-2: %[[s2_B:.+]] = xegpu.update_nd_offset %[[s1_B]] +// STAGES-2: xegpu.prefetch_nd %[[s2_A]] +// STAGES-2: xegpu.prefetch_nd %[[s2_B]] +// STAGES-2: %[[loop_pref_A:.+]] = xegpu.update_nd_offset %[[s2_A]] +// STAGES-2: %[[loop_pref_B:.+]] = xegpu.update_nd_offset %[[s2_B]] +// STAGES-2-NOT: xegpu.prefetch_nd +// STAGES-2: scf.for diff --git a/test/GPU/linalg-to-xegpu.mlir b/test/GPU/linalg-to-xegpu.mlir new file mode 100644 index 000000000..0362b6f86 --- /dev/null +++ b/test/GPU/linalg-to-xegpu.mlir @@ -0,0 +1,217 @@ +// RUN: tpp-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file | FileCheck %s + +func.func @matmul(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) { + linalg.matmul ins(%arg0, %arg1 : memref<8x16xf16>, memref<16x16xf16>) + outs(%arg2 : memref<8x16xf32>) + return +} + +// CHECK-LABEL: func.func @matmul +// CHECK-COUNT-3: xegpu.load_nd +// CHECK: xegpu.dpas +// CHECH: xegpu.store_nd + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +module { + func.func @generic_matmul(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf16>) { + linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<8x16xf16>, memref<16x16xf16>) outs(%arg2 : memref<8x16xf16>) { + ^bb0(%in: f16, %in_0: f16, %out: f16): + %0 = arith.mulf %in, %in_0 : f16 + %1 = arith.addf %out, %0 : f16 + linalg.yield %1 : f16 + } + return + } +} + +// CHECK-LABEL: func.func @generic_matmul +// CHECK-COUNT-3: xegpu.load_nd +// CHECK: xegpu.dpas +// CHECH: xegpu.store_nd + +// ----- + +func.func @matmul_trunc_result(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf16>) { + linalg.matmul ins(%arg0, %arg1 : memref<8x16xf16>, memref<16x16xf16>) + outs(%arg2 : memref<8x16xf16>) + return +} + +// CHECK-LABEL: func.func @matmul_trunc_result +// CHECK: arith.extf +// CHECK: xegpu.dpas +// CHECK: arith.truncf +// CHECH: xegpu.store_nd + +// ----- + +func.func @abs(%arg0: memref<8x16xf16>, %arg1: memref<8x16xf16>) { + linalg.abs ins(%arg0 : memref<8x16xf16>) + outs(%arg1 : memref<8x16xf16>) + return +} + +// CHECK-LABEL: func.func @abs +// CHECK-COUNT-1: xegpu.load_nd +// CHECK: math.absf +// CHECK: xegpu.store_nd + +// ----- + +func.func @add(%arg0: memref<8x16xf16>, %arg1: memref<8x16xf16>, %arg2: memref<8x16xf16>) { + linalg.add ins(%arg0, %arg1 : memref<8x16xf16>, memref<8x16xf16>) + outs(%arg2 : memref<8x16xf16>) + return +} + +// CHECK-LABEL: func.func @add +// CHECK-COUNT-2: xegpu.load_nd +// CHECK: arith.addf +// CHECK: xegpu.store_nd + +// ----- + +func.func @ceil(%arg0: memref<8x16xf16>, %arg1: memref<8x16xf16>) { + linalg.ceil ins(%arg0 : memref<8x16xf16>) + outs(%arg1 : memref<8x16xf16>) + return +} + +// CHECK-LABEL: func.func @ceil +// CHECK-COUNT-1: xegpu.load_nd +// CHECK: math.ceil +// CHECK: xegpu.store_nd + +// ----- + +func.func @div(%arg0: memref<8x16xf16>, %arg1: memref<8x16xf16>, %arg2: memref<8x16xf16>) { + linalg.div ins(%arg0, %arg1 : memref<8x16xf16>, memref<8x16xf16>) + outs(%arg2 : memref<8x16xf16>) + return +} + +// CHECK-LABEL: func.func @div +// CHECK-COUNT-2: xegpu.load_nd +// CHECK: arith.divf +// CHECK: xegpu.store_nd + +// ----- + +func.func @div_unsigned(%arg0: memref<8x16xi16>, %arg1: memref<8x16xi16>, %arg2: memref<8x16xi16>) { + linalg.div_unsigned ins(%arg0, %arg1 : memref<8x16xi16>, memref<8x16xi16>) + outs(%arg2 : memref<8x16xi16>) + return +} + +// CHECK-LABEL: func.func @div_unsigned +// CHECK-COUNT-2: xegpu.load_nd +// CHECK: arith.divui +// CHECK: xegpu.store_nd + +// ----- + +func.func @exp(%arg0: memref<8x16xf16>, %arg1: memref<8x16xf16>) { + linalg.exp ins(%arg0 : memref<8x16xf16>) + outs(%arg1 : memref<8x16xf16>) + return +} + +// CHECK-LABEL: func.func @exp +// CHECK-COUNT-1: xegpu.load_nd +// CHECK: math.exp +// CHECK: xegpu.store_nd + +// ----- + +func.func @floor(%arg0: memref<8x16xf16>, %arg1: memref<8x16xf16>) { + linalg.floor ins(%arg0 : memref<8x16xf16>) + outs(%arg1 : memref<8x16xf16>) + return +} + +// CHECK-LABEL: func.func @floor +// CHECK-COUNT-1: xegpu.load_nd +// CHECK: math.floor +// CHECK: xegpu.store_nd + +// ----- + +func.func @max(%arg0: memref<8x16xf16>, %arg1: memref<8x16xf16>, %arg2: memref<8x16xf16>) { + linalg.max ins(%arg0, %arg1 : memref<8x16xf16>, memref<8x16xf16>) + outs(%arg2 : memref<8x16xf16>) + return +} + +// CHECK-LABEL: func.func @max +// CHECK-COUNT-2: xegpu.load_nd +// CHECK: arith.maximumf +// CHECK: xegpu.store_nd + +// ----- + +func.func @mul(%arg0: memref<8x16xf16>, %arg1: memref<8x16xf16>, %arg2: memref<8x16xf16>) { + linalg.mul ins(%arg0, %arg1 : memref<8x16xf16>, memref<8x16xf16>) + outs(%arg2 : memref<8x16xf16>) + return +} + +// CHECK-LABEL: func.func @mul +// CHECK-COUNT-2: xegpu.load_nd +// CHECK: arith.mulf +// CHECK: xegpu.store_nd + +// ----- + +func.func @negf(%arg0: memref<8x16xf16>, %arg1: memref<8x16xf16>) { + linalg.negf ins(%arg0 : memref<8x16xf16>) + outs(%arg1 : memref<8x16xf16>) + return +} + +// CHECK-LABEL: func.func @negf +// CHECK-COUNT-1: xegpu.load_nd +// CHECK: arith.negf +// CHECK: xegpu.store_nd + +// ----- + +func.func @sub(%arg0: memref<8x16xf16>, %arg1: memref<8x16xf16>, %arg2: memref<8x16xf16>) { + linalg.sub ins(%arg0, %arg1 : memref<8x16xf16>, memref<8x16xf16>) + outs(%arg2 : memref<8x16xf16>) + return +} + +// CHECK-LABEL: func.func @sub +// CHECK-COUNT-2: xegpu.load_nd +// CHECK: arith.subf +// CHECK: xegpu.store_nd + +// ----- + +func.func @add_large_f16(%arg0: memref<64x64xf16>, %arg1: memref<64x64xf16>, %arg2: memref<64x64xf16>) { + linalg.add ins(%arg0, %arg1 : memref<64x64xf16>, memref<64x64xf16>) + outs(%arg2 : memref<64x64xf16>) + return +} + +// CHECK-LABEL: func.func @add_large_f16 +// CHECK: xegpu.load_nd{{.*}}: !xegpu.tensor_desc<32x32xf16{{.*}}> -> vector<32x32xf16> +// CHECK: arith.addf{{.*}}: vector<8x32xf16> +// CHECK: xegpu.store_nd{{.*}}: vector<8x32xf16> + +// ----- + +func.func @add_large_f32(%arg0: memref<64x64xf32>, %arg1: memref<64x64xf32>, %arg2: memref<64x64xf32>) { + linalg.add ins(%arg0, %arg1 : memref<64x64xf32>, memref<64x64xf32>) + outs(%arg2 : memref<64x64xf32>) + return +} + +// CHECK-LABEL: func.func @add_large_f32 +// CHECK: xegpu.load_nd{{.*}}: !xegpu.tensor_desc<32x16xf32{{.*}}> -> vector<32x16xf32> +// CHECK: arith.addf{{.*}}: vector<16x16xf32> +// CHECK: xegpu.store_nd{{.*}}: vector<16x16xf32>