diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td
index bd40f2949..fd121b41d 100644
--- a/include/TPP/Passes.td
+++ b/include/TPP/Passes.td
@@ -525,4 +525,18 @@ def SplitReductionDim : Pass<"split-reduction-dim", "func::FuncOp"> {
   ];
 }

+def GpuVectorize : Pass<"gpu-vectorize", "ModuleOp"> {
+  let summary = "Vectorize GPU kernel.";
+  let description = [{
+    Convert ops targeting GPU to a vectorized representation.
+  }];
+  let dependentDialects = ["gpu::GPUDialect",
+                           "scf::SCFDialect",
+                           "memref::MemRefDialect",
+                           "tensor::TensorDialect",
+                           "math::MathDialect",
+                           "arith::ArithDialect",
+                           "vector::VectorDialect"];
+}
+
 #endif // TPP_DIALECT_TPP_PASSES
diff --git a/lib/TPP/GPU/CMakeLists.txt b/lib/TPP/GPU/CMakeLists.txt
index bc5082b9a..d7dd06dbd 100644
--- a/lib/TPP/GPU/CMakeLists.txt
+++ b/lib/TPP/GPU/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(TPPGPU
   GpuDataTransfer.cpp
   GpuInlineConstants.cpp
   LinalgToXeGPU.cpp
+  GpuVectorize.cpp

   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/include/TPP
diff --git a/lib/TPP/GPU/GpuPipeline.cpp b/lib/TPP/GPU/GpuPipeline.cpp
index f8a574a0c..06f238aaf 100644
--- a/lib/TPP/GPU/GpuPipeline.cpp
+++ b/lib/TPP/GPU/GpuPipeline.cpp
@@ -65,9 +65,9 @@ llvm::cl::list<int64_t>
                  llvm::cl::CommaSeparated);

 // Control GPU vectorization.
-llvm::cl::opt<bool> gpuVectorize("gpu-vectorize",
-                                 llvm::cl::desc("Vectorize GPU kernel"),
-                                 llvm::cl::init(false));
+llvm::cl::opt<bool> gpuVector("gpu-vector",
+                              llvm::cl::desc("Vectorize GPU kernel"),
+                              llvm::cl::init(false));

 namespace mlir {
 namespace tpp {
@@ -187,12 +187,22 @@ struct GpuPipeline : public tpp::impl::GpuPipelineBase<GpuPipeline>,
     pm.addPass(createTileConsumerAndFuseProducers(threadTileOptions));
     pm.addPass(createCleanup());

-    if (gpuVectorize) {
+    if (gpuVector) {
       // Early reduction dimension splitting is incompatible with
       // Linalg to XeGPU lowering that expects full GEMM.
       // For now, enable only with other vectorization passes.
       pm.addPass(createSplitReductionDim(SplitReductionDimOptions{kTile}));
       pm.addPass(createCleanup());
+
+      // Vectorize at the tensor level to benefit from better cleanup utilities
+      // such as folding.
+      // TODO: Enable vectorization by default once vector unrolling is added.
+      //       When vector sizes exceed the hardware-supported lengths, the
+      //       pipeline gets stuck on the GPU binary compilation step, so
+      //       vectorization can only be enabled once a pass that resizes
+      //       vector operations is available.
+      pm.addPass(createGpuVectorize());
+      pm.addPass(createCleanup());
     }

     // Preprocess and bufferize as further conversion requires memref
diff --git a/lib/TPP/GPU/GpuToCuda.cpp b/lib/TPP/GPU/GpuToCuda.cpp
index 6ee33d0af..07cd785b9 100644
--- a/lib/TPP/GPU/GpuToCuda.cpp
+++ b/lib/TPP/GPU/GpuToCuda.cpp
@@ -67,18 +67,29 @@ struct GpuToCuda : public tpp::impl::GpuToCudaBase<GpuToCuda>,
         memref::createExpandStridedMetadataPass());
     pm.addNestedPass<gpu::GPUModuleOp>(arith::createArithExpandOpsPass());
     pm.addNestedPass<gpu::GPUModuleOp>(createLowerAffinePass());
+    pm.addNestedPass<gpu::GPUModuleOp>(createConvertVectorToSCFPass());
     pm.addNestedPass<gpu::GPUModuleOp>(createConvertSCFToCFPass());

-    // Create CUDA kernels.
-    pm.addNestedPass<gpu::GPUModuleOp>(createStripDebugInfoPass());
+    pm.addNestedPass<gpu::GPUModuleOp>(createConvertNVGPUToNVVMPass());
     pm.addNestedPass<gpu::GPUModuleOp>(createConvertGpuOpsToNVVMOps());
-    pm.addNestedPass<gpu::GPUModuleOp>(createReconcileUnrealizedCastsPass());
+    pm.addNestedPass<gpu::GPUModuleOp>(createConvertVectorToLLVMPass());
+    pm.addNestedPass<gpu::GPUModuleOp>(createConvertNVVMToLLVMPass());
+    pm.addNestedPass<gpu::GPUModuleOp>(createConvertFuncToLLVMPass());
+    pm.addNestedPass<gpu::GPUModuleOp>(createArithToLLVMConversionPass());
+    pm.addNestedPass<gpu::GPUModuleOp>(createConvertIndexToLLVMPass());
+
     GpuNVVMAttachTargetOptions nvvmTargetOptions;
     nvvmTargetOptions.triple = gpuTriple;
     nvvmTargetOptions.chip = gpuChip;
     nvvmTargetOptions.features = gpuFeatures;
     pm.addPass(createGpuNVVMAttachTarget(nvvmTargetOptions));

+    // Create CUDA kernels.
+    pm.addNestedPass<gpu::GPUModuleOp>(createStripDebugInfoPass());
+    pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
+    pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+    pm.addNestedPass<gpu::GPUModuleOp>(createReconcileUnrealizedCastsPass());
+
     // Cleanup IR.
     pm.addPass(createCanonicalizerPass());
     pm.addPass(createCSEPass());
diff --git a/lib/TPP/GPU/GpuVectorize.cpp b/lib/TPP/GPU/GpuVectorize.cpp
new file mode 100644
index 000000000..04888eddb
--- /dev/null
+++ b/lib/TPP/GPU/GpuVectorize.cpp
@@ -0,0 +1,116 @@
+//===- GpuVectorize.cpp -----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TPP/Passes.h"
+
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/TransformOps/Utils.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace tpp {
+#define GEN_PASS_DEF_GPUVECTORIZE
+#include "TPP/Passes.h.inc"
+} // namespace tpp
+} // namespace mlir
+
+namespace {
+
+// Vectorize ops within GPU kernel.
+struct VectorizeGpuLaunch : public OpRewritePattern<gpu::LaunchOp> {
+  using OpRewritePattern<gpu::LaunchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(gpu::LaunchOp launchOp,
+                                PatternRewriter &rewriter) const override {
+    // Vectorize all linalg ops within GPU kernel.
+    // It is expected that the ops operate on statically sized tiles.
+    auto walkResult = launchOp->walk([&](linalg::LinalgOp linalgOp) {
+      if (linalgOp.hasDynamicShape())
+        return WalkResult::interrupt();
+
+      if (failed(vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
+                           /*scalableVecDims=*/{})))
+        return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+
+    if (walkResult.wasInterrupted())
+      return rewriter.notifyMatchFailure(
+          launchOp, "Failed to vectorize ops within GPU launch");
+
+    return success();
+  }
+};
+
+// Vectorize linalg ops targeting GPU.
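+// Unlike VectorizeGpuLaunch above, this pattern matches ops that have not yet
+// been wrapped in a gpu.launch; it only fires on statically shaped ops with
+// pure tensor semantics that sit inside a parallel loop marking GPU execution.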
+struct GpuVectorizeLinalg : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
+                                PatternRewriter &rewriter) const override {
+    // Vectorize all Linalg ops within parallelized loops.
+    if (!linalgOp.hasPureTensorSemantics())
+      return rewriter.notifyMatchFailure(linalgOp, "Expects tensor semantics");
+
+    if (linalgOp.hasDynamicShape())
+      return rewriter.notifyMatchFailure(linalgOp,
+                                         "Expects static shapes only");
+
+    // Only process operations within parallelized loops.
+    // TODO: Use some different mechanism like annotations to determine which
+    //       ops target GPU.
+    if (!linalgOp->getParentOfType<scf::ForallOp>())
+      return rewriter.notifyMatchFailure(linalgOp,
+                                         "Expects parallel loop parent");
+
+    return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
+                     /*scalableVecDims=*/{});
+  }
+};
+
+// Vectorize operations targeting GPU.
+struct GpuVectorize : public tpp::impl::GpuVectorizeBase<GpuVectorize> {
+  using GpuVectorizeBase::GpuVectorizeBase;
+
+  void runOnOperation() override {
+    MLIRContext *ctx = getOperation().getContext();
+    RewritePatternSet patterns(ctx);
+
+    // Vectorize core computation ops within kernel launch.
+    patterns.add<VectorizeGpuLaunch, GpuVectorizeLinalg>(ctx);
+
+    // Vector postprocessing patterns.
+    vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
+    vector::populateVectorReductionToContractPatterns(patterns);
+    vector::populateSinkVectorOpsPatterns(patterns);
+    vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
+    vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
+
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
+} // namespace
diff --git a/test/GPU/CUDA/Integration/vector-contract-small.mlir b/test/GPU/CUDA/Integration/vector-contract-small.mlir
new file mode 100644
index 000000000..1d3e5c3bd
--- /dev/null
+++ b/test/GPU/CUDA/Integration/vector-contract-small.mlir
@@ -0,0 +1,28 @@
+// RUN: ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0:${ASAN_OPTIONS} \
+// RUN: tpp-run %s -gpu=cuda -print \
+// RUN:  -entry-point-result=void -e entry 2>&1 | \
+// RUN: FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @entry(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>, %arg2: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %0 = scf.forall (%arg3, %arg4) = (0, 0) to (8, 8) step (4, 4) shared_outs(%arg5 = %arg2) -> (tensor<8x8xf32>) {
+    %extracted_slice = tensor.extract_slice %arg0[%arg3, 0] [4, 8] [1, 1] : tensor<8x8xf32> to tensor<4x8xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg1[0, %arg4] [8, 4] [1, 1] : tensor<8x8xf32> to tensor<8x4xf32>
+    %extracted_slice_1 = tensor.extract_slice %arg5[%arg3, %arg4] [4, 4] [1, 1] : tensor<8x8xf32> to tensor<4x4xf32>
+    %1 = vector.transfer_read %extracted_slice[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x8xf32>, vector<4x8xf32>
+    %2 = vector.transfer_read %extracted_slice_0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<8x4xf32>, vector<8x4xf32>
+    %3 = vector.transfer_read %extracted_slice_1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
+    %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<4x8xf32>, vector<8x4xf32> into vector<4x4xf32>
+    %5 = vector.transfer_write %4, %extracted_slice_1[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %5 into %arg5[%arg3, %arg4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x8xf32>
+    }
+  }
+  return %0 : tensor<8x8xf32>
+}
+
+// CHECK-COUNT-8: 9, 9, 9, 9, 9, 9, 9, 9
diff --git a/test/GPU/gpu-vectorize.mlir b/test/GPU/gpu-vectorize.mlir
new file mode 100644
index 000000000..b3a493e19
--- /dev/null
+++ b/test/GPU/gpu-vectorize.mlir
@@ -0,0 +1,205 @@
+// RUN: tpp-opt %s -gpu-vectorize -canonicalize -split-input-file | FileCheck %s
+
+func.func @vectorize_tensor_matmul(%arg0: tensor<64x64xf32>,
+    %arg1: tensor<64x64xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %0 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 64) step (16, 16) shared_outs(%arg5 = %arg2) -> (tensor<64x64xf32>) {
+    %extracted_slice = tensor.extract_slice %arg0[%arg3, 0] [16, 64] [1, 1] : tensor<64x64xf32> to tensor<16x64xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg1[0, %arg4] [64, 16] [1, 1] : tensor<64x64xf32> to tensor<64x16xf32>
+    %extracted_slice_1 = tensor.extract_slice %arg5[%arg3, %arg4] [16, 16] [1, 1] : tensor<64x64xf32> to tensor<16x16xf32>
+    %1 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<16x64xf32>, tensor<64x16xf32>)
+                       outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %1 into %arg5[%arg3, %arg4] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<64x64xf32>
+    }
+  }
+  return %0 : tensor<64x64xf32>
+}
+
+// CHECK-LABEL: @vectorize_tensor_matmul(
+// CHECK: scf.forall
+// CHECK-NOT: linalg.matmul
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK: vector.contract
+// CHECK: vector.transfer_write
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @vectorize_tensor_binary(%arg0: tensor<64x64xf32>,
+    %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %0 = scf.forall (%arg4, %arg5) = (0, 0) to (64, 64) step (16, 16) shared_outs(%arg2 = %arg1) -> (tensor<64x64xf32>) {
+    %extracted_slice = tensor.extract_slice %arg0[%arg4, %arg5] [16, 16] [1, 1] : tensor<64x64xf32> to tensor<16x16xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg2[%arg4, %arg5] [16, 16] [1, 1] : tensor<64x64xf32> to tensor<16x16xf32>
+    %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
+      ins(%extracted_slice : tensor<16x16xf32>) outs(%extracted_slice_0 : tensor<16x16xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %3 = arith.subf %in, %out : f32
+      linalg.yield %3 : f32
+    } -> tensor<16x16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %2 into %arg2[%arg4, %arg5] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<64x64xf32>
+    }
+  }
+  return %0 : tensor<64x64xf32>
+}
+
+// CHECK-LABEL: @vectorize_tensor_binary(
+// CHECK: scf.forall
+// CHECK-NOT: linalg.generic
+// CHECK-COUNT-2: vector.transfer_read
+// CHECK: arith.subf
+// CHECK: vector.transfer_write
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @vectorize_tensor_unary(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %0 = scf.forall (%arg4, %arg5) = (0, 0) to (64, 64) step (16, 16) shared_outs(%arg1 = %arg0) -> (tensor<64x64xf32>) {
+    %extracted_slice = tensor.extract_slice %arg1[%arg4, %arg5] [16, 16] [1, 1] : tensor<64x64xf32> to tensor<16x16xf32>
+    %2 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]}
+      outs(%extracted_slice : tensor<16x16xf32>) {
+    ^bb0(%out: f32):
+      %3 = math.absf %out : f32
+      linalg.yield %3 : f32
+    } -> tensor<16x16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %2 into %arg1[%arg4, %arg5] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<64x64xf32>
+    }
+  }
+  return %0 : tensor<64x64xf32>
+}
+
+// CHECK-LABEL: @vectorize_tensor_unary(
+// CHECK: scf.forall
+// CHECK-NOT: linalg.generic
+// CHECK-COUNT-1: vector.transfer_read
+// CHECK: math.absf
+// CHECK: vector.transfer_write
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @vectorize_tensor_matmul_add(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>,
+    %arg2: tensor<64x64xf32>, %arg3: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %0 = scf.forall (%arg4, %arg5) = (0, 0) to (64, 64) step (16, 16) shared_outs(%arg6 = %arg2) -> (tensor<64x64xf32>) {
+    %extracted_slice = tensor.extract_slice %arg3[%arg4, %arg5] [16, 16] [1, 1] : tensor<64x64xf32> to tensor<16x16xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg0[%arg4, 0] [16, 64] [1, 1] : tensor<64x64xf32> to tensor<16x64xf32>
+    %extracted_slice_1 = tensor.extract_slice %arg1[0, %arg5] [64, 16] [1, 1] : tensor<64x64xf32> to tensor<64x16xf32>
+    %extracted_slice_2 = tensor.extract_slice %arg6[%arg4, %arg5] [16, 16] [1, 1] : tensor<64x64xf32> to tensor<16x16xf32>
+    %1 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<16x64xf32>, tensor<64x16xf32>) outs(%extracted_slice_2 : tensor<16x16xf32>) -> tensor<16x16xf32>
+    %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<16x16xf32>) outs(%1 : tensor<16x16xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %3 = arith.addf %in, %out : f32
+      linalg.yield %3 : f32
+    } -> tensor<16x16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %2 into %arg6[%arg4, %arg5] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<64x64xf32>
+    }
+  }
+  return %0 : tensor<64x64xf32>
+}
+
+// CHECK-LABEL: @vectorize_tensor_matmul_add(
+// CHECK: scf.forall
+// CHECK-NOT: linalg.matmul
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK: vector.contract
+// CHECK-NOT: linalg.generic
+// CHECK-COUNT-1: vector.transfer_read
+// CHECK: arith.addf
+// CHECK: vector.transfer_write
+
+// -----
+
+func.func @vectorize_matmul(%arg0: memref<64x64xf32>, %arg1: memref<64x64xf32>, %arg2: memref<64x64xf32>) {
+  %c1 = arith.constant 1 : index
+  gpu.launch blocks(%b0, %b1, %b2) in (%gs0 = %c1, %gs1 = %c1, %gs2 = %c1)
+             threads(%t0, %t1, %t2) in (%bs0 = %c1, %bs1 = %c1, %bs2 = %c1) {
+    linalg.matmul ins(%arg0, %arg1 : memref<64x64xf32>, memref<64x64xf32>)
+                  outs(%arg2 : memref<64x64xf32>)
+    gpu.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: @vectorize_matmul(
+// CHECK: gpu.launch
+// CHECK-NOT: linalg.matmul
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK: vector.contract
+// CHECK: vector.transfer_write
+
+// -----
+
+func.func @vectorize_binary(%arg0: memref<64x64xf32>, %arg1: memref<64x64xf32>, %arg2: memref<64x64xf32>) {
+  %c1 = arith.constant 1 : index
+  gpu.launch blocks(%b0, %b1, %b2) in (%gs0 = %c1, %gs1 = %c1, %gs2 = %c1)
+             threads(%t0, %t1, %t2) in (%bs0 = %c1, %bs1 = %c1, %bs2 = %c1) {
+    linalg.sub ins(%arg0, %arg1 : memref<64x64xf32>, memref<64x64xf32>)
+               outs(%arg2 : memref<64x64xf32>)
+    gpu.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: @vectorize_binary(
+// CHECK: gpu.launch
+// CHECK-NOT: linalg.sub
+// CHECK-COUNT-2: vector.transfer_read
+// CHECK: arith.subf
+// CHECK: vector.transfer_write
+
+// -----
+
+func.func @vectorize_unary(%arg0: memref<64x64xf32>, %arg1: memref<64x64xf32>) {
+  %c1 = arith.constant 1 : index
+  gpu.launch blocks(%b0, %b1, %b2) in (%gs0 = %c1, %gs1 = %c1, %gs2 = %c1)
+             threads(%t0, %t1, %t2) in (%bs0 = %c1, %bs1 = %c1, %bs2 = %c1) {
+    linalg.abs ins(%arg0 : memref<64x64xf32>)
+               outs(%arg1 : memref<64x64xf32>)
+    gpu.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: @vectorize_unary(
+// CHECK: gpu.launch
+// CHECK-NOT: linalg.abs
+// CHECK-COUNT-1: vector.transfer_read
+// CHECK: math.absf
+// CHECK: vector.transfer_write
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @vectorize_matmul_add(%arg0: memref<64x64xf32>, %arg1: memref<64x64xf32>,
+    %arg2: memref<64x64xf32>, %arg3: memref<64x64xf32>) {
+  %c1 = arith.constant 1 : index
+  gpu.launch blocks(%b0, %b1, %b2) in (%gs0 = %c1, %gs1 = %c1, %gs2 = %c1)
+             threads(%t0, %t1, %t2) in (%bs0 = %c1, %bs1 = %c1, %bs2 = %c1) {
+    linalg.matmul ins(%arg0, %arg1 : memref<64x64xf32>, memref<64x64xf32>)
+                  outs(%arg2 : memref<64x64xf32>)
+    linalg.generic {indexing_maps = [#map, #map],
+                    iterator_types = ["parallel", "parallel"]}
+      ins(%arg3 : memref<64x64xf32>) outs(%arg2 : memref<64x64xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %2 = arith.addf %in, %out : f32
+      linalg.yield %2 : f32
+    }
+    gpu.terminator
+  }
+  return
+}
+
+// NOTE: A RAW (write-then-read) pair remains between vector.contract and
+// arith.addf because the vector transfer folders only work on tensors.
+// CHECK-LABEL: @vectorize_matmul_add(
+// CHECK: gpu.launch
+// CHECK-NOT: linalg.matmul
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK: vector.contract
+// CHECK: vector.transfer_write
+// CHECK-NOT: linalg.generic
+// CHECK-COUNT-2: vector.transfer_read
+// CHECK: arith.addf
+// CHECK: vector.transfer_write
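
// For illustration only, not part of the patch: a hand-written sketch of the
// vectorized form the matmul tile in @vectorize_tensor_matmul is expected to
// take after -gpu-vectorize -canonicalize, modeled on the
// vector-contract-small.mlir integration test above. Value names, the
// in_bounds attributes, and the exact canonicalized output are assumptions.

#sketch_map = affine_map<(d0, d1, d2) -> (d0, d2)>
#sketch_map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#sketch_map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @vectorized_tile_sketch(%lhs: tensor<16x64xf32>, %rhs: tensor<64x16xf32>,
    %acc: tensor<16x16xf32>) -> tensor<16x16xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %c0 = arith.constant 0 : index
  // Whole-tile reads of the statically sized operands.
  %0 = vector.transfer_read %lhs[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<16x64xf32>, vector<16x64xf32>
  %1 = vector.transfer_read %rhs[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<64x16xf32>, vector<64x16xf32>
  %2 = vector.transfer_read %acc[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<16x16xf32>, vector<16x16xf32>
  // The matmul body becomes a single contraction that accumulates into %2.
  %3 = vector.contract {indexing_maps = [#sketch_map, #sketch_map1, #sketch_map2],
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>}
    %0, %1, %2 : vector<16x64xf32>, vector<64x16xf32> into vector<16x16xf32>
  // The result is written back into the output tile.
  %4 = vector.transfer_write %3, %acc[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<16x16xf32>
  return %4 : tensor<16x16xf32>
}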