Skip to content

Commit

Permalink
Add functionality to re-create op with new variadic argument list.
Browse files Browse the repository at this point in the history
Since we cannot update varargs directly, and often need to re-create the
op, add a small helper that re-creates a given op with a new variadic
argument list and invalidates the current op.
  • Loading branch information
Thomas Symalla committed Jul 17, 2024
1 parent 8ae2128 commit 2a3a135
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 57 deletions.
5 changes: 5 additions & 0 deletions example/ExampleMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ void createFunctionExample(Module &module, const Twine &name) {
b.create<xd::WriteVarArgOp>(p2, varArgs);
b.create<xd::HandleGetOp>();

auto *replacable = b.create<xd::WriteVarArgOp>(p2, varArgs);

Check warning on line 128 in example/ExampleMain.cpp

View workflow job for this annotation

GitHub Actions / typos

"replacable" should be "replicable" or "replaceable".
SmallVector<Value *> varArgs2 = varArgs;
varArgs2.push_back(p2);

replacable->replaceArgsAndInvalidate(b, varArgs2);

Check warning on line 132 in example/ExampleMain.cpp

View workflow job for this annotation

GitHub Actions / typos

"replacable" should be "replicable" or "replaceable".
b.create<xd::SetReadOp>(FixedVectorType::get(b.getInt32Ty(), 2));
b.create<xd::SetWriteOp>(y6);

Expand Down
51 changes: 43 additions & 8 deletions lib/TableGen/Operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class AccessorBuilder final {
: m_fmt{fmt}, m_os{out}, m_arg{arg}, m_argTypeString{argTypeString} {}

void emitAccessorDefinitions() const;
void emitVarArgReplacementDefinition(const size_t numNonVarArgs) const;

private:
FmtContext &m_fmt;
Expand Down Expand Up @@ -162,10 +163,18 @@ void OperationBase::emitArgumentAccessorDeclarations(llvm::raw_ostream &out,
const bool isVarArg = arg.type->isVarArgList();
std::string defaultDeclaration = "$0 get$1() $2;";

if (!isVarArg && !arg.type->isImmutable()) {
defaultDeclaration += R"(
void set$1($0 $3);
)";
if (!arg.type->isImmutable()) {
if (!isVarArg) {
defaultDeclaration += R"(
void set$1($0 $3);
)";
} else {
defaultDeclaration += R"(
/// Returns a new op with the same arguments and a new tail argument list.
/// The object on which this is called will be invalidated.
$_op *replace$1AndInvalidate(::llvm_dialects::Builder &, ::llvm::ArrayRef<Value *>);
)";
}
}

out << tgfmt(defaultDeclaration, &fmt, arg.type->getGetterCppType(),
Expand All @@ -174,8 +183,11 @@ void OperationBase::emitArgumentAccessorDeclarations(llvm::raw_ostream &out,
}

void AccessorBuilder::emitAccessorDefinitions() const {
// We do not generate a setter for variadic arguments for now.
emitGetterDefinition();

if (m_arg.type->isImmutable())
return;

if (!m_arg.type->isVarArgList())
emitSetterDefinition();
}
Expand Down Expand Up @@ -208,9 +220,6 @@ void AccessorBuilder::emitGetterDefinition() const {
}

void AccessorBuilder::emitSetterDefinition() const {
if (m_arg.type->isImmutable())
return;

std::string toLlvm = m_arg.name;

if (auto *attr = dyn_cast<Attr>(m_arg.type)) {
Expand All @@ -228,12 +237,34 @@ void AccessorBuilder::emitSetterDefinition() const {
&m_fmt);
}

void AccessorBuilder::emitVarArgReplacementDefinition(
const size_t numNonVarArgs) const {
std::string toLlvm = m_arg.name;

m_fmt.addSubst("numNonVarargs", std::to_string(numNonVarArgs));

m_os << tgfmt(R"(
$_op *$_op::replace$0AndInvalidate(::llvm_dialects::Builder &B, ::llvm::ArrayRef<Value *> $1) {
::llvm::SmallVector<Value *> newArgs;
if ($numNonVarargs > 0)
newArgs.append(arg_begin(), arg_begin() + $numNonVarargs);
newArgs.append($1.begin(), $1.end());
$_op *newOp = cast<$_op>(B.CreateCall(getCalledFunction(), newArgs, this->getName()));
this->replaceAllUsesWith(newOp);
this->eraseFromParent();
return newOp;
})",
&m_fmt, convertToCamelFromSnakeCase(toLlvm, true), toLlvm);
}

void OperationBase::emitArgumentAccessorDefinitions(llvm::raw_ostream &out,
FmtContext &fmt) const {
unsigned numSuperclassArgs = 0;
if (m_superclass)
numSuperclassArgs = m_superclass->getNumFullArguments();

unsigned numArgs = 0;
for (const auto &indexedArg : llvm::enumerate(m_arguments)) {
FmtContextScope scope(fmt);

Expand All @@ -247,6 +278,10 @@ void OperationBase::emitArgumentAccessorDefinitions(llvm::raw_ostream &out,
fmt.addSubst("Name", convertToCamelFromSnakeCase(arg.name, true));

builder.emitAccessorDefinitions();
if (!arg.type->isImmutable() && arg.type->isVarArgList())
builder.emitVarArgReplacementDefinition(numArgs);
else
++numArgs;
}
}

Expand Down
22 changes: 22 additions & 0 deletions test/example/generated/ExampleDialect.cpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -1502,6 +1502,17 @@ instName
value_op_iterator(arg_begin() + 0),
value_op_iterator(arg_end()));
}

InstNameConflictVarargsOp *InstNameConflictVarargsOp::replaceInstName_0AndInvalidate(::llvm_dialects::Builder &B, ::llvm::ArrayRef<Value *> instName_0) {
::llvm::SmallVector<Value *> newArgs;
if (0 > 0)
newArgs.append(arg_begin(), arg_begin() + 0);
newArgs.append(instName_0.begin(), instName_0.end());
InstNameConflictVarargsOp *newOp = cast<InstNameConflictVarargsOp>(B.CreateCall(getCalledFunction(), newArgs, this->getName()));
this->replaceAllUsesWith(newOp);
this->eraseFromParent();
return newOp;
}
::llvm::Value *InstNameConflictVarargsOp::getResult() {return this;}


Expand Down Expand Up @@ -2233,6 +2244,17 @@ data
value_op_iterator(arg_end()));
}

