Skip to content

Commit

Permalink
[MLIR][NVVM] Update cp.async.bulk Ops to use intrinsics (llvm#78900)
Browse files Browse the repository at this point in the history
This patch updates the cp.async.bulk.{commit/wait}_group Ops to use NVVM
intrinsics.
* Doc updated for the commit_group Op.
* Tests are added to verify the lowering to the intrinsics.

While we are here, fix the FileCheck directive on the
'nvvm.setmaxregister' test (`CHECK-LLVM` -> `CHECK`).

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
  • Loading branch information
durga4github authored Jan 22, 2024
1 parent 12c241b commit aa4547f
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 25 deletions.
30 changes: 18 additions & 12 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1591,19 +1591,26 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
// NVVM TMA Ops
//===----------------------------------------------------------------------===//

def NVVM_CpAsyncBulkCommitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.commit.group">,
def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
Arguments<(ins )> {
let assemblyFormat = "attr-dict";
let extraClassDefinition = [{
std::string $cppClass::getPtx() { return std::string("cp.async.bulk.commit_group;"); }
let description = [{
This Op commits all prior initiated but uncommitted cp.async.bulk
instructions into a cp.async.bulk-group.

[For more information, see PTX ISA]
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group)
}];

string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_bulk_commit_group);
}];
}

def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">,
def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
Arguments<(ins
ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group,
OptionalAttr<UnitAttr>:$read)>
{
OptionalAttr<UnitAttr>:$read)> {
let assemblyFormat = "$group attr-dict";
let description = [{
Op waits for completion of the most recent bulk async-groups.
Expand All @@ -1620,15 +1627,14 @@ def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group)
}];

let extraClassDefinition = [{
std::string $cppClass::getPtx() {
auto ptx = std::string("cp.async.bulk.wait_group");
if(getRead()) ptx += ".read";
ptx += " %0;"; return ptx; }
string llvmBuilder = [{
auto intId = op.getRead() ?
llvm::Intrinsic::nvvm_cp_async_bulk_wait_group_read :
llvm::Intrinsic::nvvm_cp_async_bulk_wait_group;
createIntrinsicCall(builder, intId, builder.getInt32($group));
}];
}


def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
Expand Down
18 changes: 7 additions & 11 deletions mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -638,23 +638,19 @@ func.func @set_max_register() {

// -----

func.func @cp_bulk_commit() {
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.commit_group;"
func.func @cp_async_bulk_commit() {
// CHECK: nvvm.cp.async.bulk.commit.group
nvvm.cp.async.bulk.commit.group
func.return
}

// -----

func.func @cp_bulk_wait_group() {
// CHECK: %[[S0:.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S0]] : (i32) -> ()
// CHECK: %[[S1:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S1]] : (i32) -> ()
// CHECK: %[[S2:.+]] = llvm.mlir.constant(5 : i32) : i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S2]] : (i32) -> ()
// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S3]] : (i32) -> ()
func.func @cp_async_bulk_wait_group() {
// CHECK: nvvm.cp.async.bulk.wait_group 1
// CHECK: nvvm.cp.async.bulk.wait_group 0
// CHECK: nvvm.cp.async.bulk.wait_group 5 {read}
// CHECK: nvvm.cp.async.bulk.wait_group 0 {read}
nvvm.cp.async.bulk.wait_group 1
nvvm.cp.async.bulk.wait_group 0
nvvm.cp.async.bulk.wait_group 5 {read}
Expand Down
24 changes: 22 additions & 2 deletions mlir/test/Target/LLVMIR/nvvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -398,13 +398,33 @@ llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.p

// CHECK-LABEL: @llvm_nvvm_setmaxregister
llvm.func @llvm_nvvm_setmaxregister() {
// CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
// CHECK: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
nvvm.setmaxregister increase 256
// CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
// CHECK: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
nvvm.setmaxregister decrease 24
llvm.return
}

// Translation test: the commit-group Op, which takes no operands, is expected
// to lower to a call to the llvm.nvvm.cp.async.bulk.commit.group intrinsic.
// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_commit_group
llvm.func @llvm_nvvm_cp_async_bulk_commit_group() {
// CHECK: call void @llvm.nvvm.cp.async.bulk.commit.group()
nvvm.cp.async.bulk.commit.group
llvm.return
}

// Translation test: the wait_group Op is expected to lower to a call to
// llvm.nvvm.cp.async.bulk.wait.group, or to the .read variant of that
// intrinsic when the `read` unit attribute is present. In both cases the
// group attribute is expected to appear as the i32 argument of the call.
// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_wait_group
llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
// Without `read`: the plain wait.group intrinsic, for group values 0 and 3.
// CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 0)
nvvm.cp.async.bulk.wait_group 0
// CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 3)
nvvm.cp.async.bulk.wait_group 3
// With `read`: the wait.group.read intrinsic, for the same group values.
// CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 0)
nvvm.cp.async.bulk.wait_group 0 {read}
// CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 3)
nvvm.cp.async.bulk.wait_group 3 {read}
llvm.return
}

// CHECK-LABEL: @ld_matrix
llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})
Expand Down

0 comments on commit aa4547f

Please sign in to comment.