Skip to content

Commit

Permalink
[mlir][EmitC] Model lvalues as a type in EmitC (llvm#91475)
Browse files Browse the repository at this point in the history
This adds an `emitc.lvalue` type which models assignable lvlaues in the
type system. Operations modifying memory are restricted to this type
accordingly.

See also the discussion on
[discourse](https://discourse.llvm.org/t/rfc-separate-variables-from-ssa-values-in-emitc/75224/9).
The most notable changes are as follows.

- `emitc.variable` and `emitc.global` ops are restricted to return
`emitc.array` or `emitc.lvalue` types
- Taking the address of a value is restricted to operands with lvalue
type
- Conversion from lvalues into SSA values is done with the new
`emitc.load` op
- The var operand of the `emitc.assign` op is restricted to lvalue type 
- The result of the `emitc.subscript` and `emitc.get_global` ops is a
lvalue type
- The operands and results of the `emitc.member` and
`emitc.member_of_ptr` ops are restricted to lvalue types

---------

Co-authored-by: Matthias Gehre <matthias.gehre@amd.com>
  • Loading branch information
simon-camp and mgehre-amd authored Aug 20, 2024
1 parent 3c53745 commit e47b507
Show file tree
Hide file tree
Showing 27 changed files with 810 additions and 382 deletions.
101 changes: 74 additions & 27 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,17 @@ def EmitC_ApplyOp : EmitC_Op<"apply", [CExpression]> {

```mlir
// Custom form of applying the & operator.
%0 = emitc.apply "&"(%arg0) : (i32) -> !emitc.ptr<i32>
%0 = emitc.apply "&"(%arg0) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>

// Generic form of the same operation.
%0 = "emitc.apply"(%arg0) {applicableOperator = "&"}
: (i32) -> !emitc.ptr<i32>
: (!emitc.lvalue<i32>) -> !emitc.ptr<i32>

```
}];
let arguments = (ins
Arg<StrAttr, "the operator to apply">:$applicableOperator,
EmitCType:$operand
AnyTypeOf<[EmitCType, EmitC_LValueType]>:$operand
);
let results = (outs EmitCType:$result);
let assemblyFormat = [{
Expand Down Expand Up @@ -836,6 +836,35 @@ def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", [CExpression]> {
let assemblyFormat = "operands attr-dict `:` type(operands)";
}

def EmitC_LoadOp : EmitC_Op<"load", [
TypesMatchWith<"result type matches value type of 'operand'",
"operand", "result",
"::llvm::cast<LValueType>($_self).getValueType()">
]> {
let summary = "Load an lvalue into an SSA value.";
let description = [{
This operation loads the content of a modifiable lvalue into an SSA value.
Modifications of the lvalue executed after the load are not observable on
the produced value.

Example:

```mlir
%1 = emitc.load %0 : !emitc.lvalue<i32>
```
```c++
// Code emitted for the operation above.
int32_t v2 = v1;
```
}];

let arguments = (ins
Res<EmitC_LValueType, "", [MemRead<DefaultResource, 0, FullEffect>]>:$operand);
let results = (outs AnyType:$result);

let assemblyFormat = "$operand attr-dict `:` type($operand)";
}

def EmitC_MulOp : EmitC_BinaryOp<"mul", [CExpression]> {
let summary = "Multiplication operation";
let description = [{
Expand Down Expand Up @@ -918,15 +947,15 @@ def EmitC_MemberOp : EmitC_Op<"member"> {

```mlir
%0 = "emitc.member" (%arg0) {member = "a"}
: (!emitc.opaque<"mystruct">) -> i32
: (!emitc.lvalue<!emitc.opaque<"mystruct">>) -> !emitc.lvalue<i32>
```
}];

let arguments = (ins
Arg<StrAttr, "the member to access">:$member,
EmitC_OpaqueType:$operand
EmitC_LValueOf<[EmitC_OpaqueType]>:$operand
);
let results = (outs EmitCType);
let results = (outs EmitC_LValueOf<[EmitCType]>);
}

def EmitC_MemberOfPtrOp : EmitC_Op<"member_of_ptr"> {
Expand All @@ -939,15 +968,16 @@ def EmitC_MemberOfPtrOp : EmitC_Op<"member_of_ptr"> {

```mlir
%0 = "emitc.member_of_ptr" (%arg0) {member = "a"}
: (!emitc.ptr<!emitc.opaque<"mystruct">>) -> i32
: (!emitc.lvalue<!emitc.ptr<!emitc.opaque<"mystruct">>>)
-> !emitc.lvalue<i32>
```
}];

let arguments = (ins
Arg<StrAttr, "the member to access">:$member,
AnyTypeOf<[EmitC_OpaqueType,EmitC_PointerType]>:$operand
EmitC_LValueOf<[EmitC_OpaqueType,EmitC_PointerType]>:$operand
);
let results = (outs EmitCType);
let results = (outs EmitC_LValueOf<[EmitCType]>);
}

def EmitC_ConditionalOp : EmitC_Op<"conditional",
Expand Down Expand Up @@ -1031,28 +1061,29 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {

```mlir
// Integer variable
%0 = "emitc.variable"(){value = 42 : i32} : () -> i32
%0 = "emitc.variable"(){value = 42 : i32} : () -> !emitc.lvalue<i32>

// Variable emitted as `int32_t* = NULL;`
%1 = "emitc.variable"() {value = #emitc.opaque<"NULL">}
: () -> !emitc.ptr<!emitc.opaque<"int32_t">>
: () -> !emitc.lvalue<!emitc.ptr<!emitc.opaque<"int32_t">>>
```

Since folding is not supported, it can be used with pointers.
As an example, it is valid to create pointers to `variable` operations
by using `apply` operations and pass these to a `call` operation.
```mlir
%0 = "emitc.variable"() {value = 0 : i32} : () -> i32
%1 = "emitc.variable"() {value = 0 : i32} : () -> i32
%2 = emitc.apply "&"(%0) : (i32) -> !emitc.ptr<i32>
%3 = emitc.apply "&"(%1) : (i32) -> !emitc.ptr<i32>
%0 = "emitc.variable"() {value = 0 : i32} : () -> !emitc.lvalue<i32>
%1 = "emitc.variable"() {value = 0 : i32} : () -> !emitc.lvalue<i32>
%2 = emitc.apply "&"(%0) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
%3 = emitc.apply "&"(%1) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
emitc.call_opaque "write"(%2, %3)
: (!emitc.ptr<i32>, !emitc.ptr<i32>) -> ()
```
}];

let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
let results = (outs EmitCType);
let results = (outs Res<AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>, "",
[MemAlloc<DefaultResource, 0, FullEffect>]>);

let hasVerifier = 1;
}
Expand Down Expand Up @@ -1118,11 +1149,12 @@ def EmitC_GetGlobalOp : EmitC_Op<"get_global",

