Skip to content

Commit

Permalink
[mlir][spirv] Add spirv-to-llvm conversion for OpControlBarrier (llvm…
Browse files Browse the repository at this point in the history
…#111864)

The conversion is based on the expected llvm function from the
LLVM/SPIRV translation tool.
  • Loading branch information
FMarno authored Oct 19, 2024
1 parent 1bbf3a3 commit 1775b98
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 3 deletions.
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def SPIRV_ControlBarrierOp : SPIRV_Op<"ControlBarrier", []> {
#### Example:

```mlir
spirv.ControlBarrier "Workgroup", "Device", "Acquire|UniformMemory"
spirv.ControlBarrier <Workgroup>, <Device>, <Acquire|UniformMemory>
```
}];

Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===-- SPIRVBarrierOps.td - MLIR SPIR-V Barrier Ops -------*- tablegen -*-===//
//===-- SPIRVMiscOps.td - MLIR SPIR-V Misc Ops -------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand Down
70 changes: 69 additions & 1 deletion mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,71 @@ class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
}
};

static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
StringRef name,
ArrayRef<Type> paramTypes,
Type resultType) {
auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
SymbolTable::lookupSymbolIn(symbolTable, name));
if (func)
return func;

OpBuilder b(symbolTable->getRegion(0));
func = b.create<LLVM::LLVMFuncOp>(
symbolTable->getLoc(), name,
LLVM::LLVMFunctionType::get(resultType, paramTypes));
func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
func.setConvergent(true);
func.setNoUnwind(true);
func.setWillReturn(true);
return func;
}

static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
LLVM::LLVMFuncOp func,
ValueRange args) {
auto call = builder.create<LLVM::CallOp>(loc, func, args);
call.setCConv(func.getCConv());
call.setConvergentAttr(func.getConvergentAttr());
call.setNoUnwindAttr(func.getNoUnwindAttr());
call.setWillReturnAttr(func.getWillReturnAttr());
return call;
}

class ControlBarrierPattern
: public SPIRVToLLVMConversion<spirv::ControlBarrierOp> {
public:
using SPIRVToLLVMConversion<spirv::ControlBarrierOp>::SPIRVToLLVMConversion;

LogicalResult
matchAndRewrite(spirv::ControlBarrierOp controlBarrierOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
constexpr StringLiteral funcName = "_Z22__spirv_ControlBarrieriii";
Operation *symbolTable =
controlBarrierOp->getParentWithTrait<OpTrait::SymbolTable>();

Type i32 = rewriter.getI32Type();

Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
LLVM::LLVMFuncOp func =
lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);

Location loc = controlBarrierOp->getLoc();
Value execution = rewriter.create<LLVM::ConstantOp>(
loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
Value memory = rewriter.create<LLVM::ConstantOp>(
loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
Value semantics = rewriter.create<LLVM::ConstantOp>(
loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));

auto call = createSPIRVBuiltinCall(loc, rewriter, func,
{execution, memory, semantics});

rewriter.replaceOp(controlBarrierOp, call);
return success();
}
};

/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
/// should be reachable for conversion to succeed. The structure of the loop in
/// LLVM dialect will be the following:
Expand Down Expand Up @@ -1648,7 +1713,10 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,

// Return ops
ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
ReturnPattern, ReturnValuePattern,

// Barrier ops
ControlBarrierPattern>(patterns.getContext(), typeConverter);

patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
typeConverter);
Expand Down
23 changes: 23 additions & 0 deletions mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s

//===----------------------------------------------------------------------===//
// spirv.ControlBarrierOp
//===----------------------------------------------------------------------===//

// CHECK: llvm.func spir_funccc @_Z22__spirv_ControlBarrieriii(i32, i32, i32) attributes {convergent, no_unwind, will_return}

// CHECK-LABEL: @control_barrier
spirv.func @control_barrier() "None" {
// CHECK: [[EXECUTION:%.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: [[MEMORY:%.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: [[SEMANTICS:%.*]] = llvm.mlir.constant(768 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[EXECUTION]], [[MEMORY]], [[SEMANTICS]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
spirv.ControlBarrier <Workgroup>, <Workgroup>, <CrossWorkgroupMemory|WorkgroupMemory>

// CHECK: [[EXECUTION:%.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: [[MEMORY:%.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: [[SEMANTICS:%.*]] = llvm.mlir.constant(256 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[EXECUTION]], [[MEMORY]], [[SEMANTICS]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
spirv.ControlBarrier <Workgroup>, <Workgroup>, <WorkgroupMemory>
spirv.Return
}

0 comments on commit 1775b98

Please sign in to comment.