diff --git a/benchmarks/config/GPU/cuda.json b/benchmarks/config/GPU/cuda.json index fece4d03e..91a199e54 100644 --- a/benchmarks/config/GPU/cuda.json +++ b/benchmarks/config/GPU/cuda.json @@ -15,40 +15,12 @@ "flags": [ "--gpu=cuda" ], "extensions": [ "(avx2|asimd)" ] }, - "fp32_1024_manual_kernel_mlir": { - "type": "MLIR", - "benchmark": "GPU/gemm-fp32-1024-manual-kernel.mlir", - "environment": {}, - "flags": [ "-n", "100", "--gpu=cuda" ], - "extensions": [ "(avx2|asimd)" ] - }, "fp32_1024_base_mlir": { "type": "MLIR", "benchmark": "GPU/gemm-fp32-1024-base.mlir", "environment": {}, "flags": [ "-n", "100", "--gpu=cuda" ], "extensions": [ "(avx2|asimd)" ] - }, - "fp32_1024_packed_mlir": { - "type": "MLIR", - "benchmark": "GPU/gemm-fp32-1024-packed.mlir", - "environment": {}, - "flags": [ "-n", "100", "--gpu=cuda" ], - "extensions": [ "(avx2|asimd)" ] - }, - "fp16_1024_packed_mlir": { - "type": "MLIR", - "benchmark": "GPU/gemm-fp16-1024-packed.mlir", - "environment": {}, - "flags": [ "-n", "100", "--gpu=cuda" ], - "extensions": [ "(avx2|asimd)" ] - }, - "fp16_1024_packed_wmma_mlir": { - "type": "MLIR", - "benchmark": "GPU/gemm-fp16-1024-packed.mlir", - "environment": {}, - "flags": [ "-n", "100", "--gpu=cuda", "-run-args=-gpu-wmma" ], - "extensions": [ "(avx2|asimd)" ] } }}, { diff --git a/benchmarks/mlir/GPU/gemm-fp16-1024-packed.mlir b/benchmarks/mlir/GPU/gemm-fp16-1024-packed.mlir deleted file mode 100644 index bab590f94..000000000 --- a/benchmarks/mlir/GPU/gemm-fp16-1024-packed.mlir +++ /dev/null @@ -1,19 +0,0 @@ -// RUN: tpp-run %s -n 10 \ -// RUN: -e entry -entry-point-result=void - -// BENCH_TOTAL_FLOPS: 536870912 - -#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)> -module { - func.func @entry(%arg0: tensor<16x64x16x16xf16>, %arg1: tensor<64x64x16x16xf16>, %arg2: tensor<16x64x16x16xf16>) -> tensor<16x64x16x16xf16> { - %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x64x16x16xf16>, tensor<64x64x16x16xf16>) outs(%arg2 : tensor<16x64x16x16xf16>) { - ^bb0(%in: f16, %in_0: f16, %out: f16): - %1 = arith.mulf %in, %in_0 : f16 - %2 = arith.addf %out, %1 : f16 - linalg.yield %2 : f16 - } -> tensor<16x64x16x16xf16> - return %0 : tensor<16x64x16x16xf16> - } -} diff --git a/benchmarks/mlir/GPU/gemm-fp32-1024-manual-kernel.mlir b/benchmarks/mlir/GPU/gemm-fp32-1024-manual-kernel.mlir deleted file mode 100644 index 511412624..000000000 --- a/benchmarks/mlir/GPU/gemm-fp32-1024-manual-kernel.mlir +++ /dev/null @@ -1,63 +0,0 @@ -// RUN: tpp-run %s -n 10 \ -// RUN: -e entry -entry-point-result=void - -// BENCH_TOTAL_FLOPS: 536870912 - -module attributes {gpu.container_module} { - func.func @entry(%arg0: memref<256x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<256x1024xf32>) { - %blocksX = arith.constant 8 : index - %blocksY = arith.constant 32 : index - %threads = arith.constant 32 : index - %m = arith.constant 256 : index - %n = arith.constant 1024 : index - %k = arith.constant 1024 : index - %c1 = arith.constant 1 : index - gpu.launch_func @entry_kernel::@entry_kernel - blocks in (%blocksX, %blocksY, %c1) - threads in (%threads, %threads, %c1) - args(%arg0 : memref<256x1024xf32>, %arg1 : memref<1024x1024xf32>, %arg2 : memref<256x1024xf32>, %m : index, %n : index, %k : index) - return - } - - gpu.module @entry_kernel { - gpu.func @entry_kernel(%arg0: memref<256x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<256x1024xf32>, %m: index, %n: index, %k: index) - kernel attributes {known_block_size = array, known_grid_size = array} { - %bx = gpu.block_id x - %by = gpu.block_id y - %bDimx = gpu.block_dim x - %bDimy = gpu.block_dim y - %tx = gpu.thread_id x - %ty = gpu.thread_id y - - // row = blockIdx.x * blockDim.x + threadIdx.x - %rowOff = arith.muli %bx, %bDimx : index - %row = arith.addi %rowOff, %tx : index - - // col = blockIdx.y * blockDim.y + threadIdx.y - %colOff = arith.muli %by, %bDimy : index - %col = arith.addi %colOff, %ty : index - - %rowCheck = arith.cmpi ult, %row, %m : index - %colCheck = arith.cmpi ult, %col, %n : index - %isValidThread = arith.andi %rowCheck, %colCheck : i1 - - scf.if %isValidThread { - %lb = arith.constant 0 : index - %step = arith.constant 1 : index - %init = memref.load %arg2[%row, %col] : memref<256x1024xf32> - - %sum = scf.for %i = %lb to %k step %step iter_args(%partial = %init) -> (f32) { - %2 = memref.load %arg0[%row, %i] : memref<256x1024xf32> - %3 = memref.load %arg1[%i, %col] : memref<1024x1024xf32> - %5 = arith.mulf %2, %3 : f32 - %6 = arith.addf %partial, %5 : f32 - scf.yield %6 : f32 - } - - memref.store %sum, %arg2[%row, %col] : memref<256x1024xf32> - } - - gpu.return - } - } -} diff --git a/benchmarks/mlir/GPU/gemm-fp32-1024-packed.mlir b/benchmarks/mlir/GPU/gemm-fp32-1024-packed.mlir deleted file mode 100644 index a8fa0fbab..000000000 --- a/benchmarks/mlir/GPU/gemm-fp32-1024-packed.mlir +++ /dev/null @@ -1,19 +0,0 @@ -// RUN: tpp-run %s -n 10 \ -// RUN: -e entry -entry-point-result=void - -// BENCH_TOTAL_FLOPS: 536870912 - -#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)> -module { - func.func @entry(%arg0: tensor<16x64x16x16xf32>, %arg1: tensor<64x64x16x16xf32>, %arg2: tensor<16x64x16x16xf32>) -> tensor<16x64x16x16xf32> { - %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x64x16x16xf32>, tensor<64x64x16x16xf32>) outs(%arg2 : tensor<16x64x16x16xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %1 = arith.mulf %in, %in_0 : f32 - %2 = arith.addf %out, %1 : f32 - linalg.yield %2 : f32 - } -> tensor<16x64x16x16xf32> - return %0 : tensor<16x64x16x16xf32> - } -} diff --git a/include/TPP/PassBundles.td b/include/TPP/PassBundles.td index d393a40f6..93c8a73ca 100644 --- a/include/TPP/PassBundles.td +++ b/include/TPP/PassBundles.td @@ -119,22 +119,12 @@ def GpuConversion : Pass<"gpu-conversion", "ModuleOp"> { let description = [{ Convert all eligble operations into generic GPU operations. }]; - let options = [ - Option<"useWmma", "wmma", - "bool", /*default=*/"false", - "Use WMMA operations">, - ListOption<"warpTile", "warp-tile", "int64_t", "Warp tile sizes MxNxK">, - ]; let dependentDialects = ["linalg::LinalgDialect", "gpu::GPUDialect", "scf::SCFDialect", "memref::MemRefDialect", "xegpu::XeGPUDialect"]; let options = [ - Option<"useWmma", "wmma", - "bool", /*default=*/"false", - "Use WMMA operations">, - ListOption<"warpTile", "warp-tile", "int64_t", "Warp tile sizes MxNxK">, Option<"isIntel", "intel", "bool", /*default=*/"false", "Convert for Intel GPU">, diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td index 60845c849..a1decd3ee 100644 --- a/include/TPP/Passes.td +++ b/include/TPP/Passes.td @@ -323,27 +323,6 @@ def DecomposeAggregatedOps : Pass<"decompose-aggregated-ops", "func::FuncOp"> { }]; } -def LinalgToGpu : Pass<"linalg-to-gpu", "func::FuncOp"> { - let summary = "Convert linalg ops to be GPU compatible."; - let description = [{ - Lower linalg to ops optimized for computation on GPU. - }]; - let dependentDialects = ["linalg::LinalgDialect", - "scf::SCFDialect", - "memref::MemRefDialect", - "gpu::GPUDialect", - "arith::ArithDialect"]; - let options = [ - Option<"useWmma", "wmma", - "bool", /*default=*/"false", - "Use WMMA operations">, - ListOption<"warpTile", "warp-tile", "int64_t", "Warp tile sizes MxNxK">, - Option<"kTile", "k-tile", "int64_t", - /*default=*/"32", - "GEMM tile size for reduction dimension.">, - ]; -} - def GpuDataTransfer : Pass<"gpu-data-transfer", "func::FuncOp"> { let summary = "Transfer data to and from GPU."; let description = [{ diff --git a/lib/TPP/GPU/CMakeLists.txt b/lib/TPP/GPU/CMakeLists.txt index 94e456b2c..bc5082b9a 100644 --- a/lib/TPP/GPU/CMakeLists.txt +++ b/lib/TPP/GPU/CMakeLists.txt @@ -5,7 +5,6 @@ add_mlir_library(TPPGPU GpuToCuda.cpp SetSPIRVCapabilities.cpp SetSPIRVAbiAttribute.cpp - LinalgToGpu.cpp GpuDataTransfer.cpp GpuInlineConstants.cpp LinalgToXeGPU.cpp diff --git a/lib/TPP/GPU/GpuConversion.cpp b/lib/TPP/GPU/GpuConversion.cpp index dc5f451ee..aa4963fd5 100644 --- a/lib/TPP/GPU/GpuConversion.cpp +++ b/lib/TPP/GPU/GpuConversion.cpp @@ -62,11 +62,8 @@ struct GpuConversion : public tpp::impl::GpuConversionBase, if (isIntel) { pm.addNestedPass( createLinalgToXeGPU(LinalgToXeGPUOptions{kTile, stages, dpasTile})); - } else { - pm.addNestedPass( - createLinalgToGpu(LinalgToGpuOptions{useWmma, warpTile, kTile})); } - pm.addNestedPass(createConvertLinalgToParallelLoopsPass()); + pm.addNestedPass(createConvertLinalgToLoopsPass()); // Map loops into GPU kernels. pm.addNestedPass(createGpuMapParallelLoopsPass()); diff --git a/lib/TPP/GPU/GpuPipeline.cpp b/lib/TPP/GPU/GpuPipeline.cpp index a57a7edfb..f64d98a0e 100644 --- a/lib/TPP/GPU/GpuPipeline.cpp +++ b/lib/TPP/GPU/GpuPipeline.cpp @@ -41,15 +41,6 @@ using namespace mlir; using namespace mlir::tpp; -llvm::cl::opt gpuWmma("gpu-wmma", - llvm::cl::desc("Enable GPU WMMA support"), - llvm::cl::init(false)); - -llvm::cl::list wmmaTileSizes( - "wmma-tile-sizes", llvm::cl::desc("GPU WMMA tile sizes MxNxK"), - llvm::cl::list_init(SmallVector{16, 16, 16}), - llvm::cl::CommaSeparated); - llvm::cl::list gpuBlockTile("gpu-block-tile", llvm::cl::desc("GPU block tile size"), llvm::cl::list_init(SmallVector{128, 128}), @@ -165,29 +156,30 @@ struct GpuPipeline : public tpp::impl::GpuPipelineBase, GpuType gpuType = parseGpuOption(this->gpuBackend); GpuOptions gpuOptions = getGpuOptions(gpuType); + // Input preprocessing. + pm.addPass(createCleanup()); + pm.addPass(createFoldIntoEltwise()); + pm.addNestedPass(createConvertLinalgToInplace()); + // Tile to split the kernel into threads and blocks. // Use default tiling to handle both packed and unpacked ops. pm.addPass(createCleanup()); - if (gpuType == GpuType::Intel) { - // First split computation into grid with blocks of specified size. - TileConsumerAndFuseProducersOptions blockTileOptions; + // First split computation into grid with blocks of specified size. + TileConsumerAndFuseProducersOptions blockTileOptions; + if (!llvm::any_of(gpuBlockTile, [](int64_t tile) { return tile == -1; })) blockTileOptions.tileSizes = gpuBlockTile; - blockTileOptions.minTileFactor = 1; - pm.addPass(createTileConsumerAndFuseProducers(blockTileOptions)); - - // Then try to further split computation into subtiles. - // This allows to split larger computations across multiple - // threads/workitems. For smaller workloads, it provides another - // chance for outlining. - TileConsumerAndFuseProducersOptions threadTileOptions; + blockTileOptions.minTileFactor = 1; + pm.addPass(createTileConsumerAndFuseProducers(blockTileOptions)); + + // Then try to further split computation into subtiles. + // This allows to split larger computations across multiple + // threads/workitems. For smaller workloads, it provides another + // chance for outlining. + TileConsumerAndFuseProducersOptions threadTileOptions; + if (!llvm::any_of(gpuThreadTile, [](int64_t tile) { return tile == -1; })) threadTileOptions.tileSizes = gpuThreadTile; - threadTileOptions.minTileFactor = 1; - pm.addPass(createTileConsumerAndFuseProducers(threadTileOptions)); - } else { - TileConsumerAndFuseProducersOptions tilingOptions; - tilingOptions.minTileFactor = 1; - pm.addPass(createTileConsumerAndFuseProducers(tilingOptions)); - } + threadTileOptions.minTileFactor = 1; + pm.addPass(createTileConsumerAndFuseProducers(threadTileOptions)); pm.addPass(createCleanup()); // Preprocess and bufferize as further conversion requires memref @@ -198,9 +190,8 @@ struct GpuPipeline : public tpp::impl::GpuPipelineBase, pm.addPass(createCleanup()); // Convert to generic GPU ops. - pm.addPass(createGpuConversion( - GpuConversionOptions{gpuWmma, wmmaTileSizes, gpuType == GpuType::Intel, - kTile, stages, gpuDpasTile})); + pm.addPass(createGpuConversion(GpuConversionOptions{ + gpuType == GpuType::Intel, kTile, stages, gpuDpasTile})); // Lower GPU ops to the chosen GPU backend. switch (gpuType) { @@ -212,7 +203,7 @@ struct GpuPipeline : public tpp::impl::GpuPipelineBase, gpuOptions.triple, gpuOptions.chip, gpuOptions.features})); break; } - case GpuType::Intel: + case GpuType::Intel: { pm.addPass(xegpu::createXeGPUFoldAliasOps()); std::string clientApi = "intel"; @@ -223,6 +214,7 @@ struct GpuPipeline : public tpp::impl::GpuPipelineBase, break; } + } // Covert all local dialects like perf. pm.addPass(createLocalDialectsLowering()); diff --git a/lib/TPP/GPU/LinalgToGpu.cpp b/lib/TPP/GPU/LinalgToGpu.cpp deleted file mode 100644 index 6fdca1894..000000000 --- a/lib/TPP/GPU/LinalgToGpu.cpp +++ /dev/null @@ -1,763 +0,0 @@ -//===- LinalgToGpu.cpp -------------------------------------------*- C++-*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "TPP/Passes.h" - -#include "TPP/IR/MatcherUtils.h" -#include "TPP/Transforms/Utils/ValueUtils.h" - -#include "mlir/Conversion/Passes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/GPU/Transforms/Passes.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/Dialect.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/Passes.h" - -#include - -using namespace mlir; -using namespace mlir::tpp; - -namespace mlir { -namespace tpp { -#define GEN_PASS_DEF_LINALGTOGPU -#include "TPP/Passes.h.inc" -} // namespace tpp -} // namespace mlir - -namespace { - -// Creates an outermost parallel loop wrapper around an operation to represent -// number of GPU blocks. -// If there is already a parallel loop present, no operation is created and -// a nullopt is returned instead. -static std::optional -createGpuBlocksWrapper(Operation *op, ArrayRef blockDims, - PatternRewriter &rewriter) { - assert(blockDims.size() <= 3 && "Too many GPU blocks dimensions"); - - auto loc = op->getLoc(); - - auto *parentOp = op->getParentOp(); - if (isa(parentOp)) - return std::nullopt; - - Value zero = rewriter.create(loc, 0); - Value one = rewriter.create(loc, 1); - - SmallVector gpuBlocks; - SmallVector lbs; - SmallVector steps; - for (auto blockDim : blockDims) { - auto blockSize = rewriter.create(loc, blockDim); - gpuBlocks.push_back(blockSize); - // Add matching number of lbs and steps. - lbs.push_back(zero); - steps.push_back(one); - } - - return rewriter.create(loc, lbs, gpuBlocks, steps); -} - -// Return true if the operation can be represented with WMMA compute. -static bool isMMACompatible(linalg::LinalgOp linalgOp, - ArrayRef warpTile, int kTile) { - if (!(isa(linalgOp) || - isa(linalgOp))) { - return false; - } - - if (warpTile.size() != 3) - return false; - - // Only static shapes are supported. - if (linalgOp.hasDynamicShape()) - return false; - - auto aType = cast(linalgOp.getDpsInputs()[0].getType()); - auto bType = cast(linalgOp.getDpsInputs()[1].getType()); - auto cType = cast(linalgOp.getDpsInits()[0].getType()); - - auto elemTypeA = aType.getElementType(); - auto elemTypeB = bType.getElementType(); - auto elemTypeC = cType.getElementType(); - - // TODO: Add more WMMA combinations. - bool isSupportedPrecision = - (elemTypeA.isF16() && elemTypeB.isF16() && elemTypeC.isF16()) || - (elemTypeA.isF16() && elemTypeB.isF16() && elemTypeC.isF32()); - if (!isSupportedPrecision) - return false; - - auto mDim = cType.getShape()[0]; - auto nDim = cType.getShape()[1]; - auto kDim = aType.getShape().back(); - - // Validate warp tile sizes. - // The computation dimensions must fit into the tiles. - // Reduction dimension tile size has to be compatible - // with the warp tile. - int wmmaTileM = warpTile[0]; - int wmmaTileN = warpTile[1]; - int wmmaTileK = warpTile[2]; - if ((mDim % wmmaTileM != 0) || (nDim % wmmaTileN != 0) || - (kDim % wmmaTileK != 0) || (kTile % wmmaTileK != 0)) { - return false; - } - - return true; -} - -// Fuse a consumer using WMMA operations. -// Returns updated store op or nullopt if the fusion fails. -static std::optional> -eltwiseFusion(linalg::LinalgOp rootOp, linalg::LinalgOp consumer, - SmallVector rootStoreOps, - PatternRewriter &rewriter) { - assert(rootStoreOps.size() > 0 && "Requires at least one store op"); - - Location loc = rootOp.getLoc(); - - auto rootOutput = rootOp.getDpsInits()[0]; - auto outputType = cast(rootOutput.getType()); - - // Must be a floating point type. - // TODO: Add integer support. - auto floatType = dyn_cast(outputType.getElementType()); - if (!floatType) - return std::nullopt; - - // Insert fused eltwise ops before the store and later replace the store - // with a new result. - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(rootStoreOps[0]); - - // It is assumed that WMMA tile sizes do not vary between different - // operations i.e., the original workload has been split into - // a series of operations using the same WMMA configuration. - gpu::MMAMatrixType mmaOutputType = rootStoreOps[0].getSrc().getType(); - auto leadingDim = rootStoreOps[0].getLeadDimension(); - - // Collect new results after fusion. - SmallVector fusedRes; - - SmallVector operands; - if (structured_match::utils::isTwoDAddOp(consumer, &operands)) { - // Get the value to be added - load the tile first. - // Must be a buffer of the same type - scalar broadcast is not supported. - // TODO: Add support for eltwise with broadcast. - auto addValue = (operands[0] != rootOutput) ? operands[0] : operands[1]; - if (addValue.getType() != rootOutput.getType()) - return std::nullopt; - - for (gpu::SubgroupMmaStoreMatrixOp rootStoreOp : rootStoreOps) { - auto storeIndices = rootStoreOp.getIndices(); - - // Fuse the add into the matmul body. - auto loadOp = - rewriter - .create( - loc, mmaOutputType, addValue, storeIndices, leadingDim, - /*transpose=*/UnitAttr()) - .getRes(); - auto eltwiseAttr = gpu::MMAElementwiseOp::ADDF; - auto addRes = - rewriter - .create( - loc, mmaOutputType, ValueRange{rootStoreOp.getSrc(), loadOp}, - eltwiseAttr) - .getRes(); - fusedRes.push_back(addRes); - } - } else if (structured_match::utils::isTwoDReluOp(consumer, &operands)) { - Value zeroFloat = rewriter.create( - loc, APFloat::getZero(floatType.getFloatSemantics()), floatType); - - Value zeroTile = rewriter.create( - loc, mmaOutputType, zeroFloat); - for (auto rootStoreOp : rootStoreOps) { - // Fuse the relu into the matmul body. - auto eltwiseAttr = gpu::MMAElementwiseOp::MAXF; - auto maxRes = - rewriter - .create( - loc, mmaOutputType, - ValueRange{rootStoreOp.getSrc(), zeroTile}, eltwiseAttr) - .getRes(); - fusedRes.push_back(maxRes); - } - } else { - // Not a fusable operation. Bail out. - return std::nullopt; - } - - // Fusion must have failed, if number of new results is different. - // Bail out. - if (fusedRes.size() != rootStoreOps.size()) - return std::nullopt; - - // Store the new result. - SmallVector newStores; - for (size_t i = 0; i < rootStoreOps.size(); i++) { - auto storeIndices = rootStoreOps[i].getIndices(); - - auto newStore = rewriter.create( - loc, fusedRes[i], rootStoreOps[i].getDstMemref(), storeIndices, - leadingDim, - /*transpose=*/UnitAttr()); - newStores.push_back(newStore); - } - - // Replace store ops and cleanup standalone consumer. - for (size_t i = 0; i < rootStoreOps.size(); i++) - rewriter.replaceOp(rootStoreOps[i], newStores[i]); - - rewriter.eraseOp(consumer); - - return newStores; -} - -// Fuse a consumer using scalar operations. -// TODO: Extend scalar fusion to support multiple stores. -// -// Returns updated store op or nullopt if the fusion fails. -static std::optional eltwiseFusion(linalg::LinalgOp rootOp, - linalg::LinalgOp consumer, - memref::StoreOp rootStoreOp, - PatternRewriter &rewriter) { - Location loc = rootOp.getLoc(); - auto rootOutput = rootOp.getDpsInits()[0]; - auto outputType = cast(rootOutput.getType()); - - // Must be a floating point type. - // TODO: Add integer support. - auto floatType = dyn_cast(outputType.getElementType()); - if (!floatType) - return std::nullopt; - - auto storeIndices = rootStoreOp.getIndices(); - - // Insert fused eltwise ops before the store and later replace the store - // with a new result. - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(rootStoreOp); - - std::optional newStore = std::nullopt; - SmallVector operands; - if (structured_match::utils::isTwoDAddOp(consumer, &operands)) { - // Get the value to be added. Load the element first, if necessary. - auto addValue = (operands[0] != rootOutput) ? operands[0] : operands[1]; - if (isa(addValue.getType())) { - addValue = rewriter.create(loc, addValue, storeIndices) - .getResult(); - } - // Fuse the add into the matmul body. - auto addOp = - rewriter.create(loc, rootStoreOp.getValue(), addValue); - // Store the new result. - newStore = rewriter.replaceOpWithNewOp( - rootStoreOp, addOp.getResult(), rootOutput, storeIndices); - } else if (structured_match::utils::isTwoDReluOp(consumer, &operands)) { - // Fuse the relu into the matmul body. - Value zeroFloat = rewriter.create( - loc, APFloat::getZero(floatType.getFloatSemantics()), floatType); - auto maxOp = rewriter.create(loc, rootStoreOp.getValue(), - zeroFloat); - // Store the new result. - newStore = rewriter.replaceOpWithNewOp( - rootStoreOp, maxOp.getResult(), rootOutput, storeIndices); - } else { - // Not a fusable operation. Bail out. - return std::nullopt; - } - - rewriter.eraseOp(consumer); - - return newStore; -} - -// Find operations fusable with the given root op. -// -// A simple fusion strategy that looks at the other operations after the root -// linalg op and tries to fuse them. -static SmallVector -getFusableConsumers(linalg::LinalgOp rootOp) { - auto *parentOp = rootOp->getParentOp(); - auto rootOutput = rootOp.getDpsInits()[0]; - - // Traverse other ops within the same region and collect consumers. - SmallVector consumers; - Operation *nextOp = rootOp; - while ((nextOp = nextOp->getNextNode())) { - // Potential consumers must be within the same region. - if (nextOp->getParentOp() != parentOp) - break; - - // Only other linalg ops are expected as consumers. - // TODO: might need to be relaxed to skip over ops without side effects - auto consumer = dyn_cast(nextOp); - if (!consumer || !linalg::isElementwise(consumer)) - break; - // Require the same iteration space. - if (consumer.getNumParallelLoops() != rootOp.getNumParallelLoops()) - break; - - auto outBuf = consumer.getDpsInitOperand(0)->get(); - // Check that the op reuses the same output buffer as the root op. - // Otherwise, it is assumed that the op cannot be fused. - // TODO: Consider adding support for eltwise with broadcast. - if (outBuf != rootOutput) - break; - - consumers.push_back(consumer); - } - - return consumers; -} - -// Fuse elementwise consumers within a GPU kernel. -// -// Fusion bails on the first mismatch. -// Returns updated store ops. -template -static StoreTy fuseEltwiseConsumers(linalg::LinalgOp rootOp, - StoreTy rootStoreOps, - PatternRewriter &rewriter) { - // Constrain conversion to the supported fusion types. - static_assert( - llvm::is_one_of>::value); - - auto consumers = getFusableConsumers(rootOp); - - for (auto op : consumers) { - std::optional updatedStoreOps = std::nullopt; - - updatedStoreOps = eltwiseFusion(rootOp, op, rootStoreOps, rewriter); - - // Failed to fuse operation. Bail out. - if (!updatedStoreOps) - break; - - rootStoreOps = *updatedStoreOps; - } - - return rootStoreOps; -} - -// Create WMMA instructions out of matmul-like operation. -static LogicalResult gemmToGpuMMA(linalg::LinalgOp linalgOp, - ArrayRef warpTile, int kTile, - PatternRewriter &rewriter) { - assert((isa(linalgOp) || - isa(linalgOp)) && - "Requires a matmul like op for MMA lowering"); - - Location loc = linalgOp.getLoc(); - - // If there is no parallel loop, create a unit blocks wrapper around the - // current op. - // This ensures that WMMA operations are created at the thread level (inner - // nested parallel loops). - auto blocksLoop = createGpuBlocksWrapper(linalgOp, {1, 1}, rewriter); - if (blocksLoop) - rewriter.setInsertionPoint(blocksLoop->getBody()->getTerminator()); - - auto matA = linalgOp.getDpsInputs()[0]; - auto matB = linalgOp.getDpsInputs()[1]; - auto matC = linalgOp.getDpsInits()[0]; - - auto typeA = cast(matA.getType()); - auto typeB = cast(matB.getType()); - auto typeC = cast(matC.getType()); - - auto stridesA = utils::getStaticStrides(matA); - auto stridesB = utils::getStaticStrides(matB); - auto stridesC = utils::getStaticStrides(matC); - - if (failed(stridesA) || failed(stridesB) || failed(stridesC)) { - return rewriter.notifyMatchFailure( - linalgOp, "Expect static strides for MMA lowering"); - } - if (stridesA->back() != 1 || stridesB->back() != 1 || stridesC->back() != 1) { - return rewriter.notifyMatchFailure( - linalgOp, - "Expect unit stride in the innermost dimension for MMA operations"); - } - - int dimM = typeC.getShape()[0]; - int dimN = typeC.getShape()[1]; - int dimK = typeA.getShape().back(); - - int64_t wmmaTileM = warpTile[0]; - int64_t wmmaTileN = warpTile[1]; - int64_t wmmaTileK = warpTile[2]; - - gpu::MMAMatrixType mmaTypeA = gpu::MMAMatrixType::get( - {wmmaTileM, wmmaTileK}, typeA.getElementType(), "AOp"); - gpu::MMAMatrixType mmaTypeB = gpu::MMAMatrixType::get( - {wmmaTileK, wmmaTileN}, typeB.getElementType(), "BOp"); - gpu::MMAMatrixType mmaTypeC = gpu::MMAMatrixType::get( - {wmmaTileM, wmmaTileN}, typeC.getElementType(), "COp"); - - bool isBrgemm = isa(linalgOp); - - // Skip batch dimension stride in case of brgemm. - auto lda = rewriter.getIndexAttr(stridesA->begin()[isBrgemm ? 1 : 0]); - auto ldb = rewriter.getIndexAttr(stridesB->begin()[isBrgemm ? 1 : 0]); - auto ldc = rewriter.getIndexAttr(stridesC->front()); - - Value zero = rewriter.create(loc, 0); - Value one = rewriter.create(loc, 1); - // WMMA requires warp/subgroup size of 32 threads/work items. - Value subgroupSize = rewriter.create(loc, 32); - - // Create parallel loop to indicate that the whole subgroup is performing MMA - // operations together. It also ensures that the kernel is outlined with - // the correct number of threads. - auto parallelLoop = rewriter.create( - loc, ValueRange{zero}, ValueRange{subgroupSize}, ValueRange{one}); - - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(parallelLoop.getBody()->getTerminator()); - - // Fetch the inital value of the output element. - SmallVector tilesC; - for (int m = 0; m < dimM; m += wmmaTileM) { - for (int n = 0; n < dimN; n += wmmaTileN) { - Value rowIdx = rewriter.create(loc, m); - Value colIdx = rewriter.create(loc, n); - Value tileC = - rewriter - .create( - loc, mmaTypeC, matC, ValueRange{rowIdx, colIdx}, ldc, - /*transpose=*/UnitAttr()) - .getRes(); - tilesC.push_back(tileC); - } - } - - // Create a loop and step into it. - auto startLoop = [&](int lb, int ub, int step) -> scf::ForOp { - Value lbCst = rewriter.create(loc, lb); - Value ubCst = rewriter.create(loc, ub); - Value stepCst = rewriter.create(loc, step); - scf::ForOp loopOp = - rewriter.create(loc, lbCst, ubCst, stepCst, tilesC); - rewriter.setInsertionPointToStart(loopOp.getBody()); - return loopOp; - }; - auto getLoopIterValues = [&](scf::ForOp loopOp) -> SmallVector { - SmallVector loopIterVals; - for (auto iterArg : loopOp.getRegionIterArgs()) - loopIterVals.push_back(iterArg); - return loopIterVals; - }; - - // Construct and move into batch reduction loop. - // Propagate output values as iter args. - scf::ForOp batchLoop; - Value batchIv; - if (isBrgemm) { - batchLoop = startLoop(0, typeA.getShape()[0], 1); - batchIv = batchLoop.getInductionVar(); - tilesC = getLoopIterValues(batchLoop); - } - - // Construct and move into GEMM reduction dimension tiling loop. - // Propagate output values as iter args. - scf::ForOp kDimLoop = startLoop(0, dimK, kTile); - Value kDimIv = kDimLoop.getInductionVar(); - tilesC = getLoopIterValues(kDimLoop); - - // Load A sub-tiles. - SmallVector tilesA; - for (int m = 0; m < dimM; m += wmmaTileM) { - for (int k = 0; k < kTile; k += wmmaTileK) { - Value rowOffset = rewriter.create(loc, m); - Value colOffset = rewriter.create(loc, k); - - Value rowIdx = rowOffset; - Value colIdx = rewriter.create(loc, kDimIv, colOffset); - - Value tileA = rewriter - .create( - loc, mmaTypeA, matA, - isBrgemm ? ValueRange{batchIv, rowIdx, colIdx} - : ValueRange{rowIdx, colIdx}, - lda, - /*transpose=*/UnitAttr()) - .getRes(); - tilesA.push_back(tileA); - } - } - - // Load B sub-tiles. - SmallVector tilesB; - for (int k = 0; k < kTile; k += wmmaTileK) { - for (int n = 0; n < dimN; n += wmmaTileN) { - Value rowOffset = rewriter.create(loc, k); - Value colOffset = rewriter.create(loc, n); - - Value rowIdx = rewriter.create(loc, kDimIv, rowOffset); - Value colIdx = colOffset; - - Value tileB = rewriter - .create( - loc, mmaTypeB, matB, - isBrgemm ? ValueRange{batchIv, rowIdx, colIdx} - : ValueRange{rowIdx, colIdx}, - ldb, /*transpose=*/UnitAttr()) - .getRes(); - tilesB.push_back(tileB); - } - } - - const int numTilesM = dimM / wmmaTileM; - const int numTilesN = dimN / wmmaTileN; - const int numTilesK = kTile / wmmaTileK; - - // Compute sub-tiles of the C tile. - // - // Iterate over the reduction dimension sub-tiles as the outermost - // loop to minimize read after write conflicts between partial - // computations of the same C sub-tile. - // - // Initialize sub-tiles with the loaded C tiles. - SmallVector results = tilesC; - for (int k = 0; k < numTilesK; k++) { - for (int m = 0; m < numTilesM; m++) { - for (int n = 0; n < numTilesN; n++) { - int aIdx = m * numTilesK + k; - int bIdx = k * numTilesN + n; - int cIdx = m * numTilesN + n; - - Value result = rewriter - .create( - loc, tilesC[cIdx].getType(), tilesA[aIdx], - tilesB[bIdx], results[cIdx], - /*a_transpose=*/UnitAttr(), - /*b_transpose=*/UnitAttr()) - .getRes(); - // Update sub-tile partial result. - results[cIdx] = result; - } - } - } - - // Create loop terminator and exit the loop. - auto terminateLoop = [&](scf::ForOp loopOp, SmallVector resultValues) { - rewriter.setInsertionPointToEnd(loopOp.getBody()); - rewriter.create(loc, resultValues); - rewriter.setInsertionPointAfter(loopOp); - }; - - // Terminate and exit reduction dim loop. - terminateLoop(kDimLoop, results); - results = kDimLoop.getResults(); - - // Terminate and exit batch reduce loop. - if (isBrgemm) { - terminateLoop(batchLoop, results); - results = batchLoop.getResults(); - } - - // Write back the final C sub-tiles results to the output buffer. - SmallVector storeOps; - for (int m = 0; m < numTilesM; m++) { - for (int n = 0; n < numTilesN; n++) { - int resIdx = m * numTilesN + n; - - Value rowIdx = - rewriter.create(loc, m * wmmaTileM); - Value colIdx = - rewriter.create(loc, n * wmmaTileN); - auto storeOp = rewriter.create( - loc, results[resIdx], matC, ValueRange{rowIdx, colIdx}, ldc, - /*transpose=*/UnitAttr()); - storeOps.push_back(storeOp); - } - } - - (void)fuseEltwiseConsumers>( - linalgOp, storeOps, rewriter); - - rewriter.eraseOp(linalgOp); - - return success(); -} - -// Create loops out of matmul-like operation. -static LogicalResult gemmToGpuLoops(linalg::LinalgOp linalgOp, - PatternRewriter &rewriter) { - assert((isa(linalgOp) || - isa(linalgOp)) && - "Requires a matmul like op for loop lowering"); - - Location loc = linalgOp.getLoc(); - - auto matA = linalgOp.getDpsInputs()[0]; - auto matB = linalgOp.getDpsInputs()[1]; - auto matC = linalgOp.getDpsInits()[0]; - - ArrayRef shapeC = cast(matC.getType()).getShape(); - ArrayRef shapeA = cast(matA.getType()).getShape(); - - // Parallel dims. - Value i = rewriter.create(loc, shapeC[0]); - Value j = rewriter.create(loc, shapeC[1]); - // Reduction dim. - Value k = rewriter.create(loc, shapeA.back()); - // Lbs. - Value zero = rewriter.create(loc, 0); - // Step. - Value one = rewriter.create(loc, 1); - SmallVector ivs; - - // Create parallel loops over the outer dimensions. - auto parallelLoop = rewriter.create( - loc, ValueRange{zero, zero}, ValueRange{i, j}, ValueRange{one, one}); - auto parallelIvs = parallelLoop.getInductionVars(); - ivs.append(parallelIvs.begin(), parallelIvs.end()); - - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(parallelLoop.getBody()->getTerminator()); - - // Fetch the inital value of the output element. - Value initVal = - rewriter.create(loc, matC, parallelIvs).getResult(); - - bool isBrgemm = isa(linalgOp); - scf::ForOp batchLoop; - Value batchIv; - if (isBrgemm) { - Value batch = rewriter.create(loc, shapeA[0]); - batchLoop = - rewriter.create(loc, zero, batch, one, ValueRange{initVal}); - rewriter.setInsertionPointToStart(batchLoop.getBody()); - batchIv = batchLoop.getInductionVar(); - initVal = batchLoop.getRegionIterArg(0); - } - - // Compute matmul with a loop over reduction dimension. - // Each GPU thread computes a single result element. - // Accumulate result locally through loop's iter args. - // This maps to more efficient computation as the accumulation is kept - // locally by a thread. - auto bodyBuilder = [&](OpBuilder &b, Location loc, Value localIv, - ValueRange iterArgs) { - SmallVector loopIvs = ivs; - loopIvs.push_back(localIv); - assert(loopIvs.size() == 3); - Value localI = loopIvs[0]; - Value localJ = loopIvs[1]; - Value localK = loopIvs[2]; - Value scalarA = - b.create(loc, matA, - isBrgemm ? ValueRange{batchIv, localI, localK} - : ValueRange{localI, localK}); - Value scalarB = - b.create(loc, matB, - isBrgemm ? ValueRange{batchIv, localK, localJ} - : ValueRange{localK, localJ}); - Value scalarMul = b.create(loc, scalarA, scalarB); - auto scalarAdd = b.create(loc, iterArgs[0], scalarMul); - - b.create(loc, scalarAdd.getResult()); - }; - auto accumulationLoop = rewriter.create( - loc, zero, k, one, ValueRange{initVal}, - [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) { - bodyBuilder(b, loc, iv, iterArgs); - }); - - Value result = accumulationLoop.getResults()[0]; - - if (isBrgemm) { - rewriter.setInsertionPointToEnd(batchLoop.getBody()); - rewriter.create(loc, ValueRange{result}); - result = batchLoop.getResults()[0]; - rewriter.setInsertionPointAfter(batchLoop); - } - - // Write back the total sum to the output buffer. - auto storeOp = - rewriter.create(loc, result, matC, parallelIvs); - - (void)fuseEltwiseConsumers(linalgOp, storeOp, rewriter); - - rewriter.eraseOp(linalgOp); - - return success(); -} - -// Convert linalg.matmul or linalg.batch_reduce_matmul to GPU-compatible kernel. -template -struct ConvertGemmLikeToGpu : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - // Constrain conversion to the supported GEMM-like ops. - static_assert(llvm::is_one_of::value); - - ConvertGemmLikeToGpu(MLIRContext *ctx, LinalgToGpuOptions options) - : OpRewritePattern(ctx), options(options) {} - - LogicalResult matchAndRewrite(LinalgOpTy gemmLikeOp, - PatternRewriter &rewriter) const override { - if (!gemmLikeOp.hasPureBufferSemantics()) { - return rewriter.notifyMatchFailure( - gemmLikeOp, "Linalg brgemm to GPU expects memref type"); - } - if (gemmLikeOp.hasDynamicShape()) { - return rewriter.notifyMatchFailure( - gemmLikeOp, "Expect static shape when mapping to GPU"); - } - - // Ensure that reduction dimension tiling also works for smaller workloads. - auto aType = cast(gemmLikeOp.getDpsInputs()[0].getType()); - auto kDim = aType.getShape().back(); - auto kTile = kDim < options.kTile ? kDim : options.kTile; - - if (options.useWmma && - isMMACompatible(gemmLikeOp, options.warpTile, kTile)) { - return gemmToGpuMMA(gemmLikeOp, options.warpTile, kTile, rewriter); - } - // TODO: Add warp and K dim tiling to looped implementation. - return gemmToGpuLoops(gemmLikeOp, rewriter); - } - -private: - LinalgToGpuOptions options; -}; - -void populateLinalgToGpuPatterns(RewritePatternSet &patterns, - LinalgToGpuOptions options) { - patterns.add, - ConvertGemmLikeToGpu>( - patterns.getContext(), options); -} - -struct LinalgToGpu : public tpp::impl::LinalgToGpuBase { - using LinalgToGpuBase::LinalgToGpuBase; - - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - populateLinalgToGpuPatterns(patterns, - LinalgToGpuOptions{useWmma, warpTile, kTile}); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - } -}; - -} // namespace diff --git a/test/GPU/CUDA/Integration/addf-cuda.mlir b/test/GPU/CUDA/Integration/addf-cuda.mlir deleted file mode 100644 index 2083e4178..000000000 --- a/test/GPU/CUDA/Integration/addf-cuda.mlir +++ /dev/null @@ -1,51 +0,0 @@ -// RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda \ -// RUN: -entry-point-result=void -e entry 2>&1 | \ -// RUN: FileCheck %s - -module attributes {gpu.container_module} { - gpu.module @kernels { - gpu.func @kernel_add(%arg0 : memref<8xf32>, %arg1 : memref<8xf32>, %arg2 : memref<8xf32>) - kernel attributes { known_block_size = array, known_grid_size = array } { - %0 = gpu.block_id x - %1 = memref.load %arg0[%0] : memref<8xf32> - %2 = memref.load %arg1[%0] : memref<8xf32> - %3 = arith.addf %1, %2 : f32 - memref.store %3, %arg2[%0] : memref<8xf32> - gpu.return - } - } - - func.func @entry() { - %arg0, %t0 = gpu.alloc async () : memref<8xf32> - gpu.wait [%t0] - %arg1, %t1 = gpu.alloc async () : memref<8xf32> - gpu.wait [%t1] - %arg2, %t2 = gpu.alloc async () : memref<8xf32> - gpu.wait [%t2] - - %value0 = arith.constant 0.0 : f32 - %value1 = arith.constant 1.1 : f32 - %value2 = arith.constant 2.2 : f32 - linalg.fill ins(%value1 : f32) outs(%arg0 : memref<8xf32>) - linalg.fill ins(%value2 : f32) outs(%arg1 : memref<8xf32>) - linalg.fill ins(%value0 : f32) outs(%arg2 : memref<8xf32>) - - %cst1 = arith.constant 1 : index - %cst8 = arith.constant 8 : index - gpu.launch_func @kernels::@kernel_add - blocks in (%cst8, %cst1, %cst1) threads in (%cst1, %cst1, %cst1) - args(%arg0 : memref<8xf32>, %arg1 : memref<8xf32>, %arg2 : memref<8xf32>) - - %out = memref.alloc() : memref<8xf32> - %tOut = gpu.memcpy async %out, %arg2 : memref<8xf32>, memref<8xf32> - gpu.wait [%tOut] - %cast = memref.cast %out : memref<8xf32> to memref<*xf32> - call @printMemrefF32(%cast) : (memref<*xf32>) -> () - - return - } - func.func private @printMemrefF32(%ptr : memref<*xf32>) -} - -// CHECK: [3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3] diff --git a/test/GPU/CUDA/Integration/kernel-args-device-memref.mlir b/test/GPU/CUDA/Integration/kernel-args-device-memref.mlir index e458eca0b..8d28266af 100644 --- a/test/GPU/CUDA/Integration/kernel-args-device-memref.mlir +++ b/test/GPU/CUDA/Integration/kernel-args-device-memref.mlir @@ -15,20 +15,24 @@ module { func.func @entry(%arg0: memref<8x8xf32>, %arg1: memref<8x8xf32>, %arg2: memref<8x8xf32>) -> memref<8x8xf32>{ + %c1 = arith.constant 1 : index + gpu.launch blocks(%b0, %b1, %b2) in (%gs0 = %c1, %gs1 = %c1, %gs2 = %c1) + threads(%t0, %t1, %t2) in (%bs0 = %c1, %bs1 = %c1, %bs2 = %c1) { + linalg.matmul ins(%arg0, %arg1 : memref<8x8xf32>, memref<8x8xf32>) + outs(%arg2 : memref<8x8xf32>) + gpu.terminator + } // Kernel arguments are already allocated on GPU - use directly - linalg.matmul ins(%arg0, %arg1 : memref<8x8xf32>, memref<8x8xf32>) - outs(%arg2 : memref<8x8xf32>) return %arg2 : memref<8x8xf32> } } -// CHECK: module attributes {gpu.container_module} +// CHECK: module attributes{{.*}}gpu.container_module // CHECK-LABEL: func.func @_entry // CHECK: gpu.launch_func @_entry_kernel::@_entry_kernel // CHECK: } // CHECK: gpu.module @_entry_kernel // CHECK-LABEL: llvm.func @_entry_kernel -// CHECK-DAG: nvvm.read // CHECK-DAG: llvm.mul // CHECK-DAG: llvm.add // CHECK-LABEL: func.func @entry diff --git a/test/GPU/CUDA/Integration/kernel-args-device-tensor.mlir b/test/GPU/CUDA/Integration/kernel-args-device-tensor.mlir index d37ef4b39..da11149d6 100644 --- a/test/GPU/CUDA/Integration/kernel-args-device-tensor.mlir +++ b/test/GPU/CUDA/Integration/kernel-args-device-tensor.mlir @@ -1,17 +1,20 @@ // RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda -print-mlir=mid -gpu-args=1 \ +// RUN: tpp-run %s -gpu=cuda -print-mlir=mid -gpu-args=1 -gpu-block-tile=-1 \ // RUN: -entry-point-result=void -e entry 2>&1 | \ // RUN: FileCheck %s // RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda -print-mlir=mid -gpu-args=1 -print \ +// RUN: tpp-run %s -gpu=cuda -print-mlir=mid -gpu-args=1 -print -gpu-block-tile=-1 \ // RUN: -entry-point-result=void -e entry 2>&1 | \ // RUN: FileCheck %s --check-prefix=PRINT #map = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> -module { +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 4 : i32>>> +} { func.func @entry(%arg0: tensor<8x8xf32> {bufferization.writable = true}, %arg1: tensor<8x8xf32> {bufferization.writable = true}, %arg2: tensor<8x8xf32> {bufferization.writable = true} @@ -23,13 +26,12 @@ module { } } -// CHECK: module attributes {gpu.container_module} +// CHECK: module attributes{{.*}}gpu.container_module // CHECK-LABEL: func.func @_entry // CHECK: gpu.launch_func @_entry_kernel::@_entry_kernel // CHECK: } // CHECK: gpu.module @_entry_kernel // CHECK-LABEL: llvm.func @_entry_kernel -// CHECK-DAG: nvvm.read // CHECK-DAG: llvm.mul // CHECK-DAG: llvm.add // CHECK-LABEL: func.func @entry diff --git a/test/GPU/CUDA/Integration/kernel-args-host-memref.mlir b/test/GPU/CUDA/Integration/kernel-args-host-memref.mlir index f263bbe86..0eb4b6fde 100644 --- a/test/GPU/CUDA/Integration/kernel-args-host-memref.mlir +++ b/test/GPU/CUDA/Integration/kernel-args-host-memref.mlir @@ -22,8 +22,13 @@ module { %t5 = gpu.memcpy async [%t4] %2, %arg2 : memref<8x8xf32>, memref<8x8xf32> gpu.wait [%t5] - linalg.matmul ins(%0, %1 : memref<8x8xf32>, memref<8x8xf32>) - outs(%2 : memref<8x8xf32>) + %c1 = arith.constant 1 : index + gpu.launch blocks(%b0, %b1, %b2) in (%gs0 = %c1, %gs1 = %c1, %gs2 = %c1) + threads(%tx0, %tx1, %tx2) in (%bs0 = %c1, %bs1 = %c1, %bs2 = %c1) { + linalg.matmul ins(%0, %1 : memref<8x8xf32>, memref<8x8xf32>) + outs(%2 : memref<8x8xf32>) + gpu.terminator + } // Retrieve data from device %tOut = gpu.memcpy async %arg2, %2 : memref<8x8xf32>, memref<8x8xf32> @@ -40,7 +45,7 @@ module { } } -// CHECK: module attributes {gpu.container_module} +// CHECK: module attributes{{.*}}gpu.container_module // CHECK: func.func @_entry(%[[ARG0:.+]]: memref<8x8xf32>, %[[ARG1:.+]]: memref<8x8xf32>, %[[ARG2:.+]]: memref<8x8xf32> // CHECK: %[[gpu0:.+]],{{.*}}= gpu.alloc async () // CHECK: %[[gpu1:.+]],{{.*}}= gpu.alloc async () @@ -53,7 +58,6 @@ module { // CHECK: } // CHECK: gpu.module @_entry_kernel // CHECK-LABEL: llvm.func @_entry_kernel -// CHECK-DAG: nvvm.read // CHECK-DAG: llvm.mul // CHECK-DAG: llvm.add // CHECK-LABEL: func.func @entry diff --git a/test/GPU/CUDA/Integration/kernel-args-host-tensor.mlir b/test/GPU/CUDA/Integration/kernel-args-host-tensor.mlir index 48f622046..4d3f847f6 100644 --- a/test/GPU/CUDA/Integration/kernel-args-host-tensor.mlir +++ b/test/GPU/CUDA/Integration/kernel-args-host-tensor.mlir @@ -1,5 +1,5 @@ // RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda -print-mlir=mid -gpu-args=0 -print \ +// RUN: tpp-run %s -gpu=cuda -print-mlir=mid -gpu-args=0 -print -gpu-block-tile=-1 \ // RUN: -entry-point-result=void -e entry 2>&1 | \ // RUN: FileCheck %s @@ -26,8 +26,13 @@ module { %t5 = gpu.memcpy async [%t4] %2, %a2 : memref<8x8xf32>, memref<8x8xf32> gpu.wait [%t5] - linalg.matmul ins(%0, %1 : memref<8x8xf32>, memref<8x8xf32>) - outs(%2 : memref<8x8xf32>) + %c1 = arith.constant 1 : index + gpu.launch blocks(%b0, %b1, %b2) in (%gs0 = %c1, %gs1 = %c1, %gs2 = %c1) + threads(%tx0, %tx1, %tx2) in (%bs0 = %c1, %bs1 = %c1, %bs2 = %c1) { + linalg.matmul ins(%0, %1 : memref<8x8xf32>, memref<8x8xf32>) + outs(%2 : memref<8x8xf32>) + gpu.terminator + } // Retrieve data from device %out = memref.alloc() : memref<8x8xf32> @@ -47,7 +52,7 @@ module { } } -// CHECK: module attributes {gpu.container_module} +// CHECK: module attributes{{.*}}gpu.container_module // CHECK: func.func @_entry(%[[ARG0:.+]]: memref<8x8xf32>, %[[ARG1:.+]]: memref<8x8xf32>, %[[ARG2:.+]]: memref<8x8xf32> // CHECK: %[[gpu0:.+]],{{.*}}= gpu.alloc async () // CHECK: %[[gpu1:.+]],{{.*}}= gpu.alloc async () @@ -61,7 +66,6 @@ module { // CHECK: } // CHECK: gpu.module @_entry_kernel // CHECK-LABEL: llvm.func @_entry_kernel -// CHECK-DAG: nvvm.read // CHECK-DAG: llvm.mul // CHECK-DAG: llvm.add // CHECK-LABEL: func.func @entry diff --git a/test/GPU/CUDA/Integration/linalg-matmul-cuda.mlir b/test/GPU/CUDA/Integration/linalg-matmul-cuda.mlir deleted file mode 100644 index c2a531ed6..000000000 --- a/test/GPU/CUDA/Integration/linalg-matmul-cuda.mlir +++ /dev/null @@ -1,45 +0,0 @@ -// RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda \ -// RUN: -entry-point-result=void -e entry 2>&1 | \ -// RUN: FileCheck %s - -func.func @entry() { - %0, %t0 = gpu.alloc async () : memref<8x8xf32> - gpu.wait [%t0] - %1, %t1 = gpu.alloc async () : memref<8x8xf32> - gpu.wait [%t1] - %2, %t2 = gpu.alloc async () : memref<8x8xf32> - gpu.wait [%t2] - - %cst0 = arith.constant 0.0 : f32 - %cst1 = arith.constant 1.0 : f32 - %cst2 = arith.constant 2.0 : f32 - - linalg.fill ins(%cst1 : f32) outs(%0 : memref<8x8xf32>) - linalg.fill ins(%cst2 : f32) outs(%1 : memref<8x8xf32>) - linalg.fill ins(%cst0 : f32) outs(%2 : memref<8x8xf32>) - - linalg.matmul ins(%0, %1 : memref<8x8xf32>, memref<8x8xf32>) - outs(%2 : memref<8x8xf32>) - - %out = memref.alloc() : memref<8x8xf32> - %tOut = gpu.memcpy async %out, %2 : memref<8x8xf32>, memref<8x8xf32> - gpu.wait [%tOut] - %cast = memref.cast %out : memref<8x8xf32> to memref<*xf32> - call @printMemrefF32(%cast) : (memref<*xf32>) -> () - - %tD0 = gpu.dealloc async %0 : memref<8x8xf32> - gpu.wait [%tD0] - %tD1 = gpu.dealloc async %1 : memref<8x8xf32> - gpu.wait [%tD1] - %tD2 = gpu.dealloc async %2 : memref<8x8xf32> - gpu.wait [%tD2] - - memref.dealloc %out : memref<8x8xf32> - - return -} - -func.func private @printMemrefF32(memref<*xf32>) - -// CHECK-COUNT-8: [16, 16, 16, 16, 16, 16, 16, 16] diff --git a/test/GPU/CUDA/Integration/linalg-matmul.mlir b/test/GPU/CUDA/Integration/linalg-matmul.mlir index b705679fd..450f2680a 100644 --- a/test/GPU/CUDA/Integration/linalg-matmul.mlir +++ b/test/GPU/CUDA/Integration/linalg-matmul.mlir @@ -3,10 +3,10 @@ // RUN: -entry-point-result=void -e entry 2>&1 | \ // RUN: FileCheck %s -func.func @entry(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>, %arg2: tensor<8x8xf32>) -> tensor<8x8xf32> { - %1 = linalg.matmul ins(%arg0, %arg1 : tensor<8x8xf32>, tensor<8x8xf32>) - outs(%arg2 : tensor<8x8xf32>) -> tensor<8x8xf32> - return %1 : tensor<8x8xf32> +func.func @entry(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> { + %1 = linalg.matmul ins(%arg0, %arg1 : tensor<64x64xf32>, tensor<64x64xf32>) + outs(%arg2 : tensor<64x64xf32>) -> tensor<64x64xf32> + return %1 : tensor<64x64xf32> } -// CHECK-COUNT-8: 9, 9, 9, 9, 9, 9, 9, 9 +// CHECK-COUNT-64: 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65 diff --git a/test/GPU/CUDA/Integration/linalg-mlp.mlir b/test/GPU/CUDA/Integration/linalg-mlp.mlir index 2d377c8d8..d9f40af93 100644 --- a/test/GPU/CUDA/Integration/linalg-mlp.mlir +++ b/test/GPU/CUDA/Integration/linalg-mlp.mlir @@ -1,5 +1,5 @@ // RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda -print -print-mlir=mid \ +// RUN: tpp-run %s -gpu=cuda -print -print-mlir=mid -gpu-block-tile=-1 \ // RUN: -entry-point-result=void -e entry 2>&1 | \ // RUN: FileCheck %s @@ -7,7 +7,10 @@ #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> #map3 = affine_map<(d0, d1) -> (d0, d1)> -module { +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 4 : i32>>> +} { func.func @entry(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> { %weights = arith.constant dense<0.1> : tensor<8x8xf32> %bias = arith.constant dense<0.4> : tensor<8x8xf32> @@ -20,18 +23,18 @@ module { linalg.yield %4 : f32 } -> tensor<8x8xf32> - %1 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel"]} - ins(%0, %bias : tensor<8x8xf32>, tensor<8x8xf32>) outs(%arg1 : tensor<8x8xf32>) { - ^bb0(%in: f32, %in1: f32, %out: f32): - %3 = arith.addf %in, %in1 : f32 + %1 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} + ins(%bias : tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) { + ^bb0(%in: f32, %out: f32): + %3 = arith.addf %in, %out : f32 linalg.yield %3 : f32 } -> tensor<8x8xf32> %cst = arith.constant 0.000000e+00 : f32 - %2 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} - ins(%1 : tensor<8x8xf32>) outs(%arg1 : tensor<8x8xf32>) { - ^bb0(%in: f32, %out: f32): - %3 = arith.maximumf %in, %cst : f32 + %2 = linalg.generic {indexing_maps = [#map3], iterator_types = ["parallel", "parallel"]} + outs(%1 : tensor<8x8xf32>) { + ^bb0(%out: f32): + %3 = arith.maximumf %out, %cst : f32 linalg.yield %3 : f32 } -> tensor<8x8xf32> diff --git a/test/GPU/CUDA/Integration/tensor-kernel-dispatch.mlir b/test/GPU/CUDA/Integration/tensor-kernel-dispatch.mlir index 8af75bfad..83bd2247e 100644 --- a/test/GPU/CUDA/Integration/tensor-kernel-dispatch.mlir +++ b/test/GPU/CUDA/Integration/tensor-kernel-dispatch.mlir @@ -1,5 +1,5 @@ // RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda -print \ +// RUN: tpp-run %s -gpu=cuda -print -gpu-block-tile=-1 \ // RUN: -entry-point-result=void -e entry 2>&1 | \ // RUN: FileCheck %s @@ -7,17 +7,21 @@ // to GPU starting from the tensor level. // Bufferization will allocate two buffers to hold matrices A and B. // This requires either GPU unified memory or explicit data transfers to GPU. -func.func @entry(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - %c0 = arith.constant 0.0 : f32 - %c1 = arith.constant 1.0 : f32 - %c2 = arith.constant 2.0 : f32 - %mat = tensor.empty() : tensor<8x8xf32> - %C = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<8x8xf32>) -> tensor<8x8xf32> - %A = linalg.fill ins(%c1 : f32) outs(%mat : tensor<8x8xf32>) -> tensor<8x8xf32> - %B = linalg.fill ins(%c2 : f32) outs(%mat : tensor<8x8xf32>) -> tensor<8x8xf32> - %R = linalg.matmul ins(%A, %B : tensor<8x8xf32>, tensor<8x8xf32>) - outs(%C : tensor<8x8xf32>) -> tensor<8x8xf32> - return %R : tensor<8x8xf32> +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 4 : i32>>> +} { + func.func @entry(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + %c0 = arith.constant 0.0 : f32 + %c1 = arith.constant 1.0 : f32 + %c2 = arith.constant 2.0 : f32 + %mat = tensor.empty() : tensor<8x8xf32> + %A = linalg.fill ins(%c1 : f32) outs(%mat : tensor<8x8xf32>) -> tensor<8x8xf32> + %B = linalg.fill ins(%c2 : f32) outs(%mat : tensor<8x8xf32>) -> tensor<8x8xf32> + %R = linalg.matmul ins(%A, %B : tensor<8x8xf32>, tensor<8x8xf32>) + outs(%arg0 : tensor<8x8xf32>) -> tensor<8x8xf32> + return %R : tensor<8x8xf32> + } } -// CHECK-COUNT-8: 16, 16, 16, 16, 16, 16, 16, 16 +// CHECK-COUNT-8: 17, 17, 17, 17, 17, 17, 17, 17 diff --git a/test/GPU/CUDA/Integration/tpp-brgemm.mlir b/test/GPU/CUDA/Integration/tpp-brgemm.mlir deleted file mode 100644 index 92acce11a..000000000 --- a/test/GPU/CUDA/Integration/tpp-brgemm.mlir +++ /dev/null @@ -1,40 +0,0 @@ -// RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda \ -// RUN: -entry-point-result=void -e entry 2>&1 | \ -// RUN: FileCheck %s - -func.func @entry(%arg0: memref<2x8x32x32xf32>, %arg1: memref<8x8x32x32xf32>, %arg2: memref<2x8x32x32xf32>) { - %c0 = arith.constant 0 : index - %c2 = arith.constant 2 : index - %c8 = arith.constant 8 : index - %c1 = arith.constant 1 : index - scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c2, %c8) step (%c1, %c1) { - %subview = memref.subview %arg0[%arg3, 0, 0, 0] [1, 8, 32, 32] [1, 1, 1, 1] - : memref<2x8x32x32xf32> to memref<8x32x32xf32, strided<[1024, 32, 1], offset: ?>> - %subview_0 = memref.subview %arg1[%arg4, 0, 0, 0] [1, 8, 32, 32] [1, 1, 1, 1] - : memref<8x8x32x32xf32> to memref<8x32x32xf32, strided<[1024, 32, 1], offset: ?>> - %subview_1 = memref.subview %arg2[%arg3, %arg4, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] - : memref<2x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.batch_reduce_matmul ins(%subview, %subview_0 : - memref<8x32x32xf32, strided<[1024, 32, 1], offset: ?>>, - memref<8x32x32xf32, strided<[1024, 32, 1], offset: ?>>) - outs(%subview_1 : memref<32x32xf32, strided<[32, 1], offset: ?>>) - scf.reduce - } - - %out = memref.alloc() : memref<2x8x32x32xf32> - %tw = gpu.wait async - %tOut = gpu.memcpy async [%tw] %out, %arg2 : memref<2x8x32x32xf32>, memref<2x8x32x32xf32> - gpu.wait [%tOut] - - %d1 = arith.constant -1.0 : f32 - %zeroCst = arith.constant 0 : index - %v0 = vector.transfer_read %out[%zeroCst, %zeroCst, %zeroCst, %zeroCst], %d1 : memref<2x8x32x32xf32>, vector<32xf32> - vector.print %v0 : vector<32xf32> - - memref.dealloc %out : memref<2x8x32x32xf32> - - return -} - -// CHECK: 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257, 257 diff --git a/test/GPU/CUDA/Integration/tpp-gemm.mlir b/test/GPU/CUDA/Integration/tpp-gemm.mlir deleted file mode 100644 index 85bdedf86..000000000 --- a/test/GPU/CUDA/Integration/tpp-gemm.mlir +++ /dev/null @@ -1,45 +0,0 @@ -// RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda \ -// RUN: -entry-point-result=void -e entry 2>&1 | \ -// RUN: FileCheck %s - -func.func @entry() { - %0, %t0 = gpu.alloc async () : memref<8x8xf32> - gpu.wait [%t0] - %1, %t1 = gpu.alloc async () : memref<8x8xf32> - gpu.wait [%t1] - %2, %t2 = gpu.alloc async () : memref<8x8xf32> - gpu.wait [%t2] - - %cst0 = arith.constant 0.0 : f32 - %cst1 = arith.constant 1.0 : f32 - %cst2 = arith.constant 2.0 : f32 - - linalg.fill ins(%cst1 : f32) outs(%0 : memref<8x8xf32>) - linalg.fill ins(%cst2 : f32) outs(%1 : memref<8x8xf32>) - linalg.fill ins(%cst0 : f32) outs(%2 : memref<8x8xf32>) - - linalg.matmul ins(%0, %1 : memref<8x8xf32>, memref<8x8xf32>) - outs(%2: memref<8x8xf32>) - - %out = memref.alloc() : memref<8x8xf32> - %tOut = gpu.memcpy async %out, %2 : memref<8x8xf32>, memref<8x8xf32> - gpu.wait [%tOut] - %cast = memref.cast %out : memref<8x8xf32> to memref<*xf32> - call @printMemrefF32(%cast) : (memref<*xf32>) -> () - - %tD0 = gpu.dealloc async %0 : memref<8x8xf32> - gpu.wait [%tD0] - %tD1 = gpu.dealloc async %1 : memref<8x8xf32> - gpu.wait [%tD1] - %tD2 = gpu.dealloc async %2 : memref<8x8xf32> - gpu.wait [%tD2] - - memref.dealloc %out : memref<8x8xf32> - - return -} - -func.func private @printMemrefF32(memref<*xf32>) - -// CHECK-COUNT-8: [16, 16, 16, 16, 16, 16, 16, 16] diff --git a/test/GPU/CUDA/Integration/wmma/brgemm-wmma-tiled.mlir b/test/GPU/CUDA/Integration/wmma/brgemm-wmma-tiled.mlir deleted file mode 100644 index 143ce9424..000000000 --- a/test/GPU/CUDA/Integration/wmma/brgemm-wmma-tiled.mlir +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda -gpu-wmma -print \ -// RUN: -entry-point-result=void -e entry 2>&1 | \ -// RUN: FileCheck %s - -func.func @entry(%arg0: memref<16x32x32xf16>, - %arg1: memref<16x32x32xf16>, - %arg2: memref<32x32xf16>) -> memref<32x32xf16> { - linalg.batch_reduce_matmul ins(%arg0, %arg1 : memref<16x32x32xf16>, memref<16x32x32xf16>) - outs(%arg2 : memref<32x32xf16>) - return %arg2 : memref<32x32xf16> -} - -// CHECK-COUNT-32: ( 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513, 513 ) diff --git a/test/GPU/CUDA/Integration/wmma/brgemm-wmma.mlir b/test/GPU/CUDA/Integration/wmma/brgemm-wmma.mlir index 3ffbd46f5..e7bbdcdfe 100644 --- a/test/GPU/CUDA/Integration/wmma/brgemm-wmma.mlir +++ b/test/GPU/CUDA/Integration/wmma/brgemm-wmma.mlir @@ -1,5 +1,5 @@ // RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda -gpu-wmma \ +// RUN: tpp-run %s -gpu=cuda \ // RUN: -entry-point-result=void -e entry 2>&1 | \ // RUN: FileCheck %s diff --git a/test/GPU/CUDA/Integration/wmma/gemm-wmma-tiled.mlir b/test/GPU/CUDA/Integration/wmma/gemm-wmma-tiled.mlir deleted file mode 100644 index 0e850cc4e..000000000 --- a/test/GPU/CUDA/Integration/wmma/gemm-wmma-tiled.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda -gpu-wmma -print \ -// RUN: -entry-point-result=void -e entry 2>&1 | \ -// RUN: FileCheck %s - -func.func @entry(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf16>) -> memref<32x32xf16> { - linalg.matmul ins(%arg0, %arg1 : memref<32x32xf16>, memref<32x32xf16>) outs(%arg2 : memref<32x32xf16>) - return %arg2 : memref<32x32xf16> -} - -// CHECK-COUNT-32: ( 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33 ) diff --git a/test/GPU/CUDA/Integration/wmma/gemm-wmma.mlir b/test/GPU/CUDA/Integration/wmma/gemm-wmma.mlir index 6d0a63747..24ed30e2d 100644 --- a/test/GPU/CUDA/Integration/wmma/gemm-wmma.mlir +++ b/test/GPU/CUDA/Integration/wmma/gemm-wmma.mlir @@ -1,5 +1,5 @@ // RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda -gpu-wmma -print \ +// RUN: tpp-run %s -gpu=cuda -print \ // RUN: -entry-point-result=void -e entry 2>&1 | \ // RUN: FileCheck %s diff --git a/test/GPU/CUDA/Integration/wmma/mlp-wmma-tiled.mlir b/test/GPU/CUDA/Integration/wmma/mlp-wmma-tiled.mlir deleted file mode 100644 index a71f4f889..000000000 --- a/test/GPU/CUDA/Integration/wmma/mlp-wmma-tiled.mlir +++ /dev/null @@ -1,28 +0,0 @@ -// RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda -gpu-wmma -print \ -// RUN: -entry-point-result=void -e entry 2>&1 | \ -// RUN: FileCheck %s - -// XFAIL:* -// See: #870 - -#map = affine_map<(d0, d1) -> (d0, d1)> -func.func @entry(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf16>, %arg3: memref<32x32xf16>) -> memref<32x32xf16> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %cst = arith.constant 0.000000e+00 : f16 - linalg.matmul ins(%arg0, %arg1 : memref<32x32xf16>, memref<32x32xf16>) outs(%arg3 : memref<32x32xf16>) - linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg2 : memref<32x32xf16>) outs(%arg3 : memref<32x32xf16>) { - ^bb0(%in: f16, %out: f16): - %0 = arith.addf %in, %out : f16 - linalg.yield %0 : f16 - } - linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%arg3 :memref<32x32xf16>) { - ^bb0(%out: f16): - %0 = arith.maximumf %out, %cst : f16 - linalg.yield %0 : f16 - } - return %arg3 : memref<32x32xf16> -} - -// CHECK-COUNT-32: ( 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34 ) diff --git a/test/GPU/CUDA/Integration/wmma/mlp-wmma.mlir b/test/GPU/CUDA/Integration/wmma/mlp-wmma.mlir deleted file mode 100644 index 249e86f38..000000000 --- a/test/GPU/CUDA/Integration/wmma/mlp-wmma.mlir +++ /dev/null @@ -1,29 +0,0 @@ -// RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda -gpu-wmma -print \ -// RUN: -entry-point-result=void -e entry 2>&1 | \ -// RUN: FileCheck %s - -// XFAIL:* -// See: #870 - -#map = affine_map<(d0, d1) -> (d0, d1)> - -func.func @entry(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>, %arg3: memref<16x16xf16>) -> memref<16x16xf16> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %cst = arith.constant 0.000000e+00 : f16 - linalg.matmul ins(%arg0, %arg1 : memref<16x16xf16>, memref<16x16xf16>) outs(%arg3 : memref<16x16xf16>) - linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg2 : memref<16x16xf16>) outs(%arg3 : memref<16x16xf16>) { - ^bb0(%in: f16, %out: f16): - %0 = arith.addf %in, %out : f16 - linalg.yield %0 : f16 - } - linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%arg3 :memref<16x16xf16>) { - ^bb0(%out: f16): - %0 = arith.maximumf %out, %cst : f16 - linalg.yield %0 : f16 - } - return %arg3 : memref<16x16xf16> -} - -// CHECK-COUNT-16: ( 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18 ) diff --git a/test/GPU/CUDA/Integration/wmma/pack-brgemm-unpack.mlir b/test/GPU/CUDA/Integration/wmma/pack-brgemm-unpack.mlir deleted file mode 100644 index 708d0f823..000000000 --- a/test/GPU/CUDA/Integration/wmma/pack-brgemm-unpack.mlir +++ /dev/null @@ -1,40 +0,0 @@ -// RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda -print -seed 123 \ -// RUN: -entry-point-result=void -e entry 2>&1 | \ -// RUN: FileCheck %s - -// RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda -print -seed 123 -gpu-wmma \ -// RUN: -entry-point-result=void -e entry 2>&1 | \ -// RUN: FileCheck %s - -func.func @entry(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf16>) -> memref<32x32xf16> { - %alloc = gpu.alloc() {alignment = 64 : i64} : memref<2x2x16x16xf16> - %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [2, 16, 2, 16] : memref<32x32xf16> into memref<2x16x2x16xf16> - %alloc_0 = gpu.alloc() {alignment = 64 : i64} : memref<2x2x16x16xf16> - linalg.transpose ins(%expand_shape : memref<2x16x2x16xf16>) outs(%alloc_0 : memref<2x2x16x16xf16>) permutation = [0, 2, 1, 3] - %expand_shape_1 = memref.expand_shape %arg1 [[0, 1], [2, 3]] output_shape [2, 16, 2, 16] : memref<32x32xf16> into memref<2x16x2x16xf16> - %alloc_2 = gpu.alloc() {alignment = 64 : i64} : memref<2x2x16x16xf16> - linalg.transpose ins(%expand_shape_1 : memref<2x16x2x16xf16>) outs(%alloc_2 : memref<2x2x16x16xf16>) permutation = [2, 0, 1, 3] - %expand_shape_3 = memref.expand_shape %arg2 [[0, 1], [2, 3]] output_shape [2, 16, 2, 16] : memref<32x32xf16> into memref<2x16x2x16xf16> - linalg.transpose ins(%expand_shape_3 : memref<2x16x2x16xf16>) outs(%alloc : memref<2x2x16x16xf16>) permutation = [0, 2, 1, 3] - scf.forall (%arg3, %arg4) in (2, 2) { - %subview = memref.subview %alloc_0[%arg3, 0, 0, 0] [1, 2, 16, 16] [1, 1, 1, 1] : memref<2x2x16x16xf16> to memref<2x16x16xf16, strided<[256, 16, 1], offset: ?>> - %subview_5 = memref.subview %alloc_2[%arg4, 0, 0, 0] [1, 2, 16, 16] [1, 1, 1, 1] : memref<2x2x16x16xf16> to memref<2x16x16xf16, strided<[256, 16, 1], offset: ?>> - %subview_6 = memref.subview %alloc[%arg3, %arg4, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : memref<2x2x16x16xf16> to memref<16x16xf16, strided<[16, 1], offset: ?>> - linalg.batch_reduce_matmul ins(%subview, %subview_5 : memref<2x16x16xf16, strided<[256, 16, 1], offset: ?>>, memref<2x16x16xf16, strided<[256, 16, 1], offset: ?>>) outs(%subview_6 : memref<16x16xf16, strided<[16, 1], offset: ?>>) - } - %alloc_4 = gpu.alloc() {alignment = 64 : i64} : memref<2x16x2x16xf16> - linalg.transpose ins(%alloc : memref<2x2x16x16xf16>) outs(%alloc_4 : memref<2x16x2x16xf16>) permutation = [0, 2, 1, 3] - %collapse_shape = memref.collapse_shape %alloc_4 [[0, 1], [2, 3]] : memref<2x16x2x16xf16> into memref<32x32xf16> - linalg.copy ins(%collapse_shape : memref<32x32xf16>) outs(%arg2 : memref<32x32xf16>) - gpu.dealloc %alloc : memref<2x2x16x16xf16> - gpu.dealloc %alloc_0 : memref<2x2x16x16xf16> - gpu.dealloc %alloc_2 : memref<2x2x16x16xf16> - gpu.dealloc %alloc_4 : memref<2x16x2x16xf16> - return %arg2 : memref<32x32xf16> - } - -// CHECK: ( 0.036{{[0-9]+}}, 0.089{{[0-9]+}}, 0.086{{[0-9]+}}, 0.35{{[0-9]+}} -// CHECK: ( 0.075{{[0-9]+}}, 0.38{{[0-9]+}}, 0.35{{[0-9]+}}, 0.27{{[0-9]+}} -// CHECK: ( 0.008{{[0-9]+}}, 0.083{{[0-9]+}}, 0.18{{[0-9]+}}, 0.15{{[0-9]+}} diff --git a/test/GPU/CUDA/Integration/wmma/packed-matmul-wmma.mlir b/test/GPU/CUDA/Integration/wmma/packed-matmul-wmma.mlir deleted file mode 100644 index 50b0ddad5..000000000 --- a/test/GPU/CUDA/Integration/wmma/packed-matmul-wmma.mlir +++ /dev/null @@ -1,35 +0,0 @@ -// RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda -gpu-wmma \ -// RUN: -entry-point-result=void -e entry 2>&1 | \ -// RUN: FileCheck %s - -#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)> - -func.func @entry(%arg0: tensor<2x4x16x16xf16>, %arg1: tensor<4x4x16x16xf16>, %arg2: tensor<2x4x16x16xf16>) -> tensor<2x4x16x16xf16> { - %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} - ins(%arg0, %arg1 : tensor<2x4x16x16xf16>, tensor<4x4x16x16xf16>) outs(%arg2 : tensor<2x4x16x16xf16>) { - ^bb0(%in: f16, %in_2: f16, %out: f16): - %4 = arith.mulf %in, %in_2 : f16 - %5 = arith.addf %out, %4 : f16 - linalg.yield %5 : f16 - } -> tensor<2x4x16x16xf16> - - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %vcst = arith.constant -1.000000e+00 : f16 - - %buf = bufferization.to_memref %0 {read_only} : memref<2x4x16x16xf16> - %out = memref.alloc() : memref<2x4x16x16xf16> - %tOut = gpu.memcpy async %out, %buf : memref<2x4x16x16xf16>, memref<2x4x16x16xf16> - gpu.wait [%tOut] - %v0 = vector.transfer_read %out[%c1, %c2, %c0, %c0], %vcst : memref<2x4x16x16xf16>, vector<16x16xf16> - vector.print %v0 : vector<16x16xf16> - memref.dealloc %out : memref<2x4x16x16xf16> - - return %0 : tensor<2x4x16x16xf16> -} - -// CHECK-COUNT-16: ( 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65 ) diff --git a/test/GPU/CUDA/Integration/wmma/subview-strided-wmma.mlir b/test/GPU/CUDA/Integration/wmma/subview-strided-wmma.mlir deleted file mode 100644 index ec064f310..000000000 --- a/test/GPU/CUDA/Integration/wmma/subview-strided-wmma.mlir +++ /dev/null @@ -1,27 +0,0 @@ -// RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda -gpu-wmma -print \ -// RUN: -entry-point-result=void -e entry 2>&1 | \ -// RUN: FileCheck %s - -func.func @entry(%arg0: memref<16x32x16xf16>, %arg1: memref<16x64x16xf16>, %arg2: memref<32x32xf16>) -> memref<32x32xf16> { - %subview = memref.subview %arg0[0, 0, 0] [16, 1, 16] [1, 1, 1] - : memref<16x32x16xf16> to memref<16x16xf16, strided<[512, 1], offset: 0>> - %subview_0 = memref.subview %arg1[0, 0, 0] [16, 1, 16] [1, 1, 1] - : memref<16x64x16xf16> to memref<16x16xf16, strided<[1024, 1], offset: 0>> - %subview_1 = memref.subview %arg2[16, 0] [16, 16] [1, 1] - : memref<32x32xf16> to memref<16x16xf16, strided<[32, 1], offset: 512>> - - %c2 = arith.constant 2.0 : f16 - %c4 = arith.constant 4.0 : f16 - - linalg.fill ins(%c2 : f16) outs(%subview : memref<16x16xf16, strided<[512, 1], offset: 0>>) - linalg.fill ins(%c4 : f16) outs(%subview_0 : memref<16x16xf16, strided<[1024, 1], offset: 0>>) - - linalg.matmul ins(%subview, %subview_0 : memref<16x16xf16, strided<[512, 1], offset: 0>>, - memref<16x16xf16, strided<[1024, 1], offset: 0>>) - outs(%subview_1 : memref<16x16xf16, strided<[32, 1], offset: 512>>) - return %arg2 : memref<32x32xf16> -} - -// CHECK-COUNT-16: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) -// CHECK-COUNT-16: ( 129, 129, 129, 129, 129, 129, 129, 129, 129, 129, 129, 129, 129, 129, 129, 129, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) diff --git a/test/GPU/CUDA/Integration/wmma/wmma-mem-access.mlir b/test/GPU/CUDA/Integration/wmma/wmma-mem-access.mlir index 93942abbd..5a60a478e 100644 --- a/test/GPU/CUDA/Integration/wmma/wmma-mem-access.mlir +++ b/test/GPU/CUDA/Integration/wmma/wmma-mem-access.mlir @@ -1,5 +1,5 @@ // RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-run %s -gpu=cuda -gpu-wmma -print \ +// RUN: tpp-run %s -gpu=cuda -print \ // RUN: -entry-point-result=void -e entry 2>&1 | \ // RUN: FileCheck %s diff --git a/test/GPU/CUDA/all-reduce-max.mlir b/test/GPU/CUDA/all-reduce-max.mlir deleted file mode 100644 index b4eb44197..000000000 --- a/test/GPU/CUDA/all-reduce-max.mlir +++ /dev/null @@ -1,77 +0,0 @@ -// RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-opt %s -gpu-pipeline=gpu=cuda | FileCheck %s - -// Original test from: llvm-project/mlir/test/Integration/GPU/CUDA/all-reduce-max.mlir - -func.func @main() { - %data = memref.alloc() : memref<2x6xi32> - %sum = memref.alloc() : memref<2xi32> - %cst0 = arith.constant 0 : i32 - %cst1 = arith.constant 1 : i32 - %cst2 = arith.constant 2 : i32 - %cst4 = arith.constant 4 : i32 - %cst8 = arith.constant 8 : i32 - %cst16 = arith.constant 16 : i32 - - %cst3 = arith.constant 3 : i32 - %cst6 = arith.constant 6 : i32 - %cst7 = arith.constant 7 : i32 - %cst10 = arith.constant 10 : i32 - %cst11 = arith.constant 11 : i32 - - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c4 = arith.constant 4 : index - %c5 = arith.constant 5 : index - %c6 = arith.constant 6 : index - - %cast_data = memref.cast %data : memref<2x6xi32> to memref<*xi32> - gpu.host_register %cast_data : memref<*xi32> - %cast_sum = memref.cast %sum : memref<2xi32> to memref<*xi32> - gpu.host_register %cast_sum : memref<*xi32> - - memref.store %cst0, %data[%c0, %c0] : memref<2x6xi32> - memref.store %cst1, %data[%c0, %c1] : memref<2x6xi32> - memref.store %cst2, %data[%c0, %c2] : memref<2x6xi32> - memref.store %cst4, %data[%c0, %c3] : memref<2x6xi32> - memref.store %cst8, %data[%c0, %c4] : memref<2x6xi32> - memref.store %cst16, %data[%c0, %c5] : memref<2x6xi32> - - memref.store %cst2, %data[%c1, %c0] : memref<2x6xi32> - memref.store %cst3, %data[%c1, %c1] : memref<2x6xi32> - memref.store %cst6, %data[%c1, %c2] : memref<2x6xi32> - memref.store %cst7, %data[%c1, %c3] : memref<2x6xi32> - memref.store %cst10, %data[%c1, %c4] : memref<2x6xi32> - memref.store %cst11, %data[%c1, %c5] : memref<2x6xi32> - - // MAX - gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c2, %grid_y = %c1, %grid_z = %c1) - threads(%tx, %ty, %tz) in (%block_x = %c6, %block_y = %c1, %block_z = %c1) { - %val = memref.load %data[%bx, %tx] : memref<2x6xi32> - %reduced = gpu.all_reduce maxsi %val uniform {} : (i32) -> (i32) - memref.store %reduced, %sum[%bx] : memref<2xi32> - gpu.terminator - } - - call @printMemrefI32(%cast_sum) : (memref<*xi32>) -> () - - memref.dealloc %data : memref<2x6xi32> - memref.dealloc %sum : memref<2xi32> - - return -} - -func.func private @printMemrefI32(memref<*xi32>) - -// CHECK: module attributes {gpu.container_module} -// CHECK-LABEL: func.func @main() -// CHECK: gpu.host_register -// CHECK: gpu.launch_func @main_kernel::@main_kernel -// CHECK: } -// CHECK: gpu.module @main_kernel -// CHECK-LABEL: llvm.func @main_kernel -// CHECK-DAG: nvvm.read -// CHECK-DAG: nvvm.shfl.sync -// CHECK-DAG: nvvm.barrier diff --git a/test/GPU/CUDA/gpu-pipeline-cuda-wmma.mlir b/test/GPU/CUDA/gpu-pipeline-cuda-wmma.mlir deleted file mode 100644 index 366646a25..000000000 --- a/test/GPU/CUDA/gpu-pipeline-cuda-wmma.mlir +++ /dev/null @@ -1,33 +0,0 @@ -// RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ -// RUN: tpp-opt %s -gpu-pipeline=gpu=cuda -gpu-wmma -split-input-file | FileCheck %s - -#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)> - -func.func @packed_matmul(%arg0: tensor<2x4x16x16xf16>, %arg1: tensor<4x4x16x16xf16>, %arg2: tensor<2x4x16x16xf16>) -> tensor<2x4x16x16xf16> { - %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} - ins(%arg0, %arg1 : tensor<2x4x16x16xf16>, tensor<4x4x16x16xf16>) outs(%arg2 : tensor<2x4x16x16xf16>) { - ^bb0(%in: f16, %in_2: f16, %out: f16): - %4 = arith.mulf %in, %in_2 : f16 - %5 = arith.addf %out, %4 : f16 - linalg.yield %5 : f16 - } -> tensor<2x4x16x16xf16> - return %0 : tensor<2x4x16x16xf16> -} - -// CHECK: module attributes {gpu.container_module} -// CHECK-LABEL: func.func @packed_matmul -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK-NOT: linalg.generic -// CHECK: gpu.launch_func @packed_matmul_kernel::@packed_matmul_kernel -// CHECK-SAME: blocks in (%[[C2]], %[[C4]], %[[C1]]) threads in (%[[C32]], %[[C1]], %[[C1]]) -// CHECK: } -// CHECK: gpu.module @packed_matmul_kernel -// CHECK-LABEL: llvm.func @packed_matmul_kernel -// CHECK-DAG: nvvm.wmma.load -// CHECK-DAG: nvvm.wmma.mma -// CHECK-DAG: nvvm.wmma.store diff --git a/test/GPU/CUDA/gpu-pipeline-cuda.mlir b/test/GPU/CUDA/gpu-pipeline-cuda.mlir index de7d8703d..b2f8040c3 100644 --- a/test/GPU/CUDA/gpu-pipeline-cuda.mlir +++ b/test/GPU/CUDA/gpu-pipeline-cuda.mlir @@ -1,64 +1,6 @@ // RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \ // RUN: tpp-opt %s -gpu-pipeline=gpu=cuda -split-input-file | FileCheck %s -func.func @linalg_matmul() { - %0 = memref.alloc() : memref<8x8xf32> - %1 = memref.alloc() : memref<8x8xf32> - %2 = memref.alloc() : memref<8x8xf32> - - %cast_a = memref.cast %0 : memref<8x8xf32> to memref<*xf32> - gpu.host_register %cast_a : memref<*xf32> - %cast_b = memref.cast %1 : memref<8x8xf32> to memref<*xf32> - gpu.host_register %cast_b : memref<*xf32> - %cast_c = memref.cast %2 :memref<8x8xf32> to memref<*xf32> - gpu.host_register %cast_c : memref<*xf32> - - linalg.matmul ins(%0, %1 : memref<8x8xf32>, memref<8x8xf32>) - outs(%2 : memref<8x8xf32>) - - call @printMemrefF32(%cast_c) : (memref<*xf32>) -> () - - return -} - -func.func private @printMemrefF32(memref<*xf32>) - -// CHECK: module attributes {gpu.container_module} -// CHECK-LABEL: func.func @linalg_matmul -// CHECK: %[[C1:.*]] = memref.cast -// CHECK: gpu.host_register %[[C1]] -// CHECK: %[[C2:.*]] = memref.cast -// CHECK: gpu.host_register %[[C2]] -// CHECK: %[[C3:.*]] = memref.cast -// CHECK: gpu.host_register %[[C3]] -// CHECK: gpu.launch_func @linalg_matmul_kernel::@linalg_matmul_kernel -// CHECK: call @printMemrefF32 -// CHECK: } -// CHECK: gpu.module @linalg_matmul_kernel -// CHECK-LABEL: llvm.func @linalg_matmul_kernel -// CHECK-DAG: nvvm.read -// CHECK-DAG: llvm.mul -// CHECK-DAG: llvm.add - -// ----- - -func.func @tpp_gemm(%arg0: memref<8x9xf32>, %arg1: memref<9x10xf32>, %arg2: memref<8x10xf32>) { - linalg.matmul ins(%arg0, %arg1 : memref<8x9xf32>, memref<9x10xf32>) - outs(%arg2: memref<8x10xf32>) - return -} - -// CHECK: module attributes {gpu.container_module} -// CHECK-LABEL: func.func @tpp_gemm -// CHECK: gpu.launch_func @tpp_gemm_kernel::@tpp_gemm_kernel -// CHECK: gpu.module @tpp_gemm_kernel -// CHECK-LABEL: llvm.func @tpp_gemm_kernel -// CHECK-DAG: nvvm.read -// CHECK-DAG: llvm.mul -// CHECK-DAG: llvm.add - -// ----- - func.func @packed_brgemm(%arg0: memref<4x16x64x64xf32>, %arg1: memref<16x16x64x64xf32>, %arg2: memref<4x16x64x64xf32>) { %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index @@ -126,10 +68,10 @@ func.func @matmul_blocks_threads( // CHECK: module attributes {gpu.container_module} // CHECK-LABEL: func.func @matmul_blocks_threads // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index +// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index // CHECK-NOT: linalg.matmul // CHECK: gpu.launch_func @matmul_blocks_threads_kernel::@matmul_blocks_threads_kernel -// CHECK-SAME: blocks in (%[[C8]], %[[C64]], %[[C1]]) threads in (%[[C32]], %[[C32]], %[[C1]]) +// CHECK-SAME: blocks in (%[[C2]], %[[C16]], %[[C1]]) threads in (%[[C4]], %[[C4]], %[[C1]]) // CHECK: gpu.module @matmul_blocks_threads_kernel diff --git a/test/GPU/gpu-conversion.mlir b/test/GPU/gpu-conversion.mlir index e48680fac..a4634dd69 100644 --- a/test/GPU/gpu-conversion.mlir +++ b/test/GPU/gpu-conversion.mlir @@ -1,32 +1,40 @@ // RUN: tpp-opt %s -gpu-conversion -split-input-file | FileCheck %s - -func.func @matmul() { - %0 = memref.alloc() : memref<8x8xf32> - %1 = memref.alloc() : memref<8x8xf32> - %2 = memref.alloc() : memref<8x8xf32> - - %cast_a = memref.cast %0 : memref<8x8xf32> to memref<*xf32> - gpu.host_register %cast_a : memref<*xf32> - %cast_b = memref.cast %1 : memref<8x8xf32> to memref<*xf32> - gpu.host_register %cast_b : memref<*xf32> - %cast_c = memref.cast %2 :memref<8x8xf32> to memref<*xf32> - gpu.host_register %cast_c : memref<*xf32> - - linalg.matmul ins(%0, %1 : memref<8x8xf32>, memref<8x8xf32>) - outs(%2 : memref<8x8xf32>) - - call @printMemrefF32(%cast_c) : (memref<*xf32>) -> () - - memref.dealloc %0 : memref<8x8xf32> - memref.dealloc %1 : memref<8x8xf32> - memref.dealloc %2 : memref<8x8xf32> - - return +// XFAIL:* +// Currently tiling needs matmul as an anchor and without it, other ops +// will not get outlined. + +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 4 : i32>>> +} { + func.func @matmul() { + %0 = memref.alloc() : memref<8x8xf32> + %1 = memref.alloc() : memref<8x8xf32> + %2 = memref.alloc() : memref<8x8xf32> + + %cast_a = memref.cast %0 : memref<8x8xf32> to memref<*xf32> + gpu.host_register %cast_a : memref<*xf32> + %cast_b = memref.cast %1 : memref<8x8xf32> to memref<*xf32> + gpu.host_register %cast_b : memref<*xf32> + %cast_c = memref.cast %2 :memref<8x8xf32> to memref<*xf32> + gpu.host_register %cast_c : memref<*xf32> + + linalg.matmul ins(%0, %1 : memref<8x8xf32>, memref<8x8xf32>) + outs(%2 : memref<8x8xf32>) + + call @printMemrefF32(%cast_c) : (memref<*xf32>) -> () + + memref.dealloc %0 : memref<8x8xf32> + memref.dealloc %1 : memref<8x8xf32> + memref.dealloc %2 : memref<8x8xf32> + + return + } + func.func private @printMemrefF32(memref<*xf32>) } -func.func private @printMemrefF32(memref<*xf32>) -// CHECK: module attributes {gpu.container_module} +// CHECK: module attributes{{.*}}gpu.container_module // CHECK-LABEL: func.func @matmul // CHECK: %[[C1:.*]] = memref.cast // CHECK: gpu.host_register %[[C1]] @@ -93,8 +101,8 @@ func.func @generic_matmul(%arg0: memref<256x2048xf32>, // ----- -func.func @identity(%arg0: memref<5x6xf32>, %arg1: memref<5x6xf32>) { - linalg.copy ins(%arg0 : memref<5x6xf32>) outs(%arg1: memref<5x6xf32>) +func.func @identity(%arg0: memref<64x128xf32>, %arg1: memref<64x128xf32>) { + linalg.copy ins(%arg0 : memref<64x128xf32>) outs(%arg1: memref<64x128xf32>) return } @@ -103,11 +111,11 @@ func.func @identity(%arg0: memref<5x6xf32>, %arg1: memref<5x6xf32>) { // CHECK: gpu.launch_func @identity_kernel::@identity_kernel // CHECK: gpu.module @identity_kernel // CHECK-LABEL: gpu.func @identity_kernel -// CHECK-SAME: %[[ARG0:.+]]: memref<5x6xf32>, %[[ARG1:.+]]: memref<5x6xf32> +// CHECK-SAME: %[[ARG0:.+]]: memref<64x128xf32>, %[[ARG1:.+]]: memref<64x128xf32> // CHECK: %[[X:.+]] = gpu.block_id x // CHECK-NEXT: %[[Y:.+]] = gpu.block_id y -// CHECK: %[[L:.+]] = memref.load %[[ARG0]][%[[X]], %[[Y]]] : memref<5x6xf32> -// CHECK: memref.store %[[L]], %[[ARG1]][%[[X]], %[[Y]]] : memref<5x6xf32> +// CHECK: %[[L:.+]] = memref.load %[[ARG0]][%[[X]], %[[Y]]] : memref<64x128xf32> +// CHECK: memref.store %[[L]], %[[ARG1]][%[[X]], %[[Y]]] : memref<64x128xf32> // CHECK: gpu.return // ----- diff --git a/test/GPU/linalg-to-gpu-wmma.mlir b/test/GPU/linalg-to-gpu-wmma.mlir deleted file mode 100644 index 75a440634..000000000 --- a/test/GPU/linalg-to-gpu-wmma.mlir +++ /dev/null @@ -1,336 +0,0 @@ -// RUN: tpp-opt %s -linalg-to-gpu="wmma=1 warp-tile=16,16,16" -canonicalize -split-input-file | FileCheck %s - -func.func @matmul(%arg0: memref<16x16xf16>, - %arg1: memref<16x16xf16>, - %arg2: memref<16x16xf16>) { - linalg.matmul ins(%arg0, %arg1 : memref<16x16xf16>, memref<16x16xf16>) - outs(%arg2 : memref<16x16xf16>) - return -} - -// CHECK-LABEL: func.func @matmul( -// CHECK-SAME: %[[A:.+]]: memref<16x16xf16>, %[[B:.+]]: memref<16x16xf16>, %[[C:.+]]: memref<16x16xf16> -// CHECK-DAG: %[[subgroup_size:.+]] = arith.constant 32 : index -// CHECK: scf.parallel {{.*}}to (%[[subgroup_size]]) -// CHECK-DAG: %[[tileC:.+]] = gpu.subgroup_mma_load_matrix %[[C]]{{.*}}leadDimension = 16 -// CHECK-DAG: %[[tileA:.+]] = gpu.subgroup_mma_load_matrix %[[A]]{{.*}}leadDimension = 16 -// CHECK-DAG: %[[tileB:.+]] = gpu.subgroup_mma_load_matrix %[[B]]{{.*}}leadDimension = 16 -// CHECK: %[[res:.+]] = gpu.subgroup_mma_compute %[[tileA]], %[[tileB]], %[[tileC]] -// CHECK: gpu.subgroup_mma_store_matrix %[[res]], %[[C]]{{.*}}leadDimension = 16 -// CHECK: scf.reduce -// CHECK: } - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -func.func @matmul_wide_tiled(%arg0: memref<16x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<16x32xf16>) { - linalg.matmul ins(%arg0, %arg1 : memref<16x32xf16>, memref<32x32xf16>) outs(%arg2 : memref<16x32xf16>) - return -} - -// Assumes 16x16 WMMA tiles. -// -// CHECK-LABEL: func.func @matmul_wide_tiled( -// CHECK-SAME: %[[A:.+]]: memref<16x32xf16>, %[[B:.+]]: memref<32x32xf16>, %[[C:.+]]: memref<16x32xf16> -// CHECK-DAG: %[[c0:.+]] = arith.constant 0 -// CHECK-DAG: %[[c16:.+]] = arith.constant 16 -// CHECK-COUNT-2: gpu.subgroup_mma_load_matrix %[[C]] -// CHECK-COUNT-2: gpu.subgroup_mma_load_matrix %[[A]] -// CHECK-DAG: gpu.subgroup_mma_load_matrix %[[B]]{{\[}}%[[c0]], %[[c0]] -// CHECK-DAG: gpu.subgroup_mma_load_matrix %[[B]]{{\[}}%[[c0]], %[[c16]] -// CHECK-DAG: gpu.subgroup_mma_load_matrix %[[B]]{{\[}}%[[c16]], %[[c0]] -// CHECK-DAG: gpu.subgroup_mma_load_matrix %[[B]]{{\[}}%[[c16]], %[[c16]] -// CHECK-COUNT-4: gpu.subgroup_mma_compute -// CHECK-COUNT-2: gpu.subgroup_mma_store_matrix{{.*}}, %[[C]] - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -func.func @matmul_tall_tiled(%arg0: memref<32x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<32x16xf16>) { - linalg.matmul ins(%arg0, %arg1 : memref<32x16xf16>, memref<16x16xf16>) outs(%arg2 : memref<32x16xf16>) - return -} - -// CHECK-LABEL: func.func @matmul_tall_tiled( -// CHECK-SAME: %[[A:.+]]: memref<32x16xf16>, %[[B:.+]]: memref<16x16xf16>, %[[C:.+]]: memref<32x16xf16> -// CHECK-COUNT-2: gpu.subgroup_mma_load_matrix %[[C]] -// CHECK-COUNT-2: gpu.subgroup_mma_load_matrix %[[A]] -// CHECK-COUNT-1: gpu.subgroup_mma_load_matrix %[[B]] -// CHECK-COUNT-2: gpu.subgroup_mma_compute -// CHECK-COUNT-2: gpu.subgroup_mma_store_matrix - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -func.func @matmul_2D_tiled(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf16>) { - linalg.matmul ins(%arg0, %arg1 : memref<32x32xf16>, memref<32x32xf16>) outs(%arg2 : memref<32x32xf16>) - return -} - -// CHECK-LABEL: func.func @matmul_2D_tiled( -// CHECK-SAME: %[[A:.+]]: memref<32x32xf16>, %[[B:.+]]: memref<32x32xf16>, %[[C:.+]]: memref<32x32xf16> -// CHECK-COUNT-4: gpu.subgroup_mma_load_matrix %[[C]] -// CHECK-COUNT-4: gpu.subgroup_mma_load_matrix %[[A]] -// CHECK-COUNT-4: gpu.subgroup_mma_load_matrix %[[B]] -// CHECK-COUNT-8: gpu.subgroup_mma_compute -// CHECK-COUNT-4: gpu.subgroup_mma_store_matrix - -// ----- - -func.func @matmul_K_dim_tiled(%arg0: memref<16x64xf16>, %arg1: memref<64x16xf16>, %arg2: memref<16x16xf16>) { - linalg.matmul ins(%arg0, %arg1 : memref<16x64xf16>, memref<64x16xf16>) outs(%arg2 : memref<16x16xf16>) - return -} - -// CHECK-LABEL: func.func @matmul_K_dim_tiled( -// CHECK-SAME: %[[A:.+]]: memref<16x64xf16>, %[[B:.+]]: memref<64x16xf16>, %[[C:.+]]: memref<16x16xf16> -// CHECK-DAG: %[[zero:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[kStep:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[kUB:.+]] = arith.constant 64 : index -// CHECK-DAG: %[[wmmaSizeK:.+]] = arith.constant 16 : index -// CHECK-COUNT-1: %[[cTile:.+]] = gpu.subgroup_mma_load_matrix %[[C]] -// CHECK: %[[loopRes:.+]] = scf.for %[[iv:.+]] = %[[zero]] to %[[kUB]] step %[[kStep]] iter_args(%[[acc_tile:.+]] = %[[cTile]]) -// CHECK: gpu.subgroup_mma_load_matrix %[[A]]{{\[}}%[[zero]], %[[iv]] -// CHECK: %[[aCol:.+]] = arith.addi %[[iv]], %[[wmmaSizeK]] -// CHECK: gpu.subgroup_mma_load_matrix %[[A]]{{\[}}%[[zero]], %[[aCol]] -// CHECK: gpu.subgroup_mma_load_matrix %[[B]]{{\[}}%[[iv]], %[[zero]] -// CHECK: %[[bRow:.+]] = arith.addi %[[iv]], %[[wmmaSizeK]] -// CHECK: gpu.subgroup_mma_load_matrix %[[B]]{{\[}}%[[bRow]], %[[zero]] -// CHECK: gpu.subgroup_mma_compute -// CHECK: %[[res:.+]] = gpu.subgroup_mma_compute -// CHECK: scf.yield %[[res]] -// CHECK: } -// CHECK: gpu.subgroup_mma_store_matrix %[[loopRes]], %[[C]] - -// ----- - -func.func @batch_reduce_matmul(%arg0: memref<64x16x16xf16>, - %arg1: memref<64x16x16xf16>, - %arg2: memref<16x16xf16>) { - linalg.batch_reduce_matmul ins(%arg0, %arg1 : memref<64x16x16xf16>, memref<64x16x16xf16>) - outs(%arg2 : memref<16x16xf16>) - return -} - -// CHECK-LABEL: func.func @batch_reduce_matmul( -// CHECK-SAME: %[[A:.+]]: memref<64x16x16xf16>, %[[B:.+]]: memref<64x16x16xf16>, %[[C:.+]]: memref<16x16xf16> -// CHECK-DAG: %[[subgroup_size:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[batch:.+]] = arith.constant 64 : index -// CHECK-DAG: %[[one:.+]] = arith.constant 1 : index -// CHECK: scf.parallel {{.*}}to (%[[subgroup_size]]) -// CHECK: %[[tileC:.+]] = gpu.subgroup_mma_load_matrix %[[C]]{{.*}}leadDimension = 16 -// CHECK: %[[res:.+]] = scf.for {{.*}}to %[[batch]] {{.*}}iter_args(%[[acc_tile:.*]] = %[[tileC]]) -// CHECK-DAG: %[[tileA:.+]] = gpu.subgroup_mma_load_matrix %[[A]]{{.*}}leadDimension = 16 -// CHECK-DAG: %[[tileB:.+]] = gpu.subgroup_mma_load_matrix %[[B]]{{.*}}leadDimension = 16 -// CHECK: %[[part_sum:.+]] = gpu.subgroup_mma_compute %[[tileA]], %[[tileB]], %[[acc_tile]] -// CHECK: scf.yield %[[part_sum]] -// CHECK: } -// CHECK: gpu.subgroup_mma_store_matrix %[[res]], %[[C]]{{.*}}leadDimension = 16 -// CHECK: scf.reduce -// CHECK: } - -// ----- - -func.func @batch_reduce_matmul_2D_tiled(%arg0: memref<64x32x32xf16>, - %arg1: memref<64x32x32xf16>, - %arg2: memref<32x32xf16>) { - linalg.batch_reduce_matmul ins(%arg0, %arg1 : memref<64x32x32xf16>, memref<64x32x32xf16>) - outs(%arg2 : memref<32x32xf16>) - return -} - -// CHECK-LABEL: func.func @batch_reduce_matmul_2D_tiled( -// CHECK-SAME: %[[A:.+]]: memref<64x32x32xf16>, %[[B:.+]]: memref<64x32x32xf16>, %[[C:.+]]: memref<32x32xf16> -// CHECK-DAG: %[[batch:.+]] = arith.constant 64 : index -// CHECK-COUNT-4: gpu.subgroup_mma_load_matrix %[[C]] -// CHECK: %[[res:.+]] = scf.for {{.*}}to %[[batch]] {{.*}}iter_args -// CHECK-COUNT-4: gpu.subgroup_mma_load_matrix %[[A]] -// CHECK-COUNT-4: gpu.subgroup_mma_load_matrix %[[B]] -// CHECK-COUNT-8: gpu.subgroup_mma_compute -// CHECK: scf.yield{{.*}}: !gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp"> -// CHECK: } -// CHECK-COUNT-4: gpu.subgroup_mma_store_matrix - -// ----- - -func.func @batch_reduce_matmul_K_dim_tiled(%arg0: memref<32x16x64xf16>, - %arg1: memref<32x64x16xf16>, - %arg2: memref<16x16xf16>) { - linalg.batch_reduce_matmul ins(%arg0, %arg1 : memref<32x16x64xf16>, memref<32x64x16xf16>) - outs(%arg2 : memref<16x16xf16>) - return -} - -// CHECK-LABEL: func.func @batch_reduce_matmul_K_dim_tiled( -// CHECK-SAME: %[[A:.+]]: memref<32x16x64xf16>, %[[B:.+]]: memref<32x64x16xf16>, %[[C:.+]]: memref<16x16xf16> -// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[c16:.+]] = arith.constant 16 : index -// CHECK-DAG: %[[c32:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[c64:.+]] = arith.constant 64 : index -// CHECK-COUNT-1: %[[cTile:.+]] = gpu.subgroup_mma_load_matrix %[[C]] -// CHECK: %[[batchLoopRes:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c32]] step %[[c1]] iter_args(%[[acc_batch:.+]] = %[[cTile]]) -// CHECK: %[[kLoopRes:.+]] = scf.for %[[iv:.+]] = %[[zero]] to %[[kUB]] step %[[kStep]] iter_args(%[[acc_k_dim:.+]] = %[[acc_batch]]) -// CHECK-COUNT-2: gpu.subgroup_mma_load_matrix %[[A]] -// CHECK-COUNT-2: gpu.subgroup_mma_load_matrix %[[B]] -// CHECK: gpu.subgroup_mma_compute -// CHECK: %[[res:.+]] = gpu.subgroup_mma_compute -// CHECK: scf.yield %[[res]] -// CHECK: scf.yield %[[kLoopRes]] -// CHECK: } -// CHECK: gpu.subgroup_mma_store_matrix %[[batchLoopRes]], %[[C]] - -// ----- - -func.func @matmul_strided_memrefs(%arg0: memref<16x32x16xf16>, %arg1: memref<16x64x16xf16>, %arg2: memref<32x32xf16>) { - %subview = memref.subview %arg0[0, 0, 0] [16, 1, 16] [1, 1, 1] - : memref<16x32x16xf16> to memref<16x16xf16, strided<[512, 1], offset: 0>> - %subview_0 = memref.subview %arg1[0, 0, 0] [16, 1, 16] [1, 1, 1] - : memref<16x64x16xf16> to memref<16x16xf16, strided<[1024, 1], offset: 0>> - %subview_1 = memref.subview %arg2[16, 0] [16, 16] [1, 1] - : memref<32x32xf16> to memref<16x16xf16, strided<[32, 1], offset: 512>> - - linalg.matmul ins(%subview, %subview_0 : memref<16x16xf16, strided<[512, 1], offset: 0>>, - memref<16x16xf16, strided<[1024, 1], offset: 0>>) - outs(%subview_1 : memref<16x16xf16, strided<[32, 1], offset: 512>>) - - return -} - -// CHECK-LABEL: func.func @matmul_strided_memrefs( -// CHECK-SAME: %[[A:.+]]: memref<16x32x16xf16>, %[[B:.+]]: memref<16x64x16xf16>, %[[C:.+]]: memref<32x32xf16> -// CHECK-DAG: %[[subgroup_size:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[one:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[subA:.+]] = memref.subview %[[A]] -// CHECK-DAG: %[[subB:.+]] = memref.subview %[[B]] -// CHECK-DAG: %[[subC:.+]] = memref.subview %[[C]] -// CHECK: scf.parallel {{.*}}to (%[[subgroup_size]]) -// CHECK-DAG: %[[tileC:.+]] = gpu.subgroup_mma_load_matrix %[[subC]]{{.*}}leadDimension = 32 -// CHECK-DAG: %[[tileA:.+]] = gpu.subgroup_mma_load_matrix %[[subA]]{{.*}}leadDimension = 512 -// CHECK-DAG: %[[tileB:.+]] = gpu.subgroup_mma_load_matrix %[[subB]]{{.*}}leadDimension = 1024 -// CHECK: %[[res:.+]] = gpu.subgroup_mma_compute %[[tileA]], %[[tileB]], %[[tileC]] -// CHECK: gpu.subgroup_mma_store_matrix %[[res]], %[[subC]]{{.*}}leadDimension = 32 -// CHECK: scf.reduce -// CHECK: } - -// ----- - -// Operands' data types do not match supported WMMA types. -func.func @wrong_data_type(%arg0: memref<16x16xf32>, - %arg1: memref<16x16xf32>, - %arg2: memref<16x16xf32>) { - linalg.matmul ins(%arg0, %arg1 : memref<16x16xf32>, memref<16x16xf32>) - outs(%arg2 : memref<16x16xf32>) - return -} - -// CHECK-LABEL: func.func @wrong_data_type( -// CHECK-NOT: gpu.{{.*}}_mma_ - -// ----- - -// Dynamic shapes are not supported. -func.func @matmul_dynamic_shapes(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.matmul ins(%arg0, %arg1 : memref, memref) - outs(%arg2 : memref) - return -} - -// CHECK-LABEL: func.func @matmul_dynamic_shape -// CHECK: linalg.matmul - -// ----- - -// Dynamic shapes are not supported. -func.func @brgemm_dynamic_shapes(%arg0: memref, - %arg1: memref, - %arg2: memref) { - linalg.batch_reduce_matmul ins(%arg0, %arg1 : memref, memref) - outs(%arg2 : memref) - return -} - -// CHECK-LABEL: func.func @brgemm_dynamic_shapes -// CHECK: linalg.batch_reduce_matmul - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> - -func.func @matmul_add_relu(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>, %arg3: memref<16x16xf16>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %cst = arith.constant 0.000000e+00 : f16 - linalg.matmul ins(%arg0, %arg1 : memref<16x16xf16>, memref<16x16xf16>) outs(%arg3 : memref<16x16xf16>) - linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg2 : memref<16x16xf16>) outs(%arg3 : memref<16x16xf16>) { - ^bb0(%in: f16, %out: f16): - %0 = arith.addf %in, %out : f16 - linalg.yield %0 : f16 - } - linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%arg3 :memref<16x16xf16>) { - ^bb0(%out: f16): - %0 = arith.maximumf %out, %cst : f16 - linalg.yield %0 : f16 - } - return -} - -// CHECK-LABEL: func.func @matmul_add_relu( -// CHECK-SAME: %[[A:.+]]: memref<16x16xf16>, %[[B:.+]]: memref<16x16xf16>, %[[BIAS:.+]]: memref<16x16xf16>, %[[C:.+]]: memref<16x16xf16> -// CHECK-DAG: %[[subgroup_size:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[one:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[zeroF16:.+]] = arith.constant 0.000000e+00 : f16 -// CHECK: scf.parallel {{.*}}to (%[[subgroup_size]]) -// CHECK-DAG: %[[tileC:.+]] = gpu.subgroup_mma_load_matrix %[[C]]{{.*}}leadDimension = 16 -// CHECK-DAG: %[[tileA:.+]] = gpu.subgroup_mma_load_matrix %[[A]]{{.*}}leadDimension = 16 -// CHECK-DAG: %[[tileB:.+]] = gpu.subgroup_mma_load_matrix %[[B]]{{.*}}leadDimension = 16 -// CHECK: %[[compRes:.+]] = gpu.subgroup_mma_compute %[[tileA]], %[[tileB]], %[[tileC]] -// CHECK: %[[tileBias:.+]] = gpu.subgroup_mma_load_matrix %[[BIAS]]{{.*}}leadDimension = 16 -// CHECK: %[[addRes:.+]] = gpu.subgroup_mma_elementwise addf %[[compRes]], %[[tileBias]] -// CHECK: %[[tileCstZero:.+]] = gpu.subgroup_mma_constant_matrix %[[zeroF16]] -// CHECK: %[[reluRes:.+]] = gpu.subgroup_mma_elementwise maxf %[[addRes]], %[[tileCstZero]] -// CHECK: gpu.subgroup_mma_store_matrix %[[reluRes]], %[[C]]{{.*}}leadDimension = 16 -// CHECK: scf.reduce -// CHECK: } - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -func.func @matmul_add_relu_2D_tiled(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf16>, %arg3: memref<32x32xf16>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %cst = arith.constant 0.000000e+00 : f16 - linalg.matmul ins(%arg0, %arg1 : memref<32x32xf16>, memref<32x32xf16>) outs(%arg3 : memref<32x32xf16>) - linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg2 : memref<32x32xf16>) outs(%arg3 : memref<32x32xf16>) { - ^bb0(%in: f16, %out: f16): - %0 = arith.addf %in, %out : f16 - linalg.yield %0 : f16 - } - linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%arg3 :memref<32x32xf16>) { - ^bb0(%out: f16): - %0 = arith.maximumf %out, %cst : f16 - linalg.yield %0 : f16 - } - return -} - -// CHECK-LABEL: func.func @matmul_add_relu_2D_tiled( -// CHECK-SAME: %[[A:.+]]: memref<32x32xf16>, %[[B:.+]]: memref<32x32xf16>, %[[BIAS:.+]]: memref<32x32xf16>, %[[C:.+]]: memref<32x32xf16> -// CHECK-DAG: %[[f0:.+]] = arith.constant 0.0 -// CHECK-DAG: %[[c0:.+]] = arith.constant 0 -// CHECK-DAG: %[[c16:.+]] = arith.constant 16 -// CHECK-COUNT-4: gpu.subgroup_mma_load_matrix %[[C]] -// CHECK-COUNT-4: gpu.subgroup_mma_load_matrix %[[A]] -// CHECK-COUNT-4: gpu.subgroup_mma_load_matrix %[[B]] -// CHECK-COUNT-8: gpu.subgroup_mma_compute -// CHECK-DAG: %[[b0:.+]] = gpu.subgroup_mma_load_matrix %[[BIAS]]{{\[}}%[[c0]], %[[c0]] -// CHECK-DAG: gpu.subgroup_mma_elementwise addf{{.*}}, %[[b0]] -// CHECK-DAG: %[[b1:.+]] = gpu.subgroup_mma_load_matrix %[[BIAS]]{{\[}}%[[c0]], %[[c16]] -// CHECK-DAG: gpu.subgroup_mma_elementwise addf{{.*}}, %[[b1]] -// CHECK-DAG: %[[b2:.+]] = gpu.subgroup_mma_load_matrix %[[BIAS]]{{\[}}%[[c16]], %[[c0]] -// CHECK-DAG: gpu.subgroup_mma_elementwise addf{{.*}}, %[[b2]] -// CHECK-DAG: %[[b3:.+]] = gpu.subgroup_mma_load_matrix %[[BIAS]]{{\[}}%[[c16]], %[[c16]] -// CHECK: gpu.subgroup_mma_elementwise addf{{.*}}, %[[b3]] -// CHECK: %[[cstMat:.+]] = gpu.subgroup_mma_constant_matrix %[[f0]] -// CHECK-COUNT-4: gpu.subgroup_mma_elementwise maxf{{.*}}, %[[cstMat]] -// CHECK-COUNT-4: gpu.subgroup_mma_store_matrix{{.*}}, %[[C]] diff --git a/test/GPU/linalg-to-gpu.mlir b/test/GPU/linalg-to-gpu.mlir deleted file mode 100644 index cbf084365..000000000 --- a/test/GPU/linalg-to-gpu.mlir +++ /dev/null @@ -1,162 +0,0 @@ -// RUN: tpp-opt %s -linalg-to-gpu -split-input-file | FileCheck %s - -func.func @matmul(%arg0: memref<256x2048xf32>, - %arg1: memref<2048x1024xf32>, - %arg2: memref<256x1024xf32>) { - linalg.matmul ins(%arg0, %arg1 : memref<256x2048xf32>, memref<2048x1024xf32>) - outs(%arg2 : memref<256x1024xf32>) - return -} - -// CHECK-LABEL: func.func @matmul( -// CHECK-SAME: %[[A:.+]]: memref<256x2048xf32>, %[[B:.+]]: memref<2048x1024xf32>, %[[C:.+]]: memref<256x1024xf32> -// CHECK-DAG: %[[m:.+]] = arith.constant 256 : index -// CHECK-DAG: %[[n:.+]] = arith.constant 1024 : index -// CHECK-DAG: %[[k:.+]] = arith.constant 2048 : index -// CHECK: scf.parallel (%[[arg3:.+]], %[[arg4:.+]]) ={{.*}}to (%[[m]], %[[n]]) -// CHECK: %[[init:.+]] = memref.load %[[C]]{{\[}}%[[arg3]], %[[arg4]]{{\]}} : memref<256x1024xf32> -// CHECK: %[[sum:.+]] = scf.for {{.*}}to %[[k]] {{.*}}iter_args(%[[acc:.*]] = %[[init]]) -// CHECK: %[[elemA:.+]] = memref.load %[[A]] -// CHECK: %[[elemB:.+]] = memref.load %[[B]] -// CHECK: %[[mul:.+]] = arith.mulf %[[elemA]], %[[elemB]] : f32 -// CHECK: %[[res:.+]] = arith.addf %[[acc]], %[[mul]] : f32 -// CHECK: scf.yield %[[res]] : f32 -// CHECK: } -// CHECK: memref.store %[[sum]], %[[C]][%arg3, %arg4] : memref<256x1024xf32> -// CHECK: scf.reduce -// CHECK: } - -// ----- - -func.func @batch_reduce_matmul(%arg0: memref<32x256x2048xf32>, - %arg1: memref<32x2048x1024xf32>, - %arg2: memref<256x1024xf32>) { - linalg.batch_reduce_matmul ins(%arg0, %arg1 : memref<32x256x2048xf32>, memref<32x2048x1024xf32>) - outs(%arg2 : memref<256x1024xf32>) - return -} - -// CHECK-LABEL: func.func @batch_reduce_matmul( -// CHECK-SAME: %[[A:.+]]: memref<32x256x2048xf32>, %[[B:.+]]: memref<32x2048x1024xf32>, %[[C:.+]]: memref<256x1024xf32> -// CHECK-DAG: %[[m:.+]] = arith.constant 256 : index -// CHECK-DAG: %[[n:.+]] = arith.constant 1024 : index -// CHECK-DAG: %[[k:.+]] = arith.constant 2048 : index -// CHECK-DAG: %[[batch:.+]] = arith.constant 32 : index -// CHECK: scf.parallel (%[[arg3:.+]], %[[arg4:.+]]) ={{.*}}to (%[[m]], %[[n]]) -// CHECK: %[[init:.+]] = memref.load %[[C]]{{\[}}%[[arg3]], %[[arg4]]{{\]}} : memref<256x1024xf32> -// CHECK: %[[res:.+]] = scf.for {{.*}}to %[[batch]] {{.*}}iter_args(%[[outerAcc:.*]] = %[[init]]) -// CHECK: %[[sum:.+]] = scf.for {{.*}}to %[[k]] {{.*}}iter_args(%[[innerAcc:.*]] = %[[outerAcc]]) -// CHECK: %[[elemA:.+]] = memref.load %[[A]] -// CHECK: %[[elemB:.+]] = memref.load %[[B]] -// CHECK: %[[mul:.+]] = arith.mulf %[[elemA]], %[[elemB]] : f32 -// CHECK: %[[elemC:.+]] = arith.addf %[[innerAcc]], %[[mul]] : f32 -// CHECK: scf.yield %[[elemC]] : f32 -// CHECK: } -// CHECK: scf.yield %[[sum]] : f32 -// CHECK: } -// CHECK: memref.store %[[res]], %[[C]][%arg3, %arg4] : memref<256x1024xf32> -// CHECK: scf.reduce -// CHECK: } - -// ----- - -// Dynamic shapes are not supported. -func.func @matmul_dynamic_shapes(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.matmul ins(%arg0, %arg1 : memref, memref) - outs(%arg2 : memref) - return -} - -// CHECK-LABEL: func.func @matmul_dynamic_shape -// CHECK: linalg.matmul - -// ----- - -// Dynamic shapes are not supported. -func.func @brgemm_dynamic_shapes(%arg0: memref, - %arg1: memref, - %arg2: memref) { - linalg.batch_reduce_matmul ins(%arg0, %arg1 : memref, memref) - outs(%arg2 : memref) - return -} - -// CHECK-LABEL: func.func @brgemm_dynamic_shapes -// CHECK: linalg.batch_reduce_matmul - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> - -func.func @matmul_add_relu(%arg0: memref<256x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<256x1024xf32>, %arg3: memref<256x1024xf32>) { - %c0 = arith.constant 0 : index - %c256 = arith.constant 256 : index - %c1024 = arith.constant 1024 : index - %c32 = arith.constant 32 : index - %cst = arith.constant 0.000000e+00 : f32 - scf.parallel (%arg4, %arg5) = (%c0, %c0) to (%c256, %c1024) step (%c32, %c32) { - %subview = memref.subview %arg2[%arg4, %arg5] [32, 32] [1, 1] : memref<256x1024xf32> to memref<32x32xf32, strided<[1024, 1], offset: ?>> - %subview_0 = memref.subview %arg0[%arg4, 0] [32, 1024] [1, 1] : memref<256x1024xf32> to memref<32x1024xf32, strided<[1024, 1], offset: ?>> - %subview_1 = memref.subview %arg1[0, %arg5] [1024, 32] [1, 1] : memref<1024x1024xf32> to memref<1024x32xf32, strided<[1024, 1], offset: ?>> - %subview_2 = memref.subview %arg3[%arg4, %arg5] [32, 32] [1, 1] : memref<256x1024xf32> to memref<32x32xf32, strided<[1024, 1], offset: ?>> - linalg.matmul ins(%subview_0, %subview_1 : memref<32x1024xf32, strided<[1024, 1], offset: ?>>, memref<1024x32xf32, strided<[1024, 1], offset: ?>>) outs(%subview_2 : memref<32x32xf32, strided<[1024, 1], offset: ?>>) - linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%subview : memref<32x32xf32, strided<[1024, 1], offset: ?>>) outs(%subview_2 : memref<32x32xf32, strided<[1024, 1], offset: ?>>) { - ^bb0(%in: f32, %out: f32): - %0 = arith.addf %in, %out : f32 - linalg.yield %0 : f32 - } - linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%subview_2 : memref<32x32xf32, strided<[1024, 1], offset: ?>>) { - ^bb0(%out: f32): - %0 = arith.maximumf %out, %cst : f32 - linalg.yield %0 : f32 - } - scf.reduce - } - return -} - -// CHECK-LABEL: func.func @matmul_add_relu( -// CHECK-SAME: %[[A:.+]]: memref<256x1024xf32>, %[[B:.+]]: memref<1024x1024xf32>, %[[BIAS:.+]]: memref<256x1024xf32>, %[[C:.+]]: memref<256x1024xf32> -// CHECK-DAG: %[[m:.+]] = arith.constant 256 : index -// CHECK-DAG: %[[n:.+]] = arith.constant 1024 : index -// CHECK-DAG: %[[tile:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[zero:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: scf.parallel (%[[arg5:.+]], %[[arg6:.+]]) ={{.*}}to (%[[m]], %[[n]]) -// CHECK: %[[outTile:.+]] = memref.subview %[[C]] -// CHECK: scf.parallel (%[[arg6:.+]], %[[arg7:.+]]) ={{.*}}to (%[[tile]], %[[tile]]) -// CHECK-NOT: linalg.matmul -// CHECK: %[[init:.+]] = memref.load %[[outTile]]{{\[}}%[[arg6]], %[[arg7]]{{\]}} : memref<32x32xf32 -// CHECK: %[[sum:.+]] = scf.for {{.*}}to %[[n]] {{.*}}iter_args(%[[acc:.*]] = %[[init]]) -// CHECK: %[[elemA:.+]] = memref.load -// CHECK: %[[elemB:.+]] = memref.load -// CHECK: %[[mul:.+]] = arith.mulf %[[elemA]], %[[elemB]] : f32 -// CHECK: %[[res:.+]] = arith.addf %[[acc]], %[[mul]] : f32 -// CHECK: scf.yield %[[res]] : f32 -// CHECK: } -// CHECK-NOT: linalg.generic -// CHECK: %[[elemBias:.+]] = memref.load -// CHECK: %[[biasAdd:.+]] = arith.addf %[[sum]], %[[elemBias]] -// CHECK: %[[reluRes:.+]] = arith.maximumf %[[biasAdd]], %[[zero]] -// CHECK: memref.store %[[reluRes]], %[[outTile]] -// CHECK: scf.reduce -// CHECK: } -// CHECK: scf.reduce -// CHECK: } - -// ----- - -// Do not fuse unknown ops. -func.func @mixed_ops_chain(%arg0: memref<256x256xf32>, %arg1: memref<256x256xf32>, %arg2: memref<256x256xf32>) { - linalg.matmul ins(%arg0, %arg1 : memref<256x256xf32>, memref<256x256xf32>) - outs(%arg2 : memref<256x256xf32>) - call @eltwiseFunc(%arg0, %arg1, %arg2) : (memref<256x256xf32>, memref<256x256xf32>, memref<256x256xf32>) -> () - linalg.add ins(%arg0, %arg1 : memref<256x256xf32>, memref<256x256xf32>) - outs(%arg2 : memref<256x256xf32>) - return -} -func.func private @eltwiseFunc(memref<256x256xf32>, memref<256x256xf32>, memref<256x256xf32>) -> () - -// CHECK-LABEL: func.func @mixed_ops_chain -// CHECK-NOT: linalg.matmul -// CHECK: call @eltwiseFunc -// CHECK: linalg.add