Skip to content

Commit

Permalink
Add VisitorResult (#81)
Browse files Browse the repository at this point in the history
Allow a visitor callback to decide whether additional matching callbacks
for the same instruction should be called or not.

This is useful when a visitor callback wants to erase an instruction and
there are multiple visitors that might match the same instruction (as
sometimes happens with generic instructions like load and store).
  • Loading branch information
nhaehnle authored Jan 31, 2024
1 parent 69e114f commit daa38d7
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 35 deletions.
10 changes: 10 additions & 0 deletions example/ExampleMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ struct VisitorNest {
void visitBinaryOperator(BinaryOperator &inst) {
*out << "visiting BinaryOperator: " << inst << '\n';
}
VisitorResult visitUnaryInstruction(UnaryInstruction &inst) {
*out << "visiting UnaryInstruction (pre): " << inst << '\n';
return isa<LoadInst>(inst) ? VisitorResult::Stop : VisitorResult::Continue;
}
};

struct VisitorContainer {
Expand Down Expand Up @@ -181,6 +185,12 @@ template <bool rpot> const Visitor<VisitorContainer> &getExampleVisitor() {
b.add<xd::ReadOp>([](VisitorNest &self, xd::ReadOp &op) {
*self.out << "visiting ReadOp: " << op << '\n';
});
b.add(&VisitorNest::visitUnaryInstruction);
b.add<xd::SetReadOp>([](VisitorNest &self, xd::SetReadOp &op) {
*self.out << "visiting SetReadOp: " << op << '\n';
return op.getType()->isIntegerTy(1) ? VisitorResult::Stop
: VisitorResult::Continue;
});
b.addSet<xd::SetReadOp, xd::SetWriteOp>(
[](VisitorNest &self, llvm::Instruction &op) {
if (isa<xd::SetReadOp>(op)) {
Expand Down
133 changes: 106 additions & 27 deletions include/llvm-dialects/Dialect/Visitor.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022 Advanced Micro Devices, Inc. All Rights Reserved.
* Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -89,6 +89,25 @@ struct VisitorPayloadProjection {
static constexpr std::size_t offset = offsetof(PayloadT, field); \
};

/// @brief Possible result states of visitor callbacks
///
/// A visitor may have multiple callbacks registered that match on the same
/// instruction. By default, all matching callbacks are invoked in the order in
/// which they were registered with the visitor. This may not be appropriate.
/// A common issue is when the callback erases and replaces the visited
/// instruction.
///
/// Callbacks may explicitly return a result state to indicate whether further
/// visits are desired.
enum class VisitorResult {
/// Continue with the next callbacks on the same instruction. This is the
/// default when the callback does not return a value.
Continue,

/// Skip subsequent callbacks
Stop,
};

namespace detail {

class VisitorBase;
Expand Down Expand Up @@ -158,8 +177,8 @@ struct VisitorCallbackData : public Foo0, Foo1 {
char data[Size];
};

using VisitorCallback = void(const VisitorCallbackData &, void *,
llvm::Instruction *);
using VisitorCallback = VisitorResult(const VisitorCallbackData &, void *,
llvm::Instruction *);
using PayloadProjectionCallback = void *(void *);

struct VisitorHandler {
Expand Down Expand Up @@ -290,8 +309,8 @@ class VisitorBase {

void call(HandlerRange handlers, void *payload,
llvm::Instruction &inst) const;
void call(const VisitorHandler &handler, void *payload,
llvm::Instruction &inst) const;
VisitorResult call(const VisitorHandler &handler, void *payload,
llvm::Instruction &inst) const;

template <typename FilterT>
void visitByDeclarations(void *payload, llvm::Module &module,
Expand Down Expand Up @@ -369,34 +388,74 @@ class VisitorBuilder : private detail::VisitorBuilderBase {

Visitor<PayloadT> build() { return VisitorBuilderBase::build(); }

template <typename OpT>
VisitorBuilder &add(VisitorResult (*fn)(PayloadT &, OpT &)) {
addCase<OpT>(detail::VisitorKey::op<OpT>(), fn);
return *this;
}

template <typename OpT> VisitorBuilder &add(void (*fn)(PayloadT &, OpT &)) {
addCase<OpT>(detail::VisitorKey::op<OpT>(), fn);
return *this;
}

template <typename... OpTs>
VisitorBuilder &addSet(VisitorResult (*fn)(PayloadT &,
llvm::Instruction &I)) {
addSetCase(detail::VisitorKey::opSet<OpTs...>(), fn);
return *this;
}

template <typename... OpTs>
VisitorBuilder &addSet(void (*fn)(PayloadT &, llvm::Instruction &I)) {
addSetCase(detail::VisitorKey::opSet<OpTs...>(), fn);
return *this;
}

VisitorBuilder &addSet(const OpSet &opSet,
VisitorResult (*fn)(PayloadT &,
llvm::Instruction &I)) {
addSetCase(detail::VisitorKey::opSet(opSet), fn);
return *this;
}

VisitorBuilder &addSet(const OpSet &opSet,
void (*fn)(PayloadT &, llvm::Instruction &I)) {
addSetCase(detail::VisitorKey::opSet(opSet), fn);
return *this;
}

template <typename OpT>
VisitorBuilder &add(VisitorResult (PayloadT::*fn)(OpT &)) {
addMemberFnCase<OpT>(detail::VisitorKey::op<OpT>(), fn);
return *this;
}

template <typename OpT> VisitorBuilder &add(void (PayloadT::*fn)(OpT &)) {
addMemberFnCase<OpT>(detail::VisitorKey::op<OpT>(), fn);
return *this;
}

VisitorBuilder &addIntrinsic(unsigned id,
VisitorResult (*fn)(PayloadT &,
llvm::IntrinsicInst &)) {
addCase<llvm::IntrinsicInst>(detail::VisitorKey::intrinsic(id), fn);
return *this;
}

VisitorBuilder &addIntrinsic(unsigned id,
void (*fn)(PayloadT &, llvm::IntrinsicInst &)) {
addCase<llvm::IntrinsicInst>(detail::VisitorKey::intrinsic(id), fn);
return *this;
}

VisitorBuilder &
addIntrinsic(unsigned id,
VisitorResult (PayloadT::*fn)(llvm::IntrinsicInst &)) {
addMemberFnCase<llvm::IntrinsicInst>(detail::VisitorKey::intrinsic(id), fn);
return *this;
}

VisitorBuilder &addIntrinsic(unsigned id,
void (PayloadT::*fn)(llvm::IntrinsicInst &)) {
addMemberFnCase<llvm::IntrinsicInst>(detail::VisitorKey::intrinsic(id), fn);
Expand Down Expand Up @@ -433,52 +492,72 @@ class VisitorBuilder : private detail::VisitorBuilderBase {
detail::PayloadProjectionCallback *projection)
: VisitorBuilderBase(parent, projection) {}

template <typename OpT>
void addCase(detail::VisitorKey key, void (*fn)(PayloadT &, OpT &)) {
template <typename OpT, typename ReturnT>
void addCase(detail::VisitorKey key, ReturnT (*fn)(PayloadT &, OpT &)) {
detail::VisitorCallbackData data{};
static_assert(sizeof(fn) <= sizeof(data.data));
memcpy(&data.data, &fn, sizeof(fn));
VisitorBuilderBase::add(key, &VisitorBuilder::forwarder<OpT>, data);
VisitorBuilderBase::add(key, &VisitorBuilder::forwarder<OpT, ReturnT>,
data);
}

template <typename ReturnT>
void addSetCase(detail::VisitorKey key,
void (*fn)(PayloadT &, llvm::Instruction &)) {
ReturnT (*fn)(PayloadT &, llvm::Instruction &)) {
detail::VisitorCallbackData data{};
static_assert(sizeof(fn) <= sizeof(data.data));
memcpy(&data.data, &fn, sizeof(fn));
VisitorBuilderBase::add(key, &VisitorBuilder::setForwarder, data);
VisitorBuilderBase::add(key, &VisitorBuilder::setForwarder<ReturnT>, data);
}

template <typename OpT>
void addMemberFnCase(detail::VisitorKey key, void (PayloadT::*fn)(OpT &)) {
template <typename OpT, typename ReturnT>
void addMemberFnCase(detail::VisitorKey key, ReturnT (PayloadT::*fn)(OpT &)) {
detail::VisitorCallbackData data{};
static_assert(sizeof(fn) <= sizeof(data.data));
memcpy(&data.data, &fn, sizeof(fn));
VisitorBuilderBase::add(key, &VisitorBuilder::memberFnForwarder<OpT>, data);
VisitorBuilderBase::add(
key, &VisitorBuilder::memberFnForwarder<OpT, ReturnT>, data);
}

template <typename OpT>
static void forwarder(const detail::VisitorCallbackData &data, void *payload,
llvm::Instruction *op) {
void (*fn)(PayloadT &, OpT &);
template <typename OpT, typename ReturnT>
static VisitorResult forwarder(const detail::VisitorCallbackData &data,
void *payload, llvm::Instruction *op) {
ReturnT (*fn)(PayloadT &, OpT &);
memcpy(&fn, &data.data, sizeof(fn));
fn(*static_cast<PayloadT *>(payload), *llvm::cast<OpT>(op));
if constexpr (std::is_same_v<ReturnT, void>) {
fn(*static_cast<PayloadT *>(payload), *llvm::cast<OpT>(op));
return VisitorResult::Continue;
} else {
return fn(*static_cast<PayloadT *>(payload), *llvm::cast<OpT>(op));
}
}

static void setForwarder(const detail::VisitorCallbackData &data,
void *payload, llvm::Instruction *op) {
void (*fn)(PayloadT &, llvm::Instruction &);
template <typename ReturnT>
static VisitorResult setForwarder(const detail::VisitorCallbackData &data,
void *payload, llvm::Instruction *op) {
ReturnT (*fn)(PayloadT &, llvm::Instruction &);
memcpy(&fn, &data.data, sizeof(fn));
fn(*static_cast<PayloadT *>(payload), *op);
if constexpr (std::is_same_v<ReturnT, void>) {
fn(*static_cast<PayloadT *>(payload), *op);
return VisitorResult::Continue;
} else {
return fn(*static_cast<PayloadT *>(payload), *op);
}
}

template <typename OpT>
static void memberFnForwarder(const detail::VisitorCallbackData &data,
void *payload, llvm::Instruction *op) {
void (PayloadT::*fn)(OpT &);
template <typename OpT, typename ReturnT>
static VisitorResult
memberFnForwarder(const detail::VisitorCallbackData &data, void *payload,
llvm::Instruction *op) {
ReturnT (PayloadT::*fn)(OpT &);
memcpy(&fn, &data.data, sizeof(fn));
PayloadT *self = static_cast<PayloadT *>(payload);
(self->*fn)(*llvm::cast<OpT>(op));
if constexpr (std::is_same_v<ReturnT, void>) {
(self->*fn)(*llvm::cast<OpT>(op));
return VisitorResult::Continue;
} else {
return (self->*fn)(*llvm::cast<OpT>(op));
}
}
};

Expand Down
15 changes: 9 additions & 6 deletions lib/Dialect/Visitor.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022 Advanced Micro Devices, Inc. All Rights Reserved.
* Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -199,12 +199,15 @@ VisitorBase::VisitorBase(VisitorTemplate &&templ)

void VisitorBase::call(HandlerRange handlers, void *payload,
Instruction &inst) const {
for (unsigned idx = handlers.first; idx != handlers.second; ++idx)
call(m_handlers[idx], payload, inst);
for (unsigned idx = handlers.first; idx != handlers.second; ++idx) {
VisitorResult result = call(m_handlers[idx], payload, inst);
if (result == VisitorResult::Stop)
return;
}
}

void VisitorBase::call(const VisitorHandler &handler, void *payload,
Instruction &inst) const {
VisitorResult VisitorBase::call(const VisitorHandler &handler, void *payload,
Instruction &inst) const {
if (handler.projection.isOffset()) {
payload = (char *)payload + handler.projection.getOffset();
} else {
Expand All @@ -216,7 +219,7 @@ void VisitorBase::call(const VisitorHandler &handler, void *payload,
}
}

handler.callback(handler.data, payload, &inst);
return handler.callback(handler.data, payload, &inst);
}

void VisitorBase::visit(void *payload, Instruction &inst) const {
Expand Down
9 changes: 7 additions & 2 deletions test/example/visitor-basic.ll
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
; RUN: llvm-dialects-example -visit %s | FileCheck --check-prefixes=DEFAULT %s

; DEFAULT: visiting ReadOp: %v = call i32 @xd.read.i32()
; DEFAULT-NEXT: visiting UnaryInstruction: %w = load i32, ptr %p
; DEFAULT-NEXT: visiting UnaryInstruction: %q = load i32, ptr %p1
; DEFAULT-NEXT: visiting UnaryInstruction (pre): %w = load i32, ptr %p
; DEFAULT-NEXT: visiting UnaryInstruction (pre): %q = load i32, ptr %p1
; DEFAULT-NEXT: visiting BinaryOperator: %v1 = add i32 %v, %w
; DEFAULT-NEXT: visiting umax intrinsic: %v2 = call i32 @llvm.umax.i32(i32 %v1, i32 %q)
; DEFAULT-NEXT: visiting WriteOp: call void (...) @xd.write(i8 %t)
; DEFAULT-NEXT: visiting SetReadOp: %v.0 = call i1 @xd.set.read.i1()
; DEFAULT-NEXT: visiting SetReadOp: %v.1 = call i32 @xd.set.read.i32()
; DEFAULT-NEXT: visiting SetReadOp (set): %v.1 = call i32 @xd.set.read.i32()
; DEFAULT-NEXT: visiting UnaryInstruction (pre): %v.2 = trunc i32 %v.1 to i8
; DEFAULT-NEXT: visiting UnaryInstruction: %v.2 = trunc i32 %v.1 to i8
; DEFAULT-NEXT: visiting SetWriteOp (set): call void (...) @xd.set.write(i8 %v.2)
; DEFAULT-NEXT: visiting WriteVarArgOp: call void (...) @xd.write.vararg(i8 %t, i32 %v2, i32 %q)
Expand All @@ -27,6 +30,7 @@ entry:
%v2 = call i32 @llvm.umax.i32(i32 %v1, i32 %q)
%t = call i8 (...) @xd.itrunc.i8(i32 %v2)
call void (...) @xd.write(i8 %t)
%v.0 = call i1 @xd.set.read.i1()
%v.1 = call i32 @xd.set.read.i32()
%v.2 = trunc i32 %v.1 to i8
call void (...) @xd.set.write(i8 %v.2)
Expand All @@ -36,6 +40,7 @@ entry:
}

declare i32 @xd.read.i32()
declare i1 @xd.set.read.i1()
declare i32 @xd.set.read.i32()
declare void @xd.write(...)
declare void @xd.set.write(...)
Expand Down

0 comments on commit daa38d7

Please sign in to comment.