diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 3a6c6e5438c6d7..1941c4dece1b86 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -463,17 +463,17 @@ def SetMaxRegisterAction : I32EnumAttr<"SetMaxRegisterAction", "NVVM set max reg } def SetMaxRegisterActionAttr : EnumAttr; -def NVVM_SetMaxRegisterOp : NVVM_PTXBuilder_Op<"setmaxregister"> { +def NVVM_SetMaxRegisterOp : NVVM_Op<"setmaxregister"> { let arguments = (ins I32Attr:$regCount, SetMaxRegisterActionAttr:$action); let assemblyFormat = "$action $regCount attr-dict"; - let extraClassDefinition = [{ - std::string $cppClass::getPtx() { - if(getAction() == NVVM::SetMaxRegisterAction::increase) - return std::string("setmaxnreg.inc.sync.aligned.u32 %0;"); - return std::string("setmaxnreg.dec.sync.aligned.u32 %0;"); - } - }]; let hasVerifier = 1; + string llvmBuilder = [{ + auto intId = (op.getAction() == NVVM::SetMaxRegisterAction::increase) ? + llvm::Intrinsic::nvvm_setmaxnreg_inc_sync_aligned_u32 : + llvm::Intrinsic::nvvm_setmaxnreg_dec_sync_aligned_u32; + + createIntrinsicCall(builder, intId, builder.getInt32($regCount)); + }]; } def NVVM_FenceMbarrierInitOp : NVVM_PTXBuilder_Op<"fence.mbarrier.init"> { diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index 7e08ec6ffcbd89..2ee92e3d9527a6 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -628,9 +628,10 @@ llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) { // ----- func.func @set_max_register() { - //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.inc.sync.aligned.u32 $0;", "n" + // CHECK: nvvm.setmaxregister increase 232 nvvm.setmaxregister increase 232 - //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.dec.sync.aligned.u32 $0;", "n" + + // CHECK: nvvm.setmaxregister decrease 40 nvvm.setmaxregister decrease 40 func.return } diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index f83be9dbb2ff30..423b1a133a4ae2 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -369,6 +369,15 @@ llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.p llvm.return } +// CHECK-LABEL: @llvm_nvvm_setmaxregister +llvm.func @llvm_nvvm_setmaxregister() { + // CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256) + nvvm.setmaxregister increase 256 + // CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24) + nvvm.setmaxregister decrease 24 + llvm.return +} + // CHECK-LABEL: @ld_matrix llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})