Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ESBMC output formatting and source code formatting #117

Merged
merged 5 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion config.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
"--context-bound",
"2"
],
"requests": {
"esbmc_output_type": "vp",
"source_code_format": "full",
"llm_requests": {
"max_tries": 5,
"timeout": 60
},
Expand Down Expand Up @@ -68,6 +70,7 @@
]
},
"generate_solution": {
"max_attempts": 5,
"temperature": 1.3,
"scenarios": {
"division by zero": {
Expand Down
3 changes: 3 additions & 0 deletions esbmc_ai/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,9 @@ def main() -> None:

check_health()

printv(f"Source code format: {config.source_code_format}")
printv(f"ESBMC output type: {config.esbmc_output_type}")

anim: LoadingWidget = create_loading_widget()

# Read the source code and esbmc output.
Expand Down
64 changes: 32 additions & 32 deletions esbmc_ai/commands/fix_code_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
from .. import config
from ..msg_bus import Signal
from ..loading_widget import create_loading_widget
from ..esbmc_util import esbmc_load_source_code
from ..solution_generator import SolutionGenerator
from ..esbmc_util import (
esbmc_get_error_type,
esbmc_load_source_code,
)
from ..solution_generator import SolutionGenerator, get_esbmc_output_formatted
from ..logging import printv, printvv

# TODO Remove built in messages and move them to config.
Expand All @@ -28,28 +31,14 @@ def __init__(self) -> None:
)
self.anim = create_loading_widget()

def _resolve_scenario(self, esbmc_output: str) -> str:
# Start search from the marker.
marker: str = "Violated property:\n"
violated_property_index: int = esbmc_output.rfind(marker) + len(marker)
from_loc_error_msg: str = esbmc_output[violated_property_index:]
# Find second new line which contains the location of the violated
# property and that should point to the line with the type of error.
# In this case, the type of error is the "scenario".
scenario_index: int = from_loc_error_msg.find("\n")
scenario: str = from_loc_error_msg[scenario_index + 1 :]
scenario_end_l_index: int = scenario.find("\n")
scenario = scenario[:scenario_end_l_index].strip()
return scenario

@override
def execute(self, **kwargs: Any) -> Tuple[bool, str]:
file_name: str = kwargs["file_name"]
source_code: str = kwargs["source_code"]
esbmc_output: str = kwargs["esbmc_output"]

# Parse the esbmc output here and determine what "Scenario" to use.
scenario: str = self._resolve_scenario(esbmc_output)
scenario: str = esbmc_get_error_type(esbmc_output)

printv(f"Scenario: {scenario}")
printv(
Expand All @@ -58,33 +47,33 @@ def execute(self, **kwargs: Any) -> Tuple[bool, str]:
else "Using generic prompt..."
)

llm = config.ai_model.create_llm(
api_keys=config.api_keys,
temperature=config.chat_prompt_generator_mode.temperature,
requests_max_tries=config.requests_max_tries,
requests_timeout=config.requests_timeout,
)

solution_generator = SolutionGenerator(
ai_model_agent=config.chat_prompt_generator_mode,
source_code=source_code,
esbmc_output=esbmc_output,
ai_model=config.ai_model,
llm=llm,
llm=config.ai_model.create_llm(
api_keys=config.api_keys,
temperature=config.chat_prompt_generator_mode.temperature,
requests_max_tries=config.requests_max_tries,
requests_timeout=config.requests_timeout,
),
scenario=scenario,
source_code_format=config.source_code_format,
esbmc_output_type=config.esbmc_output_type,
)

print()

max_retries: int = 10
max_retries: int = config.fix_code_max_attempts
for idx in range(max_retries):
# Get a response. Use while loop to account for if the message stack
# gets full, then need to compress and retry.
response: str = ""
llm_solution: str = ""
while True:
# Generate AI solution
self.anim.start("Generating Solution... Please Wait")
response, finish_reason = solution_generator.generate_solution()
llm_solution, finish_reason = solution_generator.generate_solution()
self.anim.stop()
if finish_reason == FinishReason.length:
self.anim.start("Compressing message stack... Please Wait")
Expand All @@ -96,7 +85,7 @@ def execute(self, **kwargs: Any) -> Tuple[bool, str]:
# Print verbose lvl 2
printvv("\nGeneration:")
printvv("-" * get_terminal_size().columns)
printvv(response)
printvv(llm_solution)
printvv("-" * get_terminal_size().columns)
printvv("")

Expand All @@ -105,22 +94,33 @@ def execute(self, **kwargs: Any) -> Tuple[bool, str]:
self.anim.start("Verifying with ESBMC... Please Wait")
exit_code, esbmc_output, esbmc_err_output = esbmc_load_source_code(
file_path=file_name,
source_code=str(response),
source_code=llm_solution,
esbmc_params=config.esbmc_params,
auto_clean=config.temp_auto_clean,
timeout=config.verifier_timeout,
)
self.anim.stop()

# TODO Move this process into Solution Generator since half (the beginning) is done
# inside, and the other half is done here.
try:
esbmc_output = get_esbmc_output_formatted(
esbmc_output_type=config.esbmc_output_type,
esbmc_output=esbmc_output,
)
except ValueError:
# Probably did not compile, so the ESBMC output is actually clang output.
pass

# Print verbose lvl 2
printvv("-" * get_terminal_size().columns)
printvv(esbmc_output)
printvv(esbmc_err_output)
printvv("-" * get_terminal_size().columns)

if exit_code == 0:
self.on_solution_signal.emit(response)
return False, response
self.on_solution_signal.emit(llm_solution)
return False, llm_solution

# Failure case
print(f"Failure {idx+1}/{max_retries}: Retrying...")
Expand Down
42 changes: 40 additions & 2 deletions esbmc_ai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@
temp_file_dir: str = "."
ai_model: AIModel = AIModels.GPT_3.value

esbmc_output_type: str = "full"
source_code_format: str = "full"

fix_code_max_attempts: int = 5

requests_max_tries: int = 5
requests_timeout: float = 60
verifier_timeout: float = 60
Expand Down Expand Up @@ -344,18 +349,51 @@ def load_config(file_path: str) -> None:
esbmc_params,
)

global fix_code_max_attempts
fix_code_max_attempts = int(
_load_config_real_number(
config_file=config_file["chat_modes"]["generate_solution"],
name="max_attempts",
default=fix_code_max_attempts,
)
)

global source_code_format
source_code_format, _ = _load_config_value(
config_file=config_file,
name="source_code_format",
default=source_code_format,
)

if source_code_format not in ["full", "single"]:
raise Exception(
f"Source code format in the config is not valid: {source_code_format}"
)

global esbmc_output_type
esbmc_output_type, _ = _load_config_value(
config_file=config_file,
name="esbmc_output_type",
default=esbmc_output_type,
)

if esbmc_output_type not in ["full", "vp", "ce"]:
raise Exception(
f"ESBMC output type in the config is not valid: {esbmc_output_type}"
)

global requests_max_tries
requests_max_tries = int(
_load_config_real_number(
config_file=config_file["requests"],
config_file=config_file["llm_requests"],
name="max_tries",
default=requests_max_tries,
)
)

global requests_timeout
requests_timeout = _load_config_real_number(
config_file=config_file["requests"],
config_file=config_file["llm_requests"],
name="timeout",
default=requests_timeout,
)
Expand Down
58 changes: 58 additions & 0 deletions esbmc_ai/esbmc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,64 @@
from . import config


def esbmc_get_violated_property(esbmc_output: str) -> Optional[str]:
    """Return the "Violated property:" section of the ESBMC output.

    The section is the first line that is exactly "Violated property:"
    together with up to two lines following it, joined by newlines.
    Returns None when the output contains no such line.
    """
    output_lines: list[str] = esbmc_output.splitlines()
    try:
        header_idx: int = output_lines.index("Violated property:")
    except ValueError:
        # The header line does not appear anywhere in the output.
        return None
    return "\n".join(output_lines[header_idx : header_idx + 3])


def esbmc_get_counter_example(esbmc_output: str) -> Optional[str]:
    """Return the ESBMC output from the counterexample header onwards.

    The result starts at the first occurrence of "[Counterexample]" that is
    immediately followed by a newline and runs to the end of the output.
    Returns None when that header is not found.
    """
    header: str = "[Counterexample]\n"
    start: int = esbmc_output.find(header)
    return esbmc_output[start:] if start >= 0 else None


def esbmc_get_error_type(esbmc_output: str) -> str:
    """Get the error type of the violated property, as an entire line.

    The error type is the second line after the last "Violated property:"
    marker: the first line holds the source location, the second one the
    type of error (the "scenario").

    Args:
        esbmc_output: Raw text output produced by ESBMC.

    Returns:
        The stripped error-type line, or an empty string when the marker or
        the error-type line is not present in the output.
    """
    # TODO Test me
    marker: str = "Violated property:\n"
    marker_index: int = esbmc_output.rfind(marker)
    if marker_index == -1:
        # Without this guard rfind's -1 would be turned into an arbitrary
        # positive offset and unrelated text would be returned.
        return ""
    # Split into: location line, error-type line, remainder.
    parts: list[str] = esbmc_output[marker_index + len(marker) :].split("\n", 2)
    if len(parts) < 2:
        # Only the location line exists; there is no error-type line.
        return ""
    # Splitting (instead of slicing up to find("\n")) keeps the last
    # character when the output does not end with a newline.
    return parts[1].strip()


def get_source_code_err_line(esbmc_output: str) -> Optional[int]:
    """Return the source line number reported for the violated property.

    Looks at the second line of the section returned by
    esbmc_get_violated_property, splits it on single spaces and converts
    the token that follows the first "line" token to an int. Returns None
    when there is no violated property section or no "line" token.
    """
    section: Optional[str] = esbmc_get_violated_property(esbmc_output)
    if not section:
        return None
    # The second line of the section carries the location information.
    tokens: list[str] = section.splitlines()[1].split(" ")
    if "line" not in tokens:
        return None
    # The token right after the first "line" is the line number.
    return int(tokens[tokens.index("line") + 1])


def get_source_code_err_line_idx(esbmc_output: str) -> Optional[int]:
    """Return the index form of the violated property's source line.

    This is the value of get_source_code_err_line minus one, or None when
    that function yields no (or a falsy) line number.
    """
    line_number: Optional[int] = get_source_code_err_line(esbmc_output)
    if not line_number:
        return None
    return line_number - 1


def esbmc(path: str, esbmc_params: list, timeout: Optional[float] = None):
"""Exit code will be 0 if verification successful, 1 if verification
failed. And any other number for compilation error/general errors."""
Expand Down
20 changes: 20 additions & 0 deletions esbmc_ai/frontend/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,23 @@ def get_main_source_file() -> SourceFile:
def get_source_files() -> list[SourceFile]:
global _source_files
return list(_source_files)


def apply_line_patch(source_code: str, patch: str, start: int, end: int) -> str:
    """Replace a region of lines in the source code with a patch.

    To replace a single line, pass the same value for start and end.

    Args:
        * source_code - The source code to apply the patch to.
        * patch - One or more lines of text that take the place of the
        region between start and end.
        * start - Line index of the first line to replace.
        * end - Line index of the last line to replace. The line at this
        index is replaced as well (the code keeps lines from end + 1 on).

    Returns the patched source code with its lines joined by newlines."""
    assert (
        start <= end
    ), f"start ({start}) needs to be less than or equal to end ({end})"
    original_lines: list[str] = source_code.splitlines()
    before: list[str] = original_lines[:start]
    after: list[str] = original_lines[end + 1 :]
    return "\n".join([*before, patch, *after])
2 changes: 2 additions & 0 deletions esbmc_ai/loading_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def start(self, text: str = "Please Wait") -> None:
self.thread.start()

def stop(self) -> None:
if not config.loading_hints:
return
self.done = True
# Block until end.
if self.thread:
Expand Down
Loading
Loading