Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
alex28sh committed Nov 29, 2024
1 parent 9743971 commit 60f88bf
Show file tree
Hide file tree
Showing 12 changed files with 224 additions and 60 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
The following errors occurred during verification:
{error}

Please fix the error by adding, removing or modifying the implementation, invariants or assertions and return the fixed program.
Don't add any additional text comments, your response must contain only program with invariants.
Do not provide ANY explanations. Don't include markdown backticks. Respond only in Python code, nothing else.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
There are still some errors:
{error}

Could you please fix them?
Don't add any additional text comments, your response must contain only program with invariants.
Do not provide ANY explanations. Don't include markdown backticks. Respond only in Python code, nothing else.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Generally, you should use helper functions (marked with @Pure annotation) only in invariants, asserts and conditions (in `if` or `while` conditions), not in the plain code.
But, the following helper functions you can use anywhere: {helpers}.
Do not change helper functions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
We detected an improper usage of helper functions. Here is the list of helper functions used in a wrong way:
{invalid_helpers}
You should use helper functions only in invariants, asserts and conditions (in `if` or `while` conditions), not in the plain code.
The following helper functions you can use anywhere: {helpers}.
We replaced all improper usages with `invalid_call()` and got the following program:
{program}
You should rewrite this program without changing pre/postconditions and helper functions (denoted with @Pure).
After rewriting your code should verify.
Your code should not contain any `invalid_call()` invocations.
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
Rewrite the following Nagini code with implementations of some functions missing. While rewriting it, ensure that it verifies. Include invariants and assertions. Don't remove any helper functions (they are marked with @Pure annotation), they are there to help you. Prefer loops to recursion.
Use helper functions only in invariants, asserts and conditions (in `if` or `while` conditions). Don't use helpers in the plain code.
Do not change helper functions.
Add code and invariants to other functions.
Ensure that the invariants are as comprehensive as they can be.
Even if you think some invariant is not totally necessary, better add it than not.
Don't add any additional text comments, your response must contain only program with invariants.
Do not provide ANY explanations. Don't include markdown backticks. Respond only in Python code, nothing else.


You remember the following aspects of Nagini syntax:

1. Nagini DOES NOT SUPPORT some Python features as list comprehensions (k + 1 for k in range(5)), as double inequalities (a <= b <= c).
Instead of double inequalities it's customary to use two separate inequalities (a <= b and b <= c).

2. In Nagini method preconditions (Requires) and postconditions (Ensures) placed right after method signature, like here:
"
def Sum(a : List[int], s : int, t : int) -> int :
Requires(Acc(list_pred(a)))
Requires(((0) <= (s)) and ((s) <= (t)) and ((t) <= (len(a))))
Ensures(Acc(list_pred(a)))
...
"

3. Invariant are placed right after `while` statement and before the code of `while` body:
"
while i < len(numbers):
Invariant(Acc(list_pred(numbers)))
Invariant(0 <= i and i <= len(numbers))
s = s + numbers[i]
"
Invariants CANNOT be placed in any other position.
You remember that each invariant (and each expression) should contain equal number of opening and closing brackets, so that it is valid.
You should sustain balanced parentheses.

4. Nagini requires special annotations for working with lists `Acc(list_pred(..))`. You can use these constructs only inside `Invariant`,
anywhere else you should not use `Acc()` or `list_pred()`:
"
while i < len(numbers):
Invariant(Acc(list_pred(numbers)))
"

5. Nagini contains `Forall` and `Exists` constructs that can be used in invariants. First argument of Forall/Exists is typically a type (i.e `int`),
second argument is a lambda. `Forall(type, lambda x : a)` denotes that assertion `a` is true for every element `x` of type `type`.

6. In Nagini `Implies(e1, a2)` plays role of implication. `Implies(e1, a2)` denotes that assertion a2 holds if boolean expression e1 is true.
You can use it inside invariants and asserts.

You might need to work with accumulating functions, such as sum, so here's an example of how to do that:
```
from typing import cast, List, Dict, Set, Optional, Union
from nagini_contracts.contracts import *

@Pure
def Sum(a : List[int], s : int, t : int) -> int :
Requires(Acc(list_pred(a)))
Requires(((0) <= (s)) and ((s) <= (t)) and ((t) <= (len(a))))

if s == t:
return 0
else:
return (a)[t - 1] + (Sum(a, s, t - 1))

def sum_loop(numbers: List[int]) -> int:
Requires(Acc(list_pred(numbers)))
Ensures(Acc(list_pred(numbers)))
Ensures(Result() == Sum(numbers, 0, len(numbers)))
s = int(0)
i = int(0)
while (i) < (len(numbers)):
Invariant(Acc(list_pred(numbers)))
Invariant(0 <= i and i <= len(numbers))
Invariant(Forall(int, lambda d_1_p_:
(Implies(0 <= d_1_p_ and d_1_p_ < len(numbers), Sum(numbers, 0, d_1_p_ + 1) == Sum(numbers, 0, d_1_p_) + numbers[d_1_p_]), [[Sum(numbers, 0, d_1_p_ + 1)]])))
Invariant(s == Sum(numbers, 0, i))
Assert(Sum(numbers, 0, i + 1) == Sum(numbers, 0, i) + numbers[i])
s = s + (numbers)[i]
i = i + 1
return s
```

