From aa4547fcc8eeb9bf4f3cf48cc926f62544e58767 Mon Sep 17 00:00:00 2001 From: Durgadoss R Date: Mon, 22 Jan 2024 13:09:30 +0530 Subject: [PATCH] [MLIR][NVVM] Update cp.async.bulk Ops to use intrinsics (#78900) This patch updates the cp.async.bulk.{commit/wait}_group Ops to use NVVM intrinsics. * Doc updated for the commit_group Op. * Tests are added to verify the lowering to the intrinsics. While we are there, fix the FileCheck directive on the 'nvvm.setmaxregister' test. Signed-off-by: Durgadoss R --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 30 +++++++++++-------- .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 18 +++++------ mlir/test/Target/LLVMIR/nvvmir.mlir | 24 +++++++++++++-- 3 files changed, 47 insertions(+), 25 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index b1bd3a95068076..37e525a139d4ad 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1591,19 +1591,26 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> { // NVVM TMA Ops //===----------------------------------------------------------------------===// -def NVVM_CpAsyncBulkCommitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.commit.group">, +def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">, Arguments<(ins )> { let assemblyFormat = "attr-dict"; - let extraClassDefinition = [{ - std::string $cppClass::getPtx() { return std::string("cp.async.bulk.commit_group;"); } + let description = [{ + This Op commits all prior initiated but uncommitted cp.async.bulk + instructions into a cp.async.bulk-group. + + [For more information, see PTX ISA] + (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group) + }]; + + string llvmBuilder = [{ + createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_bulk_commit_group); }]; } -def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">, +def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">, Arguments<(ins ConfinedAttr]>:$group, - OptionalAttr:$read)> -{ + OptionalAttr:$read)> { let assemblyFormat = "$group attr-dict"; let description = [{ Op waits for completion of the most recent bulk async-groups. @@ -1620,15 +1627,14 @@ def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group"> (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group) }]; - let extraClassDefinition = [{ - std::string $cppClass::getPtx() { - auto ptx = std::string("cp.async.bulk.wait_group"); - if(getRead()) ptx += ".read"; - ptx += " %0;"; return ptx; } + string llvmBuilder = [{ + auto intId = op.getRead() ? + llvm::Intrinsic::nvvm_cp_async_bulk_wait_group_read : + llvm::Intrinsic::nvvm_cp_async_bulk_wait_group; + createIntrinsicCall(builder, intId, builder.getInt32($group)); }]; } - def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", [DeclareOpInterfaceMethods, diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index 9c7c27c49eb11d..0ac7331e1f6987 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -638,23 +638,19 @@ func.func @set_max_register() { // ----- -func.func @cp_bulk_commit() { - //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.commit_group;" +func.func @cp_async_bulk_commit() { + // CHECK: nvvm.cp.async.bulk.commit.group nvvm.cp.async.bulk.commit.group func.return } // ----- -func.func @cp_bulk_wait_group() { - // CHECK: %[[S0:.+]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S0]] : (i32) -> () - // CHECK: %[[S1:.+]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S1]] : (i32) -> () - // CHECK: %[[S2:.+]] = llvm.mlir.constant(5 : i32) : i32 - // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S2]] : (i32) -> () - // CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S3]] : (i32) -> () +func.func @cp_async_bulk_wait_group() { + // CHECK: nvvm.cp.async.bulk.wait_group 1 + // CHECK: nvvm.cp.async.bulk.wait_group 0 + // CHECK: nvvm.cp.async.bulk.wait_group 5 {read} + // CHECK: nvvm.cp.async.bulk.wait_group 0 {read} nvvm.cp.async.bulk.wait_group 1 nvvm.cp.async.bulk.wait_group 0 nvvm.cp.async.bulk.wait_group 5 {read} diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index 8c5e3524a848f6..49f9426daabc21 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -398,13 +398,33 @@ llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.p // CHECK-LABEL: @llvm_nvvm_setmaxregister llvm.func @llvm_nvvm_setmaxregister() { - // CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256) + // CHECK: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256) nvvm.setmaxregister increase 256 - // CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24) + // CHECK: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24) nvvm.setmaxregister decrease 24 llvm.return } +// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_commit_group +llvm.func @llvm_nvvm_cp_async_bulk_commit_group() { + // CHECK: call void @llvm.nvvm.cp.async.bulk.commit.group() + nvvm.cp.async.bulk.commit.group + llvm.return +} + +// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_wait_group +llvm.func @llvm_nvvm_cp_async_bulk_wait_group() { + // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 0) + nvvm.cp.async.bulk.wait_group 0 + // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 3) + nvvm.cp.async.bulk.wait_group 3 + // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 0) + nvvm.cp.async.bulk.wait_group 0 {read} + // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 3) + nvvm.cp.async.bulk.wait_group 3 {read} + llvm.return +} + // CHECK-LABEL: @ld_matrix llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})