From fa89c71c9e3ee209aa0a266f0eda140a34649fbc Mon Sep 17 00:00:00 2001 From: Thomas Symalla Date: Mon, 17 Jul 2023 10:35:49 +0200 Subject: [PATCH] Simplify operating on sets and maps of dialect ops. This commit introduces a new type, `OpSet`, that simplifies membership tests of an instruction on a set of dialect operations. With this PR, applying a callback on a given set of operations also becomes simpler since a visitor function `addSet` is introduced. It also introduces a map-like data structure, `OpMap`, which can be used to simplify handling of association of dialect operations to certain values. With this PR, a new check-llvm-dialects-units target is added which tests the new GTest ADT tests as well. --- docker/dialects.Dockerfile | 4 + example/ExampleDialect.td | 26 + example/ExampleMain.cpp | 32 + include/llvm-dialects/Dialect/Dialect.h | 13 +- include/llvm-dialects/Dialect/OpDescription.h | 31 +- include/llvm-dialects/Dialect/OpMap.h | 768 ++++++++++++++++++ include/llvm-dialects/Dialect/OpSet.h | 218 +++++ include/llvm-dialects/Dialect/Visitor.h | 56 +- lib/Dialect/Dialect.cpp | 63 +- lib/Dialect/OpDescription.cpp | 7 + lib/Dialect/Visitor.cpp | 121 +-- test/CMakeLists.txt | 2 + test/example/generated/ExampleDialect.cpp.inc | 146 ++++ test/example/generated/ExampleDialect.h.inc | 43 + test/example/test-builder.test | 14 +- test/example/visitor-basic.ll | 12 + test/unit/CMakeLists.txt | 39 + test/unit/dialect/CMakeLists.txt | 12 + test/unit/dialect/TestDialect.cpp | 32 + test/unit/dialect/TestDialect.h | 32 + test/unit/dialect/TestDialect.td | 67 ++ test/unit/interface/CMakeLists.txt | 6 + test/unit/interface/OpMapIRTests.cpp | 254 ++++++ test/unit/interface/OpMapTests.cpp | 258 ++++++ test/unit/interface/OpSetTests.cpp | 114 +++ test/unit/lit.cfg.py | 60 ++ test/unit/lit.site.cfg.py.in | 22 + 27 files changed, 2324 insertions(+), 128 deletions(-) create mode 100644 include/llvm-dialects/Dialect/OpMap.h create mode 100644 include/llvm-dialects/Dialect/OpSet.h create mode 100644 test/unit/CMakeLists.txt create mode 100644 test/unit/dialect/CMakeLists.txt create mode 100644 test/unit/dialect/TestDialect.cpp create mode 100644 test/unit/dialect/TestDialect.h create mode 100644 test/unit/dialect/TestDialect.td create mode 100644 test/unit/interface/CMakeLists.txt create mode 100644 test/unit/interface/OpMapIRTests.cpp create mode 100644 test/unit/interface/OpMapTests.cpp create mode 100644 test/unit/interface/OpSetTests.cpp create mode 100644 test/unit/lit.cfg.py create mode 100644 test/unit/lit.site.cfg.py.in diff --git a/docker/dialects.Dockerfile b/docker/dialects.Dockerfile index 5e4a8a9..359099c 100644 --- a/docker/dialects.Dockerfile +++ b/docker/dialects.Dockerfile @@ -41,3 +41,7 @@ RUN source /vulkandriver/env.sh \ # Run the lit test suite. RUN source /vulkandriver/env.sh \ && cmake --build . --target check-llvm-dialects -- -v + +# Run the unit tests suite. +RUN source /vulkandriver/env.sh \ + && cmake --build . --target check-llvm-dialects-units -v diff --git a/example/ExampleDialect.td b/example/ExampleDialect.td index f2ce3cf..58fb84e 100644 --- a/example/ExampleDialect.td +++ b/example/ExampleDialect.td @@ -101,6 +101,32 @@ def WriteVarArgOp : ExampleOp<"write.vararg", Longer description of how this operation writes pieces of data. }]; } + +def SetReadOp : ExampleOp<"set.read", + [Memory<[(readwrite InaccessibleMem)]>, NoUnwind]> { + let results = (outs value:$data); + let arguments = (ins); + + let defaultBuilderHasExplicitResultType = true; + + let summary = "read a piece of data"; + let description = [{ + Longer description of how this operation reads a piece of data. + }]; +} + +def SetWriteOp : ExampleOp<"set.write", + [Memory<[(write InaccessibleMem)]>, NoUnwind, + WillReturn]> { + let results = (outs); + let arguments = (ins value:$data); + + let summary = "write a data element"; + let description = [{ + Longer description of how this operation writes pieces of data. + }]; +} + def CombineOp : ExampleOp<"combine", [Memory<[]>, NoUnwind, WillReturn]> { let results = (outs value:$result); diff --git a/example/ExampleMain.cpp b/example/ExampleMain.cpp index 3238129..cc3b97e 100644 --- a/example/ExampleMain.cpp +++ b/example/ExampleMain.cpp @@ -26,6 +26,8 @@ #include "ExampleDialect.h" #include "llvm-dialects/Dialect/Builder.h" +#include "llvm-dialects/Dialect/OpDescription.h" +#include "llvm-dialects/Dialect/OpSet.h" #include "llvm-dialects/Dialect/Verifier.h" #include "llvm/AsmParser/Parser.h" @@ -121,6 +123,9 @@ void createFunctionExample(Module &module, const Twine &name) { b.create(p2, varArgs); b.create(); + b.create(FixedVectorType::get(b.getInt32Ty(), 2)); + b.create(y6); + useUnnamedStructTypes(b); b.CreateRetVoid(); @@ -166,12 +171,39 @@ LLVM_DIALECTS_VISITOR_PAYLOAD_PROJECT_FIELD(VisitorNest, inner) // i.e. if C++ had a strong enough compile-time evaluation (constexpr), it // should be possible to evaluate the initialization entirely at compile-time. template const Visitor &getExampleVisitor() { + static const auto complexSet = OpSet::fromOpDescriptions( + {OpDescription::fromCoreOp(Instruction::Ret), + OpDescription::fromIntrinsic(Intrinsic::umin)}); + static const auto visitor = VisitorBuilder() .nest([](VisitorBuilder &b) { b.add([](VisitorNest &self, xd::ReadOp &op) { *self.out << "visiting ReadOp: " << op << '\n'; }); + b.addSet( + [](VisitorNest &self, llvm::Instruction &op) { + if (isa(op)) { + *self.out << "visiting SetReadOp (set): " << op << '\n'; + } else if (isa(op)) { + *self.out << "visiting SetWriteOp (set): " << op << '\n'; + } + }); + b.addSet(complexSet, [](VisitorNest &self, llvm::Instruction &op) { + assert((op.getOpcode() == Instruction::Ret || + (isa)(&op) && + cast(&op)->getIntrinsicID() == + Intrinsic::umin) && + "Unexpected operation detected while visiting OpSet!"); + + if (op.getOpcode() == Instruction::Ret) { + *self.out << "visiting Ret (set): " << op << '\n'; + } else if (auto *II = dyn_cast(&op)) { + if (II->getIntrinsicID() == Intrinsic::umin) { + *self.out << "visiting umin (set): " << op << '\n'; + } + } + }); b.add( [](VisitorNest &self, UnaryInstruction &inst) { *self.out << "visiting UnaryInstruction: " << inst << '\n'; diff --git a/include/llvm-dialects/Dialect/Dialect.h b/include/llvm-dialects/Dialect/Dialect.h index dc97d48..f3b92ef 100644 --- a/include/llvm-dialects/Dialect/Dialect.h +++ b/include/llvm-dialects/Dialect/Dialect.h @@ -245,11 +245,14 @@ class DialectExtensionRegistration { namespace detail { -bool isSimpleOperationDecl(const llvm::Function *fn, llvm::StringRef name); -bool isOverloadedOperationDecl(const llvm::Function *fn, llvm::StringRef name); - -bool isSimpleOperation(const llvm::CallInst *i, llvm::StringRef name); -bool isOverloadedOperation(const llvm::CallInst *i, llvm::StringRef name); +bool isSimpleOperationDecl(const llvm::Function *fn, llvm::StringRef mnemonic); +bool isOverloadedOperationDecl(const llvm::Function *fn, + llvm::StringRef mnemonic); + +bool isSimpleOperation(const llvm::CallInst *i, llvm::StringRef mnemonic); +bool isOverloadedOperation(const llvm::CallInst *i, llvm::StringRef mnemonic); +bool isOperationDecl(llvm::StringRef fn, bool isOverloaded, + llvm::StringRef mnemonic); } // namespace detail diff --git a/include/llvm-dialects/Dialect/OpDescription.h b/include/llvm-dialects/Dialect/OpDescription.h index 08b440c..1a19a21 100644 --- a/include/llvm-dialects/Dialect/OpDescription.h +++ b/include/llvm-dialects/Dialect/OpDescription.h @@ -20,7 +20,6 @@ #include "llvm/ADT/StringRef.h" #include -#include namespace llvm { class Function; @@ -40,19 +39,45 @@ class OpDescription { }; public: + OpDescription() = default; OpDescription(bool hasOverloads, llvm::StringRef mnemonic) : m_kind(hasOverloads ? Kind::DialectWithOverloads : Kind::Dialect), m_op(mnemonic) {} OpDescription(Kind kind, unsigned opcode) : m_kind(kind), m_op(opcode) {} OpDescription(Kind kind, llvm::MutableArrayRef opcodes); - template - static const OpDescription& get(); + static OpDescription fromCoreOp(unsigned op) { return {Kind::Core, op}; } + + static OpDescription fromIntrinsic(unsigned op) { + return {Kind::Intrinsic, op}; + } + + static OpDescription fromDialectOp(bool hasOverloads, + llvm::StringRef mnemonic) { + return {hasOverloads, mnemonic}; + } + + bool isCoreOp() const { return m_kind == Kind::Core; } + bool isIntrinsic() const { return m_kind == Kind::Intrinsic; } + bool isDialectOp() const { + return m_kind == Kind::Dialect || m_kind == Kind::DialectWithOverloads; + } + + template static const OpDescription &get(); Kind getKind() const { return m_kind; } + + unsigned getOpcode() const; + llvm::ArrayRef getOpcodes() const; + llvm::StringRef getMnemonic() const { + assert(m_kind == Kind::Dialect || m_kind == Kind::DialectWithOverloads); + return std::get(m_op); + } + bool matchInstruction(const llvm::Instruction &inst) const; + bool matchDeclaration(const llvm::Function &decl) const; bool canMatchDeclaration() const { diff --git a/include/llvm-dialects/Dialect/OpMap.h b/include/llvm-dialects/Dialect/OpMap.h new file mode 100644 index 0000000..180f097 --- /dev/null +++ b/include/llvm-dialects/Dialect/OpMap.h @@ -0,0 +1,768 @@ +/* +*********************************************************************************************************************** +* +* Copyright (c) 2023 Advanced Micro Devices, Inc. All Rights Reserved. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +*all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +* +**********************************************************************************************************************/ + +#pragma once + +#include "llvm-dialects/Dialect/Dialect.h" +#include "llvm-dialects/Dialect/OpDescription.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/IntrinsicInst.h" + +#include +#include + +using namespace llvm; +using namespace llvm_dialects; + +namespace { + +using DialectOpKey = std::pair; + +class DialectOpKVUtils { +public: + static DialectOpKey getDialectMapKey(const OpDescription &desc) { + return {desc.getMnemonic(), + desc.getKind() == OpDescription::Kind::DialectWithOverloads}; + } +}; + +template struct DialectOpKV final { + DialectOpKey Key; + ValueT Value; + + bool operator==(const DialectOpKV &other) const { + return Key.first == other.Key.first && Key.second == other.Key.second && + Value == other.Value; + } + + bool operator==(const OpDescription &desc) const { + const bool isOverload = + desc.getKind() == OpDescription::Kind::DialectWithOverloads; + return Key.first == desc.getMnemonic() && Key.second == isOverload; + } +}; +} // namespace + +namespace llvm_dialects { + +// Forward declarations. +template class OpMap; + +template class OpMapIteratorBase; + +// OpMap implements a map-like container that can store core opcodes, +// intrinsics and dialect operations. It provides some lookup functionality for +// these kinds of operations, OpDescriptions and functions as well as +// instructions to simplify working with dialects in scenarios requiring +// association of dialect operations with certain values. +template class OpMap final { + // We don't care about the value type in the @reserve member function, + // thus we need to make the inners of OpMaps of arbitrary value type + // accessible to a given OpMap instance. + template friend class OpMap; + + friend class OpMapIteratorBase; + friend class OpMapIteratorBase; + + using DialectOpKVT = DialectOpKV; + +public: + using iterator = OpMapIteratorBase; + using const_iterator = OpMapIteratorBase; + + OpMap() = default; + + // -------------------------------------------------------------------------- + // Convenience constructor to initialize the OpMap from a set of + // OpDescription/Value pairs. + // -------------------------------------------------------------------------- + OpMap(std::initializer_list> vals) { + for (const std::pair &val : vals) + insert(val.first, val.second); + } + + // -------------------------------------------------------------------------- + // Comparison operator overloads. + // -------------------------------------------------------------------------- + bool operator==(const OpMap &rhs) const { + if (m_dialectOps.size() != rhs.m_dialectOps.size()) + return false; + + if (m_coreOpcodes == rhs.m_coreOpcodes && + m_intrinsics == rhs.m_intrinsics) { + // Do a lookup for each vector entry, since both LHS and RHS potentially + // are in different order. + for (const auto &dialectOp : rhs.m_dialectOps) { + if (std::find(m_dialectOps.begin(), m_dialectOps.end(), dialectOp) == + m_dialectOps.end()) + return false; + } + + return true; + } + + return false; + } + + bool operator!=(const OpMap &rhs) const { return !(*this == rhs); } + + // -------------------------------------------------------------------------- + // contains checks if a given op is contained in any of the + // internal map containers. + // -------------------------------------------------------------------------- + + bool containsCoreOp(unsigned op) const { return m_coreOpcodes.contains(op); } + + bool containsIntrinsic(unsigned op) const { + return m_intrinsics.contains(op); + } + + // Check if the map contains an OpDescription created for a given dialect + // operation type. + template bool contains() const { + static OpDescription desc = OpDescription::get(); + return contains(desc); + } + + // Check if the map contains an op described by an OpDescription. + bool contains(const OpDescription &desc) const { + if (desc.isCoreOp() || desc.isIntrinsic()) { + assert(desc.getOpcodes().size() == 1 && + "OpMap only supports querying of single core opcodes and " + "intrinsics."); + + const unsigned op = desc.getOpcode(); + return (desc.isCoreOp() && containsCoreOp(op)) || + (desc.isIntrinsic() && containsIntrinsic(op)); + } + + for (const DialectOpKVT &dialectOpKV : m_dialectOps) { + if (dialectOpKV == desc) + return true; + } + + return false; + } + + // -------------------------------------------------------------------------- + // find returns an iterator that contains info about elements from one of the + // internal map containers. + // -------------------------------------------------------------------------- + + // A simple DSL to simplify generating some of the find() overloads + +#define GENERATE_FIND_BODY(iterator_t) \ + { \ + if (empty()) \ + return end(); \ + iterator_t it(this, arg); \ + if (it) \ + return it; \ + return end(); \ + } + +#define FIND_OVERLOAD(arg_t) \ + iterator find(arg_t &arg) GENERATE_FIND_BODY(iterator) + +#define FIND_CONST_OVERLOAD(arg_t) \ + const_iterator find(const arg_t &arg) const GENERATE_FIND_BODY(const_iterator) + + FIND_OVERLOAD(OpDescription) + FIND_CONST_OVERLOAD(OpDescription) + FIND_OVERLOAD(Function) + FIND_CONST_OVERLOAD(Function) + FIND_OVERLOAD(Instruction) + FIND_CONST_OVERLOAD(Instruction) + +#undef FIND_CONST_OVERLOAD +#undef FIND_OVERLOAD +#undef GENERATE_FIND_BODY + + // -------------------------------------------------------------------------- + // Convenience getter definition. + // -------------------------------------------------------------------------- + + ValueT &operator[](const OpDescription &desc) { + auto [it, inserted] = insert(desc, {}); + return (*it).second; + } + + // -------------------------------------------------------------------------- + // lookup tries to find whether a given function or instruction + // can be mapped to any of the entries in the internal map + // containers. + // It returns either a default-constructed object if the key + // was not found or a copy of the contained value. + // -------------------------------------------------------------------------- + + // Try to lookup a function which is either the callee of an intrinsic call + // or a dialect operation. + ValueT lookup(const Function &func) const { + auto it = find(func); + if (auto val = it.val(); val) + return *val; + + return {}; + } + + // Try to lookup an instruction which is either an intrinsic instruction, + // a dialect operation or a core instruction. + ValueT lookup(const Instruction &inst) const { + auto it = find(inst); + if (auto val = it.val(); val) + return *val; + + return {}; + } + + // -------------------------------------------------------------------------- + // Try to construct a value in-place for a given OpDescription + // and returns a pair which consists of the internal OpMapIterator and a + // boolean return value, marking if the value was inserted or not. If the + // OpDescription was already in the internal data structures, nothing will be + // changed. + // -------------------------------------------------------------------------- + template + std::pair try_emplace(const OpDescription &desc, + Ts &&... vals) { + auto found = find(const_cast(desc)); + if (found) + return {found, false}; + + if (desc.isCoreOp() || desc.isIntrinsic()) { + assert(desc.getOpcodes().size() == 1 && + "OpMap: Can only emplace a single op at a time."); + + const unsigned op = desc.getOpcode(); + if (desc.isCoreOp()) { + auto [it, inserted] = + m_coreOpcodes.try_emplace(op, std::forward(vals)...); + return {makeIterator(it, OpDescription::Kind::Core, false), inserted}; + } + + auto [it, inserted] = + m_intrinsics.try_emplace(op, std::forward(vals)...); + return {makeIterator(it, OpDescription::Kind::Intrinsic, false), + inserted}; + } + + // Find the iterator into the dialect ops. + size_t Idx = 0; + for (DialectOpKVT &dialectOpKV : m_dialectOps) { + if (dialectOpKV == desc) { + auto it = m_dialectOps.begin(); + std::advance(it, Idx); + + return {makeIterator(it, OpDescription::Kind::Dialect, false), false}; + } + + ++Idx; + } + + // If the entry doesn't exist, construct it and return an iterator to the + // end of dialect ops. + auto it = m_dialectOps.insert( + m_dialectOps.end(), + {DialectOpKVUtils::getDialectMapKey(desc), std::forward(vals)...}); + return {makeIterator(it, OpDescription::Kind::Dialect, false), true}; + } + + template std::pair insert(const ValueT &val) { + const OpDescription desc = OpDescription::get(); + return try_emplace(desc, val); + } + + template std::pair insert(ValueT &&val) { + const OpDescription desc = OpDescription::get(); + return try_emplace(desc, std::move(val)); + } + + std::pair insert(const OpDescription &desc, + const ValueT &val) { + return try_emplace(desc, val); + } + + std::pair insert(const OpDescription &desc, ValueT &&val) { + return try_emplace(desc, std::move(val)); + } + + // -------------------------------------------------------------------------- + // Erase a given operation from the correct container. + // -------------------------------------------------------------------------- + + // Erase a given dialect operation. + template bool erase() { + const OpDescription desc = OpDescription::get(); + return erase(const_cast(desc)); + } + + // Erase all the operations described by a given OpDescription. + bool erase(OpDescription &desc) { + iterator it = find(desc); + if (!it) + return false; + + return it.erase(); + } + + // -------------------------------------------------------------------------- + // Reserve a given number of elements for the maps. + // -------------------------------------------------------------------------- + void reserve(size_t numCoreOps, size_t numIntrinsics, size_t numDialectOps) { + m_coreOpcodes.reserve(numCoreOps); + m_intrinsics.reserve(numIntrinsics); + m_dialectOps.reserve(numDialectOps); + } + + template void reserve(const OpMap &other) { + m_coreOpcodes.reserve(other.m_coreOpcodes.size()); + m_intrinsics.reserve(other.m_intrinsics.size()); + m_dialectOps.reserve(other.m_dialectOps.size()); + } + + // -------------------------------------------------------------------------- + // Convenience helpers. + // -------------------------------------------------------------------------- + size_t size() const { + return m_coreOpcodes.size() + m_intrinsics.size() + m_dialectOps.size(); + } + + bool empty() const { + return m_coreOpcodes.empty() && m_intrinsics.empty() && + m_dialectOps.empty(); + } + + // -------------------------------------------------------------------------- + // Iterator definitions. + // -------------------------------------------------------------------------- + +#define GENERATE_ITERATOR_BODY(iterator_provider, name, isInvalid) \ + { \ + if (empty() && !isInvalid) \ + return end(); \ + if (!m_coreOpcodes.empty()) \ + return iterator_provider(m_coreOpcodes.name(), \ + OpDescription::Kind::Core, isInvalid); \ + if (!m_intrinsics.empty()) \ + return iterator_provider(m_intrinsics.name(), \ + OpDescription::Kind::Intrinsic, isInvalid); \ + return iterator_provider(m_dialectOps.name(), \ + OpDescription::Kind::Dialect, isInvalid); \ + } + +#define DEFINE_NONCONST_ITERATOR(name, isInvalid) \ + inline iterator name() GENERATE_ITERATOR_BODY(makeIterator, name, isInvalid) + +#define DEFINE_CONST_ITERATOR(name, isInvalid) \ + inline const_iterator name() \ + const GENERATE_ITERATOR_BODY(makeConstIterator, name, isInvalid) + + DEFINE_NONCONST_ITERATOR(begin, false) + DEFINE_NONCONST_ITERATOR(end, true) + DEFINE_CONST_ITERATOR(begin, false) + DEFINE_CONST_ITERATOR(end, true) + +#undef DEFINE_NONCONST_ITERATOR +#undef DEFINE_CONST_ITERATOR +#undef GENERATE_ITERATOR_BODY + +private: + DenseMap m_coreOpcodes; + DenseMap m_intrinsics; + SmallVector m_dialectOps; + + template iterator makeIterator(Args &&... args) { + return iterator(this, std::forward(args)...); + } + + template + const_iterator makeConstIterator(Args &&... args) const { + return const_iterator(this, std::forward(args)...); + } +}; + +/// A simple iterator operating on the internal data structures of the OpMap. It +/// uses separate storage and stores pointers to the elements of the internal +/// data structures. +/// It should be used with caution, since the iterators get invalidated after +/// inserting or erasing an element. +/// Note that iterating over an OpMap instance never guarantees the order of +/// insertion. +template class OpMapIteratorBase final { + using BaseIteratorT = + std::conditional_t::const_iterator, + typename DenseMap::iterator>; + using DialectOpIteratorT = std::conditional_t< + isConst, typename SmallVectorImpl>::const_iterator, + typename SmallVectorImpl>::iterator>; + + using InternalValueT = std::conditional_t; + + using OpMapT = + std::conditional_t, OpMap>; + + friend class OpMap; + friend class OpMapIteratorBase; + friend class OpMapIteratorBase; + + class OpMapIteratorState final { + OpMapIteratorBase &m_iterator; + + enum class IteratorState : uint8_t { + CoreOp, + Intrinsic, + DialectOp, + Invalid + }; + + bool isCoreOp() const { return m_iterator.m_desc.isCoreOp(); } + + bool isIntrinsic() const { return m_iterator.m_desc.isIntrinsic(); } + + bool isDialectOp() const { return m_iterator.m_desc.isDialectOp(); } + + IteratorState computeCurrentState() { + const auto isValidIterator = [&](auto it, auto endIt) -> bool { + return it != endIt; + }; + + if (isCoreOp() && + isValidIterator(std::get(m_iterator.m_iterator), + m_iterator.m_map->m_coreOpcodes.end())) { + return IteratorState::CoreOp; + } + + if (isIntrinsic() && + isValidIterator(std::get(m_iterator.m_iterator), + m_iterator.m_map->m_intrinsics.end())) { + return IteratorState::Intrinsic; + } + + if (isDialectOp() && + isValidIterator(std::get(m_iterator.m_iterator), + m_iterator.m_map->m_dialectOps.end())) { + return IteratorState::DialectOp; + } + + return IteratorState::Invalid; + } + + // Compute a possible next state after iteration. + IteratorState computeNextState(IteratorState currentState) { + IteratorState nextState = currentState; + + if (nextState == IteratorState::CoreOp || + nextState == IteratorState::Intrinsic) { + auto peek = std::get(m_iterator.m_iterator); + std::advance(peek, 1); + + if (nextState == IteratorState::CoreOp) { + if (peek == m_iterator.m_map->m_coreOpcodes.end()) { + if (!m_iterator.m_map->m_intrinsics.empty()) + return IteratorState::Intrinsic; + + nextState = IteratorState::DialectOp; + } + } + + if (nextState == IteratorState::Intrinsic) { + if (peek == m_iterator.m_map->m_intrinsics.end()) { + if (!m_iterator.m_map->m_dialectOps.empty()) + return IteratorState::DialectOp; + + return IteratorState::Invalid; + } + } + } + + if (nextState == IteratorState::DialectOp) { + auto peek = std::get(m_iterator.m_iterator); + std::advance(peek, 1); + if (peek != m_iterator.m_map->m_dialectOps.end()) + return IteratorState::DialectOp; + + return IteratorState::Invalid; + } + + return nextState; + } + + public: + OpMapIteratorState(OpMapIteratorBase &iterator) : m_iterator{iterator} {} + + void step() { + auto currentState = computeCurrentState(); + auto nextState = computeNextState(currentState); + + if (currentState == nextState) { + switch (currentState) { + case IteratorState::CoreOp: + case IteratorState::Intrinsic: { + auto &it = std::get(m_iterator.m_iterator); + ++it; + if (currentState == IteratorState::CoreOp) + m_iterator.m_desc = OpDescription::fromCoreOp(it->first); + else + m_iterator.m_desc = OpDescription::fromIntrinsic(it->first); + + break; + } + case IteratorState::DialectOp: { + auto &it = std::get(m_iterator.m_iterator); + ++it; + + m_iterator.m_desc = {it->Key.second, it->Key.first}; + break; + } + + case IteratorState::Invalid: + m_iterator.invalidate(); + break; + } + } else { + transitionTo(nextState); + } + } + + void transitionTo(IteratorState nextState) { + if (nextState == IteratorState::Intrinsic) { + auto newIt = m_iterator.m_map->m_intrinsics.begin(); + m_iterator.m_iterator = newIt; + + m_iterator.m_desc = OpDescription::fromIntrinsic(newIt->first); + } else if (nextState == IteratorState::DialectOp) { + auto newIt = m_iterator.m_map->m_dialectOps.begin(); + m_iterator.m_iterator = newIt; + + m_iterator.m_desc = {newIt->Key.second, newIt->Key.first}; + } else { + m_iterator.invalidate(); + } + } + }; + + OpMapIteratorBase(OpMapT *map, + std::variant it, + OpDescription::Kind kind, bool isInvalid = false) + : m_map{map}, m_iterator{it}, m_isInvalid{isInvalid} { + if (m_map->empty()) { + invalidate(); + return; + } + + refreshOpDescriptor(kind); + } + + OpMapIteratorBase(OpMapT *map, const OpDescription &desc) + : m_map{map}, m_desc{desc} { + if (desc.isCoreOp() || desc.isIntrinsic()) { + assert(desc.getOpcodes().size() == 1 && + "OpMapIterator only supports querying of single core opcodes and " + "intrinsics."); + + const unsigned op = desc.getOpcode(); + + if (desc.isCoreOp()) { + m_iterator = map->m_coreOpcodes.find(op); + if (std::get(m_iterator) == map->m_coreOpcodes.end()) + invalidate(); + } else { + m_iterator = map->m_intrinsics.find(op); + if (std::get(m_iterator) == map->m_intrinsics.end()) + invalidate(); + } + } else { + createFromDialectOp(desc.getMnemonic()); + } + } + + OpMapIteratorBase(OpMapT *map, const Function &func) : m_map{map} { + createFromFunc(func); + } + + // Do a lookup for a given instruction. Mark the iterator as invalid + // if the instruction is a call-like core instruction. + OpMapIteratorBase(OpMapT *map, const Instruction &inst) : m_map{map} { + if (auto *CI = dyn_cast(&inst)) { + const Function *callee = CI->getCalledFunction(); + if (callee) { + createFromFunc(*callee); + return; + } + } + + const unsigned op = inst.getOpcode(); + + // Construct an invalid iterator. + if (op == Instruction::Call || op == Instruction::CallBr) { + invalidate(); + return; + } + + BaseIteratorT it = m_map->m_coreOpcodes.find(op); + if (it != m_map->m_coreOpcodes.end()) { + m_desc = OpDescription::fromCoreOp(op); + m_iterator = it; + } else { + invalidate(); + } + } + +public: + std::pair operator*() { + return {m_desc, *val()}; + } + + InternalValueT *val() { + assert(this->operator bool() && + "Trying to call val() on invalid OpMapIterator!"); + + if (m_desc.isCoreOp() || m_desc.isIntrinsic()) + return std::addressof(std::get(m_iterator)->second); + + return std::addressof(std::get(m_iterator)->Value); + } + + operator bool() const { return !m_isInvalid; } + + OpMapIteratorBase &operator++() { + OpMapIteratorState stateMachine{*this}; + stateMachine.step(); + + return *this; + } + + OpMapIteratorBase &operator++(int) { return this->operator++(); } + + template > + bool erase() { + if (m_desc.isCoreOp() || m_desc.isIntrinsic()) { + assert(m_desc.getOpcodes().size() == 1 && + "OpMapIterator only supports erasing of single core opcodes and " + "intrinsics."); + + const unsigned op = m_desc.getOpcode(); + + if (m_desc.isCoreOp()) + return m_map->m_coreOpcodes.erase(op); + + return m_map->m_intrinsics.erase(op); + } + + // Try to erase the dialect op at last. + for (size_t I = 0; I < m_map->m_dialectOps.size(); ++I) { + if (m_map->m_dialectOps[I] == m_desc) { + DialectOpIteratorT it = m_map->m_dialectOps.begin(); + std::advance(it, I); + + if (it == m_map->m_dialectOps.end()) + return false; + + m_map->m_dialectOps.erase(it); + return true; + } + } + + return false; + } + +protected: + OpMapT *m_map = nullptr; + OpDescription m_desc; + std::variant m_iterator; + bool m_isInvalid = false; + +private: + void invalidate() { m_isInvalid = true; } + + void createFromFunc(const Function &func) { + if (func.isIntrinsic()) { + m_iterator = m_map->m_intrinsics.find(func.getIntrinsicID()); + + if (std::get(m_iterator) != m_map->m_intrinsics.end()) { + m_desc = OpDescription::fromIntrinsic(func.getIntrinsicID()); + return; + } + } + + createFromDialectOp(func.getName()); + } + + void createFromDialectOp(StringRef funcName) { + size_t idx = 0; + bool found = false; + for (auto &dialectOpKV : m_map->m_dialectOps) { + const DialectOpKey &key = dialectOpKV.Key; + if (detail::isOperationDecl(funcName, key.second, key.first)) { + m_desc = {key.second, key.first}; + auto it = m_map->m_dialectOps.begin(); + std::advance(it, idx); + m_iterator = it; + found = true; + break; + } + + ++idx; + } + + if (!found) + invalidate(); + } + + // Re-construct base OpDescription from the stored iterator. + // Since this is invoked when passing an existing iterator to the + // OpMapIterator constructor, we need to check the original kind as well + // to prevent epoch mismatches when comparing the stored iterator with the + // internal map data structures. + void refreshOpDescriptor(OpDescription::Kind kind) { + if (m_isInvalid) + return; + + if (auto baseIt = std::get_if(&m_iterator)) { + auto &unwrapped = *baseIt; + if (!m_map->m_coreOpcodes.empty() && kind == OpDescription::Kind::Core && + unwrapped != m_map->m_coreOpcodes.end()) { + m_desc = OpDescription::fromCoreOp(unwrapped->first); + } else if (!m_map->m_intrinsics.empty() && + kind == OpDescription::Kind::Intrinsic && + unwrapped != m_map->m_intrinsics.end()) { + m_desc = OpDescription::fromIntrinsic(unwrapped->first); + } else { + llvm_unreachable("OpMapIterator: Invalid iterator provided!"); + } + } else if (auto dialectOpIt = + std::get_if(&m_iterator)) { + auto &unwrapped = *dialectOpIt; + if (unwrapped != m_map->m_dialectOps.end()) + m_desc = {unwrapped->Key.second, unwrapped->Key.first}; + } else { + llvm_unreachable("OpMapIterator: Invalid iterator provided!"); + } + } +}; + +} // namespace llvm_dialects diff --git a/include/llvm-dialects/Dialect/OpSet.h b/include/llvm-dialects/Dialect/OpSet.h new file mode 100644 index 0000000..c585285 --- /dev/null +++ b/include/llvm-dialects/Dialect/OpSet.h @@ -0,0 +1,218 @@ +/* +*********************************************************************************************************************** +* +* Copyright (c) 2023 Advanced Micro Devices, Inc. All Rights Reserved. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +*all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +* +**********************************************************************************************************************/ + +#pragma once + +#include "llvm-dialects/Dialect/Dialect.h" +#include "llvm-dialects/Dialect/OpDescription.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" + +using namespace llvm; + +namespace llvm_dialects { + +struct DialectOpPair final { + StringRef mnemonic; + bool isOverload; + + // Checks whether the current pair is comparable to an OpDescription + // object. + bool operator==(const OpDescription &desc) const { + auto Kind = desc.getKind(); + bool hasOverload = Kind == OpDescription::Kind::DialectWithOverloads; + + if (Kind != OpDescription::Kind::Dialect && !hasOverload) + return false; + + return desc.getMnemonic() == mnemonic && hasOverload == isOverload; + } +}; + +// An OpSet defines a set of operations. It is used to simplify operating on a +// set of dialect operations, for instance, in the Visitor. +class OpSet final { +public: + // ------------------------------------------------------------- + // Convenience functions to generate an OpSet from a given range + // of operations. + // ------------------------------------------------------------- + + // Construct an OpSet from a set of core opcodes. + static OpSet fromCoreOpcodes(ArrayRef ops) { + OpSet set; + for (const unsigned op : ops) + set.m_coreOpcodes.insert(op); + + return set; + } + + // Construct an OpSet from a set of intrinsics. + static OpSet fromIntrinsicIDs(ArrayRef intrinsicIDs) { + OpSet set; + for (const unsigned intrinsicID : intrinsicIDs) + set.m_intrinsicIDs.insert(intrinsicID); + + return set; + } + + // Construct an OpSet from a set of OpDescriptions. + static OpSet fromOpDescriptions(ArrayRef descs) { + OpSet set; + for (const OpDescription &desc : descs) + set.tryInsertOp(desc); + + return set; + } + + // Construct an OpSet from a set of dialect ops, given as template + // arguments. + template static const OpSet get() { + static OpSet set; + (... && appendT(set)); + return set; + } + + // ------------------------------------------------------------- + // contains check to check if a given operation is stored in the OpSet. + // ------------------------------------------------------------- + // Checks if a given core opcode is stored in the set. + bool containsCoreOp(unsigned coreOpcode) const { + return m_coreOpcodes.contains(coreOpcode); + } + + // Checks if a given intrinsic ID is stored in the set. + bool containsIntrinsicID(unsigned intrinsicID) const { + return m_intrinsicIDs.contains(intrinsicID); + } + + // Checks if a given dialect operation is stored in the set. + template bool contains() const { + static OpDescription desc = OpDescription::get(); + return contains(desc); + } + + // Checks if a given OpDescription is stored in the set. + bool contains(const OpDescription &desc) const { + if (desc.isCoreOp() || desc.isIntrinsic()) { + assert(desc.getOpcodes().size() == 1 && + "OpSet only supports querying of single core opcodes and " + "intrinsics."); + + const unsigned op = desc.getOpcode(); + return (desc.isCoreOp() && containsCoreOp(op)) || + (desc.isIntrinsic() && containsIntrinsicID(op)); + } + + return isMatchingDialectOp(desc.getMnemonic()); + } + + // Checks if `inst` belongs to the OpSet. + bool contains(const Instruction &inst) const { + if (containsCoreOp(inst.getOpcode())) + return true; + + if (auto *CI = dyn_cast(&inst)) { + const Function *Callee = CI->getCalledFunction(); + if (!Callee) + return false; + + return contains(*Callee); + } + + return false; + } + + // Checks if `func` belongs to the OpSet. + bool contains(const Function &func) const { + if (func.isIntrinsic() && containsIntrinsicID(func.getIntrinsicID())) + return true; + + return isMatchingDialectOp(func.getName()); + } + + // ------------------------------------------------------------- + // Convenience getters to access the internal data structures. + // ------------------------------------------------------------- + const DenseSet &getCoreOpcodes() const { return m_coreOpcodes; } + + const DenseSet &getIntrinsicIDs() const { return m_intrinsicIDs; } + + const ArrayRef getDialectOps() const { return m_dialectOps; } + +private: + // Generates an `OpDescription` for a given `OpT`, extracts the + // internal operation representation and collects it in the set. + template static bool appendT(OpSet &set) { + static OpDescription desc = OpDescription::get(); + set.tryInsertOp(desc); + + return true; + } + + // Checks if `mnemonic` can be described by any of the stored dialect + // operations. + bool isMatchingDialectOp(StringRef mnemonic) const { + for (const auto &dialectOp : m_dialectOps) { + if (detail::isOperationDecl(mnemonic, dialectOp.isOverload, + dialectOp.mnemonic)) + return true; + } + + return false; + } + + // Tries to insert a given description in the internal data structures. + void tryInsertOp(const OpDescription &desc) { + if (desc.isCoreOp()) { + for (const unsigned op : desc.getOpcodes()) + m_coreOpcodes.insert(op); + + return; + } + + if (desc.isIntrinsic()) { + for (const unsigned op : desc.getOpcodes()) + m_intrinsicIDs.insert(op); + + return; + } + + // Store duplicate OpDescriptions once in the set. + if (!contains(desc)) + m_dialectOps.push_back({desc.getMnemonic(), hasOverloads(desc)}); + } + + static bool hasOverloads(const OpDescription &desc) { + return desc.getKind() == OpDescription::Kind::DialectWithOverloads; + } + + DenseSet m_coreOpcodes; + DenseSet m_intrinsicIDs; + SmallVector m_dialectOps; +}; +} // namespace llvm_dialects diff --git a/include/llvm-dialects/Dialect/Visitor.h b/include/llvm-dialects/Dialect/Visitor.h index 4886a4f..a3706f1 100644 --- a/include/llvm-dialects/Dialect/Visitor.h +++ b/include/llvm-dialects/Dialect/Visitor.h @@ -21,6 +21,8 @@ #include "llvm/Support/Casting.h" #include "llvm-dialects/Dialect/OpDescription.h" +#include "llvm-dialects/Dialect/OpMap.h" +#include "llvm-dialects/Dialect/OpSet.h" namespace llvm { class Function; @@ -31,8 +33,7 @@ class Module; namespace llvm_dialects { -template -class Visitor; +template class Visitor; /// The iteration strategy of Visitor. enum class VisitorStrategy { @@ -104,6 +105,19 @@ class VisitorKey { return key; } + template static VisitorKey opSet() { + VisitorKey key{Kind::OpSet}; + static const OpSet set = OpSet::get(); + key.m_set = &set; + return key; + } + + static VisitorKey opSet(const OpSet &set) { + VisitorKey key{Kind::OpSet}; + key.m_set = &set; + return key; + } + static VisitorKey intrinsic(unsigned id) { VisitorKey key{Kind::Intrinsic}; key.m_intrinsicId = id; @@ -114,12 +128,14 @@ class VisitorKey { enum class Kind { OpDescription, Intrinsic, + OpSet, }; VisitorKey(Kind kind) : m_kind(kind) {} Kind m_kind; const OpDescription *m_description = nullptr; + const OpSet *m_set = nullptr; unsigned m_intrinsicId = 0; }; @@ -211,10 +227,7 @@ class VisitorTemplate { VisitorStrategy m_strategy = VisitorStrategy::Default; std::vector m_projections; std::vector m_handlers; - llvm::DenseMap> m_coreOpcodeMap; - llvm::DenseMap> m_intrinsicIdMap; - std::vector>> - m_dialectCases; + OpMap> m_opMap; }; /// @brief Base class for VisitorBuilders @@ -287,9 +300,7 @@ class VisitorBase { VisitorStrategy m_strategy; std::vector m_projections; std::vector m_handlers; - llvm::DenseMap m_coreOpcodeMap; - llvm::DenseMap m_intrinsicIdMap; - std::vector> m_dialectCases; + OpMap m_opMap; }; } // namespace detail @@ -363,6 +374,18 @@ class VisitorBuilder : private detail::VisitorBuilderBase { return *this; } + template + VisitorBuilder &addSet(void (*fn)(PayloadT &, llvm::Instruction &I)) { + addSetCase(detail::VisitorKey::opSet(), fn); + return *this; + } + + VisitorBuilder &addSet(const OpSet &opSet, + void (*fn)(PayloadT &, llvm::Instruction &I)) { + addSetCase(detail::VisitorKey::opSet(opSet), fn); + return *this; + } + template VisitorBuilder &add(void (PayloadT::*fn)(OpT &)) { addMemberFnCase(detail::VisitorKey::op(), fn); return *this; @@ -418,6 +441,14 @@ class VisitorBuilder : private detail::VisitorBuilderBase { VisitorBuilderBase::add(key, &VisitorBuilder::forwarder, data); } + void addSetCase(detail::VisitorKey key, + void (*fn)(PayloadT &, llvm::Instruction &)) { + detail::VisitorCallbackData data{}; + static_assert(sizeof(fn) <= sizeof(data.data)); + memcpy(&data.data, &fn, sizeof(fn)); + VisitorBuilderBase::add(key, &VisitorBuilder::setForwarder, data); + } + template void addMemberFnCase(detail::VisitorKey key, void (PayloadT::*fn)(OpT &)) { detail::VisitorCallbackData data{}; @@ -434,6 +465,13 @@ class VisitorBuilder : private detail::VisitorBuilderBase { fn(*static_cast(payload), *llvm::cast(op)); } + static void setForwarder(const detail::VisitorCallbackData &data, + void *payload, llvm::Instruction *op) { + void (*fn)(PayloadT &, llvm::Instruction &); + memcpy(&fn, &data.data, sizeof(fn)); + fn(*static_cast(payload), *op); + } + template static void memberFnForwarder(const detail::VisitorCallbackData &data, void *payload, llvm::Instruction *op) { diff --git a/lib/Dialect/Dialect.cpp b/lib/Dialect/Dialect.cpp index 36de464..a9136fa 100644 --- a/lib/Dialect/Dialect.cpp +++ b/lib/Dialect/Dialect.cpp @@ -157,13 +157,13 @@ void ContextMap::remove(LLVMContext *llvmContext, void Dialect::anchor() {} -SmallVectorImpl& Dialect::Key::getRegisteredKeys() { - static SmallVector keys; +SmallVectorImpl &Dialect::Key::getRegisteredKeys() { + static SmallVector keys; return keys; } Dialect::Key::Key() { - auto& keys = getRegisteredKeys(); + auto &keys = getRegisteredKeys(); for (auto enumeratedKey : llvm::enumerate(keys)) { if (!enumeratedKey.value()) { @@ -197,17 +197,18 @@ DialectContext::~DialectContext() { for (unsigned i = 0; i < m_extensionArraySize; ++i) std::destroy_n(extensionArray, m_extensionArraySize); - Dialect** dialectArray = getTrailingObjects(); + Dialect **dialectArray = getTrailingObjects(); for (unsigned i = 0; i < m_dialectArraySize; ++i) delete dialectArray[i]; // may be nullptr } void DialectContext::operator delete(void *ctx) { free(ctx); } -std::unique_ptr DialectContext::make(LLVMContext& context, - ArrayRef dialects) { +std::unique_ptr +DialectContext::make(LLVMContext &context, + ArrayRef dialects) { unsigned dialectArraySize = 0; - for (const auto& desc : dialects) + for (const auto &desc : dialects) dialectArraySize = std::max(dialectArraySize, desc.index + 1); unsigned extensionArraySize = detail::ContextExtensionKey::getKeys().size(); @@ -215,14 +216,14 @@ std::unique_ptr DialectContext::make(LLVMContext& context, size_t totalSize = totalSizeToAlloc>( dialectArraySize, extensionArraySize); - void* ptr = malloc(totalSize); + void *ptr = malloc(totalSize); std::unique_ptr result{ new (ptr) DialectContext(context, dialectArraySize, extensionArraySize)}; - Dialect** dialectArray = result->getTrailingObjects(); + Dialect **dialectArray = result->getTrailingObjects(); std::uninitialized_fill_n(dialectArray, dialectArraySize, nullptr); - for (const auto& desc : dialects) + for (const auto &desc : dialects) dialectArray[desc.index] = desc.make(context); auto *extensionArray = @@ -232,7 +233,7 @@ std::unique_ptr DialectContext::make(LLVMContext& context, return result; } -DialectContext& DialectContext::get(LLVMContext& context) { +DialectContext &DialectContext::get(LLVMContext &context) { return *CurrentContextCache::get(&context); } @@ -249,28 +250,40 @@ void DialectExtensionPointBase::clear(unsigned index) { } bool llvm_dialects::detail::isSimpleOperationDecl(const Function *fn, - StringRef name) { - return fn->getName() == name; + StringRef mnemonic) { + return isOperationDecl(fn->getName(), false, mnemonic); } bool llvm_dialects::detail::isOverloadedOperationDecl(const Function *fn, - StringRef name) { - StringRef fnName = fn->getName(); - if (name.size() >= fnName.size()) - return false; - if (!fnName.startswith(name)) - return false; - return fnName[name.size()] == '.'; + StringRef mnemonic) { + return isOperationDecl(fn->getName(), true, mnemonic); } -bool llvm_dialects::detail::isSimpleOperation(const CallInst *i, StringRef name) { - if (auto* fn = i->getCalledFunction()) - return isSimpleOperationDecl(fn, name); +bool llvm_dialects::detail::isSimpleOperation(const CallInst *i, + StringRef mnemonic) { + if (auto *fn = i->getCalledFunction()) + return isSimpleOperationDecl(fn, mnemonic); return false; } -bool llvm_dialects::detail::isOverloadedOperation(const CallInst *i, StringRef name) { +bool llvm_dialects::detail::isOverloadedOperation(const CallInst *i, + StringRef mnemonic) { if (auto *fn = i->getCalledFunction()) - return isOverloadedOperationDecl(fn, name); + return isOverloadedOperationDecl(fn, mnemonic); return false; } + +bool llvm_dialects::detail::isOperationDecl(llvm::StringRef fn, + bool isOverloaded, + llvm::StringRef mnemonic) { + if (isOverloaded) { + if (mnemonic.size() >= fn.size()) + return false; + if (!fn.startswith(mnemonic)) + return false; + + return fn[mnemonic.size()] == '.'; + } + + return fn == mnemonic; +} diff --git a/lib/Dialect/OpDescription.cpp b/lib/Dialect/OpDescription.cpp index 0d6901f..652a449 100644 --- a/lib/Dialect/OpDescription.cpp +++ b/lib/Dialect/OpDescription.cpp @@ -30,6 +30,13 @@ OpDescription::OpDescription(Kind kind, MutableArrayRef opcodes) llvm::sort(opcodes); } +unsigned OpDescription::getOpcode() const { + const ArrayRef opcodes{getOpcodes()}; + assert(!opcodes.empty() && "OpDescription does not contain any opcode!"); + + return opcodes.front(); +} + ArrayRef OpDescription::getOpcodes() const { assert(m_kind == Kind::Core || m_kind == Kind::Intrinsic); diff --git a/lib/Dialect/Visitor.cpp b/lib/Dialect/Visitor.cpp index af4fad7..2096be0 100644 --- a/lib/Dialect/Visitor.cpp +++ b/lib/Dialect/Visitor.cpp @@ -51,36 +51,39 @@ void VisitorTemplate::add(VisitorKey key, VisitorCallback *fn, handler.callback = fn; handler.data = data; handler.projection = projection; + m_handlers.emplace_back(handler); - unsigned handlerIdx = m_handlers.size() - 1; + const unsigned handlerIdx = m_handlers.size() - 1; if (key.m_kind == VisitorKey::Kind::Intrinsic) { - m_intrinsicIdMap[key.m_intrinsicId].push_back(handlerIdx); - } else { - const OpDescription *description = key.m_description; - switch (description->getKind()) { - case OpDescription::Kind::Core: - for (unsigned opcode : description->getOpcodes()) - m_coreOpcodeMap[opcode].push_back(handlerIdx); - break; - case OpDescription::Kind::Intrinsic: - for (unsigned id : description->getOpcodes()) - m_intrinsicIdMap[id].push_back(handlerIdx); - break; - default: { - auto it = find_if(m_dialectCases, [=](const auto &theCase) { - return theCase.first == description; - }); - if (it != m_dialectCases.end()) { - it->second.push_back(handlerIdx); - } else { - SmallVector handlers; - handlers.push_back(handlerIdx); - m_dialectCases.emplace_back(description, std::move(handlers)); - } - break; + m_opMap[OpDescription::fromIntrinsic(key.m_intrinsicId)].push_back( + handlerIdx); + } else if (key.m_kind == VisitorKey::Kind::OpDescription) { + const OpDescription *opDesc = key.m_description; + + if (opDesc->isCoreOp()) { + for (const unsigned op : opDesc->getOpcodes()) + m_opMap[OpDescription::fromCoreOp(op)].push_back(handlerIdx); + } else if (opDesc->isIntrinsic()) { + for (const unsigned op : opDesc->getOpcodes()) + m_opMap[OpDescription::fromIntrinsic(op)].push_back(handlerIdx); + } else { + m_opMap[*opDesc].push_back(handlerIdx); } + } else if (key.m_kind == VisitorKey::Kind::OpSet) { + const OpSet *opSet = key.m_set; + + for (unsigned opcode : opSet->getCoreOpcodes()) + m_opMap[OpDescription::fromCoreOp(opcode)].push_back(handlerIdx); + + for (unsigned intrinsicID : opSet->getIntrinsicIDs()) + m_opMap[OpDescription::fromIntrinsic(intrinsicID)].push_back(handlerIdx); + + for (const auto &dialectOpPair : opSet->getDialectOps()) { + m_opMap[OpDescription::fromDialectOp(dialectOpPair.isOverload, + dialectOpPair.mnemonic)] + .push_back(handlerIdx); } } } @@ -182,26 +185,16 @@ VisitorBase::VisitorBase(VisitorTemplate &&templ) : m_strategy(templ.m_strategy), m_projections(std::move(templ.m_projections)) { if (m_strategy == VisitorStrategy::Default) { - m_strategy = templ.m_coreOpcodeMap.empty() - ? VisitorStrategy::ByFunctionDeclaration - : VisitorStrategy::ByInstruction; + m_strategy = templ.m_opMap.empty() ? VisitorStrategy::ByFunctionDeclaration + : VisitorStrategy::ByInstruction; } BuildHelper helper(*this, templ.m_handlers); - m_coreOpcodeMap.reserve(templ.m_coreOpcodeMap.size()); - m_intrinsicIdMap.reserve(templ.m_intrinsicIdMap.size()); - m_dialectCases.reserve(templ.m_dialectCases.size()); + m_opMap.reserve(templ.m_opMap); - for (const auto &entry : templ.m_coreOpcodeMap) { - m_coreOpcodeMap.try_emplace(entry.first, helper.mapHandlers(entry.second)); - } - for (const auto &entry : templ.m_intrinsicIdMap) { - m_intrinsicIdMap.try_emplace(entry.first, helper.mapHandlers(entry.second)); - } - for (const auto &entry : templ.m_dialectCases) { - m_dialectCases.emplace_back(entry.first, helper.mapHandlers(entry.second)); - } + for (auto it : templ.m_opMap) + m_opMap[it.first] = helper.mapHandlers(it.second); } void VisitorBase::call(HandlerRange handlers, void *payload, @@ -222,30 +215,16 @@ void VisitorBase::call(const VisitorHandler &handler, void *payload, payload = m_projections[idx].projection(payload); } } + handler.callback(handler.data, payload, &inst); } void VisitorBase::visit(void *payload, Instruction &inst) const { - if (auto *callInst = dyn_cast(&inst)) { - // Note: Always fall through to case handlers installed for generic - // CallInst instructions, if there are any. - if (auto *intrinsicInst = dyn_cast(callInst)) { - auto it = m_intrinsicIdMap.find(intrinsicInst->getIntrinsicID()); - if (it != m_intrinsicIdMap.end()) - call(it->second, payload, inst); - } else { - for (const auto &theCase : m_dialectCases) { - if (theCase.first->matchInstruction(inst)) { - call(theCase.second, payload, inst); - break; - } - } - } - } + auto handlers = m_opMap.find(inst); + if (!handlers) + return; - auto it = m_coreOpcodeMap.find(inst.getOpcode()); - if (it != m_coreOpcodeMap.end()) - call(it->second, payload, inst); + call(*handlers.val(), payload, inst); } template @@ -259,26 +238,8 @@ void VisitorBase::visitByDeclarations(void *payload, llvm::Module &module, LLVM_DEBUG(dbgs() << "visit " << decl.getName() << '\n'); - HandlerRange handlers{0, 0}; - if (unsigned intrinsicId = decl.getIntrinsicID()) { - auto it = m_intrinsicIdMap.find(intrinsicId); - if (it == m_intrinsicIdMap.end()) { - // Can't be a dialect op, so skip this declaration entirely. - continue; - } - handlers = it->second; - } - - if (handlers.second == 0) { - for (const auto &theCase : m_dialectCases) { - if (theCase.first->matchDeclaration(decl)) { - handlers = theCase.second; - break; - } - } - } - - if (handlers.second == 0) { + auto handlers = m_opMap.find(decl); + if (!handlers) { // Neither a matched intrinsic nor a matched dialect op; skip. continue; } @@ -289,7 +250,7 @@ void VisitorBase::visitByDeclarations(void *payload, llvm::Module &module, continue; if (auto *callInst = dyn_cast(inst)) { if (&use == &callInst->getCalledOperandUse()) - call(handlers, payload, *callInst); + call(*handlers.val(), payload, *callInst); } } } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 2d09a4e..5b56dac 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -51,3 +51,5 @@ add_lit_testsuites(LLVM_DIALECTS ${CMAKE_CURRENT_SOURCE_DIR} ${exclude_from_check_all} DEPENDS ${LLVM_DIALECTS_TEST_DEPENDS} ) + +add_subdirectory(unit) diff --git a/test/example/generated/ExampleDialect.cpp.inc b/test/example/generated/ExampleDialect.cpp.inc index 2903bf6..67dd832 100644 --- a/test/example/generated/ExampleDialect.cpp.inc +++ b/test/example/generated/ExampleDialect.cpp.inc @@ -79,6 +79,16 @@ namespace xd { state.setError(); }); + builder.add([](::llvm_dialects::VerifierState &state, SetReadOp &op) { + if (!op.verifier(state.out())) + state.setError(); + }); + + builder.add([](::llvm_dialects::VerifierState &state, SetWriteOp &op) { + if (!op.verifier(state.out())) + state.setError(); + }); + builder.add([](::llvm_dialects::VerifierState &state, SizeOfOp &op) { if (!op.verifier(state.out())) state.setError(); @@ -1213,6 +1223,126 @@ index + const ::llvm::StringLiteral SetReadOp::s_name{"xd.set.read"}; + + SetReadOp* SetReadOp::create(llvm_dialects::Builder& b, ::llvm::Type* dataType) { + ::llvm::LLVMContext& context = b.getContext(); + (void)context; + ::llvm::Module& module = *b.GetInsertBlock()->getModule(); + + + const ::llvm::AttributeList attrs + = ExampleDialect::get(context).getAttributeList(3); + + std::string mangledName = + ::llvm_dialects::getMangledName(s_name, {dataType}); + auto fnType = ::llvm::FunctionType::get(dataType, { +}, false); + + auto fn = module.getOrInsertFunction(mangledName, fnType, attrs); + ::llvm::SmallString<32> newName; + for (unsigned i = 0; !::llvm::isa<::llvm::Function>(fn.getCallee()) || + ::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() != fn.getFunctionType(); i++) { + // If a function with the same name but a different types already exists, + // we get a bitcast of a function or a function with the wrong type. + // Try new names until we get one with the correct type. + newName = ""; + ::llvm::raw_svector_ostream newNameStream(newName); + newNameStream << mangledName << "_" << i; + fn = module.getOrInsertFunction(newNameStream.str(), fnType, attrs); + } + assert(::llvm::isa<::llvm::Function>(fn.getCallee())); + assert(fn.getFunctionType() == fnType); + assert(::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() == fn.getFunctionType()); + + return ::llvm::cast(b.CreateCall(fn)); +} + + + bool SetReadOp::verifier(::llvm::raw_ostream &errs) { + ::llvm::LLVMContext &context = getModule()->getContext(); + (void)context; + + using ::llvm_dialects::printable; + + if (arg_size() != 0) { + errs << " wrong number of arguments: " << arg_size() + << ", expected 0\n"; + return false; + } + ::llvm::Type * const dataType = getData()->getType(); +(void)dataType; + return true; +} + + +::llvm::Value *SetReadOp::getData() {return this;} + + + + const ::llvm::StringLiteral SetWriteOp::s_name{"xd.set.write"}; + + SetWriteOp* SetWriteOp::create(llvm_dialects::Builder& b, ::llvm::Value * data) { + ::llvm::LLVMContext& context = b.getContext(); + (void)context; + ::llvm::Module& module = *b.GetInsertBlock()->getModule(); + + + const ::llvm::AttributeList attrs + = ExampleDialect::get(context).getAttributeList(0); + auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true); + + auto fn = module.getOrInsertFunction(s_name, fnType, attrs); + ::llvm::SmallString<32> newName; + for (unsigned i = 0; !::llvm::isa<::llvm::Function>(fn.getCallee()) || + ::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() != fn.getFunctionType(); i++) { + // If a function with the same name but a different types already exists, + // we get a bitcast of a function or a function with the wrong type. + // Try new names until we get one with the correct type. + newName = ""; + ::llvm::raw_svector_ostream newNameStream(newName); + newNameStream << s_name << "_" << i; + fn = module.getOrInsertFunction(newNameStream.str(), fnType, attrs); + } + assert(::llvm::isa<::llvm::Function>(fn.getCallee())); + assert(fn.getFunctionType() == fnType); + assert(::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() == fn.getFunctionType()); + + ::llvm::SmallVector<::llvm::Value*, 1> args = { +data + }; + + return ::llvm::cast(b.CreateCall(fn, args)); + } + + + bool SetWriteOp::verifier(::llvm::raw_ostream &errs) { + ::llvm::LLVMContext &context = getModule()->getContext(); + (void)context; + + using ::llvm_dialects::printable; + + if (arg_size() != 1) { + errs << " wrong number of arguments: " << arg_size() + << ", expected 1\n"; + return false; + } + ::llvm::Type * const dataType = getData()->getType(); +(void)dataType; + return true; +} + + + ::llvm::Value * SetWriteOp::getData() { + return getArgOperand(0); + } + + void SetWriteOp::setData(::llvm::Value * data) { + setArgOperand(0, data); + } + + + const ::llvm::StringLiteral SizeOfOp::s_name{"xd.sizeof"}; SizeOfOp* SizeOfOp::create(llvm_dialects::Builder& b, ::llvm::Type * sizeofType) { @@ -1769,6 +1899,22 @@ data } + template <> + const ::llvm_dialects::OpDescription & + ::llvm_dialects::OpDescription::get() { + static const ::llvm_dialects::OpDescription desc{true, "xd.set.read"}; + return desc; + } + + + template <> + const ::llvm_dialects::OpDescription & + ::llvm_dialects::OpDescription::get() { + static const ::llvm_dialects::OpDescription desc{false, "xd.set.write"}; + return desc; + } + + template <> const ::llvm_dialects::OpDescription & ::llvm_dialects::OpDescription::get() { diff --git a/test/example/generated/ExampleDialect.h.inc b/test/example/generated/ExampleDialect.h.inc index c95bc7b..691317a 100644 --- a/test/example/generated/ExampleDialect.h.inc +++ b/test/example/generated/ExampleDialect.h.inc @@ -317,6 +317,49 @@ bool verifier(::llvm::raw_ostream &errs); ::llvm::Value * getData(); + }; + + class SetReadOp : public ::llvm::CallInst { + static const ::llvm::StringLiteral s_name; //{"xd.set.read"}; + + public: + static bool classof(const ::llvm::CallInst* i) { + return ::llvm_dialects::detail::isOverloadedOperation(i, s_name); + } + static bool classof(const ::llvm::Value* v) { + return ::llvm::isa<::llvm::CallInst>(v) && + classof(::llvm::cast<::llvm::CallInst>(v)); + } + static SetReadOp* create(::llvm_dialects::Builder& b, ::llvm::Type* dataType); + +bool verifier(::llvm::raw_ostream &errs); + + +::llvm::Value * getData(); + + + }; + + class SetWriteOp : public ::llvm::CallInst { + static const ::llvm::StringLiteral s_name; //{"xd.set.write"}; + + public: + static bool classof(const ::llvm::CallInst* i) { + return ::llvm_dialects::detail::isSimpleOperation(i, s_name); + } + static bool classof(const ::llvm::Value* v) { + return ::llvm::isa<::llvm::CallInst>(v) && + classof(::llvm::cast<::llvm::CallInst>(v)); + } + static SetWriteOp* create(::llvm_dialects::Builder& b, ::llvm::Value * data); + +bool verifier(::llvm::raw_ostream &errs); + +::llvm::Value * getData(); + void setData(::llvm::Value * data); + + + }; class SizeOfOp : public ::llvm::CallInst { diff --git a/test/example/test-builder.test b/test/example/test-builder.test index 05a6fe8..0b8df76 100644 --- a/test/example/test-builder.test +++ b/test/example/test-builder.test @@ -25,11 +25,13 @@ ; CHECK-NEXT: call void (...) @xd.write(i8 [[P2]]) ; CHECK-NEXT: call void (...) @xd.write.vararg(i8 [[P2]], ptr [[P1]], i8 [[P2]]) ; CHECK-NEXT: [[TMP14:%.*]] = call target("xd.handle") @xd.handle.get() -; CHECK-NEXT: [[TMP15:%.*]] = call [[TMP0]] @xd.read.s_s() -; CHECK-NEXT: [[TMP16:%.*]] = call [[TMP1]] @xd.read.s_s_0() -; CHECK-NEXT: [[TMP17:%.*]] = call [[TMP2]] @xd.read.s_s_1() -; CHECK-NEXT: call void (...) @xd.write([[TMP0]] [[TMP15]]) -; CHECK-NEXT: call void (...) @xd.write([[TMP1]] [[TMP16]]) -; CHECK-NEXT: call void (...) @xd.write([[TMP2]] [[TMP17]]) +; CHECK-NEXT: [[TMP15:%.*]] = call <2 x i32> @xd.set.read.v2i32() +; CHECK-NEXT: call void (...) @xd.set.write(target("xd.vector", i32, 1, 2) [[TMP13]]) +; CHECK-NEXT: [[TMP16:%.*]] = call [[TMP0]] @xd.read.s_s() +; CHECK-NEXT: [[TMP17:%.*]] = call [[TMP1]] @xd.read.s_s_0() +; CHECK-NEXT: [[TMP18:%.*]] = call [[TMP2]] @xd.read.s_s_1() +; CHECK-NEXT: call void (...) @xd.write([[TMP0]] [[TMP16]]) +; CHECK-NEXT: call void (...) @xd.write([[TMP1]] [[TMP17]]) +; CHECK-NEXT: call void (...) @xd.write([[TMP2]] [[TMP18]]) ; CHECK-NEXT: ret void ; diff --git a/test/example/visitor-basic.ll b/test/example/visitor-basic.ll index 05d0be8..7d5813c 100644 --- a/test/example/visitor-basic.ll +++ b/test/example/visitor-basic.ll @@ -6,9 +6,14 @@ ; DEFAULT-NEXT: visiting BinaryOperator: %v1 = add i32 %v, %w ; DEFAULT-NEXT: visiting umax intrinsic: %v2 = call i32 @llvm.umax.i32(i32 %v1, i32 %q) ; DEFAULT-NEXT: visiting WriteOp: call void (...) @xd.write(i8 %t) +; DEFAULT-NEXT: visiting SetReadOp (set): %v.1 = call i32 @xd.set.read.i32() +; DEFAULT-NEXT: visiting UnaryInstruction: %v.2 = trunc i32 %v.1 to i8 +; DEFAULT-NEXT: visiting SetWriteOp (set): call void (...) @xd.set.write(i8 %v.2) ; DEFAULT-NEXT: visiting WriteVarArgOp: call void (...) @xd.write.vararg(i8 %t, i32 %v2, i32 %q) ; DEFAULT-NEXT: %v2 = ; DEFAULT-NEXT: %q = +; DEFAULT-NEXT: visiting umin (set): %vm = call i32 @llvm.umin.i32(i32 %v1, i32 %q) +; DEFAULT-NEXT: visiting Ret (set): ret void ; DEFAULT-NEXT: visiting ReturnInst: ret void ; DEFAULT-NEXT: inner.counter = 1 @@ -22,13 +27,20 @@ entry: %v2 = call i32 @llvm.umax.i32(i32 %v1, i32 %q) %t = call i8 (...) @xd.itrunc.i8(i32 %v2) call void (...) @xd.write(i8 %t) + %v.1 = call i32 @xd.set.read.i32() + %v.2 = trunc i32 %v.1 to i8 + call void (...) @xd.set.write(i8 %v.2) call void (...) @xd.write.vararg(i8 %t, i32 %v2, i32 %q) + %vm = call i32 @llvm.umin.i32(i32 %v1, i32 %q) ret void } declare i32 @xd.read.i32() +declare i32 @xd.set.read.i32() declare void @xd.write(...) +declare void @xd.set.write(...) declare void @xd.write.vararg(...) declare i8 @xd.itrunc.i8(...) declare i32 @llvm.umax.i32(i32, i32) +declare i32 @llvm.umin.i32(i32, i32) diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt new file mode 100644 index 0000000..e8379c5 --- /dev/null +++ b/test/unit/CMakeLists.txt @@ -0,0 +1,39 @@ +add_custom_target(DialectsUnitTests) +set_target_properties(DialectsUnitTests PROPERTIES FOLDER "Dialects Unit Tests") + +llvm_map_components_to_libnames(llvm_libs Support Core) + +# Inputs for lit.cfg.site.py.in +set(DIALECTS_UNIT_TEST_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +set(DIALECTS_UNIT_TEST_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) + +function(add_dialects_unit_test test_folder) + add_unittest(DialectsUnitTests ${test_folder} ${ARGN}) + target_link_libraries(${test_folder} PRIVATE ${llvm_libs} llvm_dialects) + + # Link to the generated dialect sources + target_sources(${test_folder} + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../dialect/TestDialect.cpp) + + target_include_directories(${test_folder} + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../dialect + ${CMAKE_CURRENT_BINARY_DIR}/../dialect) +endfunction() + +add_subdirectory(dialect) +add_subdirectory(interface) + +# Let lit discover the GTest tests +configure_lit_site_cfg( + ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in + ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py + MAIN_CONFIG + ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py +) + +add_lit_testsuite(check-llvm-dialects-units "Running the llvm-dialects unit tests" + ${CMAKE_CURRENT_SOURCE_DIR} + ${EXCLUDE_FROM_CHECK_ALL} + DEPENDS TestDialectTableGen DialectsUnitTests) diff --git a/test/unit/dialect/CMakeLists.txt b/test/unit/dialect/CMakeLists.txt new file mode 100644 index 0000000..a5a05bd --- /dev/null +++ b/test/unit/dialect/CMakeLists.txt @@ -0,0 +1,12 @@ + +### TableGen for the test dialect + +set(TEST_TABLEGEN_EXE ${LLVM_TOOLS_BINARY_DIR}/llvm-dialects-tblgen) +set(TEST_TABLEGEN_TARGET llvm-dialects-tblgen) +set(LLVM_TARGET_DEFINITIONS TestDialect.td) + +tablegen(TEST TestDialect.h.inc -gen-dialect-decls --dialect test + EXTRA_INCLUDES ${CMAKE_CURRENT_SOURCE_DIR}/../../../include) +tablegen(TEST TestDialect.cpp.inc -gen-dialect-defs --dialect test + EXTRA_INCLUDES ${CMAKE_CURRENT_SOURCE_DIR}/../../../include) +add_public_tablegen_target(TestDialectTableGen) diff --git a/test/unit/dialect/TestDialect.cpp b/test/unit/dialect/TestDialect.cpp new file mode 100644 index 0000000..e3a24e9 --- /dev/null +++ b/test/unit/dialect/TestDialect.cpp @@ -0,0 +1,32 @@ +/* +*********************************************************************************************************************** +* +* Copyright (c) 2023 Advanced Micro Devices, Inc. All Rights Reserved. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +*all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +* +**********************************************************************************************************************/ + +#include "TestDialect.h" + +#define GET_INCLUDES +#include "TestDialect.cpp.inc" + +#define GET_DIALECT_DEFS +#include "TestDialect.cpp.inc" diff --git a/test/unit/dialect/TestDialect.h b/test/unit/dialect/TestDialect.h new file mode 100644 index 0000000..b9af23b --- /dev/null +++ b/test/unit/dialect/TestDialect.h @@ -0,0 +1,32 @@ +/* +*********************************************************************************************************************** +* +* Copyright (c) 2023 Advanced Micro Devices, Inc. All Rights Reserved. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +*all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +* +**********************************************************************************************************************/ + +#pragma once + +#define GET_INCLUDES +#include "TestDialect.h.inc" + +#define GET_DIALECT_DECLS +#include "TestDialect.h.inc" diff --git a/test/unit/dialect/TestDialect.td b/test/unit/dialect/TestDialect.td new file mode 100644 index 0000000..6add4ae --- /dev/null +++ b/test/unit/dialect/TestDialect.td @@ -0,0 +1,67 @@ +/* + *********************************************************************************************************************** + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. All Rights Reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + **********************************************************************************************************************/ + +include "llvm-dialects/Dialect/Dialect.td" + +def TestDialect : Dialect { + let name = "test"; + let cppNamespace = "test"; +} + +class TestOp traits_> + : Op; + +def DialectOp1 : TestOp<"dialect.op.1", + []> { + let results = (outs); + let arguments = (ins); + + let summary = "Test operation 1"; +} + +def DialectOp2 : TestOp<"dialect.op.2", + []> { + let results = (outs); + let arguments = (ins); + + let summary = "Test operation 2"; +} + +def DialectOp3 : TestOp<"dialect.op.3", + []> { + let results = (outs); + let arguments = (ins); + + let summary = "Test operation 3"; +} + +def DialectOp4 : TestOp<"dialect.op.4", []> { + let results = (outs value:$r); + let arguments = (ins value:$v); + + let defaultBuilderHasExplicitResultType = true; + + let summary = "Test operation 4"; +} diff --git a/test/unit/interface/CMakeLists.txt b/test/unit/interface/CMakeLists.txt new file mode 100644 index 0000000..664e266 --- /dev/null +++ b/test/unit/interface/CMakeLists.txt @@ -0,0 +1,6 @@ +add_dialects_unit_test(DialectsADTTests + OpSetTests.cpp + OpMapTests.cpp + OpMapIRTests.cpp) + +add_dependencies(DialectsADTTests TestDialectTableGen) diff --git a/test/unit/interface/OpMapIRTests.cpp b/test/unit/interface/OpMapIRTests.cpp new file mode 100644 index 0000000..a0aa6c8 --- /dev/null +++ b/test/unit/interface/OpMapIRTests.cpp @@ -0,0 +1,254 @@ +/* +*********************************************************************************************************************** +* +* Copyright (c) 2023 Advanced Micro Devices, Inc. All Rights Reserved. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +*all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +* +**********************************************************************************************************************/ + +#include "TestDialect.h" +#include "llvm-dialects/Dialect/Builder.h" +#include "llvm-dialects/Dialect/Dialect.h" +#include "llvm-dialects/Dialect/OpMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" +#include "gtest/gtest.h" + +#include + +using namespace llvm; +using namespace llvm_dialects; + +class OpMapIRTestFixture : public testing::Test { +protected: + void SetUp() override { + setupDialectsContext(); + makeModule(); + } + + LLVMContext Context; + std::unique_ptr DC; + std::unique_ptr Mod; + Function *EP = nullptr; + + BasicBlock *getEntryBlock() { return EntryBlock; } + +private: + BasicBlock *EntryBlock = nullptr; + + void makeModule() { + Mod = std::make_unique("dialects_test", Context); + const std::array Args = {Type::getInt32Ty(Mod->getContext())}; + FunctionCallee FC = Mod->getOrInsertFunction( + "main", + FunctionType::get(Type::getVoidTy(Mod->getContext()), Args, false)); + EP = cast(FC.getCallee()); + EntryBlock = BasicBlock::Create(Mod->getContext(), "entry", EP); + } + + void setupDialectsContext() { + DC = DialectContext::make(Context); + } +}; + +TEST_F(OpMapIRTestFixture, CoreOpMatchesInstructionTest) { + OpMap map; + + IRBuilder<> Builder{Context}; + Builder.SetInsertPoint(getEntryBlock()); + + const OpDescription SubDesc = OpDescription::fromCoreOp(Instruction::Sub); + map[SubDesc] = "Sub"; + map[OpDescription::fromCoreOp(Instruction::Add)] = "Add"; + + Value *Arg = EP->getArg(0); + + Value *Mul = Builder.CreateMul(Arg, Arg); + Builder.CreateSub(Arg, Mul); + + const Instruction &SubInst = *&getEntryBlock()->back(); + + Builder.CreateAdd(Arg, Arg); + + const Instruction &AddInst = *&getEntryBlock()->back(); + + EXPECT_FALSE(map.lookup(SubInst) == map.lookup(AddInst)); + EXPECT_EQ(map.lookup(SubInst), "Sub"); + EXPECT_EQ(map.lookup(AddInst), "Add"); + + map[SubDesc] = "Sub_Override"; + EXPECT_EQ(map.lookup(SubInst), "Sub_Override"); +} + +TEST_F(OpMapIRTestFixture, IntrinsicOpMatchesInstructionTest) { + OpMap map; + + llvm_dialects::Builder B{Context}; + B.SetInsertPoint(getEntryBlock()); + + const OpDescription SideEffectDesc = + OpDescription::fromIntrinsic(Intrinsic::sideeffect); + const OpDescription AssumeDesc = + OpDescription::fromIntrinsic(Intrinsic::assume); + + map.insert(SideEffectDesc, "sideeffect"); + map.insert(AssumeDesc, "assume"); + + EXPECT_EQ(map[SideEffectDesc], "sideeffect"); + EXPECT_EQ(map[AssumeDesc], "assume"); + + const auto &SideEffect = *B.CreateCall( + Intrinsic::getDeclaration(Mod.get(), Intrinsic::sideeffect)); + + const std::array AssumeArgs = { + ConstantInt::getBool(Type::getInt1Ty(Context), true)}; + const auto &Assume = *B.CreateCall( + Intrinsic::getDeclaration(Mod.get(), Intrinsic::assume), AssumeArgs); + + EXPECT_FALSE(map.lookup(SideEffect) == map.lookup(Assume)); + EXPECT_EQ(map.lookup(SideEffect), "sideeffect"); + EXPECT_EQ(map.lookup(Assume), "assume"); + + map[OpDescription::fromIntrinsic(Intrinsic::sideeffect)] = + "sideeffect_Override"; + EXPECT_EQ(map.lookup(SideEffect), "sideeffect_Override"); +} + +TEST_F(OpMapIRTestFixture, DialectOpMatchesInstructionTest) { + OpMap map; + + map.insert("DialectOp1"); + map.insert("DialectOp2"); + + llvm_dialects::Builder B{Context}; + B.SetInsertPoint(getEntryBlock()); + + const Instruction &Op = *B.create(); + const Instruction &Op2 = *B.create(); + + EXPECT_FALSE(map.lookup(Op) == map.lookup(Op2)); + EXPECT_EQ(map.lookup(Op), "DialectOp1"); + EXPECT_EQ(map.lookup(Op2), "DialectOp2"); + + map[OpDescription::get()] = "DialectOp1_Override"; + EXPECT_EQ(map.lookup(Op), "DialectOp1_Override"); +} + +TEST_F(OpMapIRTestFixture, MixedOpMatchesInstructionTest) { + OpMap map; + + llvm_dialects::Builder B{Context}; + B.SetInsertPoint(getEntryBlock()); + + const OpDescription SideEffectDesc = + OpDescription::fromIntrinsic(Intrinsic::sideeffect); + + map.insert(SideEffectDesc, "sideeffect"); + + const Instruction &Op1 = *B.create(); + const Instruction &Op2 = *B.create(); + + EXPECT_EQ(map[SideEffectDesc], "sideeffect"); + + const auto &SideEffect = *B.CreateCall( + Intrinsic::getDeclaration(Mod.get(), Intrinsic::sideeffect)); + + EXPECT_EQ(map.lookup(SideEffect), "sideeffect"); + + map[OpDescription::get()] = "DO2"; + map[OpDescription::get()] = "DO3"; + + map[OpDescription::fromIntrinsic(Intrinsic::sideeffect)] = + "sideeffect_Override"; + + EXPECT_EQ(map.lookup(SideEffect), "sideeffect_Override"); + EXPECT_EQ(map.lookup(Op1), "DO2"); + EXPECT_EQ(map.lookup(Op2), "DO3"); +} + +TEST_F(OpMapIRTestFixture, DialectOpMatchesFunctionTest) { + OpMap map; + + map.insert("DialectOp1"); + map.insert("DialectOp2"); + + llvm_dialects::Builder B{Context}; + B.SetInsertPoint(getEntryBlock()); + + const auto &Op = *B.create(); + const auto &Op2 = *B.create(); + + const Function &DO1 = *Op.getCalledFunction(); + const Function &DO2 = *Op2.getCalledFunction(); + + EXPECT_FALSE(map.lookup(DO1) == map.lookup(DO2)); + EXPECT_EQ(map.lookup(DO1), "DialectOp1"); + EXPECT_EQ(map.lookup(DO2), "DialectOp2"); + + map[OpDescription::get()] = "DialectOp1_Override"; + EXPECT_EQ(map.lookup(DO1), "DialectOp1_Override"); +} + +TEST_F(OpMapIRTestFixture, OpMapLookupTests) { + OpMap map; + + map.insert("DialectOp1"); + map.insert(OpDescription::fromCoreOp(Instruction::Ret), "RetInst"); + + llvm_dialects::Builder B{Context}; + B.SetInsertPoint(getEntryBlock()); + + const auto &Op = *B.create(); + const auto &Ret = *B.CreateRetVoid(); + + EXPECT_EQ(map.lookup(Op), "DialectOp1"); + EXPECT_EQ(map.lookup(*Op.getCalledFunction()), "DialectOp1"); + EXPECT_EQ(map.lookup(Op), "DialectOp1"); + + EXPECT_EQ(map.lookup(Ret), "RetInst"); +} + +TEST_F(OpMapIRTestFixture, DialectOpOverloadTests) { + OpMap map; + + map.insert("DialectOp4"); + + llvm_dialects::Builder B{Context}; + B.SetInsertPoint(getEntryBlock()); + + Value *Arg = EP->getArg(0); + + Value *Mul = B.CreateMul(Arg, Arg); + B.CreateSub(Arg, Mul); + + Value *AddInt = B.CreateAdd(Arg, Arg); + const auto &Op1 = + *B.create(Type::getInt32Ty(Mod->getContext()), AddInt); + + auto *AddFloat = B.CreateBitCast(AddInt, Type::getFloatTy(Mod->getContext())); + const auto &Op2 = *B.create( + Type::getFloatTy(Mod->getContext()), AddFloat); + + EXPECT_EQ(map.lookup(Op1), "DialectOp4"); + EXPECT_EQ(map.lookup(Op2), "DialectOp4"); +} diff --git a/test/unit/interface/OpMapTests.cpp b/test/unit/interface/OpMapTests.cpp new file mode 100644 index 0000000..00e8483 --- /dev/null +++ b/test/unit/interface/OpMapTests.cpp @@ -0,0 +1,258 @@ +/* +*********************************************************************************************************************** +* +* Copyright (c) 2023 Advanced Micro Devices, Inc. All Rights Reserved. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +*all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +* +**********************************************************************************************************************/ + +#include "TestDialect.h" +#include "llvm-dialects/Dialect/OpDescription.h" +#include "llvm-dialects/Dialect/OpMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/IR/Intrinsics.h" +#include "gtest/gtest.h" + +#include + +using namespace llvm; +using namespace llvm_dialects; + +[[maybe_unused]] constexpr const char OpMapBasicTestsName[] = "OpMapBasicTests"; + +TEST(OpMapBasicTestsName, CoreOpContainsTests) { + OpMap map; + + OpDescription retDesc = OpDescription::fromCoreOp(Instruction::Ret); + OpDescription brDesc = OpDescription::fromCoreOp(Instruction::Br); + map[retDesc] = "RetInst"; + + EXPECT_TRUE(map.containsCoreOp(Instruction::Ret)); + EXPECT_FALSE(map.containsCoreOp(Instruction::Br)); + EXPECT_EQ(map[retDesc], "RetInst"); + + map[brDesc] = "BrInst"; + EXPECT_EQ(map[retDesc], "RetInst"); + EXPECT_TRUE(map.containsCoreOp(Instruction::Br)); + EXPECT_EQ(map[brDesc], "BrInst"); +} + +TEST(OpMapBasicTestsName, IntrinsicOpContainsTests) { + OpMap map; + + OpDescription memCpyDesc = OpDescription::fromIntrinsic(Intrinsic::memcpy); + OpDescription memMoveDesc = OpDescription::fromIntrinsic(Intrinsic::memmove); + map[memCpyDesc] = "MemCpy"; + + EXPECT_TRUE(map.containsIntrinsic(Intrinsic::memcpy)); + EXPECT_FALSE(map.containsIntrinsic(Intrinsic::memmove)); + EXPECT_EQ(map[memCpyDesc], "MemCpy"); + + map[memMoveDesc] = "MemMove"; + EXPECT_EQ(map[memMoveDesc], "MemMove"); + EXPECT_TRUE(map.containsIntrinsic(Intrinsic::memmove)); + EXPECT_EQ(map[memMoveDesc], "MemMove"); +} + +TEST(OpMapBasicTestsName, DialectOpContainsTests) { + OpMap map; + const OpDescription sampleDesc = OpDescription::get(); + + map[OpDescription::get()] = "Hello"; + + EXPECT_TRUE(map.contains()); + EXPECT_TRUE(map.contains(sampleDesc)); + EXPECT_FALSE(map.contains()); + EXPECT_EQ(map[OpDescription::get()], "Hello"); + + map[OpDescription::get()] = "World"; + map[OpDescription::get()] = "DialectOp1"; + + EXPECT_TRUE(map.contains()); + EXPECT_EQ(map[OpDescription::get()], "DialectOp1"); + EXPECT_EQ(map[OpDescription::get()], "World"); + + map[OpDescription::get()] = "DialectOp3"; + EXPECT_TRUE(map.contains()); + EXPECT_EQ(map[OpDescription::get()], "DialectOp3"); +} + +TEST(OpMapBasicTestsName, OpMapLookupTests) { + OpMap map; + map.insert("Hello"); + map.insert("World"); + map.insert("DO3"); + + EXPECT_EQ(static_cast(map.size()), 3); + EXPECT_EQ(map[OpDescription::get()], "Hello"); + EXPECT_EQ(map[OpDescription::get()], "World"); + EXPECT_EQ(map[OpDescription::get()], "DO3"); + map[OpDescription::get()] = "DO3_Override"; + EXPECT_EQ(map[OpDescription::get()], "DO3_Override"); + map.erase(); + EXPECT_EQ(static_cast(map.size()), 2); + EXPECT_FALSE(map.contains()); +} + +TEST(OpMapBasicTestsName, OpMapInitializerTests) { + OpMap map = { + {{OpDescription::get(), "Hello"}, + {OpDescription::get(), "World"}, + {OpDescription::get(), "DO3"}, + {OpDescription::fromCoreOp(Instruction::Ret), "Ret"}, + {OpDescription::fromIntrinsic(Intrinsic::assume), "Assume"}}}; + + EXPECT_TRUE(map.contains()); + EXPECT_TRUE(map.contains()); + EXPECT_TRUE(map.contains()); + EXPECT_TRUE(map.contains(OpDescription::fromCoreOp(Instruction::Ret))); + EXPECT_TRUE(map.contains(OpDescription::fromIntrinsic(Intrinsic::assume))); + + EXPECT_EQ(map[OpDescription::get()], "Hello"); + EXPECT_EQ(map[OpDescription::get()], "World"); + EXPECT_EQ(map[OpDescription::get()], "DO3"); + EXPECT_EQ(map[OpDescription::fromCoreOp(Instruction::Ret)], "Ret"); + EXPECT_EQ(map[OpDescription::fromIntrinsic(Intrinsic::assume)], "Assume"); + + map[OpDescription::get()] = "DO1"; + EXPECT_EQ(map[OpDescription::get()], "DO1"); + + map[OpDescription::fromCoreOp(Instruction::Ret)] = "RetInst"; + EXPECT_EQ(map[OpDescription::fromCoreOp(Instruction::Ret)], "RetInst"); +} + +TEST(OpMapBasicTestsName, OpMapEqualityTests) { + OpMap map = {{{OpDescription::get(), "Hello"}, + {OpDescription::get(), "World"}, + {OpDescription::get(), "DO3"}}}; + + OpMap map2 = {{{OpDescription::get(), "Hello"}, + {OpDescription::get(), "World"}, + {OpDescription::get(), "DO3"}}}; + + EXPECT_EQ(map, map2); + + map[OpDescription::get()] = "DO1"; + + EXPECT_NE(map, map2); +} + +TEST(OpMapBasicTestsName, OpMapEqualityOrderingTests) { + OpMap map = {{{OpDescription::get(), "Hello"}, + {OpDescription::get(), "DO3"}, + {OpDescription::get(), "World"}}}; + + OpMap map2 = {{{OpDescription::get(), "Hello"}, + {OpDescription::get(), "World"}, + {OpDescription::get(), "DO3"}}}; + + EXPECT_EQ(map, map2); +} + +TEST(OpMapBasicTestsName, OpMapEqualityEraseTests) { + OpMap map = {{{OpDescription::get(), "Hello"}, + {OpDescription::get(), "World"}, + {OpDescription::get(), "DO3"}}}; + + OpMap map2 = {{{OpDescription::get(), "Hello"}, + {OpDescription::get(), "World"}, + {OpDescription::get(), "DO3"}}}; + + EXPECT_EQ(map, map2); + + map.erase(); + + EXPECT_NE(map, map2); +} + +TEST(OpMapBasicTestsName, OpMapCopyTests) { + OpMap map = {{{OpDescription::get(), "Hello"}, + {OpDescription::get(), "World"}, + {OpDescription::get(), "DO3"}}}; + + OpMap map2 = map; + + (void)map2; + EXPECT_EQ(map, map2); +} + +TEST(OpMapBasicTestsName, OpMapMoveTests) { + OpMap map = {{{OpDescription::get(), "Hello"}, + {OpDescription::get(), "World"}, + {OpDescription::get(), "DO3"}}}; + + OpMap map2 = std::move(map); + + EXPECT_TRUE(map.empty()); + EXPECT_NE(map, map2); +} + +TEST(OpMapBasicTestsName, OpMapIteratorBaseTests) { + OpMap map = { + {{OpDescription::get(), "Hello"}, + {OpDescription::get(), "World"}, + {OpDescription::fromIntrinsic(Intrinsic::fabs), "fabs"}, + {OpDescription::get(), "DO3"}}}; + + EXPECT_EQ(*map.find(OpDescription::get()).val(), "Hello"); + EXPECT_EQ(*map.find(OpDescription::get()).val(), "World"); + EXPECT_EQ(*map.find(OpDescription::fromIntrinsic(Intrinsic::fabs)).val(), + "fabs"); + EXPECT_EQ(*map.find(OpDescription::get()).val(), "DO3"); +} + +TEST(OpMapBasicTestsName, OpMapIteratorIncTests) { + OpMap map; + + const OpDescription desc1 = OpDescription::get(); + const OpDescription desc2 = OpDescription::get(); + const OpDescription desc3 = OpDescription::fromIntrinsic(Intrinsic::fabs); + const OpDescription desc4 = OpDescription::get(); + const OpDescription desc5 = OpDescription::fromCoreOp(Instruction::FAdd); + + map[desc1] = "DialectOp1"; + map[desc2] = "DialectOp2"; + map[desc3] = "Fabs"; + map[desc4] = "DialectOp3"; + map[desc5] = "FAdd"; + + size_t Idx = 0; + for (auto it = map.begin(); Idx < 5 && it != map.end(); ++it) { + switch (Idx) { + case 0: + EXPECT_EQ((*it).second, "FAdd"); + break; + case 1: + EXPECT_EQ((*it).second, "Fabs"); + break; + case 2: + EXPECT_EQ((*it).second, "DialectOp1"); + break; + case 3: + EXPECT_EQ((*it).second, "DialectOp2"); + break; + case 4: + EXPECT_EQ((*it).second, "DialectOp3"); + break; + } + + ++Idx; + } +} \ No newline at end of file diff --git a/test/unit/interface/OpSetTests.cpp b/test/unit/interface/OpSetTests.cpp new file mode 100644 index 0000000..8a226ad --- /dev/null +++ b/test/unit/interface/OpSetTests.cpp @@ -0,0 +1,114 @@ +/* +*********************************************************************************************************************** +* +* Copyright (c) 2023 Advanced Micro Devices, Inc. All Rights Reserved. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +*all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +* +**********************************************************************************************************************/ + +#include "TestDialect.h" +#include "llvm-dialects/Dialect/OpDescription.h" +#include "llvm-dialects/Dialect/OpSet.h" +#include "llvm/IR/Intrinsics.h" +#include "gtest/gtest.h" + +using namespace llvm; +using namespace llvm_dialects; + +[[maybe_unused]] constexpr const char DialectsOpSetSizeTestsName[] = + "DialectsOpSetSizeTests"; + +[[maybe_unused]] constexpr const char DialectsOpSetContainsTestsName[] = + "DialectsOpSetContainsTests"; + +#define EXPECT_EQ_SIZE(Expression, Value) \ + EXPECT_EQ(static_cast(Expression), Value) + +TEST(DialectsOpSetSizeTestsName, EmptyTest) { + const OpSet set; + EXPECT_TRUE(set.getCoreOpcodes().empty()); + EXPECT_TRUE(set.getIntrinsicIDs().empty()); + EXPECT_TRUE(set.getDialectOps().empty()); +} + +TEST(DialectsOpSetSizeTestsName, NonEmptyCoreOpcodesTest) { + const OpSet set = OpSet::fromCoreOpcodes({1, 2}); + EXPECT_EQ_SIZE(set.getCoreOpcodes().size(), 2); +} + +TEST(DialectsOpSetSizeTestsName, NonEmptyIntrinsicsTest) { + const OpSet set = OpSet::fromIntrinsicIDs( + {Intrinsic::lifetime_start, Intrinsic::lifetime_end}); + EXPECT_EQ_SIZE(set.getIntrinsicIDs().size(), 2); +} + +TEST(DialectsOpSetSizeTestsName, NonEmptyOpDescriptionsTest) { + OpDescription const desc1 = OpDescription::get(); + OpDescription const desc2 = OpDescription::get(); + const OpSet set = OpSet::fromOpDescriptions({desc1, desc2}); + EXPECT_EQ_SIZE(set.getDialectOps().size(), 2); +} + +TEST(DialectsOpSetSizeTestsName, NonEmptyOpDescriptionsTemplatizedMakerTest) { + const OpSet set = OpSet::get(); + EXPECT_EQ_SIZE(set.getDialectOps().size(), 2); +} + +TEST(DialectsOpSetContainsTestsName, containsCoreOps) { + const OpSet set = OpSet::fromCoreOpcodes({1, 2}); // Ret, Br + EXPECT_TRUE(set.containsCoreOp(Instruction::Ret)); + EXPECT_TRUE(set.containsCoreOp(Instruction::Br)); + EXPECT_FALSE(set.containsCoreOp(Instruction::Switch)); +} + +TEST(DialectsOpSetContainsTestsName, ContainsIntrinsicIDs) { + const OpSet set = OpSet::fromIntrinsicIDs({Intrinsic::sqrt, Intrinsic::fabs}); + EXPECT_TRUE(set.containsIntrinsicID(Intrinsic::sqrt)); + EXPECT_TRUE(set.containsIntrinsicID(Intrinsic::fabs)); + EXPECT_FALSE(set.containsIntrinsicID(Intrinsic::floor)); +} + +TEST(DialectsOpSetContainsTestsName, contains) { + const OpSet set = OpSet::get(); + EXPECT_TRUE(set.contains()); + EXPECT_TRUE(set.contains()); + EXPECT_TRUE(set.contains()); + EXPECT_TRUE(set.contains()); + EXPECT_FALSE(set.contains()); +} + +TEST(DialectsOpSetSizeTestsName, + StoreDuplicateOpDescriptionsOnceTemplatizedMaker) { + const OpSet set = OpSet::get(); + EXPECT_EQ_SIZE(set.getDialectOps().size(), 1); + EXPECT_TRUE(set.contains()); + EXPECT_FALSE(set.contains()); +} + +TEST(DialectsOpSetSizeTestsName, StoreDuplicateOpDescriptionsOnce) { + const OpDescription desc1 = OpDescription::get(); + const OpDescription desc2 = OpDescription::get(); + OpSet set = OpSet::fromOpDescriptions({desc1, desc2}); + EXPECT_EQ_SIZE(set.getDialectOps().size(), 1); + EXPECT_TRUE(set.contains()); + EXPECT_FALSE(set.contains()); +} + +#undef EXPECT_EQ_SIZE diff --git a/test/unit/lit.cfg.py b/test/unit/lit.cfg.py new file mode 100644 index 0000000..3c9b3d4 --- /dev/null +++ b/test/unit/lit.cfg.py @@ -0,0 +1,60 @@ +## +####################################################################################################################### +# +# Copyright (c) 2021 Google LLC. All Rights Reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +####################################################################################################################### + +# Configuration file for the 'lit' test runner for llvm-dialects unit tests. Based on the MLIR unit test config. +import os + +import lit.formats + +# name: The name of this test suite. +config.name = 'Dialects_Unit' + +# suffixes: A list of file extensions to treat as test files. +config.suffixes = [] + +# test_source_root: The root path where tests are located. +# test_exec_root: The root path where tests should be run. +config.test_exec_root = os.path.join(config.dialects_unit_test_binary_dir, "interface") +config.test_source_root = config.test_exec_root + +# testFormat: The test format to use to interpret tests. +config.test_format = lit.formats.GoogleTest(config.llvm_build_mode, 'Tests') + +# Propagate the temp directory. Windows requires this because it uses \Windows\ +# if none of these are present. +if 'TMP' in os.environ: + config.environment['TMP'] = os.environ['TMP'] +if 'TEMP' in os.environ: + config.environment['TEMP'] = os.environ['TEMP'] + +# Propagate HOME as it can be used to override incorrect homedir in passwd +# that causes the tests to fail. +if 'HOME' in os.environ: + config.environment['HOME'] = os.environ['HOME'] + +# Propagate path to symbolizer for ASan/MSan. +for symbolizer in ['ASAN_SYMBOLIZER_PATH', 'MSAN_SYMBOLIZER_PATH']: + if symbolizer in os.environ: + config.environment[symbolizer] = os.environ[symbolizer] diff --git a/test/unit/lit.site.cfg.py.in b/test/unit/lit.site.cfg.py.in new file mode 100644 index 0000000..3fc3a07 --- /dev/null +++ b/test/unit/lit.site.cfg.py.in @@ -0,0 +1,22 @@ +@LIT_SITE_CFG_IN_HEADER@ + +import sys + +config.llvm_src_root = "@LLVM_BUILD_MAIN_SRC_DIR@" +config.llvm_obj_root = "@LLVM_BINARY_DIR@" +config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" +config.llvm_build_mode = "@LLVM_BUILD_MODE@" +config.dialects_unit_test_binary_dir = "@DIALECTS_UNIT_TEST_BINARY_DIR@" + +# Support substitution of the tools and libs dirs with user parameters. This is +# used when we can't determine the tool dir at configuration time. +try: + config.llvm_tools_dir = config.llvm_tools_dir % lit_config.params + config.llvm_build_mode = config.llvm_build_mode % lit_config.params +except KeyError: + e = sys.exc_info()[1] + key, = e.args + lit_config.fatal("unable to find %r parameter, use '--param=%s=VALUE'" % (key,key)) + +# Let the main config do the real work. +lit_config.load_config(config, "@DIALECTS_UNIT_TEST_SOURCE_DIR@/lit.cfg.py")