To help you, here's a text description given to a person who wrote this program:

{text_description}

The program:
{program}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
You are an expert in a Python verification framework Nagini.
You will be given tasks dealing with Python programs including precise annotations.
Do not provide ANY explanations. Don't include markdown backticks. Respond only in Python code, nothing else.
You respond only with code blocks.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
The verifier timed out during the verification.
This usually means that the provided invariants were too broad or were difficult to check.
Could you please try to improve the invariants and try again?
10 changes: 9 additions & 1 deletion verified_cogen/args.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
import os
from typing import Optional, no_type_check
from typing import Optional, no_type_check, List

from verified_cogen.tools.modes import VALID_MODES

Expand All @@ -26,6 +26,7 @@ class ProgramArgs:
remove_conditions: bool
remove_implementations: bool
include_text_descriptions: bool
manual_rewriters: List[str]

@no_type_check
def __init__(self, args):
Expand All @@ -49,6 +50,7 @@ def __init__(self, args):
self.remove_conditions = args.remove_conditions
self.remove_implementations = args.remove_implementations
self.include_text_descriptions = args.include_text_descriptions
self.manual_rewriters = args.manual_rewriters


def get_default_parser():
Expand Down Expand Up @@ -129,6 +131,12 @@ def get_default_parser():
default=False,
action="store_true",
)
parser.add_argument(
"--manual-rewriters",
help="Manual rewriters for additional program modifications",
default=[],
nargs="+",
)
return parser


Expand Down
8 changes: 6 additions & 2 deletions verified_cogen/experiments/incremental_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from verified_cogen.args import ProgramArgs, get_default_parser
from verified_cogen.llm.llm import LLM
from verified_cogen.main import make_runner_cls
from verified_cogen.main import make_runner_cls, construct_rewriter
from verified_cogen.runners import RunnerConfig
from verified_cogen.runners.languages import AnnotationType, register_basic_languages
from verified_cogen.tools import (
Expand Down Expand Up @@ -36,6 +36,7 @@ def main():
"--ignore-failed", help="Ignore failed files", action="store_true"
)
args = IncrementalRunArgs(parser.parse_args())
print(args.manual_rewriters)

