From 669b75108350206cac0f7bffd6a7ddb2f4a20526 Mon Sep 17 00:00:00 2001 From: Thomas Symalla Date: Tue, 21 May 2024 11:16:09 +0200 Subject: [PATCH] Add support for immutable strings. We want a way to store strings in a global variable by passing in a StringRef to an op, but we don't want to generate a setter for it, since we currently don't inject the builder into the setter. So, add an immutable string type based on the existing isImmutable option for attributes. --- example/ExampleDialect.td | 10 ++ example/ExampleMain.cpp | 5 + include/llvm-dialects/Dialect/Dialect.td | 9 ++ test/example/generated/ExampleDialect.cpp.inc | 118 +++++++++++++++--- test/example/generated/ExampleDialect.h.inc | 20 +++ test/example/test-builder.test | 6 +- test/example/visitor-basic.ll | 6 +- 7 files changed, 154 insertions(+), 20 deletions(-) diff --git a/example/ExampleDialect.td b/example/ExampleDialect.td index 616a73a..3bc6761 100644 --- a/example/ExampleDialect.td +++ b/example/ExampleDialect.td @@ -317,3 +317,13 @@ def ImmutableOp : Op { Make an argument immutable }]; } + +def StringAttrOp : Op { + let results = (outs); + let arguments = (ins ImmutableStringAttr:$val); + + let summary = "demonstrate an argument that takes in a StringRef"; + let description = [{ + The argument should not have a setter method + }]; +} diff --git a/example/ExampleMain.cpp b/example/ExampleMain.cpp index ca01caa..2346fb6 100644 --- a/example/ExampleMain.cpp +++ b/example/ExampleMain.cpp @@ -141,6 +141,8 @@ void createFunctionExample(Module &module, const Twine &name) { moreVarArgs.push_back(b.getInt32(4)); b.create(moreVarArgs, "four.varargs"); + b.create("Hello world!"); + b.CreateRetVoid(); } @@ -242,6 +244,9 @@ template const Visitor &getExampleVisitor() { for (Value *arg : op.getArgs()) out << " " << *arg << '\n'; }); + b.add([](raw_ostream &out, xd::StringAttrOp &op) { + out << "visiting StringAttrOp: " << op.getVal() << '\n'; + }); b.add([](raw_ostream &out, ReturnInst &ret) { out << "visiting ReturnInst: " << ret << '\n'; }); diff --git a/include/llvm-dialects/Dialect/Dialect.td b/include/llvm-dialects/Dialect/Dialect.td index 7e3c949..fd4aa27 100644 --- a/include/llvm-dialects/Dialect/Dialect.td +++ b/include/llvm-dialects/Dialect/Dialect.td @@ -279,6 +279,15 @@ def : AttrLlvmType; def : AttrLlvmType; def : AttrLlvmType; +def ImmutableStringAttr : Attr<"::llvm::StringRef"> { + let toLlvmValue = [{ $_builder.CreateGlobalString($0) }]; + let fromLlvmValue = [{ ::llvm::cast<::llvm::ConstantDataArray>(::llvm::cast<::llvm::GlobalVariable>($0)->getInitializer())->getAsString() }]; + let isImmutable = true; +} + +// Global string variables are essentially pointers in addrspace(0). +def : AttrLlvmType; + // ============================================================================ /// More general attributes // ============================================================================ diff --git a/test/example/generated/ExampleDialect.cpp.inc b/test/example/generated/ExampleDialect.cpp.inc index 138817b..3cedb67 100644 --- a/test/example/generated/ExampleDialect.cpp.inc +++ b/test/example/generated/ExampleDialect.cpp.inc @@ -129,6 +129,11 @@ namespace xd { state.setError(); }); + builder.add([](::llvm_dialects::VerifierState &state, StringAttrOp &op) { + if (!op.verifier(state.out())) + state.setError(); + }); + builder.add([](::llvm_dialects::VerifierState &state, WriteOp &op) { if (!op.verifier(state.out())) state.setError(); @@ -154,21 +159,21 @@ namespace xd { ::llvm::AttrBuilder attrBuilder{context}; attrBuilder.addAttribute(::llvm::Attribute::NoUnwind); attrBuilder.addAttribute(::llvm::Attribute::WillReturn); -attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::ModRefInfo::Ref)); +attrBuilder.addMemoryAttr(::llvm::MemoryEffects::none()); m_attributeLists[0] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder); } { ::llvm::AttrBuilder attrBuilder{context}; attrBuilder.addAttribute(::llvm::Attribute::NoUnwind); attrBuilder.addAttribute(::llvm::Attribute::WillReturn); -attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::MemoryEffects::Location::InaccessibleMem, ::llvm::ModRefInfo::Mod)); +attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::ModRefInfo::Ref)); m_attributeLists[1] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder); } { ::llvm::AttrBuilder attrBuilder{context}; attrBuilder.addAttribute(::llvm::Attribute::NoUnwind); attrBuilder.addAttribute(::llvm::Attribute::WillReturn); -attrBuilder.addMemoryAttr(::llvm::MemoryEffects::none()); +attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::MemoryEffects::Location::InaccessibleMem, ::llvm::ModRefInfo::Mod)); m_attributeLists[2] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder); } { @@ -329,7 +334,7 @@ return true; const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(0); auto fnType = ::llvm::FunctionType::get(::llvm::IntegerType::get(context, 32), { lhs->getType(), rhs->getType(), @@ -451,7 +456,7 @@ uint32_t const extra = getExtra(); const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(0); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {lhs->getType()}); @@ -546,7 +551,7 @@ rhs const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(0); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {::llvm::cast(vector->getType())->getElementType()}); @@ -650,7 +655,7 @@ index const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(0); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {resultType}); @@ -820,7 +825,7 @@ source const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(0); auto fnType = ::llvm::FunctionType::get(XdHandleType::get(context), { }, false); @@ -882,7 +887,7 @@ source const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(0); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {resultType}); @@ -980,7 +985,7 @@ source const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(0); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {resultType}); @@ -1147,7 +1152,7 @@ source const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(0); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {vector->getType()}); @@ -1613,7 +1618,7 @@ instName const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(1); + = ExampleDialect::get(context).getAttributeList(2); auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true); auto fn = module.getOrInsertFunction(s_name, fnType, attrs); @@ -1676,7 +1681,7 @@ data const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(2); + = ExampleDialect::get(context).getAttributeList(0); auto fnType = ::llvm::FunctionType::get(::llvm::IntegerType::get(context, 64), true); auto fn = module.getOrInsertFunction(s_name, fnType, attrs); @@ -1750,7 +1755,7 @@ data const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(0); + = ExampleDialect::get(context).getAttributeList(1); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {initial->getType()}); @@ -1842,7 +1847,7 @@ initial const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(0); + = ExampleDialect::get(context).getAttributeList(1); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {initial->getType()}); @@ -1934,7 +1939,7 @@ initial const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(0); + = ExampleDialect::get(context).getAttributeList(1); std::string mangledName = ::llvm_dialects::getMangledName(s_name, {initial->getType()}); @@ -2017,6 +2022,75 @@ initial + const ::llvm::StringLiteral StringAttrOp::s_name{"xd.string.attr.op"}; + + StringAttrOp* StringAttrOp::create(llvm_dialects::Builder& b, ::llvm::StringRef val, const llvm::Twine &instName) { + ::llvm::LLVMContext& context = b.getContext(); + (void)context; + ::llvm::Module& module = *b.GetInsertBlock()->getModule(); + + + const ::llvm::AttributeList attrs + = ExampleDialect::get(context).getAttributeList(4); + auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), { +::llvm::PointerType::get(::llvm::Type::getInt8Ty(context), 0), +}, false); + + auto fn = module.getOrInsertFunction(s_name, fnType, attrs); + ::llvm::SmallString<32> newName; + for (unsigned i = 0; !::llvm::isa<::llvm::Function>(fn.getCallee()) || + ::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() != fn.getFunctionType(); i++) { + // If a function with the same name but a different types already exists, + // we get a bitcast of a function or a function with the wrong type. + // Try new names until we get one with the correct type. + newName = ""; + ::llvm::raw_svector_ostream newNameStream(newName); + newNameStream << s_name << "_" << i; + fn = module.getOrInsertFunction(newNameStream.str(), fnType, attrs); + } + assert(::llvm::isa<::llvm::Function>(fn.getCallee())); + assert(fn.getFunctionType() == fnType); + assert(::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() == fn.getFunctionType()); + + +::llvm::SmallVector<::llvm::Value*, 1> args = { + b.CreateGlobalString(val) + }; + + return ::llvm::cast(b.CreateCall(fn, args, instName)); + } + + + bool StringAttrOp::verifier(::llvm::raw_ostream &errs) { + ::llvm::LLVMContext &context = getModule()->getContext(); + (void)context; + + using ::llvm_dialects::printable; + + if (arg_size() != 1) { + errs << " wrong number of arguments: " << arg_size() + << ", expected 1\n"; + return false; + } + + if (getArgOperand(0)->getType() != ::llvm::PointerType::get(::llvm::Type::getInt8Ty(context), 0)) { + errs << " argument 0 (val) has type: " + << *getArgOperand(0)->getType() << '\n'; + errs << " expected: " << *::llvm::PointerType::get(::llvm::Type::getInt8Ty(context), 0) << '\n'; + return false; + } + ::llvm::StringRef const val = getVal(); +(void)val; + return true; +} + + + ::llvm::StringRef StringAttrOp::getVal() { + return ::llvm::cast<::llvm::ConstantDataArray>(::llvm::cast<::llvm::GlobalVariable>(getArgOperand(0))->getInitializer())->getAsString() ; + } + + + const ::llvm::StringLiteral WriteOp::s_name{"xd.write"}; WriteOp* WriteOp::create(llvm_dialects::Builder& b, ::llvm::Value * data, const llvm::Twine &instName) { @@ -2026,7 +2100,7 @@ initial const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(1); + = ExampleDialect::get(context).getAttributeList(2); auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true); auto fn = module.getOrInsertFunction(s_name, fnType, attrs); @@ -2089,7 +2163,7 @@ data const ::llvm::AttributeList attrs - = ExampleDialect::get(context).getAttributeList(1); + = ExampleDialect::get(context).getAttributeList(2); auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true); auto fn = module.getOrInsertFunction(s_name, fnType, attrs); @@ -2303,6 +2377,14 @@ data } + template <> + const ::llvm_dialects::OpDescription & + ::llvm_dialects::OpDescription::get() { + static const ::llvm_dialects::OpDescription desc{false, "xd.string.attr.op"}; + return desc; + } + + template <> const ::llvm_dialects::OpDescription & ::llvm_dialects::OpDescription::get() { diff --git a/test/example/generated/ExampleDialect.h.inc b/test/example/generated/ExampleDialect.h.inc index 0b29e7d..d6b2a5f 100644 --- a/test/example/generated/ExampleDialect.h.inc +++ b/test/example/generated/ExampleDialect.h.inc @@ -535,6 +535,26 @@ bool verifier(::llvm::raw_ostream &errs); ::llvm::Value * getResult(); + }; + + class StringAttrOp : public ::llvm::CallInst { + static const ::llvm::StringLiteral s_name; //{"xd.string.attr.op"}; + + public: + static bool classof(const ::llvm::CallInst* i) { + return ::llvm_dialects::detail::isSimpleOperation(i, s_name); + } + static bool classof(const ::llvm::Value* v) { + return ::llvm::isa<::llvm::CallInst>(v) && + classof(::llvm::cast<::llvm::CallInst>(v)); + } + static StringAttrOp* create(::llvm_dialects::Builder& b, ::llvm::StringRef val, const llvm::Twine &instName = ""); + +bool verifier(::llvm::raw_ostream &errs); + +::llvm::StringRef getVal(); + + }; class WriteOp : public ::llvm::CallInst { diff --git a/test/example/test-builder.test b/test/example/test-builder.test index baf29e0..eef0f91 100644 --- a/test/example/test-builder.test +++ b/test/example/test-builder.test @@ -1,7 +1,10 @@ -; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool llvm-dialects-example --include-generated-funcs +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool llvm-dialects-example --include-generated-funcs --check-globals ; NOTE: stdin isn't used by the example program, but the redirect makes the UTC tool happy. ; RUN: llvm-dialects-example - | FileCheck --check-prefixes=CHECK %s +;. +; CHECK: @[[GLOB0:[0-9]+]] = private unnamed_addr constant [13 x i8] c"Hello world!\00", align 1 +;. ; CHECK-LABEL: @example( ; CHECK-NEXT: entry: ; CHECK-NEXT: [[TMP0:%.*]] = call i32 @xd.read__i32() @@ -42,5 +45,6 @@ ; CHECK-NEXT: [[TWO_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]]) ; CHECK-NEXT: [[THREE_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]], i32 3) ; CHECK-NEXT: [[FOUR_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]], i32 3, i32 4) +; CHECK-NEXT: call void @xd.string.attr.op(ptr @[[GLOB0:[0-9]+]]) ; CHECK-NEXT: ret void ; diff --git a/test/example/visitor-basic.ll b/test/example/visitor-basic.ll index f173888..eaeea9b 100644 --- a/test/example/visitor-basic.ll +++ b/test/example/visitor-basic.ll @@ -16,10 +16,13 @@ ; DEFAULT-NEXT: %v2 = ; DEFAULT-NEXT: %q = ; DEFAULT-NEXT: visiting umin (set): %vm = call i32 @llvm.umin.i32(i32 %v1, i32 %q) +; DEFAULT-NEXT: visiting StringAttrOp: Hello world! ; DEFAULT-NEXT: visiting Ret (set): ret void ; DEFAULT-NEXT: visiting ReturnInst: ret void ; DEFAULT-NEXT: inner.counter = 1 +@0 = private unnamed_addr constant [13 x i8] c"Hello world!\00", align 1 + define void @test1(ptr %p) { entry: %v = call i32 @xd.read__i32() @@ -36,6 +39,7 @@ entry: call void (...) @xd.set.write(i8 %v.2) call void (...) @xd.write.vararg(i8 %t, i32 %v2, i32 %q) %vm = call i32 @llvm.umin.i32(i32 %v1, i32 %q) + call void @xd.string.attr.op(ptr @0) ret void } @@ -46,6 +50,6 @@ declare void @xd.write(...) declare void @xd.set.write(...) declare void @xd.write.vararg(...) declare i8 @xd.itrunc__i8(...) - +declare void @xd.string.attr.op(ptr) declare i32 @llvm.umax.i32(i32, i32) declare i32 @llvm.umin.i32(i32, i32)