Skip to content

Commit

Permalink
Adds IndexNotationVisitors
Browse files Browse the repository at this point in the history
  • Loading branch information
fredrikbk committed Apr 27, 2018
1 parent b29b033 commit bf0a6e0
Show file tree
Hide file tree
Showing 16 changed files with 136 additions and 135 deletions.
7 changes: 4 additions & 3 deletions include/taco/index_notation/expr_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
namespace taco {

class IndexVar;
class ExprVisitorStrict;
class IndexExprVisitorStrict;
class IndexNotationVisitorStrict;
class OperatorSplit;

/// A node of a scalar index expression tree.
Expand All @@ -20,7 +21,7 @@ struct ExprNode : public util::Manageable<ExprNode>, private util::Uncopyable {
ExprNode();
ExprNode(DataType type);
virtual ~ExprNode() = default;
virtual void accept(ExprVisitorStrict*) const = 0;
virtual void accept(IndexExprVisitorStrict*) const = 0;

/// Split the expression.
void splitOperator(IndexVar old, IndexVar left, IndexVar right);
Expand All @@ -43,7 +44,7 @@ struct TensorExprNode : public util::Manageable<TensorExprNode>,
TensorExprNode();
TensorExprNode(Type type);
virtual ~TensorExprNode() = default;
virtual void accept(ExprVisitorStrict*) const = 0;
virtual void accept(IndexNotationVisitorStrict*) const = 0;

Type getType() const;

Expand Down
30 changes: 15 additions & 15 deletions include/taco/index_notation/expr_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ struct AccessNode : public ExprNode {
AccessNode(TensorVar tensorVar, const std::vector<IndexVar>& indices)
: ExprNode(tensorVar.getType().getDataType()), tensorVar(tensorVar), indexVars(indices) {}

void accept(ExprVisitorStrict* v) const {
void accept(IndexExprVisitorStrict* v) const {
v->visit(this);
}

Expand Down Expand Up @@ -46,15 +46,15 @@ struct UnaryExprNode : public ExprNode {
struct NegNode : public UnaryExprNode {
NegNode(IndexExpr operand) : UnaryExprNode(operand) {}

void accept(ExprVisitorStrict* v) const {
void accept(IndexExprVisitorStrict* v) const {
v->visit(this);
}
};

struct SqrtNode : public UnaryExprNode {
SqrtNode(IndexExpr operand) : UnaryExprNode(operand) {}

void accept(ExprVisitorStrict* v) const {
void accept(IndexExprVisitorStrict* v) const {
v->visit(this);
}

Expand All @@ -80,7 +80,7 @@ struct AddNode : public BinaryExprNode {
return "+";
}

void accept(ExprVisitorStrict* v) const {
void accept(IndexExprVisitorStrict* v) const {
v->visit(this);
}
};
Expand All @@ -92,7 +92,7 @@ struct SubNode : public BinaryExprNode {
return "-";
}

void accept(ExprVisitorStrict* v) const {
void accept(IndexExprVisitorStrict* v) const {
v->visit(this);
}
};
Expand All @@ -104,7 +104,7 @@ struct MulNode : public BinaryExprNode {
return "*";
}

void accept(ExprVisitorStrict* v) const {
void accept(IndexExprVisitorStrict* v) const {
v->visit(this);
}
};
Expand All @@ -116,15 +116,15 @@ struct DivNode : public BinaryExprNode {
return "/";
}

void accept(ExprVisitorStrict* v) const {
void accept(IndexExprVisitorStrict* v) const {
v->visit(this);
}
};

struct ReductionNode : public ExprNode {
ReductionNode(IndexExpr op, IndexVar var, IndexExpr a);

void accept(ExprVisitorStrict* v) const {
void accept(IndexExprVisitorStrict* v) const {
v->visit(this);
}

Expand All @@ -137,7 +137,7 @@ struct ReductionNode : public ExprNode {
struct IntImmNode : public ImmExprNode {
IntImmNode(long long val) : ImmExprNode(Int(sizeof(long long)*8)), val(val) {}

void accept(ExprVisitorStrict* v) const {
void accept(IndexExprVisitorStrict* v) const {
v->visit(this);
}

Expand All @@ -147,7 +147,7 @@ struct IntImmNode : public ImmExprNode {
struct UIntImmNode : public ImmExprNode {
UIntImmNode(unsigned long long val) : ImmExprNode(UInt(sizeof(long long)*8)), val(val) {}

void accept(ExprVisitorStrict* v) const {
void accept(IndexExprVisitorStrict* v) const {
v->visit(this);
}

Expand All @@ -157,7 +157,7 @@ struct UIntImmNode : public ImmExprNode {
struct ComplexImmNode : public ImmExprNode {
ComplexImmNode(std::complex<double> val) : ImmExprNode(Complex128), val(val){}

void accept(ExprVisitorStrict* v) const {
void accept(IndexExprVisitorStrict* v) const {
v->visit(this);
}

Expand All @@ -167,7 +167,7 @@ struct ComplexImmNode : public ImmExprNode {
struct FloatImmNode : public ImmExprNode {
FloatImmNode(double val) : ImmExprNode(Float()), val(val) {}

void accept(ExprVisitorStrict* v) const {
void accept(IndexExprVisitorStrict* v) const {
v->visit(this);
}

Expand All @@ -180,7 +180,7 @@ struct AssignmentNode : public TensorExprNode {
AssignmentNode(const Access& lhs, const IndexExpr& rhs, const IndexExpr& op)
: lhs(lhs), rhs(rhs), op(op) {}

void accept(ExprVisitorStrict* v) const {
void accept(IndexNotationVisitorStrict* v) const {
v->visit(this);
}

Expand All @@ -193,7 +193,7 @@ struct ForallNode : public TensorExprNode {
ForallNode(IndexVar indexVar, TensorExpr expr)
: indexVar(indexVar), expr(expr) {}

void accept(ExprVisitorStrict* v) const {
void accept(IndexNotationVisitorStrict* v) const {
v->visit(this);
}

Expand All @@ -205,7 +205,7 @@ struct WhereNode : public TensorExprNode {
WhereNode(TensorExpr consumer, TensorExpr producer)
: consumer(consumer), producer(producer) {}

void accept(ExprVisitorStrict* v) const {
void accept(IndexNotationVisitorStrict* v) const {
v->visit(this);
}

Expand Down
6 changes: 3 additions & 3 deletions include/taco/index_notation/expr_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

namespace taco {

class ExprPrinter : public ExprVisitorStrict {
class IndexNotationPrinter : public IndexNotationVisitorStrict {
public:
ExprPrinter(std::ostream& os);
IndexNotationPrinter(std::ostream& os);

void print(const IndexExpr& expr);
void print(const TensorExpr& expr);

using ExprVisitorStrict::visit;
using IndexExprVisitorStrict::visit;

// Scalar Expressions
void visit(const AccessNode*);
Expand Down
4 changes: 2 additions & 2 deletions include/taco/index_notation/expr_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct ReductionNode;

struct AssignmentNode;

class ExprRewriterStrict : public ExprVisitorStrict {
class ExprRewriterStrict : public IndexExprVisitorStrict {
public:
virtual ~ExprRewriterStrict() {}

Expand All @@ -32,7 +32,7 @@ class ExprRewriterStrict : public ExprVisitorStrict {
TensorExpr rewrite(TensorExpr);

protected:
using ExprVisitorStrict::visit;
using IndexExprVisitorStrict::visit;

/// assign to expr in visit methods to replace the visited expr
IndexExpr expr;
Expand Down
37 changes: 22 additions & 15 deletions include/taco/index_notation/expr_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,11 @@ struct WhereNode;

/// Visit the nodes in an expression. This visitor provides some type safety
/// by requing all visit methods to be overridden.
class ExprVisitorStrict {
class IndexExprVisitorStrict {
public:
virtual ~ExprVisitorStrict();
virtual ~IndexExprVisitorStrict();

void visit(const IndexExpr&);
void visit(const TensorExpr&);

// Scalar Index Expressions
virtual void visit(const AccessNode*) = 0;
Expand All @@ -51,22 +50,30 @@ class ExprVisitorStrict {
virtual void visit(const ComplexImmNode*) = 0;
virtual void visit(const UIntImmNode*) = 0;
virtual void visit(const ReductionNode*) = 0;

// Tensor Expressions
virtual void visit(const AssignmentNode*) {}
virtual void visit(const ForallNode*) {}
virtual void visit(const WhereNode*) {}
};

/// Visit nodes in index notation
class IndexNotationVisitorStrict : public IndexExprVisitorStrict {
public:
virtual ~IndexNotationVisitorStrict();

void visit(const TensorExpr&);

using IndexExprVisitorStrict::visit;

virtual void visit(const AssignmentNode*) = 0;
virtual void visit(const ForallNode*) = 0;
virtual void visit(const WhereNode*) = 0;
};

/// Visit nodes in an expression.
class ExprVisitor : public ExprVisitorStrict {
class IndexNotationVisitor : public IndexExprVisitorStrict {
public:
virtual ~ExprVisitor();
virtual ~IndexNotationVisitor();

using ExprVisitorStrict::visit;
using IndexExprVisitorStrict::visit;

// Scalar Index Expressions
// Index Expressions
virtual void visit(const AccessNode* op);
virtual void visit(const NegNode* op);
virtual void visit(const SqrtNode* op);
Expand Down Expand Up @@ -109,10 +116,10 @@ void visit(const Rule* op) { \
Rule##CtxFunc(op, this); \
return; \
} \
ExprVisitor::visit(op); \
IndexNotationVisitor::visit(op); \
}

class Matcher : public ExprVisitor {
class Matcher : public IndexNotationVisitor {
public:
template <class IndexExpr>
void match(IndexExpr indexExpr) {
Expand All @@ -132,7 +139,7 @@ class Matcher : public ExprVisitor {
unpack(rest...);
}

using ExprVisitor::visit;
using IndexNotationVisitor::visit;
RULE(AccessNode)
RULE(NegNode)
RULE(SqrtNode)
Expand Down
4 changes: 2 additions & 2 deletions include/taco/index_notation/index_notation.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class IndexExpr : public util::IntrusivePtr<const ExprNode> {
const Schedule& getSchedule() const;

/// Visit the index expression's sub-expressions.
void accept(ExprVisitorStrict *) const;
void accept(IndexExprVisitorStrict *) const;

/// Print the index expression.
friend std::ostream& operator<<(std::ostream&, const IndexExpr&);
Expand Down Expand Up @@ -213,7 +213,7 @@ class TensorExpr : public util::IntrusivePtr<const TensorExprNode> {
TensorExpr(const TensorExprNode* n);

/// Visit the tensor expression
void accept(ExprVisitorStrict *) const;
void accept(IndexNotationVisitorStrict *) const;
};

std::ostream& operator<<(std::ostream&, const TensorExpr&);
Expand Down
4 changes: 2 additions & 2 deletions src/index_notation/expr_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ using namespace std;
namespace taco {

vector<TensorVar> getOperands(const IndexExpr& expr) {
struct GetOperands : public ExprVisitor {
using ExprVisitor::visit;
struct GetOperands : public IndexNotationVisitor {
using IndexNotationVisitor::visit;
set<TensorVar> inserted;
vector<TensorVar> operands;
void visit(const AccessNode* node) {
Expand Down
Loading

0 comments on commit bf0a6e0

Please sign in to comment.