```mlir
%x = emitc.get_global @foo : !emitc.array<2xf32>
%y = emitc.get_global @bar : !emitc.lvalue<i32>
```
}];

let arguments = (ins FlatSymbolRefAttr:$name);
let results = (outs EmitCType:$result);
let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>:$result);
let assemblyFormat = "$name `:` type($result) attr-dict";
}

Expand Down Expand Up @@ -1172,15 +1204,17 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {

```mlir
// Integer variable
%0 = "emitc.variable"(){value = 42 : i32} : () -> i32
%0 = "emitc.variable"(){value = 42 : i32} : () -> !emitc.lvalue<i32>
%1 = emitc.call_opaque "foo"() : () -> (i32)

// Assign emitted as `... = ...;`
"emitc.assign"(%0, %1) : (i32, i32) -> ()
"emitc.assign"(%0, %1) : (!emitc.lvalue<i32>, i32) -> ()
```
}];

let arguments = (ins EmitCType:$var, EmitCType:$value);
let arguments = (ins
Res<EmitC_LValueType, "", [MemWrite<DefaultResource, 1, FullEffect>]>:$var,
EmitCType:$value);
let results = (outs);

let hasVerifier = 1;
Expand Down Expand Up @@ -1276,8 +1310,10 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
```mlir
%i = index.constant 1
%j = index.constant 7
%0 = emitc.subscript %arg0[%i, %j] : !emitc.array<4x8xf32>, index, index
%1 = emitc.subscript %arg1[%i] : !emitc.ptr<i32>, index
%0 = emitc.subscript %arg0[%i, %j] : (!emitc.array<4x8xf32>, index, index)
-> !emitc.lvalue<f32>
%1 = emitc.subscript %arg1[%i] : (!emitc.ptr<i32>, index)
-> !emitc.lvalue<i32>
```
}];
let arguments = (ins Arg<AnyTypeOf<[
Expand All @@ -1286,15 +1322,26 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
EmitC_PointerType]>,
"the value to subscript">:$value,
Variadic<EmitCType>:$indices);
let results = (outs EmitCType:$result);
let results = (outs EmitC_LValueType:$result);

let builders = [
OpBuilder<(ins "TypedValue<ArrayType>":$array, "ValueRange":$indices), [{
build($_builder, $_state, array.getType().getElementType(), array, indices);
build(
$_builder,
$_state,
emitc::LValueType::get(array.getType().getElementType()),
array,
indices
);
}]>,
OpBuilder<(ins "TypedValue<PointerType>":$pointer, "Value":$index), [{
build($_builder, $_state, pointer.getType().getPointee(), pointer,
ValueRange{index});
build(
$_builder,
$_state,
emitc::LValueType::get(pointer.getType().getPointee()),
pointer,
ValueRange{index}
);
}]>
];

