diff --git a/example/ExampleMain.cpp b/example/ExampleMain.cpp index 91203cc..3ac7e5e 100644 --- a/example/ExampleMain.cpp +++ b/example/ExampleMain.cpp @@ -148,6 +148,10 @@ struct VisitorNest { void visitBinaryOperator(BinaryOperator &inst) { *out << "visiting BinaryOperator: " << inst << '\n'; } + VisitorResult visitUnaryInstruction(UnaryInstruction &inst) { + *out << "visiting UnaryInstruction (pre): " << inst << '\n'; + return isa(inst) ? VisitorResult::Stop : VisitorResult::Continue; + } }; struct VisitorContainer { @@ -181,6 +185,12 @@ template const Visitor &getExampleVisitor() { b.add([](VisitorNest &self, xd::ReadOp &op) { *self.out << "visiting ReadOp: " << op << '\n'; }); + b.add(&VisitorNest::visitUnaryInstruction); + b.add([](VisitorNest &self, xd::SetReadOp &op) { + *self.out << "visiting SetReadOp: " << op << '\n'; + return op.getType()->isIntegerTy(1) ? VisitorResult::Stop + : VisitorResult::Continue; + }); b.addSet( [](VisitorNest &self, llvm::Instruction &op) { if (isa(op)) { diff --git a/include/llvm-dialects/Dialect/Visitor.h b/include/llvm-dialects/Dialect/Visitor.h index a3706f1..bce34ca 100644 --- a/include/llvm-dialects/Dialect/Visitor.h +++ b/include/llvm-dialects/Dialect/Visitor.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Advanced Micro Devices, Inc. All Rights Reserved. + * Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -89,6 +89,25 @@ struct VisitorPayloadProjection { static constexpr std::size_t offset = offsetof(PayloadT, field); \ }; +/// @brief Possible result states of visitor callbacks +/// +/// A visitor may have multiple callbacks registered that match on the same +/// instruction. By default, all matching callbacks are invoked in the order in +/// which they were registered with the visitor. This may not be appropriate. +/// A common issue is when the callback erases and replaces the visited +/// instruction. +/// +/// Callbacks may explicitly return a result state to indicate whether further +/// visits are desired. +enum class VisitorResult { + /// Continue with the next callbacks on the same instruction. This is the + /// default when the callback does not return a value. + Continue, + + /// Skip subsequent callbacks + Stop, +}; + namespace detail { class VisitorBase; @@ -158,8 +177,8 @@ struct VisitorCallbackData : public Foo0, Foo1 { char data[Size]; }; -using VisitorCallback = void(const VisitorCallbackData &, void *, - llvm::Instruction *); +using VisitorCallback = VisitorResult(const VisitorCallbackData &, void *, + llvm::Instruction *); using PayloadProjectionCallback = void *(void *); struct VisitorHandler { @@ -290,8 +309,8 @@ class VisitorBase { void call(HandlerRange handlers, void *payload, llvm::Instruction &inst) const; - void call(const VisitorHandler &handler, void *payload, - llvm::Instruction &inst) const; + VisitorResult call(const VisitorHandler &handler, void *payload, + llvm::Instruction &inst) const; template void visitByDeclarations(void *payload, llvm::Module &module, @@ -369,34 +388,74 @@ class VisitorBuilder : private detail::VisitorBuilderBase { Visitor build() { return VisitorBuilderBase::build(); } + template + VisitorBuilder &add(VisitorResult (*fn)(PayloadT &, OpT &)) { + addCase(detail::VisitorKey::op(), fn); + return *this; + } + template VisitorBuilder &add(void (*fn)(PayloadT &, OpT &)) { addCase(detail::VisitorKey::op(), fn); return *this; } + template + VisitorBuilder &addSet(VisitorResult (*fn)(PayloadT &, + llvm::Instruction &I)) { + addSetCase(detail::VisitorKey::opSet(), fn); + return *this; + } + template VisitorBuilder &addSet(void (*fn)(PayloadT &, llvm::Instruction &I)) { addSetCase(detail::VisitorKey::opSet(), fn); return *this; } + VisitorBuilder &addSet(const OpSet &opSet, + VisitorResult (*fn)(PayloadT &, + llvm::Instruction &I)) { + addSetCase(detail::VisitorKey::opSet(opSet), fn); + return *this; + } + VisitorBuilder &addSet(const OpSet &opSet, void (*fn)(PayloadT &, llvm::Instruction &I)) { addSetCase(detail::VisitorKey::opSet(opSet), fn); return *this; } + template + VisitorBuilder &add(VisitorResult (PayloadT::*fn)(OpT &)) { + addMemberFnCase(detail::VisitorKey::op(), fn); + return *this; + } + template VisitorBuilder &add(void (PayloadT::*fn)(OpT &)) { addMemberFnCase(detail::VisitorKey::op(), fn); return *this; } + VisitorBuilder &addIntrinsic(unsigned id, + VisitorResult (*fn)(PayloadT &, + llvm::IntrinsicInst &)) { + addCase(detail::VisitorKey::intrinsic(id), fn); + return *this; + } + VisitorBuilder &addIntrinsic(unsigned id, void (*fn)(PayloadT &, llvm::IntrinsicInst &)) { addCase(detail::VisitorKey::intrinsic(id), fn); return *this; } + VisitorBuilder & + addIntrinsic(unsigned id, + VisitorResult (PayloadT::*fn)(llvm::IntrinsicInst &)) { + addMemberFnCase(detail::VisitorKey::intrinsic(id), fn); + return *this; + } + VisitorBuilder &addIntrinsic(unsigned id, void (PayloadT::*fn)(llvm::IntrinsicInst &)) { addMemberFnCase(detail::VisitorKey::intrinsic(id), fn); @@ -433,52 +492,72 @@ class VisitorBuilder : private detail::VisitorBuilderBase { detail::PayloadProjectionCallback *projection) : VisitorBuilderBase(parent, projection) {} - template - void addCase(detail::VisitorKey key, void (*fn)(PayloadT &, OpT &)) { + template + void addCase(detail::VisitorKey key, ReturnT (*fn)(PayloadT &, OpT &)) { detail::VisitorCallbackData data{}; static_assert(sizeof(fn) <= sizeof(data.data)); memcpy(&data.data, &fn, sizeof(fn)); - VisitorBuilderBase::add(key, &VisitorBuilder::forwarder, data); + VisitorBuilderBase::add(key, &VisitorBuilder::forwarder, + data); } + template void addSetCase(detail::VisitorKey key, - void (*fn)(PayloadT &, llvm::Instruction &)) { + ReturnT (*fn)(PayloadT &, llvm::Instruction &)) { detail::VisitorCallbackData data{}; static_assert(sizeof(fn) <= sizeof(data.data)); memcpy(&data.data, &fn, sizeof(fn)); - VisitorBuilderBase::add(key, &VisitorBuilder::setForwarder, data); + VisitorBuilderBase::add(key, &VisitorBuilder::setForwarder, data); } - template - void addMemberFnCase(detail::VisitorKey key, void (PayloadT::*fn)(OpT &)) { + template + void addMemberFnCase(detail::VisitorKey key, ReturnT (PayloadT::*fn)(OpT &)) { detail::VisitorCallbackData data{}; static_assert(sizeof(fn) <= sizeof(data.data)); memcpy(&data.data, &fn, sizeof(fn)); - VisitorBuilderBase::add(key, &VisitorBuilder::memberFnForwarder, data); + VisitorBuilderBase::add( + key, &VisitorBuilder::memberFnForwarder, data); } - template - static void forwarder(const detail::VisitorCallbackData &data, void *payload, - llvm::Instruction *op) { - void (*fn)(PayloadT &, OpT &); + template + static VisitorResult forwarder(const detail::VisitorCallbackData &data, + void *payload, llvm::Instruction *op) { + ReturnT (*fn)(PayloadT &, OpT &); memcpy(&fn, &data.data, sizeof(fn)); - fn(*static_cast(payload), *llvm::cast(op)); + if constexpr (std::is_same_v) { + fn(*static_cast(payload), *llvm::cast(op)); + return VisitorResult::Continue; + } else { + return fn(*static_cast(payload), *llvm::cast(op)); + } } - static void setForwarder(const detail::VisitorCallbackData &data, - void *payload, llvm::Instruction *op) { - void (*fn)(PayloadT &, llvm::Instruction &); + template + static VisitorResult setForwarder(const detail::VisitorCallbackData &data, + void *payload, llvm::Instruction *op) { + ReturnT (*fn)(PayloadT &, llvm::Instruction &); memcpy(&fn, &data.data, sizeof(fn)); - fn(*static_cast(payload), *op); + if constexpr (std::is_same_v) { + fn(*static_cast(payload), *op); + return VisitorResult::Continue; + } else { + return fn(*static_cast(payload), *op); + } } - template - static void memberFnForwarder(const detail::VisitorCallbackData &data, - void *payload, llvm::Instruction *op) { - void (PayloadT::*fn)(OpT &); + template + static VisitorResult + memberFnForwarder(const detail::VisitorCallbackData &data, void *payload, + llvm::Instruction *op) { + ReturnT (PayloadT::*fn)(OpT &); memcpy(&fn, &data.data, sizeof(fn)); PayloadT *self = static_cast(payload); - (self->*fn)(*llvm::cast(op)); + if constexpr (std::is_same_v) { + (self->*fn)(*llvm::cast(op)); + return VisitorResult::Continue; + } else { + return (self->*fn)(*llvm::cast(op)); + } } }; diff --git a/lib/Dialect/Visitor.cpp b/lib/Dialect/Visitor.cpp index 50aa21a..69e9e4f 100644 --- a/lib/Dialect/Visitor.cpp +++ b/lib/Dialect/Visitor.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Advanced Micro Devices, Inc. All Rights Reserved. + * Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -199,12 +199,15 @@ VisitorBase::VisitorBase(VisitorTemplate &&templ) void VisitorBase::call(HandlerRange handlers, void *payload, Instruction &inst) const { - for (unsigned idx = handlers.first; idx != handlers.second; ++idx) - call(m_handlers[idx], payload, inst); + for (unsigned idx = handlers.first; idx != handlers.second; ++idx) { + VisitorResult result = call(m_handlers[idx], payload, inst); + if (result == VisitorResult::Stop) + return; + } } -void VisitorBase::call(const VisitorHandler &handler, void *payload, - Instruction &inst) const { +VisitorResult VisitorBase::call(const VisitorHandler &handler, void *payload, + Instruction &inst) const { if (handler.projection.isOffset()) { payload = (char *)payload + handler.projection.getOffset(); } else { @@ -216,7 +219,7 @@ void VisitorBase::call(const VisitorHandler &handler, void *payload, } } - handler.callback(handler.data, payload, &inst); + return handler.callback(handler.data, payload, &inst); } void VisitorBase::visit(void *payload, Instruction &inst) const { diff --git a/test/example/visitor-basic.ll b/test/example/visitor-basic.ll index 7d5813c..c7c1d38 100644 --- a/test/example/visitor-basic.ll +++ b/test/example/visitor-basic.ll @@ -1,12 +1,15 @@ ; RUN: llvm-dialects-example -visit %s | FileCheck --check-prefixes=DEFAULT %s ; DEFAULT: visiting ReadOp: %v = call i32 @xd.read.i32() -; DEFAULT-NEXT: visiting UnaryInstruction: %w = load i32, ptr %p -; DEFAULT-NEXT: visiting UnaryInstruction: %q = load i32, ptr %p1 +; DEFAULT-NEXT: visiting UnaryInstruction (pre): %w = load i32, ptr %p +; DEFAULT-NEXT: visiting UnaryInstruction (pre): %q = load i32, ptr %p1 ; DEFAULT-NEXT: visiting BinaryOperator: %v1 = add i32 %v, %w ; DEFAULT-NEXT: visiting umax intrinsic: %v2 = call i32 @llvm.umax.i32(i32 %v1, i32 %q) ; DEFAULT-NEXT: visiting WriteOp: call void (...) @xd.write(i8 %t) +; DEFAULT-NEXT: visiting SetReadOp: %v.0 = call i1 @xd.set.read.i1() +; DEFAULT-NEXT: visiting SetReadOp: %v.1 = call i32 @xd.set.read.i32() ; DEFAULT-NEXT: visiting SetReadOp (set): %v.1 = call i32 @xd.set.read.i32() +; DEFAULT-NEXT: visiting UnaryInstruction (pre): %v.2 = trunc i32 %v.1 to i8 ; DEFAULT-NEXT: visiting UnaryInstruction: %v.2 = trunc i32 %v.1 to i8 ; DEFAULT-NEXT: visiting SetWriteOp (set): call void (...) @xd.set.write(i8 %v.2) ; DEFAULT-NEXT: visiting WriteVarArgOp: call void (...) @xd.write.vararg(i8 %t, i32 %v2, i32 %q) @@ -27,6 +30,7 @@ entry: %v2 = call i32 @llvm.umax.i32(i32 %v1, i32 %q) %t = call i8 (...) @xd.itrunc.i8(i32 %v2) call void (...) @xd.write(i8 %t) + %v.0 = call i1 @xd.set.read.i1() %v.1 = call i32 @xd.set.read.i32() %v.2 = trunc i32 %v.1 to i8 call void (...) @xd.set.write(i8 %v.2) @@ -36,6 +40,7 @@ entry: } declare i32 @xd.read.i32() +declare i1 @xd.set.read.i1() declare i32 @xd.set.read.i32() declare void @xd.write(...) declare void @xd.set.write(...)