all_removed = [AnnotationType.INVARIANTS, AnnotationType.ASSERTS]
if args.remove_conditions:
Expand Down Expand Up @@ -85,9 +86,12 @@ def main():
args.prompts_directory,
args.temperature,
)
rewriter = construct_rewriter(
extension_from_file_list([file]), args.manual_rewriters
)
runner = make_runner_cls(
args.bench_type, extension_from_file_list([file]), config
)(llm, logger, verifier)
)(llm, logger, verifier, rewriter)
display_name = rename_file(file)
marker_name = str(file.relative_to(directory))
if (
Expand Down
69 changes: 57 additions & 12 deletions verified_cogen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,24 @@
import pathlib
from logging import Logger
from pathlib import Path
from typing import Callable
from typing import Callable, List, Optional

from verified_cogen.args import ProgramArgs, get_args
from verified_cogen.llm import LLM
from verified_cogen.runners import Runner, RunnerConfig
from verified_cogen.runners.flush import FlushRunner
from verified_cogen.runners.generate import GenerateRunner
from verified_cogen.runners.generic import GenericRunner
from verified_cogen.runners.invariants import InvariantRunner
from verified_cogen.runners.languages import register_basic_languages
from verified_cogen.runners.languages.language import AnnotationType, LanguageDatabase
from verified_cogen.runners.rewriters import Rewriter
from verified_cogen.runners.rewriters.nagini_rewriter import NaginiRewriter
from verified_cogen.runners.rewriters.nagini_rewriter_fixing import NaginiRewriterFixing
from verified_cogen.runners.rewriters.nagini_rewriter_fixing_ast import (
NaginiRewriterFixingAST,
)
from verified_cogen.runners.step_by_step import StepByStepRunner
from verified_cogen.runners.step_by_step_flush import StepByStepFlushRunner
from verified_cogen.runners.validating import ValidatingRunner
from verified_cogen.tools import (
ext_glob,
Expand All @@ -33,9 +39,10 @@
def run_once(
files: list[Path],
args: ProgramArgs,
runner_cls: Callable[[LLM, Logger, Verifier], Runner],
runner_cls: Callable[[LLM, Logger, Verifier, Optional[Rewriter]], Runner],
verifier: Verifier,
mode: Mode,
rewriter: Optional[Rewriter],
is_once: bool,
) -> tuple[int, int, int, dict[str, int]]:
_init: tuple[list[str], list[str], list[str]] = ([], [], [])
Expand All @@ -51,7 +58,7 @@ def run_once(
args.temperature,
)

runner = runner_cls(llm, logger, verifier)
runner = runner_cls(llm, logger, verifier, rewriter)

retries = args.retries + 1
tries = None
Expand Down Expand Up @@ -96,10 +103,36 @@ def run_once(
return len(success_zero_tries), len(success), len(failed), cnt


def construct_nagini_rewriter(runner_types: List[str]) -> Optional[Runner]:
runner = None
for runner_type in runner_types:
if runner_type == "NaginiRewriter":
runner = NaginiRewriter()
elif runner_type == "NaginiRewriterFixing":
runner = NaginiRewriterFixing(runner)
elif runner_type == "NaginiRewriterFixingAST":
runner = NaginiRewriterFixingAST(runner)
else:
raise ValueError(f"Unexpected nagini rewriter type: {runner_type}")
return runner


def construct_rewriter(extension: str, runner_types: List[str]) -> Optional[Runner]:
if extension == "py":
return construct_nagini_rewriter(runner_types)
if runner_types:
raise ValueError(
f"Not implemented rewriters for language: {LanguageDatabase().regularise[extension]}"
)
return None


def make_runner_cls(
bench_type: str, extension: str, config: RunnerConfig
) -> Callable[[LLM, Logger, Verifier], Runner]:
def runner_cls(llm: LLM, logger: Logger, verifier: Verifier):
) -> Callable[[LLM, Logger, Verifier, Optional[Rewriter]], Runner]:
def runner_cls(
llm: LLM, logger: Logger, verifier: Verifier, rewriter: Optional[Rewriter]
):
if bench_type == "invariants":
return InvariantRunner(llm, logger, verifier, config)
elif bench_type == "generic":
Expand All @@ -108,17 +141,23 @@ def runner_cls(llm: LLM, logger: Logger, verifier: Verifier):
return GenerateRunner(llm, logger, verifier, config)
elif bench_type == "validating":
return ValidatingRunner(
InvariantRunner(llm, logger, verifier, config),
InvariantRunner(llm, logger, verifier, config, rewriter),
LanguageDatabase().get(extension),
)
elif bench_type == "step-by-step":
return ValidatingRunner(
StepByStepRunner(InvariantRunner(llm, logger, verifier, config)),
StepByStepRunner(
InvariantRunner(llm, logger, verifier, config, rewriter)
),
LanguageDatabase().get(extension),
)
elif bench_type == "step-by-step-flush":
return ValidatingRunner(
StepByStepFlushRunner(InvariantRunner(llm, logger, verifier, config)),
FlushRunner(
StepByStepRunner(
InvariantRunner(llm, logger, verifier, config, rewriter)
)
),
LanguageDatabase().get(extension),
)
else:
Expand Down Expand Up @@ -152,13 +191,18 @@ def main():

verifier = Verifier(args.verifier_command, args.verifier_timeout)
config = RunnerConfig(
log_tries=log_tries, include_text_descriptions=args.include_text_descriptions
log_tries=log_tries,
include_text_descriptions=args.include_text_descriptions,
remove_implementations=args.remove_implementations,
)
if args.dir is not None:
files = sorted(list(pathlib.Path(args.dir).glob(ext_glob(args.filter_by_ext))))
runner_cls = make_runner_cls(
args.bench_type, extension_from_file_list(files), config
)
rewriter = construct_rewriter(
extension_from_file_list(files), args.manual_rewriters
)
runner = runner_cls(
LLM(
args.grazie_token,
Expand All @@ -168,21 +212,22 @@ def main():
),
logger,
verifier,
rewriter,
)
for file in files:
with open(file) as f:
runner.precheck(f.read(), mode)

if args.runs == 1:
_, _, _, total_cnt = run_once(
files, args, runner_cls, verifier, mode, is_once=True
files, args, runner_cls, verifier, mode, rewriter, is_once=True
)
else:
success_zero_tries, success, failed = 0, 0, 0
total_cnt = {rename_file(f): 0 for f in files}
for _ in range(args.runs):
s0, s, f, cnt = run_once(
files, args, runner_cls, verifier, mode, is_once=False
files, args, runner_cls, verifier, mode, rewriter, is_once=False
)
success_zero_tries += s0
success += s
Expand Down
Loading

0 comments on commit 60f88bf

Please sign in to comment.