WriteVarArgOp *WriteVarArgOp::replaceArgsAndInvalidate(::llvm_dialects::Builder &B, ::llvm::ArrayRef<Value *> args) {
::llvm::SmallVector<Value *> newArgs;
if (1 > 0)
newArgs.append(arg_begin(), arg_begin() + 1);
newArgs.append(args.begin(), args.end());
WriteVarArgOp *newOp = cast<WriteVarArgOp>(B.CreateCall(getCalledFunction(), newArgs, this->getName()));
this->replaceAllUsesWith(newOp);
this->eraseFromParent();
return newOp;
}


} // namespace xd

Expand Down
100 changes: 54 additions & 46 deletions test/example/generated/ExampleDialect.h.inc
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ uint32_t getNumElements() const;
classof(::llvm::cast<::llvm::CallInst>(v));
}
::llvm::Value * getPtr() const;
void setPtr(::llvm::Value * ptr);
::llvm::Value * getCount() const;
void setCount(::llvm::Value * count);
::llvm::Value * getInitial() const;
void setInitial(::llvm::Value * initial);

void setPtr(::llvm::Value * ptr);
::llvm::Value * getCount() const;
void setCount(::llvm::Value * count);
::llvm::Value * getInitial() const;
void setInitial(::llvm::Value * initial);
};

class Add32Op : public ::llvm::CallInst {
Expand All @@ -123,12 +123,12 @@ uint32_t getNumElements() const;
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getLhs() const;
void setLhs(::llvm::Value * lhs);
::llvm::Value * getRhs() const;
void setRhs(::llvm::Value * rhs);
uint32_t getExtra() const;
void setExtra(uint32_t extra);

void setLhs(::llvm::Value * lhs);
::llvm::Value * getRhs() const;
void setRhs(::llvm::Value * rhs);
uint32_t getExtra() const;
void setExtra(uint32_t extra);
::llvm::Value * getResult();


Expand All @@ -150,10 +150,10 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getLhs() const;
void setLhs(::llvm::Value * lhs);
::llvm::Value * getRhs() const;
void setRhs(::llvm::Value * rhs);

void setLhs(::llvm::Value * lhs);
::llvm::Value * getRhs() const;
void setRhs(::llvm::Value * rhs);
::llvm::Value * getResult();


Expand All @@ -175,10 +175,10 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getVector() const;
void setVector(::llvm::Value * vector);
::llvm::Value * getIndex() const;
void setIndex(::llvm::Value * index);

void setVector(::llvm::Value * vector);
::llvm::Value * getIndex() const;
void setIndex(::llvm::Value * index);
::llvm::Value * getResult();


Expand All @@ -200,8 +200,8 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getSource() const;
void setSource(::llvm::Value * source);

void setSource(::llvm::Value * source);
::llvm::Value * getResult();


Expand Down Expand Up @@ -244,8 +244,8 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getSource() const;
void setSource(::llvm::Value * source);

void setSource(::llvm::Value * source);
::llvm::Value * getResult();


Expand All @@ -267,8 +267,8 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getSource() const;
void setSource(::llvm::Value * source);

void setSource(::llvm::Value * source);
::llvm::Value * getResult();


Expand Down Expand Up @@ -310,12 +310,12 @@ bool getVal() const;
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getVector() const;
void setVector(::llvm::Value * vector);
::llvm::Value * getValue() const;
void setValue(::llvm::Value * value);
::llvm::Value * getIndex() const;
void setIndex(::llvm::Value * index);

void setVector(::llvm::Value * vector);
::llvm::Value * getValue() const;
void setValue(::llvm::Value * value);
::llvm::Value * getIndex() const;
void setIndex(::llvm::Value * index);
::llvm::Value * getResult();


Expand All @@ -337,10 +337,10 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getInstName() const;
void setInstName(::llvm::Value * instName);
::llvm::Value * getInstName_0() const;
void setInstName_0(::llvm::Value * instName_0);

void setInstName(::llvm::Value * instName);
::llvm::Value * getInstName_0() const;
void setInstName_0(::llvm::Value * instName_0);
::llvm::Value * getResult();


Expand All @@ -362,8 +362,8 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getInstName() const;
void setInstName(::llvm::Value * instName);

void setInstName(::llvm::Value * instName);
::llvm::Value * getResult();


Expand All @@ -385,6 +385,10 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::iterator_range<::llvm::User::value_op_iterator> getInstName_0() ;
/// Returns a new op with the same arguments and a new tail argument list.
/// The object on which this is called will be invalidated.
InstNameConflictVarargsOp *replaceInstName_0AndInvalidate(::llvm_dialects::Builder &, ::llvm::ArrayRef<Value *>);

::llvm::Value * getResult();


Expand Down Expand Up @@ -448,8 +452,8 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getData() const;
void setData(::llvm::Value * data);

void setData(::llvm::Value * data);


};
Expand All @@ -470,8 +474,8 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Type * getSizeofType() const;
void setSizeofType(::llvm::Type * sizeof_type);

