diff --git a/example/ExampleMain.cpp b/example/ExampleMain.cpp index bc20d42..063fff1 100644 --- a/example/ExampleMain.cpp +++ b/example/ExampleMain.cpp @@ -125,6 +125,11 @@ void createFunctionExample(Module &module, const Twine &name) { b.create(p2, varArgs); b.create(); + auto *replacable = b.create(p2, varArgs); + SmallVector varArgs2 = varArgs; + varArgs2.push_back(p2); + + replacable->replaceArgsAndInvalidate(b, varArgs2); b.create(FixedVectorType::get(b.getInt32Ty(), 2)); b.create(y6); diff --git a/lib/TableGen/Operations.cpp b/lib/TableGen/Operations.cpp index f53a09a..f4bdef0 100644 --- a/lib/TableGen/Operations.cpp +++ b/lib/TableGen/Operations.cpp @@ -64,6 +64,7 @@ class AccessorBuilder final { : m_fmt{fmt}, m_os{out}, m_arg{arg}, m_argTypeString{argTypeString} {} void emitAccessorDefinitions() const; + void emitVarArgReplacementDefinition(const size_t numNonVarArgs) const; private: FmtContext &m_fmt; @@ -162,10 +163,18 @@ void OperationBase::emitArgumentAccessorDeclarations(llvm::raw_ostream &out, const bool isVarArg = arg.type->isVarArgList(); std::string defaultDeclaration = "$0 get$1() $2;"; - if (!isVarArg && !arg.type->isImmutable()) { - defaultDeclaration += R"( - void set$1($0 $3); - )"; + if (!arg.type->isImmutable()) { + if (!isVarArg) { + defaultDeclaration += R"( + void set$1($0 $3); + )"; + } else { + defaultDeclaration += R"( + /// Returns a new op with the same arguments and a new tail argument list. + /// The object on which this is called will be invalidated. + $_op *replace$1AndInvalidate(::llvm_dialects::Builder &, ::llvm::ArrayRef); + )"; + } } out << tgfmt(defaultDeclaration, &fmt, arg.type->getGetterCppType(), @@ -174,8 +183,11 @@ void OperationBase::emitArgumentAccessorDeclarations(llvm::raw_ostream &out, } void AccessorBuilder::emitAccessorDefinitions() const { - // We do not generate a setter for variadic arguments for now. emitGetterDefinition(); + + if (m_arg.type->isImmutable()) + return; + if (!m_arg.type->isVarArgList()) emitSetterDefinition(); } @@ -208,9 +220,6 @@ void AccessorBuilder::emitGetterDefinition() const { } void AccessorBuilder::emitSetterDefinition() const { - if (m_arg.type->isImmutable()) - return; - std::string toLlvm = m_arg.name; if (auto *attr = dyn_cast(m_arg.type)) { @@ -228,12 +237,34 @@ void AccessorBuilder::emitSetterDefinition() const { &m_fmt); } +void AccessorBuilder::emitVarArgReplacementDefinition( + const size_t numNonVarArgs) const { + std::string toLlvm = m_arg.name; + + m_fmt.addSubst("numNonVarargs", std::to_string(numNonVarArgs)); + + m_os << tgfmt(R"( + + $_op *$_op::replace$0AndInvalidate(::llvm_dialects::Builder &B, ::llvm::ArrayRef $1) { + ::llvm::SmallVector newArgs; + if ($numNonVarargs > 0) + newArgs.append(arg_begin(), arg_begin() + $numNonVarargs); + newArgs.append($1.begin(), $1.end()); + $_op *newOp = cast<$_op>(B.CreateCall(getCalledFunction(), newArgs, this->getName())); + this->replaceAllUsesWith(newOp); + this->eraseFromParent(); + return newOp; + })", + &m_fmt, convertToCamelFromSnakeCase(toLlvm, true), toLlvm); +} + void OperationBase::emitArgumentAccessorDefinitions(llvm::raw_ostream &out, FmtContext &fmt) const { unsigned numSuperclassArgs = 0; if (m_superclass) numSuperclassArgs = m_superclass->getNumFullArguments(); + unsigned numArgs = 0; for (const auto &indexedArg : llvm::enumerate(m_arguments)) { FmtContextScope scope(fmt); @@ -247,6 +278,10 @@ void OperationBase::emitArgumentAccessorDefinitions(llvm::raw_ostream &out, fmt.addSubst("Name", convertToCamelFromSnakeCase(arg.name, true)); builder.emitAccessorDefinitions(); + if (!arg.type->isImmutable() && arg.type->isVarArgList()) + builder.emitVarArgReplacementDefinition(numArgs); + else + ++numArgs; } } diff --git a/test/example/generated/ExampleDialect.cpp.inc b/test/example/generated/ExampleDialect.cpp.inc index c8b2695..7ca78d2 100644 --- a/test/example/generated/ExampleDialect.cpp.inc +++ b/test/example/generated/ExampleDialect.cpp.inc @@ -1502,6 +1502,17 @@ instName value_op_iterator(arg_begin() + 0), value_op_iterator(arg_end())); } + + InstNameConflictVarargsOp *InstNameConflictVarargsOp::replaceInstName_0AndInvalidate(::llvm_dialects::Builder &B, ::llvm::ArrayRef instName_0) { + ::llvm::SmallVector newArgs; + if (0 > 0) + newArgs.append(arg_begin(), arg_begin() + 0); + newArgs.append(instName_0.begin(), instName_0.end()); + InstNameConflictVarargsOp *newOp = cast(B.CreateCall(getCalledFunction(), newArgs, this->getName())); + this->replaceAllUsesWith(newOp); + this->eraseFromParent(); + return newOp; + } ::llvm::Value *InstNameConflictVarargsOp::getResult() {return this;} @@ -2233,6 +2244,17 @@ data value_op_iterator(arg_end())); } + WriteVarArgOp *WriteVarArgOp::replaceArgsAndInvalidate(::llvm_dialects::Builder &B, ::llvm::ArrayRef args) { + ::llvm::SmallVector newArgs; + if (1 > 0) + newArgs.append(arg_begin(), arg_begin() + 1); + newArgs.append(args.begin(), args.end()); + WriteVarArgOp *newOp = cast(B.CreateCall(getCalledFunction(), newArgs, this->getName())); + this->replaceAllUsesWith(newOp); + this->eraseFromParent(); + return newOp; + } + } // namespace xd diff --git a/test/example/generated/ExampleDialect.h.inc b/test/example/generated/ExampleDialect.h.inc index 913d295..cf7471d 100644 --- a/test/example/generated/ExampleDialect.h.inc +++ b/test/example/generated/ExampleDialect.h.inc @@ -99,12 +99,12 @@ uint32_t getNumElements() const; classof(::llvm::cast<::llvm::CallInst>(v)); } ::llvm::Value * getPtr() const; - void setPtr(::llvm::Value * ptr); - ::llvm::Value * getCount() const; - void setCount(::llvm::Value * count); - ::llvm::Value * getInitial() const; - void setInitial(::llvm::Value * initial); - + void setPtr(::llvm::Value * ptr); + ::llvm::Value * getCount() const; + void setCount(::llvm::Value * count); + ::llvm::Value * getInitial() const; + void setInitial(::llvm::Value * initial); + }; class Add32Op : public ::llvm::CallInst { @@ -123,12 +123,12 @@ uint32_t getNumElements() const; bool verifier(::llvm::raw_ostream &errs); ::llvm::Value * getLhs() const; - void setLhs(::llvm::Value * lhs); - ::llvm::Value * getRhs() const; - void setRhs(::llvm::Value * rhs); - uint32_t getExtra() const; - void setExtra(uint32_t extra); - + void setLhs(::llvm::Value * lhs); + ::llvm::Value * getRhs() const; + void setRhs(::llvm::Value * rhs); + uint32_t getExtra() const; + void setExtra(uint32_t extra); + ::llvm::Value * getResult(); @@ -150,10 +150,10 @@ bool verifier(::llvm::raw_ostream &errs); bool verifier(::llvm::raw_ostream &errs); ::llvm::Value * getLhs() const; - void setLhs(::llvm::Value * lhs); - ::llvm::Value * getRhs() const; - void setRhs(::llvm::Value * rhs); - + void setLhs(::llvm::Value * lhs); + ::llvm::Value * getRhs() const; + void setRhs(::llvm::Value * rhs); + ::llvm::Value * getResult(); @@ -175,10 +175,10 @@ bool verifier(::llvm::raw_ostream &errs); bool verifier(::llvm::raw_ostream &errs); ::llvm::Value * getVector() const; - void setVector(::llvm::Value * vector); - ::llvm::Value * getIndex() const; - void setIndex(::llvm::Value * index); - + void setVector(::llvm::Value * vector); + ::llvm::Value * getIndex() const; + void setIndex(::llvm::Value * index); + ::llvm::Value * getResult(); @@ -200,8 +200,8 @@ bool verifier(::llvm::raw_ostream &errs); bool verifier(::llvm::raw_ostream &errs); ::llvm::Value * getSource() const; - void setSource(::llvm::Value * source); - + void setSource(::llvm::Value * source); + ::llvm::Value * getResult(); @@ -244,8 +244,8 @@ bool verifier(::llvm::raw_ostream &errs); bool verifier(::llvm::raw_ostream &errs); ::llvm::Value * getSource() const; - void setSource(::llvm::Value * source); - + void setSource(::llvm::Value * source); + ::llvm::Value * getResult(); @@ -267,8 +267,8 @@ bool verifier(::llvm::raw_ostream &errs); bool verifier(::llvm::raw_ostream &errs); ::llvm::Value * getSource() const; - void setSource(::llvm::Value * source); - + void setSource(::llvm::Value * source); + ::llvm::Value * getResult(); @@ -310,12 +310,12 @@ bool getVal() const; bool verifier(::llvm::raw_ostream &errs); ::llvm::Value * getVector() const; - void setVector(::llvm::Value * vector); - ::llvm::Value * getValue() const; - void setValue(::llvm::Value * value); - ::llvm::Value * getIndex() const; - void setIndex(::llvm::Value * index); - + void setVector(::llvm::Value * vector); + ::llvm::Value * getValue() const; + void setValue(::llvm::Value * value); + ::llvm::Value * getIndex() const; + void setIndex(::llvm::Value * index); + ::llvm::Value * getResult(); @@ -337,10 +337,10 @@ bool verifier(::llvm::raw_ostream &errs); bool verifier(::llvm::raw_ostream &errs); ::llvm::Value * getInstName() const; - void setInstName(::llvm::Value * instName); - ::llvm::Value * getInstName_0() const; - void setInstName_0(::llvm::Value * instName_0); - + void setInstName(::llvm::Value * instName); + ::llvm::Value * getInstName_0() const; + void setInstName_0(::llvm::Value * instName_0); + ::llvm::Value * getResult(); @@ -362,8 +362,8 @@ bool verifier(::llvm::raw_ostream &errs); bool verifier(::llvm::raw_ostream &errs); ::llvm::Value * getInstName() const; - void setInstName(::llvm::Value * instName); - + void setInstName(::llvm::Value * instName); + ::llvm::Value * getResult(); @@ -385,6 +385,10 @@ bool verifier(::llvm::raw_ostream &errs); bool verifier(::llvm::raw_ostream &errs); ::llvm::iterator_range<::llvm::User::value_op_iterator> getInstName_0() ; + /// Returns a new op with the same arguments and a new tail argument list. + /// The object on which this is called will be invalidated. + InstNameConflictVarargsOp *replaceInstName_0AndInvalidate(::llvm_dialects::Builder &, ::llvm::ArrayRef); + ::llvm::Value * getResult(); @@ -448,8 +452,8 @@ bool verifier(::llvm::raw_ostream &errs); bool verifier(::llvm::raw_ostream &errs); ::llvm::Value * getData() const; - void setData(::llvm::Value * data); - + void setData(::llvm::Value * data); + }; @@ -470,8 +474,8 @@ bool verifier(::llvm::raw_ostream &errs); bool verifier(::llvm::raw_ostream &errs); ::llvm::Type * getSizeofType() const; - void setSizeofType(::llvm::Type * sizeof_type); - + void setSizeofType(::llvm::Type * sizeof_type); + ::llvm::Value * getResult(); @@ -576,8 +580,8 @@ bool verifier(::llvm::raw_ostream &errs); bool verifier(::llvm::raw_ostream &errs); ::llvm::Value * getData() const; - void setData(::llvm::Value * data); - + void setData(::llvm::Value * data); + }; @@ -598,8 +602,12 @@ bool verifier(::llvm::raw_ostream &errs); bool verifier(::llvm::raw_ostream &errs); ::llvm::Value * getData() const; - void setData(::llvm::Value * data); - ::llvm::iterator_range<::llvm::User::value_op_iterator> getArgs() ; + void setData(::llvm::Value * data); + ::llvm::iterator_range<::llvm::User::value_op_iterator> getArgs() ; + /// Returns a new op with the same arguments and a new tail argument list. + /// The object on which this is called will be invalidated. + WriteVarArgOp *replaceArgsAndInvalidate(::llvm_dialects::Builder &, ::llvm::ArrayRef); + }; diff --git a/test/example/test-builder.test b/test/example/test-builder.test index ec1e292..0afbc60 100644 --- a/test/example/test-builder.test +++ b/test/example/test-builder.test @@ -1,9 +1,9 @@ -; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool llvm-dialects-example --include-generated-funcs --check-globals +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool llvm-dialects-example --check-globals --include-generated-funcs ; NOTE: stdin isn't used by the example program, but the redirect makes the UTC tool happy. ; RUN: llvm-dialects-example - | FileCheck --check-prefixes=CHECK %s ;. -; CHECK: @[[GLOB0:.*]] = private unnamed_addr constant [13 x i8] c"Hello world!\00", align 1 +; CHECK: @str = private unnamed_addr constant [13 x i8] c"Hello world!\00", align 1 ;. ; CHECK-LABEL: @example( ; CHECK-NEXT: entry: @@ -28,6 +28,7 @@ ; CHECK-NEXT: call void (...) @xd.write(i8 [[P2]]) ; CHECK-NEXT: call void (...) @xd.write.vararg(i8 [[P2]], ptr [[P1]], i8 [[P2]]) ; CHECK-NEXT: [[TMP14:%.*]] = call target("xd.handle") @xd.handle.get() +; CHECK-NEXT: call void (...) @xd.write.vararg(i8 [[P2]], ptr [[P1]], i8 [[P2]], i8 [[P2]]) ; CHECK-NEXT: [[TMP15:%.*]] = call <2 x i32> @xd.set.read__v2i32() ; CHECK-NEXT: call void (...) @xd.set.write(target("xd.vector", i32, 1, 2) [[TMP13]]) ; CHECK-NEXT: [[TMP16:%.*]] = call [[TMP0]] @xd.read__s_s() @@ -45,6 +46,13 @@ ; CHECK-NEXT: [[TWO_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]]) ; CHECK-NEXT: [[THREE_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]], i32 3) ; CHECK-NEXT: [[FOUR_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]], i32 3, i32 4) -; CHECK-NEXT: call void @xd.string.attr.op(ptr @[[GLOB0]]) +; CHECK-NEXT: call void @xd.string.attr.op(ptr @str) ; CHECK-NEXT: ret void ; +;. +; CHECK: attributes #[[ATTR0:[0-9]+]] = { nounwind memory(inaccessiblemem: readwrite) } +; CHECK: attributes #[[ATTR1:[0-9]+]] = { nounwind willreturn memory(none) } +; CHECK: attributes #[[ATTR2:[0-9]+]] = { nounwind willreturn memory(inaccessiblemem: write) } +; CHECK: attributes #[[ATTR3:[0-9]+]] = { nounwind willreturn memory(read) } +; CHECK: attributes #[[ATTR4:[0-9]+]] = { willreturn } +;.