Skip to content

Commit

Permalink
Merge branch 'main' into HumanEval-Nagini
Browse files Browse the repository at this point in the history
  • Loading branch information
alex28sh authored Sep 22, 2024
2 parents 333f960 + f9c0a9f commit 47640f8
Show file tree
Hide file tree
Showing 14 changed files with 183 additions and 52 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ data
llm-generated
.envrc
**/.DS_Store
pyrightconfig.json
benches/DafnyBench/
benches/HumanEval-Dafny-Mini/
break_assert.rs
Expand All @@ -13,6 +12,7 @@ break_assert.rs
log
.ruff_cache
.vscode
.zed
run.sh
/dist/
**/.pytest_cache
Expand All @@ -22,3 +22,4 @@ results
results/*
/log_tries/
/log_tries/*
.direnv
45 changes: 45 additions & 0 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions flake.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
inputs = {
flakelight.url = "github:nix-community/flakelight";
};
outputs = { flakelight, ... }@inputs:
flakelight ./. {
inherit inputs;

systems = [ "x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin" ];
devShell.packages = pkgs: with pkgs; [
poetry
dafny
];
formatter = pkgs: with pkgs; [ nixpkgs-fmt ];
};
}
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ incremental_run = "verified_cogen.experiments.incremental_run:main"
profile = "black"
src_paths = ["verified_cogen"]

[tool.pyright]
typeCheckingMode = "strict"

[[tool.poetry.source]]
name = "PyPI"
priority = "primary"
Expand All @@ -33,7 +36,7 @@ ruff = "^0.5.4"
pytest = "^8.3.1"
matplotlib = "^3.9.2"
ipykernel = "^6.29.5"
pyright = "^1.1.380"
pyright = "^1.1.381"

[build-system]
requires = ["poetry-core"]
Expand Down
49 changes: 45 additions & 4 deletions verified_cogen/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,47 @@
import os

from verified_cogen.tools.modes import VALID_MODES
from typing import no_type_check, Optional


class ProgramArgs:
input: Optional[str]
dir: Optional[str]
runs: int
insert_conditions_mode: str
bench_type: str
temperature: int
shell: str
verifier_command: str
verifier_timeout: int
prompts_directory: str
grazie_token: str
llm_profile: str
tries: int
retries: int
output_style: str
filter_by_ext: Optional[str]
log_tries: Optional[str]

@no_type_check
def __init__(self, args):
self.input = args.input
self.dir = args.dir
self.runs = args.runs
self.insert_conditions_mode = args.insert_conditions_mode
self.bench_type = args.bench_type
self.temperature = args.temperature
self.shell = args.shell
self.verifier_command = args.verifier_command
self.verifier_timeout = args.verifier_timeout
self.prompts_directory = args.prompts_directory
self.grazie_token = args.grazie_token
self.llm_profile = args.llm_profile
self.tries = args.tries
self.retries = args.retries
self.output_style = args.output_style
self.filter_by_ext = args.filter_by_ext
self.log_tries = args.log_tries


def get_default_parser():
Expand Down Expand Up @@ -50,15 +91,15 @@ def get_default_parser():
parser.add_argument(
"-s", "--output-style", choices=["stats", "full"], default="full"
)
parser.add_argument("--filter-by-ext", help="filter by extension", default=None)
parser.add_argument("--filter-by-ext", help="filter by extension", required=False)
parser.add_argument(
"--log-tries", help="Save output of every try to given dir", default=None
"--log-tries", help="Save output of every try to given dir", required=False
)
parser.add_argument(
"--output-logging", help="Print logs to standard output", default=False
)
return parser


def get_args():
return get_default_parser().parse_args()
def get_args() -> ProgramArgs:
return ProgramArgs(get_default_parser().parse_args())
32 changes: 25 additions & 7 deletions verified_cogen/experiments/use_houdini.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import logging
import os
from typing import Optional
from typing import Optional, no_type_check

from verified_cogen.llm import LLM
from verified_cogen.runners import LLM_GENERATED_DIR
Expand All @@ -12,6 +12,24 @@
log = logging.getLogger(__name__)


class ProgramArgs:
grazie_token: str
profile: str
prompt_dir: str
program: str
verifier_command: str

@no_type_check
def __init__(self, *args):
(
self.grazie_token,
self.profile,
self.prompt_dir,
self.program,
self.verifier_command,
) = args


INVARIANTS_JSON_PROMPT = """Given the following Rust program, output Verus invariants that should go into the `while` loop
in the function {function}.
Ensure that the invariants are as comprehensive as they can be.
Expand Down Expand Up @@ -96,9 +114,9 @@
"""


def collect_invariants(args, prg: str):
def collect_invariants(args: ProgramArgs, prg: str) -> list[str]:
func = basename(args.program)[:-3]
result_invariants = []
result_invariants: list[str] = []
for temperature in [0.0, 0.1, 0.3, 0.4, 0.5, 0.7, 1.0]:
llm = LLM(
grazie_token=args.grazie_token,
Expand All @@ -110,7 +128,7 @@ def collect_invariants(args, prg: str):
llm.user_prompts.append(
INVARIANTS_JSON_PROMPT.replace("{program}", prg).replace("{function}", func)
)
response = llm._make_request()
response = llm._make_request() # type: ignore
try:
invariants = json.loads(response)
result_invariants.extend(invariants)
Expand All @@ -126,7 +144,7 @@ def remove_failed_invariants(
llm: LLM, invariants: list[str], err: str
) -> Optional[list[str]]:
llm.user_prompts.append(REMOVE_FAILED_INVARIANTS_PROMPT.format(error=err))
response = llm._make_request()
response = llm._make_request() # type: ignore
try:
new_invariants = json.loads(response)
log.debug("REMOVED: {}".format(set(invariants).difference(set(new_invariants))))
Expand All @@ -138,7 +156,7 @@ def remove_failed_invariants(


def houdini(
args, verifier: Verifier, prg: str, invariants: list[str]
args: ProgramArgs, verifier: Verifier, prg: str, invariants: list[str]
) -> Optional[list[str]]:
func = basename(args.program).strip(".rs")
log.info(f"Starting Houdini for {func} in file {args.program}")
Expand Down Expand Up @@ -201,7 +219,7 @@ def main():
parser.add_argument("--program", required=True)
parser.add_argument("--verifier-command", required=True)

args = parser.parse_args()
args = ProgramArgs(*parser.parse_args())

log.info("Running on program: {}".format(args.program))

Expand Down
21 changes: 12 additions & 9 deletions verified_cogen/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional

from grazie.api.client.chat.prompt import ChatPrompt
from grazie.api.client.chat.response import ChatResponse
from grazie.api.client.endpoints import GrazieApiGatewayUrls
from grazie.api.client.gateway import AuthType, GrazieApiGatewayClient
from grazie.api.client.llm_parameters import LLMParameters
Expand Down Expand Up @@ -32,15 +33,17 @@ def __init__(
self.profile = Profile.get_by_name(profile)
self.prompt_dir = prompt_dir
self.is_gpt = "gpt" in self.profile.name
self.user_prompts = []
self.responses = []
self.user_prompts: list[str] = []
self.responses: list[str] = []
self.had_errors = False
self.temperature = temperature
self.system_prompt = (
system_prompt if system_prompt else prompts.sys_prompt(self.prompt_dir)
)

def _request(self, temperature: Optional[float] = None, tries: int = 5):
def _request(
self, temperature: Optional[float] = None, tries: int = 5
) -> ChatResponse:
if tries == 0:
raise Exception("Exhausted tries to get response from Grazie API")
if temperature is None:
Expand Down Expand Up @@ -70,31 +73,31 @@ def _request(self, temperature: Optional[float] = None, tries: int = 5):
logger.warning("Grazie API is down, retrying...")
return self._request(temperature, tries - 1)

def _make_request(self):
def _make_request(self) -> str:
response = self._request().content
self.responses.append(response)
return extract_code_from_llm_output(response)

def produce(self, prg: str):
def produce(self, prg: str) -> str:
self.user_prompts.append(
prompts.produce_prompt(self.prompt_dir).format(program=prg)
)
return self._make_request()

def add(self, prg: str, checks: str, function: Optional[str] = None):
def add(self, prg: str, checks: str, function: Optional[str] = None) -> str:
prompt = prompts.add_prompt(self.prompt_dir).format(program=prg, checks=checks)
if "{function}" in prompt and function is not None:
prompt = prompt.replace("{function}", function)
self.user_prompts.append(prompt)
return self._make_request()

def rewrite(self, prg: str):
def rewrite(self, prg: str) -> str:
self.user_prompts.append(
prompts.rewrite_prompt(self.prompt_dir).replace("{program}", prg)
)
return self._make_request()

def ask_for_fixed(self, err: str):
def ask_for_fixed(self, err: str) -> str:
prompt = (
prompts.ask_for_fixed_had_errors_prompt(self.prompt_dir)
if self.had_errors
Expand All @@ -103,6 +106,6 @@ def ask_for_fixed(self, err: str):
self.user_prompts.append(prompt.format(error=err))
return self._make_request()

def ask_for_timeout(self):
def ask_for_timeout(self) -> str:
self.user_prompts.append(prompts.ask_for_timeout_prompt(self.prompt_dir))
return self._make_request()
8 changes: 4 additions & 4 deletions verified_cogen/llm/prompts.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from typing import Optional
from typing import Any, Optional


class Singleton(object):
_instance = None

def __new__(cls, *args, **kwargs):
def __new__(cls, *args: list[Any], **kwargs: dict[str, Any]):
if not isinstance(cls._instance, cls):
cls._instance = super().__new__(cls, *args, **kwargs)
return cls._instance


class PromptCache(Singleton):
cache: dict = {}
cache: dict[str, str] = {}

def __init__(self):
def __init__(self, *args: list[Any], **kwargs: dict[str, Any]):
self.cache = {}

def get(self, key: str) -> Optional[str]:
Expand Down
Loading

0 comments on commit 47640f8

Please sign in to comment.