Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add argument to write result of fix code to directory #148

Merged
merged 11 commits
Nov 6, 2024
93 changes: 46 additions & 47 deletions esbmc_ai/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,18 @@

# Author: Yiannis Charalambous 2023

import os
from pathlib import Path
import re
import sys

# Enables arrow key functionality for input(). Do not remove import.
import readline
from typing import Optional

from langchain_core.language_models import BaseChatModel

from esbmc_ai.commands.fix_code_command import FixCodeCommandResult

_ = readline

import argparse
from langchain_core.language_models import BaseChatModel

import argparse

from esbmc_ai import Config
from esbmc_ai import __author__, __version__
Expand All @@ -29,9 +24,10 @@
FixCodeCommand,
HelpCommand,
ExitCommand,
FixCodeCommandResult,
)

from esbmc_ai.loading_widget import LoadingWidget, create_loading_widget
from esbmc_ai.loading_widget import BaseLoadingWidget, LoadingWidget
from esbmc_ai.chats import UserChat
from esbmc_ai.logging import print_horizontal_line, printv, printvv
from esbmc_ai.esbmc_util import ESBMCUtil
Expand Down Expand Up @@ -110,7 +106,8 @@ def print_assistant_response(


def init_commands_list() -> None:
# Setup Help command and commands list.
"""Setup Help command and commands list."""
# Built in commands
global help_command
commands.extend(
[
Expand All @@ -129,18 +126,15 @@ def update_solution(source_code: str) -> None:
get_solution().files[0].update_content(content=source_code, reset_changes=True)


def _run_esbmc(source_file: SourceFile, anim: Optional[LoadingWidget] = None) -> str:
def _run_esbmc(source_file: SourceFile, anim: BaseLoadingWidget) -> str:
assert source_file.file_path

if anim:
anim.start("ESBMC is processing... Please Wait")
exit_code, esbmc_output = ESBMCUtil.esbmc(
path=source_file.file_path,
esbmc_params=Config.get_value("esbmc.params"),
timeout=Config.get_value("esbmc.timeout"),
)
if anim:
anim.stop()
with anim("ESBMC is processing... Please Wait"):
exit_code, esbmc_output = ESBMCUtil.esbmc(
path=source_file.file_path,
esbmc_params=Config.get_value("esbmc.params"),
timeout=Config.get_value("esbmc.timeout"),
)

# ESBMC will output 0 for verification success and 1 for verification
# failed, if anything else gets thrown, it's an ESBMC error.
Expand Down Expand Up @@ -186,12 +180,18 @@ def _execute_fix_code_command(source_file: SourceFile) -> FixCodeCommandResult:
source_code_format=Config.get_value("source_code_format"),
esbmc_output_format=Config.get_value("esbmc.output_type"),
scenarios=Config.get_fix_code_scenarios(),
temp_file_dir=Config.get_value("temp_file_dir"),
output_dir=Config.output_dir,
)


def _run_command_mode(command: ChatCommand, args: argparse.Namespace) -> None:
path_arg: Path = Path(args.filename)

anim: BaseLoadingWidget = (
LoadingWidget() if Config.get_value("loading_hints") else BaseLoadingWidget()
)

solution: Solution = get_solution()
if path_arg.is_dir():
for path in path_arg.glob("**/*"):
Expand All @@ -204,7 +204,7 @@ def _run_command_mode(command: ChatCommand, args: argparse.Namespace) -> None:
case fix_code_command.command_name:
for source_file in solution.files:
# Run ESBMC first round
esbmc_output: str = _run_esbmc(source_file)
esbmc_output: str = _run_esbmc(source_file, anim)
source_file.assign_verifier_output(esbmc_output)

result: FixCodeCommandResult = _execute_fix_code_command(source_file)
Expand Down Expand Up @@ -310,12 +310,13 @@ def main() -> None:
help="Generate patch files and place them in the same folder as the source files.",
)

# parser.add_argument(
# "--generate-default-config",
# action="store_true",
# default=False,
# help="Will generate and save the default config to the current working directory as 'esbmcai.toml'."
# )
parser.add_argument(
"-o",
"--output-dir",
default="",
help="Store the result at the following dir. Specifying the same directory will "
+ "overwrite the original file.",
)

args: argparse.Namespace = parser.parse_args()

Expand All @@ -331,7 +332,9 @@ def main() -> None:
printv(f"Source code format: {Config.get_value('source_code_format')}")
printv(f"ESBMC output type: {Config.get_value('esbmc.output_type')}")

anim: LoadingWidget = create_loading_widget()
anim: BaseLoadingWidget = (
LoadingWidget() if Config.get_value("loading_hints") else BaseLoadingWidget()
)

# Read the source code and esbmc output.
printv("Reading source code...")
Expand Down Expand Up @@ -378,7 +381,6 @@ def main() -> None:
assert len(solution.files) == 1

source_file: SourceFile = solution.files[0]
assert source_file.file_path

esbmc_output: str = _run_esbmc(source_file, anim)

Expand All @@ -394,8 +396,8 @@ def main() -> None:
chat_llm: BaseChatModel = Config.get_ai_model().create_llm(
api_keys=Config.api_keys,
temperature=Config.get_value("user_chat.temperature"),
requests_max_tries=Config.get_value("requests.max_tries"),
requests_timeout=Config.get_value("requests.timeout"),
requests_max_tries=Config.get_value("llm_requests.max_tries"),
requests_timeout=Config.get_value("llm_requests.timeout"),
)

printv("Creating user chat")
Expand All @@ -416,16 +418,14 @@ def main() -> None:
response: ChatResponse
if len(str(Config.get_user_chat_initial().content)) > 0:
printv("Using initial prompt from file...\n")
anim.start("Model is parsing ESBMC output... Please Wait")
try:
response = chat.send_message(
message=str(Config.get_user_chat_initial().content),
)
except Exception as e:
print("There was an error while generating a response: {e}")
sys.exit(1)
finally:
anim.stop()
with anim("Model is parsing ESBMC output... Please Wait"):
try:
response = chat.send_message(
message=str(Config.get_user_chat_initial().content),
)
except Exception as e:
print("There was an error while generating a response: {e}")
sys.exit(1)

if response.finish_reason == FinishReason.length:
raise RuntimeError(f"The token length is too large: {chat.ai_model.tokens}")
Expand Down Expand Up @@ -479,17 +479,16 @@ def main() -> None:
# User chat mode send and process current message response.
while True:
# Send user message to AI model and process.
anim.start("Generating response... Please Wait")
response = chat.send_message(user_message)
anim.stop()
with anim("Generating response... Please Wait"):
response = chat.send_message(user_message)

if response.finish_reason == FinishReason.stop:
break
elif response.finish_reason == FinishReason.length:
anim.start(
with anim(
"Message stack limit reached. Shortening message stack... Please Wait"
)
chat.compress_message_stack()
anim.stop()
):
chat.compress_message_stack()
continue
else:
raise NotImplementedError(
Expand Down
5 changes: 3 additions & 2 deletions esbmc_ai/ai_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,9 @@ def is_valid_ai_model(
name: str = ai_model.name if isinstance(ai_model, AIModel) else ai_model

# Try accessing openai api and checking if there is a model defined.
# NOTE: This is not tested as no way to mock API currently.
if api_keys and api_keys.openai:
# Will only work on models that start with gpt- to avoid spamming API and
# getting blocked. NOTE: This is not tested as no way to mock API currently.
if name.startswith("gpt-") and api_keys and api_keys.openai:
try:
for model in Client(api_key=api_keys.openai).models.list().data:
if model.id == name:
Expand Down
1 change: 1 addition & 0 deletions esbmc_ai/chats/base_chat_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from abc import abstractmethod
from typing import Optional
import traceback

from langchain.schema import (
BaseMessage,
Expand Down
5 changes: 4 additions & 1 deletion esbmc_ai/commands/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .chat_command import ChatCommand
from .exit_command import ExitCommand
from .fix_code_command import FixCodeCommand
from .fix_code_command import FixCodeCommand, FixCodeCommandResult
from .help_command import HelpCommand
from .command_result import CommandResult

"""This module contains built-in commands that can be executed by ESBMC-AI."""

Expand All @@ -10,4 +11,6 @@
"ExitCommand",
"FixCodeCommand",
"HelpCommand",
"CommandResult",
"FixCodeCommandResult",
]
5 changes: 2 additions & 3 deletions esbmc_ai/commands/chat_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,16 @@


class ChatCommand(ABC):
command_name: str
help_message: str

def __init__(
    self,
    command_name: str = "",
    help_message: str = "",
    authors: str = "",
) -> None:
    """Initialise the command descriptor state.

    Args:
        command_name: Identifier used to invoke this command.
        help_message: Short description of what the command does.
        authors: Credit string naming the command's author(s).
    """
    super().__init__()
    # Stored as plain instance attributes; subclasses (e.g. FixCodeCommand)
    # supply concrete values through this constructor.
    self.command_name = command_name
    self.help_message = help_message
    self.authors = authors

@abstractmethod
def execute(self, **kwargs: Optional[Any]) -> Optional[CommandResult]:
Expand Down
62 changes: 43 additions & 19 deletions esbmc_ai/commands/fix_code_command.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Author: Yiannis Charalambous

from pathlib import Path
import sys
from typing import Any, Optional
from typing_extensions import override
Expand All @@ -11,20 +12,28 @@
from esbmc_ai.chats.solution_generator import ESBMCTimedOutException
from esbmc_ai.commands.command_result import CommandResult
from esbmc_ai.config import FixCodeScenarios
from esbmc_ai.reverse_order_solution_generator import ReverseOrderSolutionGenerator
from esbmc_ai.chats.reverse_order_solution_generator import (
ReverseOrderSolutionGenerator,
)
from esbmc_ai.solution import SourceFile

from .chat_command import ChatCommand
from ..msg_bus import Signal
from ..loading_widget import create_loading_widget
from ..loading_widget import BaseLoadingWidget
from ..esbmc_util import ESBMCUtil
from ..logging import print_horizontal_line, printv, printvv


class FixCodeCommandResult(CommandResult):
def __init__(self, successful: bool, repaired_source: Optional[str] = None) -> None:
def __init__(
    self,
    successful: bool,
    attempts: int,
    repaired_source: Optional[str] = None,
) -> None:
    """Record the outcome of a fix-code run.

    Args:
        successful: True when the repair was verified; False when all
            attempts were exhausted without success.
        attempts: Number of repair attempts made before returning.
        repaired_source: The final source text on success; None on failure.
    """
    super().__init__()
    # Private: presumably read back through a property accessor declared
    # below (cut off in this view) — confirm against the full class.
    self._successful: bool = successful
    self.attempts: int = attempts
    self.repaired_source: Optional[str] = repaired_source

@property
Expand All @@ -51,7 +60,6 @@ def __init__(self) -> None:
command_name="fix-code",
help_message="Generates a solution for this code, and reevaluates it with ESBMC.",
)
self.anim = create_loading_widget()

@override
def execute(self, **kwargs: Any) -> FixCodeCommandResult:
Expand All @@ -66,7 +74,6 @@ def print_raw_conversation() -> None:

# Handle kwargs
source_file: SourceFile = kwargs["source_file"]
assert source_file.file_path

generate_patches: bool = (
kwargs["generate_patches"] if "generate_patches" in kwargs else False
Expand All @@ -88,9 +95,18 @@ def print_raw_conversation() -> None:
esbmc_params: list[str] = kwargs["esbmc_params"]
verifier_timeout: int = kwargs["verifier_timeout"]
temp_auto_clean: bool = kwargs["temp_auto_clean"]
temp_file_dir: Optional[Path] = (
kwargs["temp_file_dir"] if "temp_file_dir" in kwargs else None
)
raw_conversation: bool = (
kwargs["raw_conversation"] if "raw_conversation" in kwargs else False
)
output_dir: Optional[Path] = (
kwargs["output_dir"] if "output_dir" in kwargs else None
)
anim: BaseLoadingWidget = (
kwargs["anim"] if "anim" in kwargs else BaseLoadingWidget()
)
# End of handle kwargs

match message_history:
Expand Down Expand Up @@ -154,12 +170,13 @@ def print_raw_conversation() -> None:
# gets full, then need to compress and retry.
while True:
# Generate AI solution
self.anim.start("Generating Solution... Please Wait")
llm_solution, finish_reason = solution_generator.generate_solution()
self.anim.stop()
with anim("Generating Solution... Please Wait"):
llm_solution, finish_reason = solution_generator.generate_solution()

if finish_reason == FinishReason.length:
solution_generator.compress_message_stack()
else:
# Update the source file state
source_file.update_content(llm_solution)
break

Expand All @@ -172,15 +189,15 @@ def print_raw_conversation() -> None:

# Pass to ESBMC, a workaround is used where the file is saved
# to a temporary location since ESBMC needs it in file format.
self.anim.start("Verifying with ESBMC... Please Wait")
exit_code, esbmc_output = ESBMCUtil.esbmc_load_source_code(
source_file=source_file,
source_file_content_index=-1,
esbmc_params=esbmc_params,
auto_clean=temp_auto_clean,
timeout=verifier_timeout,
)
self.anim.stop()
with anim("Verifying with ESBMC... Please Wait"):
exit_code, esbmc_output = ESBMCUtil.esbmc_load_source_code(
source_file=source_file,
source_file_content_index=-1,
esbmc_params=esbmc_params,
auto_clean=temp_auto_clean,
temp_file_dir=temp_file_dir,
timeout=verifier_timeout,
)

source_file.assign_verifier_output(esbmc_output)
del esbmc_output
Expand All @@ -206,7 +223,14 @@ def print_raw_conversation() -> None:
else:
returned_source = source_file.latest_content

return FixCodeCommandResult(True, returned_source)
# Check if an output directory is specified and save to it
if output_dir:
assert (
output_dir.is_dir()
), "FixCodeCommand: Output directory needs to be valid"
with open(output_dir / source_file.file_path.name, "w") as file:
file.write(source_file.latest_content)
return FixCodeCommandResult(True, attempt, returned_source)

try:
# Update state
Expand All @@ -228,4 +252,4 @@ def print_raw_conversation() -> None:
if raw_conversation:
print_raw_conversation()

return FixCodeCommandResult(False, None)
return FixCodeCommandResult(False, max_attempts, None)
Loading
Loading