From 737f5f15d33d74b3107e83a6a061e2dbe6e1e99e Mon Sep 17 00:00:00 2001 From: AlexShefY Date: Sat, 2 Nov 2024 12:19:54 +0100 Subject: [PATCH 1/2] start manual rewriting --- scripts/plot_statistics.py | 14 +- tests/test_inequality_replacer.py | 262 ++++++++++++ tests/test_rewriters.py | 383 ++++++++++++++++++ .../runners/invariants_with_rewriting.py | 32 ++ verified_cogen/runners/rewriters/__init__.py | 5 + .../runners/rewriters/nagini_rewriter.py | 22 + .../rewriters/nagini_rewriter_fixing.py | 74 ++++ .../rewriters/nagini_rewriter_fixing_ast.py | 32 ++ verified_cogen/tools/inequality_replacer.py | 48 +++ 9 files changed, 866 insertions(+), 6 deletions(-) create mode 100644 tests/test_inequality_replacer.py create mode 100644 tests/test_rewriters.py create mode 100644 verified_cogen/runners/invariants_with_rewriting.py create mode 100644 verified_cogen/runners/rewriters/__init__.py create mode 100644 verified_cogen/runners/rewriters/nagini_rewriter.py create mode 100644 verified_cogen/runners/rewriters/nagini_rewriter_fixing.py create mode 100644 verified_cogen/runners/rewriters/nagini_rewriter_fixing_ast.py create mode 100644 verified_cogen/tools/inequality_replacer.py diff --git a/scripts/plot_statistics.py b/scripts/plot_statistics.py index e2ecaf2..15664da 100644 --- a/scripts/plot_statistics.py +++ b/scripts/plot_statistics.py @@ -1,5 +1,5 @@ # %% -with open('log_tries/logs3.txt', 'r') as file: +with open('log_tries/logs9.txt', 'r') as file: # Read the file and split the contents into lines lines = file.readlines() @@ -16,6 +16,7 @@ "Loop invariant might not be preserved." : 0, "Loop invariant might not hold on entry." : 0, "Assert might fail." : 0, + "Verification timed out" : 0, } explanations = { @@ -27,12 +28,13 @@ "Loop invariant might not be preserved." : "Loop invariant might not be preserved", "Loop invariant might not hold on entry." : "Loop invariant might not hold on entry", "Assert might fail." : "Assert might fail", + "Verification timed out" : "Verification timed out", } dict_erros_numbered = {} for (key, value) in dict_erros.items(): - for j in range(1, 11): + for j in range(1, 6): dict_erros_numbered[(key, j)] = 0 print(dict_erros_numbered) @@ -40,11 +42,11 @@ # The 'lines' variable will now be a list of lines from the file idx_line = 0 for line in lines: - if "Verification failed:" in line: + if "Verification failed:" in line or "Verification timed out" in line: idx_line += 1 - if "verified with" in line: + if "Verified" in line or "Failed to verify" in line: idx_line = 0 - if idx_line == 11: + if idx_line == 6: idx_line = 1 for (key, value) in dict_erros.items(): if key in line: @@ -80,7 +82,7 @@ # Step 3: Customize the plot plt.xlabel('Try') -tries_range = range(1, 11) # Assuming the "tries" are from 1 to 5 +tries_range = range(1, 6) # Assuming the "tries" are from 1 to 5 plt.xticks(tries_range) plt.ylabel('Number of Occurrences') diff --git a/tests/test_inequality_replacer.py b/tests/test_inequality_replacer.py new file mode 100644 index 0000000..12a07c5 --- /dev/null +++ b/tests/test_inequality_replacer.py @@ -0,0 +1,262 @@ +from textwrap import dedent +from verified_cogen.tools.inequality_replacer import replace_inequalities, contains_double_inequality + +def test_simple_contains1(): + code = dedent( + """\ + def test(): + if a >= b >= c > d: + print('Chained') + """ + ) + + assert contains_double_inequality(code) + +def test_simple_contains2(): + code = dedent( + """\ +Implies(1 <= k < d_6_i_, xs[k - 1] <= xs[k]) +d_4_increasing_ == Forall(int, lambda k: Implies(1 <= k < d_6_i_, xs[k - 1] <= xs[k])) +Invariant(d_4_increasing_ == Forall(int, lambda k: Implies(1 <= k < d_6_i_, xs[k - 1] <= xs[k]))) + """ + ) + + new_code = replace_inequalities(code) + compare_code = dedent( + """\ +Implies(1 <= k and k < d_6_i_, xs[k - 1] <= xs[k]) +d_4_increasing_ == Forall(int, lambda k: Implies(1 <= k and k < d_6_i_, xs[k - 1] <= xs[k])) +Invariant(d_4_increasing_ == Forall(int, lambda k: Implies(1 <= k and k < d_6_i_, xs[k - 1] <= xs[k])))""" + ) + + assert contains_double_inequality(code) + assert new_code == compare_code + +def test_simple_contains3(): + code = dedent( + """\ + def test(): + if a < b and b < c and c <= d: + print('Chained') + """ + ) + + assert not contains_double_inequality(code) + +def test_simple(): + code = dedent( + """\ + def test(): + if a < b < c < d: + print('Chained') + """ + ) + + new_code = replace_inequalities(code) + + compare_code = dedent( + """\ + def test(): + if a < b and b < c and (c < d): + print('Chained')""" + ) + + assert new_code == compare_code + +def test_simple1(): + code = dedent( + """\ + def test(): + if a < b <= c < d: + print('Chained') + """ + ) + + new_code = replace_inequalities(code) + + compare_code = dedent( + """\ + def test(): + if a < b and b <= c and (c < d): + print('Chained')""" + ) + + assert new_code == compare_code + +def test_complicated(): + code = dedent( + """\ +from typing import cast, List, Dict, Set, Optional, Union +from nagini_contracts.contracts import * + +@Pure +def factorial__spec(n : int) -> int : + Requires(n >= -1) + Ensures(Result() >= 0) + if n == -1: + return 1 + else: + return (n + 1) * factorial__spec(n - 1) + +@Pure +def sum__spec(n : int) -> int : + Requires(n >= -1) + Ensures(Result() >= 0) + if 0 > n: + return 0 + else: + return n + 1 + sum__spec(n - 1) + +def f(n : int) -> List[int]: + Requires((n) >= (1)) + Ensures(Acc(list_pred(Result()))) + Ensures((len(Result())) == (n)) + Ensures(Forall(int, lambda d_2_i_: + not ((((d_2_i_) >= (0)) and ((d_2_i_) < (len(Result())))) and (((d_2_i_ % 2)) == (0))) or (((Result())[d_2_i_]) == (factorial__spec(d_2_i_ - 1))))) + Ensures(Forall(int, lambda d_3_i_: + not ((((d_3_i_) >= (0)) and ((d_3_i_) < (len(Result())))) and (((d_3_i_ % 2)) != (0))) or (((Result())[d_3_i_]) == (sum__spec(d_3_i_ - 1))))) + + result: List[int] = [] + d_4_i_ = 0 + while (d_4_i_) < (n): + Invariant(0 <= d_4_i_ <= n) + Invariant(len(result) == d_4_i_) + Invariant(Acc(list_pred(result))) + Invariant(Forall(int, lambda i:Implies( (0 <= i < d_4_i_ and i % 2 == 0) , result[i] == factorial__spec(i - 1)))) + Invariant(Forall(int, lambda i:Implies( (0 <= i < d_4_i_ and i % 2 != 0) , result[i] == sum__spec(i - 1)))) + + if ((d_4_i_ % 2)) == (0): + d_7_x_ = 1 + d_8_j_ = 0 + while (d_8_j_) < (d_4_i_): + Invariant(0 <= d_8_j_ <= d_4_i_) + Invariant(d_7_x_ == factorial__spec(d_8_j_ - 1)) + d_7_x_ = (d_7_x_) * (d_8_j_ + 1) + d_8_j_ = (d_8_j_) + (1) + result = (result) + [d_7_x_] + else: + d_9_x_ = 0 + d_10_j_ = 0 + while (d_10_j_) < (d_4_i_): + Invariant(0 <= d_10_j_ <= d_4_i_) + Invariant(d_9_x_ == sum__spec(d_10_j_ - 1)) + d_9_x_ = (d_9_x_) + (d_10_j_ + 1) + d_10_j_ = (d_10_j_) + (1) + result = (result) + [d_9_x_] + d_4_i_ = (d_4_i_) + (1) + return result + """ + ) + + new_code = replace_inequalities(code) + + compare_code = dedent( + """\ +from typing import cast, List, Dict, Set, Optional, Union +from nagini_contracts.contracts import * + +@Pure +def factorial__spec(n: int) -> int: + Requires(n >= -1) + Ensures(Result() >= 0) + if n == -1: + return 1 + else: + return (n + 1) * factorial__spec(n - 1) + +@Pure +def sum__spec(n: int) -> int: + Requires(n >= -1) + Ensures(Result() >= 0) + if 0 > n: + return 0 + else: + return n + 1 + sum__spec(n - 1) + +def f(n: int) -> List[int]: + Requires(n >= 1) + Ensures(Acc(list_pred(Result()))) + Ensures(len(Result()) == n) + Ensures(Forall(int, lambda d_2_i_: not ((d_2_i_ >= 0 and d_2_i_ < len(Result())) and d_2_i_ % 2 == 0) or Result()[d_2_i_] == factorial__spec(d_2_i_ - 1))) + Ensures(Forall(int, lambda d_3_i_: not ((d_3_i_ >= 0 and d_3_i_ < len(Result())) and d_3_i_ % 2 != 0) or Result()[d_3_i_] == sum__spec(d_3_i_ - 1))) + result: List[int] = [] + d_4_i_ = 0 + while d_4_i_ < n: + Invariant(0 <= d_4_i_ and d_4_i_ <= n) + Invariant(len(result) == d_4_i_) + Invariant(Acc(list_pred(result))) + Invariant(Forall(int, lambda i: Implies((0 <= i and i < d_4_i_) and i % 2 == 0, result[i] == factorial__spec(i - 1)))) + Invariant(Forall(int, lambda i: Implies((0 <= i and i < d_4_i_) and i % 2 != 0, result[i] == sum__spec(i - 1)))) + if d_4_i_ % 2 == 0: + d_7_x_ = 1 + d_8_j_ = 0 + while d_8_j_ < d_4_i_: + Invariant(0 <= d_8_j_ and d_8_j_ <= d_4_i_) + Invariant(d_7_x_ == factorial__spec(d_8_j_ - 1)) + d_7_x_ = d_7_x_ * (d_8_j_ + 1) + d_8_j_ = d_8_j_ + 1 + result = result + [d_7_x_] + else: + d_9_x_ = 0 + d_10_j_ = 0 + while d_10_j_ < d_4_i_: + Invariant(0 <= d_10_j_ and d_10_j_ <= d_4_i_) + Invariant(d_9_x_ == sum__spec(d_10_j_ - 1)) + d_9_x_ = d_9_x_ + (d_10_j_ + 1) + d_10_j_ = d_10_j_ + 1 + result = result + [d_9_x_] + d_4_i_ = d_4_i_ + 1 + return result""" + ) + + assert new_code == compare_code + +def test_complicated1(): + + code = dedent("""\ +def BubbleSort(a1 : List[int]) -> List[int]: + Requires(Acc(list_pred(a1), 1/2)) + Requires(Forall(int, lambda i : Implies(i >= 0 and i < len(a1), a1[i] >= 0))) + Ensures(Acc(list_pred(a1), 1/2)) + Ensures(Acc(list_pred(Result()))) + Ensures((len(a1)) == (len(Result()))) + Ensures(Forall(int, lambda i : Implies(i >= 0 and i < len(Result()), Result()[i] >= 0))) + Ensures(Forall(int, lambda d_0_i_: + Forall(int, lambda d_1_j_: + Implies((((0) <= (d_0_i_)) and ((d_0_i_) < (d_1_j_))) and ((d_1_j_) < (len((Result())))), popcount((Result())[d_0_i_]) <= popcount((Result())[d_1_j_]))))) + + a = list(a1) + d_2_i_ = int(0) + d_2_i_ = (len((a))) - (1) + while (d_2_i_) > (0): + Invariant(Acc(list_pred(a))) + Invariant(0 <= d_2_i_ < len(a)) + Invariant(Forall(int, lambda k: Implies(d_2_i_ < k and k < len(a), Forall(int, lambda m: Implies(0 <= m and m < k, popcount(a[m]) <= popcount(a[k])))))) + Invariant(Forall(int, lambda i: Implies(0 <= i and i < len(a), a[i] >= 0))) + Invariant(len(a) == len(a1)) + Invariant(Forall(int, lambda i: Implies(0 <= i and i < len(a), Exists(int, lambda j: (0 <= j and j < len(a1) and a[i] == a1[j]))))) + Invariant(Forall(int, lambda i: Implies(0 <= i and i < len(a1), Exists(int, lambda j: (0 <= j and j < len(a) and a1[i] == a[j]))))) + + d_7_j_ = int(0) + d_7_j_ = 0 + while (d_7_j_) < (d_2_i_): + Invariant(Acc(list_pred(a))) + Invariant(0 <= d_7_j_ <= d_2_i_ < len(a)) + Invariant(Forall(int, lambda k: Implies(0 <= k and k < d_7_j_, popcount(a[k]) <= popcount(a[d_7_j_])))) + Invariant(Forall(int, lambda i: Implies(0 <= i and i < len(a), a[i] >= 0))) + Invariant(len(a) == len(a1)) + Invariant(Forall(int, lambda i: Implies(0 <= i and i < len(a), Exists(int, lambda j: (0 <= j and j < len(a1) and a[i] == a1[j]))))) + Invariant(Forall(int, lambda i: Implies(0 <= i and i < len(a1), Exists(int, lambda j: (0 <= j and j < len(a) and a1[i] == a[j]))))) + + if popcount((a)[d_7_j_]) > popcount((a)[(d_7_j_) + (1)]): + rhs0_ = (a)[(d_7_j_) + (1)] + (a)[(d_7_j_) + (1)] = (a)[d_7_j_] + (a)[d_7_j_] = rhs0_ + d_7_j_ = (d_7_j_) + (1) + d_2_i_ = (d_2_i_) - (1) + return a + """) + + new_code = replace_inequalities(code) + + print(new_code) \ No newline at end of file diff --git a/tests/test_rewriters.py b/tests/test_rewriters.py new file mode 100644 index 0000000..399ea18 --- /dev/null +++ b/tests/test_rewriters.py @@ -0,0 +1,383 @@ +from textwrap import dedent + +from verified_cogen.runners.rewriters.nagini_rewriter import NaginiRewriter +from verified_cogen.runners.rewriters.nagini_rewriter_fixing import NaginiRewriterFixing +from verified_cogen.runners.rewriters.nagini_rewriter_fixing_ast import NaginiRewriterFixingAST + + +def test_nagini_rewriter(): + code = dedent( + """\ +from typing import cast, List, Dict, Set, Optional, Union +from nagini_contracts.contracts import * + +def Compare(scores: List[int], guesses: List[int]) -> List[int]: + Requires(Acc(list_pred(guesses))) + Requires(Acc(list_pred(scores))) + Requires(len(scores) == len(guesses)) + Ensures(Acc(list_pred(guesses))) + Ensures(Acc(list_pred(scores))) + Ensures(Acc(list_pred(Result()))) + Ensures(len(Result()) == len(scores)) + Ensures(len(scores) == len(guesses)) + Ensures(Forall(int, lambda d_0_i_: + not (0 <= d_0_i_ and d_0_i_ < len(Result())) or (Result()[d_0_i_] == abs(scores[d_0_i_] - guesses[d_0_i_])))) + + result = [int(0)] * 0 + nw0_ = [int(0)] * len(scores) + result = nw0_ + d_1_i_ = int(0) + d_1_i_ = 0 + while d_1_i_ < len(scores): + Invariant(Acc(list_pred(scores))) + Invariant(Acc(list_pred(guesses))) + Invariant(Acc(list_pred(result))) + Invariant(0 <= d_1_i_ and d_1_i_ <= len(scores)) + Invariant(len(result) == len(scores)) + Invariant(len(scores) == len(guesses)) + Invariant(Forall(int, lambda k: (0 <= k and k < d_1_i_) ==> (result[k] == abs(scores[k] - guesses[k])))) + Invariant(Forall(int, lambda k: (0 <= k and k < len(scores)) ==> (scores[k] == Old(scores[k])))) + Invariant(Forall(int, lambda k: (0 <= k and k < len(guesses)) ==> (guesses[k] == Old(guesses[k])))) + result[d_1_i_] = abs(scores[d_1_i_] - guesses[d_1_i_]) + d_1_i_ = d_1_i_ + 1 + return result + """ + ) + + error = dedent( + """\ +Manual inspection revealed occurrences of `==>` operator for implication on the following positions: +(28, 65), (29, 70), (30, 71) +`==>` operator does not exist in Nagini. All occurrences of `==>` operator should be replaced with `Implies(a, b)` operator, that is used to express implication in Nagini + """ + ) + + prg, prompt = NaginiRewriter().rewrite(code) + assert prompt == error + + +def test_nagini_rewriter1(): + code = dedent( + """\ +from typing import cast, List, Dict, Set, Optional, Union +from nagini_contracts.contracts import * + +def get__positive(l : List[int]) -> List[int]: + Requires(Acc(list_pred(l))) + Ensures(Acc(list_pred(l))) + Ensures(Acc(list_pred(Result()))) + Ensures(Forall(int, lambda d_0_i_: + not (((d_0_i_) >= (0)) and ((d_0_i_) < (len(Result())))) or (((Result())[d_0_i_]) > (0)))) + Ensures((len(Result())) <= (len(l))) + Ensures(Forall(int, lambda d_1_i1_: + not (((d_1_i1_) >= (0)) and ((d_1_i1_) < (len(l)))) or (not (((l)[d_1_i1_]) > (0)) or (Exists(int, lambda d_2_i2_: + (((d_2_i2_) >= (0)) and ((d_2_i2_) < (len(Result())))) and (((Result())[d_2_i2_]) == ((l)[d_1_i1_]))))))) + Ensures(((len(Result())) == (0)) or (Forall(int, lambda d_3_i1_: + not (((d_3_i1_) >= (0)) and ((d_3_i1_) < (len(Result())))) or (Exists(int, lambda d_4_i2_: + (((d_4_i2_) >= (0)) and ((d_4_i2_) < (len(l)))) and (((l)[d_4_i2_]) == ((Result())[d_3_i1_]))))))) + result = list([0] * 0) + d_5_i_ = int(0) + d_5_i_ = 0 + while (d_5_i_) < (len(l)): + Invariant(Acc(list_pred(l))) + Invariant(Acc(list_pred(result))) + Invariant(0 <= d_5_i_ and d_5_i_ <= len(l)) + Invariant(len(result) <= d_5_i_) + Invariant(Forall(int, lambda j: (0 <= j and j < len(result)) ==> (result[j] > 0))) + Invariant(Forall(int, lambda j: (0 <= j and j < d_5_i_) ==> + (l[j] > 0 ==> Exists(int, lambda k: (0 <= k and k < len(result) and result[k] == l[j]))))) + Invariant(Forall(int, lambda j: (0 <= j and j < len(result)) ==> + Exists(int, lambda k: (0 <= k and k < d_5_i_ and l[k] == result[j])))) + Invariant(Forall(int, lambda j1, j2: (0 <= j1 and j1 < j2 and j2 < len(result)) ==> + Exists(int, lambda k1, k2: (0 <= k1 and k1 < k2 and k2 < d_5_i_ and + l[k1] == result[j1] and l[k2] == result[j2])))) + d_13_n_ = int(0) + d_13_n_ = (l)[d_5_i_] + if (d_13_n_) > (0): + d_17_res__prev_ = result + result = (result) + ([d_13_n_]) + d_5_i_ = (d_5_i_) + (1) + return result + """ + ) + + res = dedent( + """\ +from typing import cast, List, Dict, Set, Optional, Union +from nagini_contracts.contracts import * + +def get__positive(l : List[int]) -> List[int]: + Requires(Acc(list_pred(l))) + Ensures(Acc(list_pred(l))) + Ensures(Acc(list_pred(Result()))) + Ensures(Forall(int, lambda d_0_i_: + not (((d_0_i_) >= (0)) and ((d_0_i_) < (len(Result())))) or (((Result())[d_0_i_]) > (0)))) + Ensures((len(Result())) <= (len(l))) + Ensures(Forall(int, lambda d_1_i1_: + not (((d_1_i1_) >= (0)) and ((d_1_i1_) < (len(l)))) or (not (((l)[d_1_i1_]) > (0)) or (Exists(int, lambda d_2_i2_: + (((d_2_i2_) >= (0)) and ((d_2_i2_) < (len(Result())))) and (((Result())[d_2_i2_]) == ((l)[d_1_i1_]))))))) + Ensures(((len(Result())) == (0)) or (Forall(int, lambda d_3_i1_: + not (((d_3_i1_) >= (0)) and ((d_3_i1_) < (len(Result())))) or (Exists(int, lambda d_4_i2_: + (((d_4_i2_) >= (0)) and ((d_4_i2_) < (len(l)))) and (((l)[d_4_i2_]) == ((Result())[d_3_i1_]))))))) + result = list([0] * 0) + d_5_i_ = int(0) + d_5_i_ = 0 + while (d_5_i_) < (len(l)): + Invariant(Acc(list_pred(l))) + Invariant(Acc(list_pred(result))) + Invariant(0 <= d_5_i_ and d_5_i_ <= len(l)) + Invariant(len(result) <= d_5_i_) + Invariant(Forall(int, lambda j:Implies( (0 <= j and j < len(result)) , (result[j] > 0)))) + Invariant(Forall(int, lambda j:Implies( (0 <= j and j < d_5_i_) , + (Implies(l[j] > 0 , Exists(int, lambda k: (0 <= k and k < len(result) and result[k] == l[j]))))))) + Invariant(Forall(int, lambda j:Implies( (0 <= j and j < len(result)) , + Exists(int, lambda k: (0 <= k and k < d_5_i_ and l[k] == result[j]))))) + Invariant(Forall(int, lambda j1, j2:Implies( (0 <= j1 and j1 < j2 and j2 < len(result)) , + Exists(int, lambda k1, k2: (0 <= k1 and k1 < k2 and k2 < d_5_i_ and + l[k1] == result[j1] and l[k2] == result[j2]))))) + d_13_n_ = int(0) + d_13_n_ = (l)[d_5_i_] + if (d_13_n_) > (0): + d_17_res__prev_ = result + result = (result) + ([d_13_n_]) + d_5_i_ = (d_5_i_) + (1) + return result + """ + ) + + prg = NaginiRewriterFixing().replace_impl(code) + assert prg == res + +def test_nagini_rewriter2(): + code = dedent( + """\ +def fizz__buzz(n : int) -> int: + Requires(n >= 0) + Ensures(Result() >= 0) + Ensures((Result()) == fizz_buzz_fun(n)) + result = int(0) + result = 0 + d_1_i_ = int(0) + d_1_i_ = 0 + while (d_1_i_) < (n): + Invariant(0 <= d_1_i_) + Invariant(d_1_i_ <= n) + Invariant(result >= 0) + Invariant(result == fizz_buzz_fun(d_1_i_)) + Invariant(Forall(int, lambda k: (0 <= k and k < d_1_i_) ==> + (((k % 11 == 0) or (k % 13 == 0)) ==> + (fizz_buzz_fun(k+1) == fizz_buzz_fun(k) + count7__r(k))))) + if (((d_1_i_ % 11)) == (0)) or (((d_1_i_ % 13)) == (0)): + d_4_cnt_ = int(0) + d_4_cnt_ = count7(d_1_i_) + result = (result) + (d_4_cnt_) + d_1_i_ = (d_1_i_) + (1) + return result + """ + ) + + res = dedent( + """\ +def fizz__buzz(n : int) -> int: + Requires(n >= 0) + Ensures(Result() >= 0) + Ensures((Result()) == fizz_buzz_fun(n)) + result = int(0) + result = 0 + d_1_i_ = int(0) + d_1_i_ = 0 + while (d_1_i_) < (n): + Invariant(0 <= d_1_i_) + Invariant(d_1_i_ <= n) + Invariant(result >= 0) + Invariant(result == fizz_buzz_fun(d_1_i_)) + Invariant(Forall(int, lambda k:Implies( (0 <= k and k < d_1_i_) , + (Implies(((k % 11 == 0) or (k % 13 == 0)) , + (fizz_buzz_fun(k+1) == fizz_buzz_fun(k) + count7__r(k))))))) + if (((d_1_i_ % 11)) == (0)) or (((d_1_i_ % 13)) == (0)): + d_4_cnt_ = int(0) + d_4_cnt_ = count7(d_1_i_) + result = (result) + (d_4_cnt_) + d_1_i_ = (d_1_i_) + (1) + return result + """ + ) + + prg = NaginiRewriterFixing().replace_impl(code) + assert prg == res + +def test_nagini_rewriter3(): + code = dedent( + """\ +from typing import cast, List, Dict, Set, Optional, Union +from nagini_contracts.contracts import * + +def Compare(scores: List[int], guesses: List[int]) -> List[int]: + Requires(Acc(list_pred(guesses))) + Requires(Acc(list_pred(scores))) + Requires(len(scores) == len(guesses)) + Ensures(Acc(list_pred(guesses))) + Ensures(Acc(list_pred(scores))) + Ensures(Acc(list_pred(Result()))) + Ensures(len(Result()) == len(scores)) + Ensures(len(scores) == len(guesses)) + Ensures(Forall(int, lambda d_0_i_: + not (0 <= d_0_i_ and d_0_i_ < len(Result())) or (Result()[d_0_i_] == abs(scores[d_0_i_] - guesses[d_0_i_])))) + + result = [int(0)] * 0 + nw0_ = [int(0)] * len(scores) + result = nw0_ + d_1_i_ = int(0) + d_1_i_ = 0 + while d_1_i_ < len(scores): + Invariant(Acc(list_pred(scores))) + Invariant(Acc(list_pred(guesses))) + Invariant(Acc(list_pred(result))) + Invariant(0 <= d_1_i_ and d_1_i_ <= len(scores)) + Invariant(len(result) == len(scores)) + Invariant(len(scores) == len(guesses)) + Invariant(Forall(int, lambda k: (0 <= k and k < d_1_i_) ==> (result[k] == abs(scores[k] - guesses[k])))) + Invariant(Forall(int, lambda k: (0 <= k and k < len(scores)) ==> (scores[k] == Old(scores[k])))) + Invariant(Forall(int, lambda k: (0 <= k and k < len(guesses)) ==> (guesses[k] == Old(guesses[k])))) + result[d_1_i_] = abs(scores[d_1_i_] - guesses[d_1_i_]) + d_1_i_ = d_1_i_ + 1 + return result + """ + ) + + error = dedent( + """\ +Manual inspection revealed occurrences of `==>` operator for implication on the following positions: +(28, 65), (29, 70), (30, 71) +`==>` operator does not exist in Nagini. All occurrences of `==>` operator should be replaced with `Implies(a, b)` operator, that is used to express implication in Nagini +We fixed errors with `==>` occurrences for you, and got the following program: +from typing import cast, List, Dict, Set, Optional, Union +from nagini_contracts.contracts import * + +def Compare(scores: List[int], guesses: List[int]) -> List[int]: + Requires(Acc(list_pred(guesses))) + Requires(Acc(list_pred(scores))) + Requires(len(scores) == len(guesses)) + Ensures(Acc(list_pred(guesses))) + Ensures(Acc(list_pred(scores))) + Ensures(Acc(list_pred(Result()))) + Ensures(len(Result()) == len(scores)) + Ensures(len(scores) == len(guesses)) + Ensures(Forall(int, lambda d_0_i_: + not (0 <= d_0_i_ and d_0_i_ < len(Result())) or (Result()[d_0_i_] == abs(scores[d_0_i_] - guesses[d_0_i_])))) + + result = [int(0)] * 0 + nw0_ = [int(0)] * len(scores) + result = nw0_ + d_1_i_ = int(0) + d_1_i_ = 0 + while d_1_i_ < len(scores): + Invariant(Acc(list_pred(scores))) + Invariant(Acc(list_pred(guesses))) + Invariant(Acc(list_pred(result))) + Invariant(0 <= d_1_i_ and d_1_i_ <= len(scores)) + Invariant(len(result) == len(scores)) + Invariant(len(scores) == len(guesses)) + Invariant(Forall(int, lambda k:Implies( (0 <= k and k < d_1_i_) , (result[k] == abs(scores[k] - guesses[k]))))) + Invariant(Forall(int, lambda k:Implies( (0 <= k and k < len(scores)) , (scores[k] == Old(scores[k]))))) + Invariant(Forall(int, lambda k:Implies( (0 <= k and k < len(guesses)) , (guesses[k] == Old(guesses[k]))))) + result[d_1_i_] = abs(scores[d_1_i_] - guesses[d_1_i_]) + d_1_i_ = d_1_i_ + 1 + return result + +Next, we run verifier on this program. Using the following verdict, you should possibly modify this program. + """ + ) + + prg, prompt = NaginiRewriterFixing(NaginiRewriter()).rewrite(code) + assert prompt == error + + +def test_nagini_rewriter4(): + code = dedent( + """\ +from typing import cast, List, Dict, Set, Optional, Union +from nagini_contracts.contracts import * + +def Compare(scores: List[int], guesses: List[int]) -> List[int]: + Requires(Acc(list_pred(guesses))) + Requires(Acc(list_pred(scores))) + Requires(len(scores) == len(guesses)) + Ensures(Acc(list_pred(guesses))) + Ensures(Acc(list_pred(scores))) + Ensures(Acc(list_pred(Result()))) + Ensures(len(Result()) == len(scores)) + Ensures(len(scores) == len(guesses)) + Ensures(Forall(int, lambda d_0_i_: + not (0 <= d_0_i_ and d_0_i_ < len(Result())) or (Result()[d_0_i_] == abs(scores[d_0_i_] - guesses[d_0_i_])))) + + result = [int(0)] * 0 + nw0_ = [int(0)] * len(scores) + result = nw0_ + d_1_i_ = int(0) + d_1_i_ = 0 + while d_1_i_ < len(scores): + Invariant(Acc(list_pred(scores))) + Invariant(Acc(list_pred(guesses))) + Invariant(Acc(list_pred(result))) + Invariant(0 <= d_1_i_ and d_1_i_ <= len(scores)) + Invariant(len(result) == len(scores)) + Invariant(len(scores) == len(guesses)) + Invariant(Forall(int, lambda k:Implies( (0 <= k and k < d_1_i_) , (result[k] == abs(scores[k] - guesses[k]))))) + Invariant(Forall(int, lambda k:Implies( (0 <= k and k < len(scores)) , (scores[k] == Old(scores[k]))))) + Invariant(Forall(int, lambda k:Implies( (0 <= k and k < len(guesses)) , (guesses[k] == Old(guesses[k]))))) + result[d_1_i_] = abs(scores[d_1_i_] - guesses[d_1_i_]) + d_1_i_ = d_1_i_ + 1 + return result + """ + ) + + error = "" + + prg, prompt = NaginiRewriterFixing(NaginiRewriter()).rewrite(code) + assert prompt == error + + +def test_nagini_rewriter5(): + code = dedent( + """\ +from typing import cast, List, Dict, Set, Optional, Union +from nagini_contracts.contracts import * + +def Compare(scores: List[int], guesses: List[int]) -> List[int]: + result = [int(0)] * 0 + nw0_ = [int(0)] * len(scores) + result = nw0_ + d_1_i_ = int(0) + d_1_i_ = 0 + while d_1_i_ < len(scores): + Invariant(0 <= d_1_i_ <= len(scores)) + result[d_1_i_] = abs(scores[d_1_i_] - guesses[d_1_i_]) + d_1_i_ = d_1_i_ + 1 + return result + """ + ) + + error = dedent( + """\ +We replaced all double (triple and so on) inequalities with their equivalents (as they are prohibited) and got the following program: +from typing import cast, List, Dict, Set, Optional, Union +from nagini_contracts.contracts import * + +def Compare(scores: List[int], guesses: List[int]) -> List[int]: + result = [int(0)] * 0 + nw0_ = [int(0)] * len(scores) + result = nw0_ + d_1_i_ = int(0) + d_1_i_ = 0 + while d_1_i_ < len(scores): + Invariant(0 <= d_1_i_ and d_1_i_ <= len(scores)) + result[d_1_i_] = abs(scores[d_1_i_] - guesses[d_1_i_]) + d_1_i_ = d_1_i_ + 1 + return result +Next, we run verifier on this program. Using the following verdict, you should possibly modify this program. + """ + ) + + prg, prompt = NaginiRewriterFixingAST(NaginiRewriterFixing(NaginiRewriter())).rewrite(code) + assert prompt == error \ No newline at end of file diff --git a/verified_cogen/runners/invariants_with_rewriting.py b/verified_cogen/runners/invariants_with_rewriting.py new file mode 100644 index 0000000..c1e9c25 --- /dev/null +++ b/verified_cogen/runners/invariants_with_rewriting.py @@ -0,0 +1,32 @@ +from logging import Logger + +from verified_cogen.llm import LLM +from verified_cogen.runners import RunnerConfig +from verified_cogen.runners.invariants import InvariantRunner +from verified_cogen.runners.rewriters import Rewriter +from verified_cogen.tools.verifier import Verifier + + +class InvariantsWithRewriting(InvariantRunner): + rewriter: Rewriter + + def __init__( + self, + llm: LLM, + logger: Logger, + verifier: Verifier, + config: RunnerConfig, + rewriter: Rewriter, + ): + super().__init__(llm, logger, verifier, config) + self.rewriter = rewriter + + def postprocess(self, inv_prg: str) -> str: + prg, prompt = self.rewriter.rewrite(inv_prg) + + if prompt != "": + self.logger.info("Manual rewriting results: ") + self.logger.info(prompt) + self.llm.add_user_prompt(prompt) + + return prg diff --git a/verified_cogen/runners/rewriters/__init__.py b/verified_cogen/runners/rewriters/__init__.py new file mode 100644 index 0000000..9572800 --- /dev/null +++ b/verified_cogen/runners/rewriters/__init__.py @@ -0,0 +1,5 @@ +from typing import Tuple + + +class Rewriter: + def rewrite(self, prg: str) -> Tuple[str, str]: ... diff --git a/verified_cogen/runners/rewriters/nagini_rewriter.py b/verified_cogen/runners/rewriters/nagini_rewriter.py new file mode 100644 index 0000000..8a5123f --- /dev/null +++ b/verified_cogen/runners/rewriters/nagini_rewriter.py @@ -0,0 +1,22 @@ +from typing import Tuple, List + +from verified_cogen.runners.rewriters import Rewriter + + +class NaginiRewriter(Rewriter): + def rewrite(self, prg: str) -> Tuple[str, str]: + pos_implications: List[Tuple[int, int]] = [] + + for idx, line in enumerate(prg.splitlines()): + for j in range(0, len(line) - 2): + if line[j : j + 3] == "==>": + pos_implications.append((idx + 1, j + 1)) + + if len(pos_implications) == 0: + return prg, "" + + prompt = "Manual inspection revealed occurrences of `==>` operator for implication on the following positions:\n" + prompt += ", ".join(f"({a}, {b})" for a, b in pos_implications) + "\n" + prompt += "`==>` operator does not exist in Nagini. All occurrences of `==>` operator should be replaced with `Implies(a, b)` operator, that is used to express implication in Nagini\n" + + return prg, prompt diff --git a/verified_cogen/runners/rewriters/nagini_rewriter_fixing.py b/verified_cogen/runners/rewriters/nagini_rewriter_fixing.py new file mode 100644 index 0000000..480a658 --- /dev/null +++ b/verified_cogen/runners/rewriters/nagini_rewriter_fixing.py @@ -0,0 +1,74 @@ +from typing import Tuple, Optional, Dict, List + +from verified_cogen.runners.rewriters.__init__ import Rewriter + + +class NaginiRewriterFixing(Rewriter): + wrapped_rewriter: Optional[Rewriter] + + def __init__(self, rewriter: Optional[Rewriter] = None): + super().__init__() + self.wrapped_rewriter = rewriter + + def replace_impl(self, prg: str): + indices: List[Tuple[int, str]] = [] + + for i in range(len(prg) - 2): + if prg[i : i + 3] == "==>": + cnt = 0 + for j in range(i - 1, -1, -1): + if cnt == 0 and (prg[j] == ":" or prg[j] == "("): + indices.append((j, "Implies(")) + break + if prg[j] == "(": + cnt -= 1 + elif prg[j] == ")": + cnt += 1 + cnt = 0 + for j in range(i + 3, len(prg)): + if cnt == 0 and prg[j] == ")": + indices.append((j - 1, ")")) + break + if prg[j] == ")": + cnt -= 1 + elif prg[j] == "(": + cnt += 1 + + dict: Dict[int, str] = {} + for i, st in indices: + dict[i] = st + + new_prg = "" + + j = 0 + while j < len(prg): + if j < len(prg) - 2 and prg[j : j + 3] == "==>": + j = j + 3 + new_prg += "," + continue + new_prg += prg[j] + if j in dict: + for s in dict[j]: + new_prg += s + j = j + 1 + + return new_prg + + def rewrite(self, prg: str) -> Tuple[str, str]: + prompt: str = "" + + if self.wrapped_rewriter is not None: + prg, prompt = self.wrapped_rewriter.rewrite(prg) + + if prompt == "": + return prg, "" + + prompt += "We fixed errors with `==>` occurrences for you, and got the following program:\n" + + new_prg = self.replace_impl(prg) + + prompt += new_prg + "\n" + + prompt += "Next, we run verifier on this program. Using the following verdict, you should possibly modify this program.\n" + + return new_prg, prompt diff --git a/verified_cogen/runners/rewriters/nagini_rewriter_fixing_ast.py b/verified_cogen/runners/rewriters/nagini_rewriter_fixing_ast.py new file mode 100644 index 0000000..bf2eb7d --- /dev/null +++ b/verified_cogen/runners/rewriters/nagini_rewriter_fixing_ast.py @@ -0,0 +1,32 @@ +from typing import Tuple, Optional + +from verified_cogen.runners.rewriters.__init__ import Rewriter +from verified_cogen.tools.inequality_replacer import ( + replace_inequalities, + contains_double_inequality, +) + + +class NaginiRewriterFixingAST(Rewriter): + wrapped_rewriter: Optional[Rewriter] + + def __init__(self, rewriter: Optional[Rewriter] = None): + super().__init__() + self.wrapped_rewriter = rewriter + + def rewrite(self, prg: str) -> Tuple[str, str]: + prompt: str = "" + + if self.wrapped_rewriter is not None: + prg, prompt = self.wrapped_rewriter.rewrite(prg) + + try: + if contains_double_inequality(prg): + prg = replace_inequalities(prg) + prompt += "We replaced all double (triple and so on) inequalities with their equivalents (as they are prohibited) and got the following program:\n" + prompt += prg + "\n" + prompt += "Next, we run verifier on this program. Using the following verdict, you should possibly modify this program.\n" + except Exception: + pass + + return prg, prompt diff --git a/verified_cogen/tools/inequality_replacer.py b/verified_cogen/tools/inequality_replacer.py new file mode 100644 index 0000000..0c91d8e --- /dev/null +++ b/verified_cogen/tools/inequality_replacer.py @@ -0,0 +1,48 @@ +import ast +from _ast import Compare, expr +from typing import List + + +class InequalityReplacer(ast.NodeTransformer): + def visit_Compare(self, node: Compare): + if len(node.comparators) > 1: + new_nodes: List[expr] = [] + left: expr = node.left + left = self.visit(left) + for op, right in zip(node.ops, node.comparators): + right = self.visit(right) + new_nodes.append(ast.Compare(left=left, ops=[op], comparators=[right])) + left = right + return ast.BoolOp(op=ast.And(), values=new_nodes) + self.generic_visit(node) + return node + + +def replace_inequalities(code: str) -> str: + tree = ast.parse(code) + + transformer = InequalityReplacer() + modified_tree = transformer.visit(tree) + + ast.fix_missing_locations(modified_tree) + + new_code = ast.unparse(modified_tree) + return new_code + + +class DoubleInequalityChecker(ast.NodeVisitor): + def __init__(self): + self.has_double_inequality: bool = False + + def visit_Compare(self, node: Compare): + if len(node.comparators) > 1: + self.has_double_inequality = True + self.generic_visit(node) + + +def contains_double_inequality(code: str) -> bool: + tree = ast.parse(code) + + checker = DoubleInequalityChecker() + checker.visit(tree) + return checker.has_double_inequality From d9b44c2bd70467d5b14041b4c7b86d1fe3f7f7ea Mon Sep 17 00:00:00 2001 From: AlexShefY Date: Thu, 7 Nov 2024 10:46:07 +0100 Subject: [PATCH 2/2] fixes in manual --- .../ask_for_fixed.txt | 1 + .../ask_for_fixed_had_errors.txt | 1 + .../humaneval-nagini-cot-instruct/timeout.txt | 1 + tests/test_rewriters.py | 34 ++++++++++++++++++- verified_cogen/args.py | 9 +++-- verified_cogen/llm/llm.py | 10 +++--- .../runners/invariants_with_rewriting.py | 4 ++- .../rewriters/nagini_rewriter_fixing.py | 17 +++++++--- 8 files changed, 62 insertions(+), 15 deletions(-) diff --git a/prompts/humaneval-nagini-cot-instruct/ask_for_fixed.txt b/prompts/humaneval-nagini-cot-instruct/ask_for_fixed.txt index 38f5c7b..7ac35ab 100644 --- a/prompts/humaneval-nagini-cot-instruct/ask_for_fixed.txt +++ b/prompts/humaneval-nagini-cot-instruct/ask_for_fixed.txt @@ -3,4 +3,5 @@ The following errors occurred during verification: Please fix the error by adding, removing or modifying the invariants and return the fixed program. Don't add any additional text comments, your response must contain only program with invariants. +You SHOULD NOT modify anything except invariants and asserts. Do not provide ANY explanations. Don't include markdown backticks. Respond only in Python code, nothing else. \ No newline at end of file diff --git a/prompts/humaneval-nagini-cot-instruct/ask_for_fixed_had_errors.txt b/prompts/humaneval-nagini-cot-instruct/ask_for_fixed_had_errors.txt index 9e6f1ab..174212b 100644 --- a/prompts/humaneval-nagini-cot-instruct/ask_for_fixed_had_errors.txt +++ b/prompts/humaneval-nagini-cot-instruct/ask_for_fixed_had_errors.txt @@ -3,4 +3,5 @@ There are still some errors: Could you please fix them? Don't add any additional text comments, your response must contain only program with invariants. +You SHOULD NOT modify anything except invariants and asserts. Do not provide ANY explanations. Don't include markdown backticks. Respond only in Python code, nothing else. \ No newline at end of file diff --git a/prompts/humaneval-nagini-cot-instruct/timeout.txt b/prompts/humaneval-nagini-cot-instruct/timeout.txt index 1c50276..20a8c1b 100644 --- a/prompts/humaneval-nagini-cot-instruct/timeout.txt +++ b/prompts/humaneval-nagini-cot-instruct/timeout.txt @@ -1,3 +1,4 @@ The verifier timed out during the verification. This usually means that the provided invariants were too broad or were difficult to check. +You SHOULD NOT modify anything except invariants and asserts. Could you please try to improve the invariants and try again? \ No newline at end of file diff --git a/tests/test_rewriters.py b/tests/test_rewriters.py index 399ea18..60668b0 100644 --- a/tests/test_rewriters.py +++ b/tests/test_rewriters.py @@ -380,4 +380,36 @@ def Compare(scores: List[int], guesses: List[int]) -> List[int]: ) prg, prompt = NaginiRewriterFixingAST(NaginiRewriterFixing(NaginiRewriter())).rewrite(code) - assert prompt == error \ No newline at end of file + assert prompt == error + +def test_rewriter6(): + code = dedent( + """\ + Invariant(0 <= k and k < d_1_i_ ==> s[k] != s[len(s) - 1 - k] ==> c > smallest__change__fun(s, 0, k)) + """ + ) + + prg, prompt = NaginiRewriterFixing(NaginiRewriter()).rewrite(code) + + assert prompt != "" + assert prg == dedent( + """\ + Invariant(Implies(0 <= k and k < d_1_i_ ,Implies( s[k] != s[len(s) - 1 - k] , c > smallest__change__fun(s, 0, k)))) + """ + ) + +def test_rewriter7(): + code = dedent( + """\ + Invariant(Forall(int, lambda i: Forall(int, lambda j: Implies(0 <= i and i < j and j < len(result), result[i] < result[j])))) + """ + ) + + prg, prompt = NaginiRewriterFixingAST(NaginiRewriterFixing(NaginiRewriter())).rewrite(code) + + assert prompt == "" + assert prg == dedent( + """\ + Invariant(Forall(int, lambda i: Forall(int, lambda j: Implies(0 <= i and i < j and j < len(result), result[i] < result[j])))) + """ + ) \ No newline at end of file diff --git a/verified_cogen/args.py b/verified_cogen/args.py index ed056c8..730aaeb 100644 --- a/verified_cogen/args.py +++ b/verified_cogen/args.py @@ -11,7 +11,7 @@ class ProgramArgs: runs: int insert_conditions_mode: str bench_type: str - temperature: int + temperature: float verifier_command: str verifier_timeout: int prompts_directory: str @@ -68,7 +68,12 @@ def get_default_parser(): help="benchmark type, available: {invariants, generic, generate, validating, step-by-step, step-by-step-flush}", default="invariants", ) - parser.add_argument("--temperature", help="model temperature", default=0, type=int) + parser.add_argument( + "--temperature", + help="model temperature", + default=0, + type=float, + ) parser.add_argument( "--verifier-command", help="command to run (cmd [file_path]) to verify a file", diff --git a/verified_cogen/llm/llm.py b/verified_cogen/llm/llm.py index be5b194..9a872cf 100644 --- a/verified_cogen/llm/llm.py +++ b/verified_cogen/llm/llm.py @@ -1,5 +1,4 @@ import logging -from http.client import RemoteDisconnected from pathlib import Path from typing import Optional @@ -9,7 +8,6 @@ from grazie.api.client.gateway import ( AuthType, GrazieApiGatewayClient, - RequestFailedException, ) from grazie.api.client.llm_parameters import LLMParameters from grazie.api.client.parameters import Parameters @@ -104,12 +102,12 @@ def _request( LLMParameters.Temperature: Parameters.FloatValue(temperature) }, ) - except RemoteDisconnected: + except Exception: logger.warning("Grazie API is down, retrying...") return self._request(temperature, tries - 1) - except RequestFailedException as e: - self.dump_history(Path("err_dump.txt")) - raise e + # except RequestFailedException as e: + # self.dump_history(Path("err_dump.txt")) + # raise e def make_request(self) -> str: response = self._request().content diff --git a/verified_cogen/runners/invariants_with_rewriting.py b/verified_cogen/runners/invariants_with_rewriting.py index c1e9c25..b45a92e 100644 --- a/verified_cogen/runners/invariants_with_rewriting.py +++ b/verified_cogen/runners/invariants_with_rewriting.py @@ -25,7 +25,9 @@ def postprocess(self, inv_prg: str) -> str: prg, prompt = self.rewriter.rewrite(inv_prg) if prompt != "": - self.logger.info("Manual rewriting results: ") + self.logger.info("Manually rewrite:") + self.logger.info(inv_prg) + self.logger.info("Manual rewriting results:") self.logger.info(prompt) self.llm.add_user_prompt(prompt) diff --git a/verified_cogen/runners/rewriters/nagini_rewriter_fixing.py b/verified_cogen/runners/rewriters/nagini_rewriter_fixing.py index 480a658..a34440e 100644 --- a/verified_cogen/runners/rewriters/nagini_rewriter_fixing.py +++ b/verified_cogen/runners/rewriters/nagini_rewriter_fixing.py @@ -17,7 +17,12 @@ def replace_impl(self, prg: str): if prg[i : i + 3] == "==>": cnt = 0 for j in range(i - 1, -1, -1): - if cnt == 0 and (prg[j] == ":" or prg[j] == "("): + if cnt == 0 and ( + prg[j] == ":" + or prg[j] == "(" + or j >= 2 + and prg[j - 2 : j + 1] == "==>" + ): indices.append((j, "Implies(")) break if prg[j] == "(": @@ -36,17 +41,19 @@ def replace_impl(self, prg: str): dict: Dict[int, str] = {} for i, st in indices: - dict[i] = st + if i not in dict: + dict[i] = "" + dict[i] = dict[i] + st new_prg = "" j = 0 while j < len(prg): if j < len(prg) - 2 and prg[j : j + 3] == "==>": - j = j + 3 + j = j + 2 new_prg += "," - continue - new_prg += prg[j] + else: + new_prg += prg[j] if j in dict: for s in dict[j]: new_prg += s