Skip to content

Commit

Permalink
Adds compound operator to assignments
Browse files Browse the repository at this point in the history
  • Loading branch information
fredrikbk committed Apr 27, 2018
1 parent 8ab0a12 commit b29b033
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 15 deletions.
4 changes: 3 additions & 1 deletion include/taco/index_notation/expr_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,16 @@ struct FloatImmNode : public ImmExprNode {

// Tensor Index Expressions
struct AssignmentNode : public TensorExprNode {
AssignmentNode(const Access& lhs, const IndexExpr& rhs) : lhs(lhs), rhs(rhs){}
AssignmentNode(const Access& lhs, const IndexExpr& rhs, const IndexExpr& op)
: lhs(lhs), rhs(rhs), op(op) {}

void accept(ExprVisitorStrict* v) const {
v->visit(this);
}

Access lhs;
IndexExpr rhs;
IndexExpr op;
};

struct ForallNode : public TensorExprNode {
Expand Down
10 changes: 7 additions & 3 deletions include/taco/index_notation/index_notation.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,14 @@ class Access : public IndexExpr {
/// ```
Assignment operator=(const IndexExpr&);

// Must override the default Access operator=, otherwise it is a copy.
/// Must override the default Access operator=, otherwise it is a copy.
Assignment operator=(const Access&);

/// Accumulate the result of an expression to a left-hand-side tensor access.
/// ```
/// a(i) += B(i,j) * c(j);
/// ```
void operator+=(const IndexExpr&);
Assignment operator+=(const IndexExpr&);

private:
const Node* getPtr() const;
Expand Down Expand Up @@ -224,7 +224,11 @@ std::ostream& operator<<(std::ostream&, const TensorExpr&);
class Assignment : public TensorExpr {
public:
Assignment(const AssignmentNode*);
Assignment(TensorVar tensor, std::vector<IndexVar> indices, IndexExpr expr);

/// Create an assignment. Can specify an optional operator `op` that turns the
/// assignment into a compound assignment, e.g. `+=`.
Assignment(TensorVar tensor, std::vector<IndexVar> indices, IndexExpr expr,
IndexExpr op = IndexExpr());

Access getLhs() const;
IndexExpr getRhs() const;
Expand Down
14 changes: 13 additions & 1 deletion src/index_notation/expr_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,20 @@ void ExprPrinter::visit(const ReductionNode* op) {
}

void ExprPrinter::visit(const AssignmentNode* op) {
struct OperatorName : ExprVisitor {
std::string operatorName;
std::string get(IndexExpr expr) {
if (!expr.defined()) return "";
expr.accept(this);
return operatorName;
}
void visit(const BinaryExprNode* node) {
operatorName = node->getOperatorString();
}
};

op->lhs.accept(this);
os << " = ";
os << " " << OperatorName().get(op->op) << "= ";
op->rhs.accept(this);
}

Expand Down
6 changes: 4 additions & 2 deletions src/index_notation/expr_rewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,16 @@ void ExprRewriter::visit(const ReductionNode* op) {
}

void ExprRewriter::visit(const AssignmentNode* op) {
// A design decission is to not visit the rhs access expressions or the op,
// as these are considered part of the assignment. When visiting access
// expressions, therefore, we only visit read access expressions.
IndexExpr rhs = rewrite(op->rhs);
if (rhs == op->rhs) {
texpr = op;
}
else {
texpr = new AssignmentNode(op->lhs, rhs);
texpr = new AssignmentNode(op->lhs, rhs, op->op);
}

}


Expand Down
7 changes: 4 additions & 3 deletions src/index_notation/index_notation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,13 @@ Assignment Access::operator=(const Access& expr) {
return operator=(static_cast<IndexExpr>(expr));
}

void Access::operator+=(const IndexExpr& expr) {
Assignment Access::operator+=(const IndexExpr& expr) {
TensorVar result = getTensorVar();
taco_uassert(!result.getIndexExpr().defined()) << "Cannot reassign " <<result;
// TODO: check that result format is dense. For now only support accumulation
/// into dense. If it's not dense, then we can insert an operator split.
const_cast<AccessNode*>(getPtr())->setIndexExpression(expr, true);
return Assignment(result, result.getFreeVars(), expr, new AddNode);
}


Expand Down Expand Up @@ -286,8 +287,8 @@ Assignment::Assignment(const AssignmentNode* n) : TensorExpr(n) {
}

Assignment::Assignment(TensorVar tensor, vector<IndexVar> indices,
IndexExpr expr)
: Assignment(new AssignmentNode(Access(tensor, indices), expr)) {
IndexExpr expr, IndexExpr op)
: Assignment(new AssignmentNode(Access(tensor, indices), expr, op)) {
}

Access Assignment::getLhs() const {
Expand Down
8 changes: 3 additions & 5 deletions test/concrete-notation-tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,13 @@ TEST(concrete, where) {
// std::cout << vecmul << std::endl;
}

TEST(DISABLED_concrete, spmm) {
TEST(concrete, spmm) {
Type t(type<double>(), {3,3});
TensorVar A("A", t, Sparse), B("B", t, Sparse), C("C", t, Sparse);
TensorVar w("w", Type(type<double>(),{3}), Dense);

auto spmm = forall(i,
TensorVar w("w", Type(type<double>(),{3}), Dense); auto spmm = forall(i,
forall(k,
where(forall(j, A(i,j) = w(j)),
forall(j, w(j) = B(i,k)*C(k,j))
forall(j, w(j) += B(i,k)*C(k,j))
)
)
);
Expand Down

0 comments on commit b29b033

Please sign in to comment.