From 2c336c134d980dca6d7665d2b22335575c924298 Mon Sep 17 00:00:00 2001 From: Kukovec Date: Thu, 19 Oct 2023 15:03:05 +0200 Subject: [PATCH] `UninterpretedConstOracle` refactor (#2734) * UninterpretedConstOracle * Reduced # of tests for CI * solver reuse * Suggestion by Thomas --- .../oracles/UninterpretedConstOracle.scala | 67 +++++++ .../aux/oracles/TestIntOracle.scala | 2 +- .../TestUninterpretedConstOracle.scala | 168 ++++++++++++++++++ 3 files changed, 236 insertions(+), 1 deletion(-) create mode 100644 tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/oracles/UninterpretedConstOracle.scala create mode 100644 tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/oracles/TestUninterpretedConstOracle.scala diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/oracles/UninterpretedConstOracle.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/oracles/UninterpretedConstOracle.scala new file mode 100644 index 0000000000..92e5549c93 --- /dev/null +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/oracles/UninterpretedConstOracle.scala @@ -0,0 +1,67 @@ +package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.oracles + +import at.forsyte.apalache.tla.bmcmt._ +import at.forsyte.apalache.tla.bmcmt.smt.SolverContext +import at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.caches.UninterpretedLiteralCache +import at.forsyte.apalache.tla.bmcmt.stratifiedRules.{Rewriter, RewriterScope} +import at.forsyte.apalache.tla.bmcmt.types.CellT +import at.forsyte.apalache.tla.lir.ConstT1 +import at.forsyte.apalache.tla.typecomp.TBuilderInstruction +import at.forsyte.apalache.tla.types.tla + +/** + * An oracle that uses a fixed collection of potential cells. + * + * The oracle value must be equal to one of the cells, if the collection is nonempty. + * + * @author + * Jure Kukovec + */ +class UninterpretedConstOracle(val valueCells: Seq[ArenaCell], val oracleCell: ArenaCell) extends Oracle { + + /** + * The number of values that this oracle is defined over: |valueCells| + */ + override def size: Int = valueCells.size + + override def chosenValueIsEqualToIndexedValue(scope: RewriterScope, index: BigInt): TBuilderInstruction = + if (valueCells.indices.contains(index)) tla.eql(oracleCell.toBuilder, valueCells(index.toInt).toBuilder) + else tla.bool(false) + + def getIndexOfChosenValueFromModel(solverContext: SolverContext): BigInt = + // the oracle must be equal to one of the values. If not, indexWhere returns -1 + valueCells.indexWhere { valueCell => + val eq = tla.eql(valueCell.toBuilder, oracleCell.toBuilder) + solverContext.evalGroundExpr(eq) == tla.bool(true).build + } +} + +object UninterpretedConstOracle { + + /** + * Designated type to be used in this oracle. + */ + val UNINTERPRETED_TYPE: ConstT1 = ConstT1("_ORA") + + def create( + rewriter: Rewriter, + cache: UninterpretedLiteralCache, + scope: RewriterScope, + nvalues: Int): (RewriterScope, UninterpretedConstOracle) = { + require(nvalues >= 0, "UninterpretedConstOracle must have a non-negative number of candidate values.") + val (newArena, valueCells) = + 0.until(nvalues).map(_.toString).foldLeft((scope.arena, Seq.empty[ArenaCell])) { case ((arena, cells), name) => + val (newArena, newCell) = cache.getOrCreate(arena, (UNINTERPRETED_TYPE, name)) + (newArena, cells :+ newCell) + } + val arenaWithCell = newArena.appendCell(CellT.fromType1(UNINTERPRETED_TYPE)) + val newScope = scope.copy(arena = arenaWithCell) + val oracleCell = arenaWithCell.topCell + val oracle = new UninterpretedConstOracle(valueCells, oracleCell) + + // the oracle value must be equal to one of the value cells, if there are any + if (nvalues > 0) + rewriter.assert(tla.or(0.until(nvalues).map(i => oracle.chosenValueIsEqualToIndexedValue(newScope, i)): _*)) + (newScope, oracle) + } +} diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/oracles/TestIntOracle.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/oracles/TestIntOracle.scala index 56b99a5cf1..8485c3f3e4 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/oracles/TestIntOracle.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/oracles/TestIntOracle.scala @@ -47,7 +47,7 @@ class TestIntOracle extends AnyFunSuite with BeforeAndAfterEach with Checkers { check(prop, minSuccessful(100), sizeRange(4)) } - test("oracleValueIsEqualToIndexedValue returns an integer comparison") { + test("chosenValueIsEqualToIndexedValue returns an integer comparison") { val prop = forAll(maxSizeAndIndexGen) { case (size, index) => val (scope, oracle) = IntOracle.create(initScope, size) diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/oracles/TestUninterpretedConstOracle.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/oracles/TestUninterpretedConstOracle.scala new file mode 100644 index 0000000000..6688fbc3df --- /dev/null +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/oracles/TestUninterpretedConstOracle.scala @@ -0,0 +1,168 @@ +package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.oracles + +import at.forsyte.apalache.tla.bmcmt.PureArena +import at.forsyte.apalache.tla.bmcmt.arena.PureArenaAdapter +import at.forsyte.apalache.tla.bmcmt.smt.{SolverConfig, Z3SolverContext} +import at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.caches.UninterpretedLiteralCache +import at.forsyte.apalache.tla.bmcmt.stratifiedRules.{Rewriter, RewriterScope, TestingRewriter} +import at.forsyte.apalache.tla.lir._ +import at.forsyte.apalache.tla.lir.oper.TlaOper +import at.forsyte.apalache.tla.typecomp.TBuilderInstruction +import at.forsyte.apalache.tla.types.tla +import org.junit.runner.RunWith +import org.scalacheck.Prop.forAll +import org.scalacheck.{Gen, Prop} +import org.scalatest.BeforeAndAfterEach +import org.scalatest.funsuite.AnyFunSuite +import org.scalatestplus.junit.JUnitRunner +import org.scalatestplus.scalacheck.Checkers + +@RunWith(classOf[JUnitRunner]) +class TestUninterpretedConstOracle extends AnyFunSuite with BeforeAndAfterEach with Checkers { + + var rewriter: Rewriter = TestingRewriter(Map.empty) + var cache: UninterpretedLiteralCache = new UninterpretedLiteralCache + var initScope: RewriterScope = RewriterScope.initial() + + override def beforeEach(): Unit = { + rewriter = TestingRewriter(Map.empty) + cache = new UninterpretedLiteralCache + initScope = RewriterScope.initial() + } + + val intGen: Gen[Int] = Gen.choose(-10, 10) + val nonNegIntGen: Gen[Int] = Gen.choose(0, 10) + + val maxSizeAndIndexGen: Gen[(Int, Int)] = for { + max <- Gen.choose(1, 10) // size 0 is degenerate + idx <- Gen.choose(0, max - 1) // index must be < + } yield (max, idx) + + test("Oracle cannot be constructed with negative size") { + val prop = + forAll(intGen) { + case i if i < 0 => + Prop.throws(classOf[IllegalArgumentException]) { + UninterpretedConstOracle.create(rewriter, cache, initScope, i) + } + case i => UninterpretedConstOracle.create(rewriter, cache, initScope, i)._2.size == i + } + + check(prop, minSuccessful(100), sizeRange(4)) + } + + test("chosenValueIsEqualToIndexedValue returns an equality, or shorthands") { + val prop = + forAll(Gen.zip(nonNegIntGen, intGen)) { case (size, index) => + val (scope, oracle) = UninterpretedConstOracle.create(rewriter, cache, initScope, size) + val cmp: TlaEx = oracle.chosenValueIsEqualToIndexedValue(scope, index) + if (index < 0 || index >= size) + cmp == tla.bool(false).build + else + cmp match { + case OperEx(TlaOper.eq, NameEx(name1), NameEx(name2)) => + name1 == oracle.oracleCell.toString && name2 == oracle.valueCells(index).toString + case _ => false + } + } + + check(prop, minSuccessful(200), sizeRange(4)) + } + + val (assertionsA, assertionsB): (Seq[TBuilderInstruction], Seq[TBuilderInstruction]) = 0 + .to(10) + .map { i => + (tla.name(s"A$i", BoolT1), tla.name(s"B$i", BoolT1)) + } + .unzip + + test("caseAssertions requires assertion sequences of equal length") { + val assertionsGen: Gen[(Seq[TBuilderInstruction], Option[Seq[TBuilderInstruction]])] = for { + i <- Gen.choose(0, assertionsA.size) + j <- Gen.choose(0, assertionsB.size) + opt <- Gen.option(Gen.const(assertionsB.take(j))) + } yield (assertionsA.take(i), opt) + + val prop = + forAll(Gen.zip(nonNegIntGen, assertionsGen)) { case (size, (assertionsIfTrue, assertionsIfFalseOpt)) => + val (scope, oracle) = UninterpretedConstOracle.create(rewriter, cache, initScope, size) + if (assertionsIfTrue.size != oracle.size || assertionsIfFalseOpt.exists { _.size != oracle.size }) + Prop.throws(classOf[IllegalArgumentException]) { + oracle.caseAssertions(scope, assertionsIfTrue, assertionsIfFalseOpt) + } + else true + } + + check(prop, minSuccessful(200), sizeRange(4)) + } + + test("caseAssertions constructs a collection of ITEs, or shorthands") { + val gen: Gen[(Int, Seq[TBuilderInstruction], Option[Seq[TBuilderInstruction]])] = for { + size <- nonNegIntGen + opt <- Gen.option(Gen.const(assertionsB.take(size))) + } yield (size, assertionsA.take(size), opt) + + val prop = + forAll(gen) { case (size, assertionsIfTrue, assertionsIfFalseOpt) => + val (scope, oracle) = UninterpretedConstOracle.create(rewriter, cache, initScope, size) + val caseEx: TlaEx = oracle.caseAssertions(scope, assertionsIfTrue, assertionsIfFalseOpt) + size match { + case 0 => + caseEx == PureArena.cellTrue(scope.arena).toBuilder.build + case 1 => + caseEx == assertionsA.head.build + case _ => + assertionsIfFalseOpt match { + case None => + val ites = assertionsIfTrue.zip(oracle.valueCells).map { case (a, c) => + tla.ite(tla.eql(oracle.oracleCell.toBuilder, c.toBuilder), a, tla.bool(true)) + } + caseEx == tla.and(ites: _*).build + case Some(assertionsIfFalse) => + val ites = assertionsIfTrue.zip(assertionsIfFalse).zip(oracle.valueCells).map { case ((at, af), c) => + tla.ite(tla.eql(oracle.oracleCell.toBuilder, c.toBuilder), at, af) + } + caseEx == tla.and(ites: _*).build + } + } + } + + check(prop, minSuccessful(200), sizeRange(4)) + } + + // We cannot test getIndexOfChosenValueFromModel without running the solver + // Ignored until we figure out why it's killing GH CLI + ignore("getIndexOfChosenValueFromModel recovers the index correctly for nonempty cell collection") { + val ctx = new Z3SolverContext(SolverConfig.default) + val paa = PureArenaAdapter.create(ctx) // We use PAA, since it performs the basic context initialization + initScope = initScope.copy(arena = paa.arena) + val prop = + forAll(maxSizeAndIndexGen) { case (size, index) => + cache.dispose() // prevent redeclarations in every loop + val (scope, oracle) = UninterpretedConstOracle.create(rewriter, cache, initScope, size) + ctx.push() + oracle.valueCells.foreach(ctx.declareCell) + ctx.declareCell(oracle.oracleCell) + cache.addAllConstraints(ctx) + val eql = oracle.chosenValueIsEqualToIndexedValue(scope, index) + ctx.assertGroundExpr(eql) + ctx.sat() + val ret = oracle.getIndexOfChosenValueFromModel(ctx) == index + ctx.pop() + ret + } + + // 1000 is too many, since each run invokes the solver + check(prop, minSuccessful(80), sizeRange(4)) + } + + test("getIndexOfChosenValueFromModel recovers the index correctly for empty collections") { + val ctx = new Z3SolverContext(SolverConfig.default) + val paa = PureArenaAdapter.create(ctx) // We use PAA, since it performs the basic context initialization + val (_, oracle) = UninterpretedConstOracle.create(rewriter, cache, initScope.copy(arena = paa.arena), 0) + ctx.declareCell(oracle.oracleCell) + ctx.sat() + assert(oracle.getIndexOfChosenValueFromModel(ctx) == -1) + } + +}