Skip to content

Commit

Permalink
ref: introduce templated formula and expression visitors
Browse files Browse the repository at this point in the history
  • Loading branch information
TendTo committed Oct 16, 2024
1 parent ac1aeaf commit 6e3e387
Show file tree
Hide file tree
Showing 31 changed files with 810 additions and 609 deletions.
4 changes: 2 additions & 2 deletions dlinear/parser/smt2/Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ void Smt2Driver::GetValue(const std::vector<Term> &term_list) const {
switch (term.type()) {
case Term::Type::EXPRESSION: {
const Expression &e{term.expression()};
const ExpressionEvaluator evaluator{e};
const ExpressionEvaluator evaluator{e, context_.config()};
pp.Print(e);
term_str = ss.str();
const Interval iv{ExpressionEvaluator(term.expression())(box)};
const Interval iv{ExpressionEvaluator(term.expression(), context_.config())(box)};
value_str = (std::stringstream{} << iv).str();
break;
}
Expand Down
48 changes: 24 additions & 24 deletions dlinear/parser/smt2/Sort.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
/**
* @author Ernesto Casablanca (casablancaernesto@gmail.com)
* @copyright 2024 dlinear
* @licence BSD 3-Clause License
* Sort enum.
*/
* @author Ernesto Casablanca (casablancaernesto@gmail.com)
* @copyright 2024 dlinear
* @licence BSD 3-Clause License
* Sort enum.
*/
#pragma once

#include <ostream>
Expand All @@ -15,34 +15,34 @@ namespace dlinear::smt2 {

/** Sort of a term. */
enum class Sort {
Binary, ///< Binary sort.
Bool, ///< Boolean sort.
Int, ///< Integer sort.
Real, ///< Real sort.
Binary, ///< Binary sort.
Bool, ///< Boolean sort.
Int, ///< Integer sort.
Real, ///< Real sort.
};

/**
* Parse a string to a sort.
* @param s string to parse
* @return sort parsed from @p s
*/
* Parse a string to a sort.
* @param s string to parse
* @return sort parsed from @p s
*/
Sort ParseSort(const std::string &s);
/**
* Convert a sort to a variable type.
*
* The conversion is as follows:
* - Binary -> BINARY
* - Bool -> BOOLEAN
* - Int -> INTEGER
* - Real -> CONTINUOUS
* @param sort sort to convert
* @return variable type corresponding to @p sort
*/
* Convert a sort to a variable type.
*
* The conversion is as follows:
* - Binary -> BINARY
* - Bool -> BOOLEAN
* - Int -> INTEGER
* - Real -> CONTINUOUS
* @param sort sort to convert
* @return variable type corresponding to @p sort
*/
Variable::Type SortToType(Sort sort);

std::ostream &operator<<(std::ostream &os, const Sort &sort);

} // namespace dlinear::vnnlib
} // namespace dlinear::smt2

#ifdef DLINEAR_INCLUDE_FMT

Expand Down
4 changes: 2 additions & 2 deletions dlinear/parser/vnnlib/Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ void VnnlibDriver::GetValue(const std::vector<Term> &term_list) const {
switch (term.type()) {
case Term::Type::EXPRESSION: {
const Expression &e{term.expression()};
const ExpressionEvaluator evaluator{e};
const ExpressionEvaluator evaluator{e, context_.config()};
pp.Print(e);
term_str = ss.str();
const Interval iv{ExpressionEvaluator(term.expression())(box)};
const Interval iv{ExpressionEvaluator(term.expression(), context_.config())(box)};
value_str = (std::stringstream{} << iv).str();
break;
}
Expand Down
6 changes: 3 additions & 3 deletions dlinear/solver/ContextImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ void Context::Impl::AssertPiecewiseLinearFunction(const Variable &var, const For
DLINEAR_ASSERT(!var.is_dummy() && var.get_type() == Variable::Type::CONTINUOUS, "Variable must be a real variable");
DLINEAR_ASSERT(is_relational(cond), "Condition must be a relational formula");

const Formula condition_lit = predicate_abstractor_.Convert(cond);
const Formula active_lit = predicate_abstractor_.Convert(var - active == 0);
const Formula inactive_lit = predicate_abstractor_.Convert(var - inactive == 0);
const Formula condition_lit = predicate_abstractor_(cond);
const Formula active_lit = predicate_abstractor_(var - active == 0);
const Formula inactive_lit = predicate_abstractor_(var - inactive == 0);
// Make sure the cond is assigned a value (true or false) in the SAT solver
const Formula force_assignment(condition_lit || !condition_lit);
const Formula active_assertion{active_lit || !condition_lit};
Expand Down
6 changes: 3 additions & 3 deletions dlinear/solver/SatSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ std::vector<std::vector<Literal>> SatSolver::clauses() const {

void SatSolver::AddFormula(const Formula &f) {
DLINEAR_DEBUG_FMT("SatSolver::AddFormula({})", f);
std::vector<Formula> clauses{cnfizer_.Convert(f)};
auto [clauses, aux] = cnfizer_(f);

// Collect CNF variables and store them in `cnf_variables_`.
for (const Variable &p : cnfizer_.vars()) cnf_variables_.insert(p.get_id());
for (const Variable &p : aux) cnf_variables_.insert(p.get_id());
// Convert a first-order clauses into a Boolean formula by predicate abstraction
// The original can be retrieved by `predicate_abstractor_[abstracted_formula]`.
for (Formula &clause : clauses) clause = predicate_abstractor_.Convert(clause);
for (Formula &clause : clauses) clause = predicate_abstractor_.Process(clause);

AddClauses(clauses);
}
Expand Down
27 changes: 23 additions & 4 deletions dlinear/symbolic/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ dlinear_cc_library(
":literal",
"//dlinear/libs:gmp",
],
deps = [":symbolic"],
deps = [
":expression_visitor",
":formula_visitor",
":symbolic",
],
)

dlinear_cc_library(
Expand All @@ -44,9 +48,20 @@ dlinear_cc_library(

dlinear_cc_library(
name = "formula_visitor",
srcs = ["FormulaVisitor.cpp"],
hdrs = ["FormulaVisitor.h"],
implementation_deps = ["//dlinear/util:exception"],
hdrs = [
"FormulaVisitor.h",
"GenericFormulaVisitor.h",
],
deps = [
":symbolic",
"//dlinear/util:config",
"//dlinear/util:stats",
],
)

dlinear_cc_library(
name = "expression_visitor",
hdrs = ["GenericExpressionVisitor.h"],
deps = [
":symbolic",
"//dlinear/util:config",
Expand Down Expand Up @@ -94,6 +109,7 @@ dlinear_cc_library(
"//dlinear/util:exception",
],
deps = [
":expression_visitor",
":symbolic",
"//dlinear/util:box",
],
Expand All @@ -105,6 +121,7 @@ dlinear_cc_library(
hdrs = ["Nnfizer.h"],
implementation_deps = ["//dlinear/util:logging"],
deps = [
":formula_visitor",
":symbolic",
"//dlinear/util:config",
],
Expand Down Expand Up @@ -132,6 +149,8 @@ dlinear_cc_library(
"//dlinear/util:timer",
],
deps = [
":expression_visitor",
":formula_visitor",
":literal",
":symbolic",
"//dlinear/util:config",
Expand Down
27 changes: 16 additions & 11 deletions dlinear/symbolic/ExpressionEvaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,26 @@

namespace dlinear {

ExpressionEvaluator::ExpressionEvaluator(Expression e) : e_{std::move(e)} {}
ExpressionEvaluator::ExpressionEvaluator(Expression e, const Config& config)
: GenericExpressionVisitor<Interval, const Box&>{config, "ExpressionEvaluator"}, e_{std::move(e)} {}

Interval ExpressionEvaluator::operator()(const Box& box) const { return Visit(e_, box); }

Interval ExpressionEvaluator::Visit(const Expression& e, const Box& box) const {
return VisitExpression<Interval>(this, e, box);
Interval ExpressionEvaluator::Process(const Box& box) const {
const TimerGuard timer_guard(&stats_.m_timer(), stats_.enabled());
stats_.Increase();
return VisitExpression(e_, box);
}
Interval ExpressionEvaluator::operator()(const Box& box) const { return Process(box); }

Interval ExpressionEvaluator::VisitVariable(const Expression& e, const Box& box) {
Interval ExpressionEvaluator::VisitVariable(const Expression& e, const Box& box) const {
const Variable& var{get_variable(e)};
return box[var];
}

Interval ExpressionEvaluator::VisitConstant(const Expression& e, const Box&) { return Interval{get_constant_value(e)}; }
Interval ExpressionEvaluator::VisitConstant(const Expression& e, const Box&) const {
return Interval{get_constant_value(e)};
}

Interval ExpressionEvaluator::VisitRealConstant(const Expression&, const Box&) {
Interval ExpressionEvaluator::VisitRealConstant(const Expression&, const Box&) const {
DLINEAR_RUNTIME_ERROR("Operation is not supported yet.");
}

Expand All @@ -40,7 +44,7 @@ Interval ExpressionEvaluator::VisitAddition(const Expression& e, const Box& box)
const auto& expr_to_coeff_map = get_expr_to_coeff_map_in_addition(e);
return std::accumulate(expr_to_coeff_map.begin(), expr_to_coeff_map.end(), Interval{c},
[this, &box](const Interval& init, const std::pair<const Expression, mpq_class>& p) {
return init + Visit(p.first, box) * p.second;
return init + VisitExpression(p.first, box) * p.second;
});
}

Expand Down Expand Up @@ -148,11 +152,12 @@ Interval ExpressionEvaluator::VisitMax(const Expression&, const Box&) const {
DLINEAR_RUNTIME_ERROR("Operation is not supported yet.");
}

Interval ExpressionEvaluator::VisitIfThenElse(const Expression& /* unused */, const Box& /* unused */) {
Interval ExpressionEvaluator::VisitIfThenElse(const Expression& /* unused */, const Box& /* unused */) const {
DLINEAR_RUNTIME_ERROR("If-then-else expression is not supported yet.");
}

Interval ExpressionEvaluator::VisitUninterpretedFunction(const Expression& /* unused */, const Box& /* unused */) {
Interval ExpressionEvaluator::VisitUninterpretedFunction(const Expression& /* unused */,
const Box& /* unused */) const {
DLINEAR_RUNTIME_ERROR("Uninterpreted function is not supported.");
}

Expand Down
69 changes: 35 additions & 34 deletions dlinear/symbolic/ExpressionEvaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <iosfwd>

#include "dlinear/symbolic/GenericExpressionVisitor.h"
#include "dlinear/symbolic/symbolic.h"
#include "dlinear/util/Box.h"
#include "dlinear/util/Interval.h"
Expand All @@ -20,51 +21,51 @@ namespace dlinear {
* The ExpressionEvaluator is used to evaluate an expression with a given box.
* The box provides the values of the variables in the expression with intervals.
*/
class ExpressionEvaluator {
class ExpressionEvaluator : public GenericExpressionVisitor<Interval, const Box&> {
public:
explicit ExpressionEvaluator(Expression e);
/**
* Construct a new ExpressionEvaluator object with the given expression and configuration.
* @param e expression to evaluate
* @param config configuration to use
*/
ExpressionEvaluator(Expression e, const Config& config);

/// Evaluates the expression with @p box.
Interval operator()(const Box& box) const;
[[nodiscard]] Interval Process(const Box& box) const;
[[nodiscard]] Interval operator()(const Box& box) const;

[[nodiscard]] const Variables& variables() const { return e_.GetVariables(); }

[[nodiscard]] const Expression& expression() const { return e_; }

private:
[[nodiscard]] Interval Visit(const Expression& e, const Box& box) const;
static Interval VisitVariable(const Expression& e, const Box& box);
static Interval VisitConstant(const Expression& e, const Box& box);
static Interval VisitRealConstant(const Expression& e, const Box& box);
[[nodiscard]] Interval VisitAddition(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitMultiplication(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitDivision(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitLog(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitAbs(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitExp(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitSqrt(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitPow(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitVariable(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitConstant(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitRealConstant(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitAddition(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitMultiplication(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitDivision(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitLog(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitAbs(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitExp(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitSqrt(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitPow(const Expression& e, const Box& box) const override;

// Evaluates `pow(e1, e2)` with the @p box.
[[nodiscard]] Interval VisitPow(const Expression& e1, const Expression& e2, const Box& box) const;
[[nodiscard]] Interval VisitSin(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitCos(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitTan(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitAsin(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitAcos(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitAtan(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitAtan2(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitSinh(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitCosh(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitTanh(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitMin(const Expression& e, const Box& box) const;
[[nodiscard]] Interval VisitMax(const Expression& e, const Box& box) const;
static Interval VisitIfThenElse(const Expression& e, const Box& box);
static Interval VisitUninterpretedFunction(const Expression& e, const Box& box);

// Makes VisitExpression a friend of this class so that it can use private
// operator()s.
friend Interval drake::symbolic::VisitExpression<Interval>(const ExpressionEvaluator*, const Expression&, const Box&);
[[nodiscard]] Interval VisitSin(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitCos(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitTan(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitAsin(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitAcos(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitAtan(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitAtan2(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitSinh(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitCosh(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitTanh(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitMin(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitMax(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitIfThenElse(const Expression& e, const Box& box) const override;
[[nodiscard]] Interval VisitUninterpretedFunction(const Expression& e, const Box& box) const override;

const Expression e_;
};
Expand Down
44 changes: 0 additions & 44 deletions dlinear/symbolic/FormulaVisitor.cpp

This file was deleted.

Loading

0 comments on commit 6e3e387

Please sign in to comment.