Expand Down Expand Up @@ -1338,7 +1385,7 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
emitc.yield
}
default {
%3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
%3 = "emitc.constant"(){value = 42.0 : f32} : () -> f32
emitc.call_opaque "func2" (%3) : (f32) -> ()
}
```
Expand Down
27 changes: 27 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,23 @@ def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> {
let hasCustomAssemblyFormat = 1;
}

def EmitC_LValueType : EmitC_Type<"LValue", "lvalue"> {
let summary = "EmitC lvalue type";

let description = [{
Values of this type can be assigned to and their address can be taken.
}];

let parameters = (ins "Type":$valueType);
let builders = [
TypeBuilderWithInferredContext<(ins "Type":$valueType), [{
return $_get(valueType.getContext(), valueType);
}]>
];
let assemblyFormat = "`<` qualified($valueType) `>`";
let genVerifyDecl = 1;
}

def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> {
let summary = "EmitC opaque type";

Expand Down Expand Up @@ -129,6 +146,7 @@ def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> {
}]>
];
let assemblyFormat = "`<` qualified($pointee) `>`";
let genVerifyDecl = 1;
}

def EmitC_SignedSizeT : EmitC_Type<"SignedSizeT", "ssize_t"> {
Expand Down Expand Up @@ -158,4 +176,13 @@ def EmitC_SizeT : EmitC_Type<"SizeT", "size_t"> {
}];
}

class EmitC_LValueOf<list<Type> allowedTypes> :
ContainerType<
AnyTypeOf<allowedTypes>,
CPred<"::llvm::isa<::mlir::emitc::LValueType>($_self)">,
"::llvm::cast<::mlir::emitc::LValueType>($_self).getValueType()",
"emitc.lvalue",
"::mlir::emitc::LValueType"
>;

#endif // MLIR_DIALECT_EMITC_IR_EMITCTYPES
7 changes: 1 addition & 6 deletions mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,7 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
auto subscript = rewriter.create<emitc::SubscriptOp>(
op.getLoc(), arrayValue, operands.getIndices());

auto noInit = emitc::OpaqueAttr::get(getContext(), "");
auto var =
rewriter.create<emitc::VariableOp>(op.getLoc(), resultTy, noInit);

rewriter.create<emitc::AssignOp>(op.getLoc(), var, subscript);
rewriter.replaceOp(op, var);
rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
return success();
}
};
Expand Down
36 changes: 31 additions & 5 deletions mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,10 @@ static SmallVector<Value> createVariablesForResults(T op,

for (OpResult result : op.getResults()) {
Type resultType = result.getType();
Type varType = emitc::LValueType::get(resultType);
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
emitc::VariableOp var =
rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
rewriter.create<emitc::VariableOp>(loc, varType, noInit);
resultVariables.push_back(var);
}

Expand All @@ -80,6 +81,14 @@ static void assignValues(ValueRange values, SmallVector<Value> &variables,
rewriter.create<emitc::AssignOp>(loc, var, value);
}

SmallVector<Value> loadValues(const SmallVector<Value> &variables,
PatternRewriter &rewriter, Location loc) {
return llvm::map_to_vector<>(variables, [&](Value var) {
Type type = cast<emitc::LValueType>(var.getType()).getValueType();
return rewriter.create<emitc::LoadOp>(loc, type, var).getResult();
});
}

static void lowerYield(SmallVector<Value> &resultVariables,
PatternRewriter &rewriter, scf::YieldOp yield) {
Location loc = yield.getLoc();
Expand Down Expand Up @@ -126,15 +135,26 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
// Erase the auto-generated terminator for the lowered for op.
rewriter.eraseOp(loweredBody->getTerminator());

IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
rewriter.setInsertionPointToEnd(loweredBody);

SmallVector<Value> iterArgsValues =
loadValues(resultVariables, rewriter, loc);

rewriter.restoreInsertionPoint(ip);

SmallVector<Value> replacingValues;
replacingValues.push_back(loweredFor.getInductionVar());
replacingValues.append(resultVariables.begin(), resultVariables.end());
replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());

rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
lowerYield(resultVariables, rewriter,
cast<scf::YieldOp>(loweredBody->getTerminator()));

rewriter.replaceOp(forOp, resultVariables);
// Load variables into SSA values after the for loop.
SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);

rewriter.replaceOp(forOp, resultValues);
return success();
}

Expand Down Expand Up @@ -174,7 +194,10 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
lowerRegion(resultVariables, rewriter, elseRegion, loweredElseRegion);
}

rewriter.replaceOp(ifOp, resultVariables);
rewriter.setInsertionPointAfter(ifOp);
SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);

rewriter.replaceOp(ifOp, results);
return success();
}

Expand Down Expand Up @@ -212,7 +235,10 @@ IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
lowerRegion(resultVariables, rewriter, indexSwitchOp.getDefaultRegion(),
loweredSwitch.getDefaultRegion());

rewriter.replaceOp(indexSwitchOp, resultVariables);
rewriter.setInsertionPointAfter(indexSwitchOp);
SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);

rewriter.replaceOp(indexSwitchOp, results);
return success();
}

Expand Down
Loading

0 comments on commit e47b507

Please sign in to comment.