void setSizeofType(::llvm::Type * sizeof_type);
::llvm::Value * getResult();


Expand Down Expand Up @@ -576,8 +580,8 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getData() const;
void setData(::llvm::Value * data);

void setData(::llvm::Value * data);


};
Expand All @@ -598,8 +602,12 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getData() const;
void setData(::llvm::Value * data);
::llvm::iterator_range<::llvm::User::value_op_iterator> getArgs() ;
void setData(::llvm::Value * data);
::llvm::iterator_range<::llvm::User::value_op_iterator> getArgs() ;
/// Returns a new op with the same arguments and a new tail argument list.
/// The object on which this is called will be invalidated.
WriteVarArgOp *replaceArgsAndInvalidate(::llvm_dialects::Builder &, ::llvm::ArrayRef<Value *>);



};
Expand Down
14 changes: 11 additions & 3 deletions test/example/test-builder.test
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool llvm-dialects-example --include-generated-funcs --check-globals
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool llvm-dialects-example --check-globals --include-generated-funcs
; NOTE: stdin isn't used by the example program, but the redirect makes the UTC tool happy.
; RUN: llvm-dialects-example - | FileCheck --check-prefixes=CHECK %s

;.
; CHECK: @[[GLOB0:.*]] = private unnamed_addr constant [13 x i8] c"Hello world!\00", align 1
; CHECK: @str = private unnamed_addr constant [13 x i8] c"Hello world!\00", align 1
;.
; CHECK-LABEL: @example(
; CHECK-NEXT: entry:
Expand All @@ -28,6 +28,7 @@
; CHECK-NEXT: call void (...) @xd.write(i8 [[P2]])
; CHECK-NEXT: call void (...) @xd.write.vararg(i8 [[P2]], ptr [[P1]], i8 [[P2]])
; CHECK-NEXT: [[TMP14:%.*]] = call target("xd.handle") @xd.handle.get()
; CHECK-NEXT: call void (...) @xd.write.vararg(i8 [[P2]], ptr [[P1]], i8 [[P2]], i8 [[P2]])
; CHECK-NEXT: [[TMP15:%.*]] = call <2 x i32> @xd.set.read__v2i32()
; CHECK-NEXT: call void (...) @xd.set.write(target("xd.vector", i32, 1, 2) [[TMP13]])
; CHECK-NEXT: [[TMP16:%.*]] = call [[TMP0]] @xd.read__s_s()
Expand All @@ -45,6 +46,13 @@
; CHECK-NEXT: [[TWO_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]])
; CHECK-NEXT: [[THREE_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]], i32 3)
; CHECK-NEXT: [[FOUR_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]], i32 3, i32 4)
; CHECK-NEXT: call void @xd.string.attr.op(ptr @[[GLOB0]])
; CHECK-NEXT: call void @xd.string.attr.op(ptr @str)
; CHECK-NEXT: ret void
;
;.
; CHECK: attributes #[[ATTR0:[0-9]+]] = { nounwind memory(inaccessiblemem: readwrite) }
; CHECK: attributes #[[ATTR1:[0-9]+]] = { nounwind willreturn memory(none) }
; CHECK: attributes #[[ATTR2:[0-9]+]] = { nounwind willreturn memory(inaccessiblemem: write) }
; CHECK: attributes #[[ATTR3:[0-9]+]] = { nounwind willreturn memory(read) }
; CHECK: attributes #[[ATTR4:[0-9]+]] = { willreturn }
;.

0 comments on commit 2a3a135

Please sign in to comment.