diff --git a/esbmc_ai/__main__.py b/esbmc_ai/__main__.py
index ee80d02..2bbc755 100755
--- a/esbmc_ai/__main__.py
+++ b/esbmc_ai/__main__.py
@@ -2,23 +2,18 @@
 
 # Author: Yiannis Charalambous 2023
 
-import os
 from pathlib import Path
 import re
 import sys
 
 # Enables arrow key functionality for input(). Do not remove import.
 import readline
 
-from typing import Optional
-
-from langchain_core.language_models import BaseChatModel
-
-from esbmc_ai.commands.fix_code_command import FixCodeCommandResult
 _ = readline
 
-import argparse
+from langchain_core.language_models import BaseChatModel
 
+import argparse
 
 from esbmc_ai import Config
 from esbmc_ai import __author__, __version__
@@ -29,9 +24,10 @@
     FixCodeCommand,
     HelpCommand,
     ExitCommand,
+    FixCodeCommandResult,
 )
-from esbmc_ai.loading_widget import LoadingWidget, create_loading_widget
+from esbmc_ai.loading_widget import BaseLoadingWidget, LoadingWidget
 from esbmc_ai.chats import UserChat
 from esbmc_ai.logging import print_horizontal_line, printv, printvv
 from esbmc_ai.esbmc_util import ESBMCUtil
@@ -110,7 +106,8 @@ def print_assistant_response(
 
 
 def init_commands_list() -> None:
-    # Setup Help command and commands list.
+    """Setup Help command and commands list."""
+    # Built in commands
     global help_command
     commands.extend(
         [
@@ -129,18 +126,15 @@ def update_solution(source_code: str) -> None:
     get_solution().files[0].update_content(content=source_code, reset_changes=True)
 
 
-def _run_esbmc(source_file: SourceFile, anim: Optional[LoadingWidget] = None) -> str:
+def _run_esbmc(source_file: SourceFile, anim: BaseLoadingWidget) -> str:
     assert source_file.file_path
 
-    if anim:
-        anim.start("ESBMC is processing... Please Wait")
-    exit_code, esbmc_output = ESBMCUtil.esbmc(
-        path=source_file.file_path,
-        esbmc_params=Config.get_value("esbmc.params"),
-        timeout=Config.get_value("esbmc.timeout"),
-    )
-    if anim:
-        anim.stop()
+    with anim("ESBMC is processing... Please Wait"):
+        exit_code, esbmc_output = ESBMCUtil.esbmc(
+            path=source_file.file_path,
+            esbmc_params=Config.get_value("esbmc.params"),
+            timeout=Config.get_value("esbmc.timeout"),
+        )
 
     # ESBMC will output 0 for verification success and 1 for verification
     # failed, if anything else gets thrown, it's an ESBMC error.
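The hunk above replaces the optional-widget plumbing (start()/stop() calls guarded by "if anim:") with a context-manager call. A minimal, self-contained sketch of that pattern follows, assuming only what this diff shows; the real classes live in esbmc_ai/loading_widget.py further down in the patch.

# Minimal sketch (not the project's code) of the pattern this diff adopts:
# a do-nothing BaseLoadingWidget lets call sites drop their "if anim:" checks
# and wrap work in a single "with" block instead of start()/stop() pairs.
from typing import Optional


class BaseLoadingWidget:
    """Null-object widget: accepted everywhere, displays nothing."""

    def __call__(self, text: Optional[str] = None) -> "BaseLoadingWidget":
        return self

    def __enter__(self) -> "BaseLoadingWidget":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        pass


class LoadingWidget(BaseLoadingWidget):
    """Prints the hint text while the wrapped block runs."""

    def __call__(self, text: Optional[str] = None) -> "LoadingWidget":
        self._text = text or "Please Wait"
        return self

    def __enter__(self) -> "LoadingWidget":
        print(self._text)
        return self


def run_verifier(anim: BaseLoadingWidget) -> None:
    # Same shape as the new _run_esbmc(): no None checks, one with-block.
    with anim("ESBMC is processing... Please Wait"):
        pass  # the verifier call would go here


run_verifier(BaseLoadingWidget())  # silent
run_verifier(LoadingWidget())      # prints the hint text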
@@ -186,12 +180,18 @@ def _execute_fix_code_command(source_file: SourceFile) -> FixCodeCommandResult:
         source_code_format=Config.get_value("source_code_format"),
         esbmc_output_format=Config.get_value("esbmc.output_type"),
         scenarios=Config.get_fix_code_scenarios(),
+        temp_file_dir=Config.get_value("temp_file_dir"),
+        output_dir=Config.output_dir,
     )
 
 
 def _run_command_mode(command: ChatCommand, args: argparse.Namespace) -> None:
     path_arg: Path = Path(args.filename)
 
+    anim: BaseLoadingWidget = (
+        LoadingWidget() if Config.get_value("loading_hints") else BaseLoadingWidget()
+    )
+
     solution: Solution = get_solution()
     if path_arg.is_dir():
         for path in path_arg.glob("**/*"):
@@ -204,7 +204,7 @@ def _run_command_mode(command: ChatCommand, args: argparse.Namespace) -> None:
         case fix_code_command.command_name:
             for source_file in solution.files:
                 # Run ESBMC first round
-                esbmc_output: str = _run_esbmc(source_file)
+                esbmc_output: str = _run_esbmc(source_file, anim)
                 source_file.assign_verifier_output(esbmc_output)
 
                 result: FixCodeCommandResult = _execute_fix_code_command(source_file)
@@ -310,12 +310,13 @@ def main() -> None:
         help="Generate patch files and place them in the same folder as the source files.",
     )
 
-    # parser.add_argument(
-    #     "--generate-default-config",
-    #     action="store_true",
-    #     default=False,
-    #     help="Will generate and save the default config to the current working directory as 'esbmcai.toml'."
-    # )
+    parser.add_argument(
+        "-o",
+        "--output-dir",
+        default="",
+        help="Store the result in the specified directory. Specifying the same directory "
+        + "as the source will overwrite the original file.",
+    )
 
     args: argparse.Namespace = parser.parse_args()
@@ -331,7 +332,9 @@ def main() -> None:
     printv(f"Source code format: {Config.get_value('source_code_format')}")
     printv(f"ESBMC output type: {Config.get_value('esbmc.output_type')}")
 
-    anim: LoadingWidget = create_loading_widget()
+    anim: BaseLoadingWidget = (
+        LoadingWidget() if Config.get_value("loading_hints") else BaseLoadingWidget()
+    )
 
     # Read the source code and esbmc output.
     printv("Reading source code...")
@@ -378,7 +381,6 @@ def main() -> None:
 
     assert len(solution.files) == 1
     source_file: SourceFile = solution.files[0]
-    assert source_file.file_path
 
     esbmc_output: str = _run_esbmc(source_file, anim)
 
@@ -394,8 +396,8 @@ def main() -> None:
     chat_llm: BaseChatModel = Config.get_ai_model().create_llm(
         api_keys=Config.api_keys,
         temperature=Config.get_value("user_chat.temperature"),
-        requests_max_tries=Config.get_value("requests.max_tries"),
-        requests_timeout=Config.get_value("requests.timeout"),
+        requests_max_tries=Config.get_value("llm_requests.max_tries"),
+        requests_timeout=Config.get_value("llm_requests.timeout"),
     )
 
     printv("Creating user chat")
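The widget-selection expression now appears twice above, in _run_command_mode() and in main(). A small hypothetical helper (not part of this diff) that both call sites could share; the import path is the one introduced by this patch.

# Hypothetical helper, not in the diff: centralises the conditional that both
# _run_command_mode() and main() now duplicate.
from esbmc_ai.loading_widget import BaseLoadingWidget, LoadingWidget


def select_loading_widget(loading_hints: bool) -> BaseLoadingWidget:
    """Return an animated widget when hints are enabled, a silent one otherwise."""
    return LoadingWidget() if loading_hints else BaseLoadingWidget()


# Usage, mirroring the call sites in this diff:
# anim = select_loading_widget(Config.get_value("loading_hints"))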
Please Wait"): + try: + response = chat.send_message( + message=str(Config.get_user_chat_initial().content), + ) + except Exception as e: + print("There was an error while generating a response: {e}") + sys.exit(1) if response.finish_reason == FinishReason.length: raise RuntimeError(f"The token length is too large: {chat.ai_model.tokens}") @@ -479,17 +479,16 @@ def main() -> None: # User chat mode send and process current message response. while True: # Send user message to AI model and process. - anim.start("Generating response... Please Wait") - response = chat.send_message(user_message) - anim.stop() + with anim("Generating response... Please Wait"): + response = chat.send_message(user_message) + if response.finish_reason == FinishReason.stop: break elif response.finish_reason == FinishReason.length: - anim.start( + with anim( "Message stack limit reached. Shortening message stack... Please Wait" - ) - chat.compress_message_stack() - anim.stop() + ): + chat.compress_message_stack() continue else: raise NotImplementedError( diff --git a/esbmc_ai/ai_models.py b/esbmc_ai/ai_models.py index db13e6e..33384f9 100644 --- a/esbmc_ai/ai_models.py +++ b/esbmc_ai/ai_models.py @@ -249,8 +249,9 @@ def is_valid_ai_model( name: str = ai_model.name if isinstance(ai_model, AIModel) else ai_model # Try accessing openai api and checking if there is a model defined. - # NOTE: This is not tested as no way to mock API currently. - if api_keys and api_keys.openai: + # Will only work on models that start with gpt- to avoid spamming API and + # getting blocked. NOTE: This is not tested as no way to mock API currently. + if name.startswith("gpt-") and api_keys and api_keys.openai: try: for model in Client(api_key=api_keys.openai).models.list().data: if model.id == name: diff --git a/esbmc_ai/chats/base_chat_interface.py b/esbmc_ai/chats/base_chat_interface.py index feb27b0..00c570e 100644 --- a/esbmc_ai/chats/base_chat_interface.py +++ b/esbmc_ai/chats/base_chat_interface.py @@ -5,6 +5,7 @@ from abc import abstractmethod from typing import Optional +import traceback from langchain.schema import ( BaseMessage, diff --git a/esbmc_ai/reverse_order_solution_generator.py b/esbmc_ai/chats/reverse_order_solution_generator.py similarity index 100% rename from esbmc_ai/reverse_order_solution_generator.py rename to esbmc_ai/chats/reverse_order_solution_generator.py diff --git a/esbmc_ai/commands/__init__.py b/esbmc_ai/commands/__init__.py index 01e7577..52f4e49 100644 --- a/esbmc_ai/commands/__init__.py +++ b/esbmc_ai/commands/__init__.py @@ -1,7 +1,8 @@ from .chat_command import ChatCommand from .exit_command import ExitCommand -from .fix_code_command import FixCodeCommand +from .fix_code_command import FixCodeCommand, FixCodeCommandResult from .help_command import HelpCommand +from .command_result import CommandResult """This module contains built-in commands that can be executed by ESBMC-AI.""" @@ -10,4 +11,6 @@ "ExitCommand", "FixCodeCommand", "HelpCommand", + "CommandResult", + "FixCodeCommandResult", ] diff --git a/esbmc_ai/commands/chat_command.py b/esbmc_ai/commands/chat_command.py index 6109e1c..88c7851 100644 --- a/esbmc_ai/commands/chat_command.py +++ b/esbmc_ai/commands/chat_command.py @@ -7,17 +7,16 @@ class ChatCommand(ABC): - command_name: str - help_message: str - def __init__( self, command_name: str = "", help_message: str = "", + authors: str = "", ) -> None: super().__init__() self.command_name = command_name self.help_message = help_message + self.authors = authors @abstractmethod def execute(self, 
diff --git a/esbmc_ai/commands/fix_code_command.py b/esbmc_ai/commands/fix_code_command.py
index 12c9847..84dc3a4 100644
--- a/esbmc_ai/commands/fix_code_command.py
+++ b/esbmc_ai/commands/fix_code_command.py
@@ -1,5 +1,6 @@
 # Author: Yiannis Charalambous
 
+from pathlib import Path
 import sys
 from typing import Any, Optional
 from typing_extensions import override
@@ -11,20 +12,28 @@
 from esbmc_ai.chats.solution_generator import ESBMCTimedOutException
 from esbmc_ai.commands.command_result import CommandResult
 from esbmc_ai.config import FixCodeScenarios
-from esbmc_ai.reverse_order_solution_generator import ReverseOrderSolutionGenerator
+from esbmc_ai.chats.reverse_order_solution_generator import (
+    ReverseOrderSolutionGenerator,
+)
 from esbmc_ai.solution import SourceFile
 
 from .chat_command import ChatCommand
 from ..msg_bus import Signal
-from ..loading_widget import create_loading_widget
+from ..loading_widget import BaseLoadingWidget
 from ..esbmc_util import ESBMCUtil
 from ..logging import print_horizontal_line, printv, printvv
 
 
 class FixCodeCommandResult(CommandResult):
-    def __init__(self, successful: bool, repaired_source: Optional[str] = None) -> None:
+    def __init__(
+        self,
+        successful: bool,
+        attempts: int,
+        repaired_source: Optional[str] = None,
+    ) -> None:
         super().__init__()
         self._successful: bool = successful
+        self.attempts: int = attempts
         self.repaired_source: Optional[str] = repaired_source
 
     @property
@@ -51,7 +60,6 @@ def __init__(self) -> None:
             command_name="fix-code",
             help_message="Generates a solution for this code, and reevaluates it with ESBMC.",
         )
-        self.anim = create_loading_widget()
 
     @override
     def execute(self, **kwargs: Any) -> FixCodeCommandResult:
@@ -66,7 +74,6 @@ def print_raw_conversation() -> None:
 
         # Handle kwargs
         source_file: SourceFile = kwargs["source_file"]
-        assert source_file.file_path
 
         generate_patches: bool = (
             kwargs["generate_patches"] if "generate_patches" in kwargs else False
@@ -88,9 +95,18 @@ def print_raw_conversation() -> None:
         esbmc_params: list[str] = kwargs["esbmc_params"]
         verifier_timeout: int = kwargs["verifier_timeout"]
         temp_auto_clean: bool = kwargs["temp_auto_clean"]
+        temp_file_dir: Optional[Path] = (
+            kwargs["temp_file_dir"] if "temp_file_dir" in kwargs else None
+        )
         raw_conversation: bool = (
             kwargs["raw_conversation"] if "raw_conversation" in kwargs else False
         )
+        output_dir: Optional[Path] = (
+            kwargs["output_dir"] if "output_dir" in kwargs else None
+        )
+        anim: BaseLoadingWidget = (
+            kwargs["anim"] if "anim" in kwargs else BaseLoadingWidget()
+        )
         # End of handle kwargs
 
         match message_history:
@@ -154,12 +170,13 @@ def print_raw_conversation() -> None:
         # gets full, then need to compress and retry.
         while True:
             # Generate AI solution
-            self.anim.start("Generating Solution... Please Wait")
-            llm_solution, finish_reason = solution_generator.generate_solution()
-            self.anim.stop()
+            with anim("Generating Solution... Please Wait"):
+                llm_solution, finish_reason = solution_generator.generate_solution()
+
             if finish_reason == FinishReason.length:
                 solution_generator.compress_message_stack()
             else:
+                # Update the source file state
                 source_file.update_content(llm_solution)
                 break
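FixCodeCommandResult above now takes the number of attempts as its second positional argument. A short consumer sketch follows, assuming the successful property carried over from the original class; the surrounding caller code is illustrative only.

# Illustrative consumer of the updated result type: the new positional
# "attempts" argument sits between "successful" and "repaired_source".
from esbmc_ai.commands import FixCodeCommandResult

result = FixCodeCommandResult(True, 3, "int main(void) { return 0; }")
if result.successful:
    print(f"repaired after {result.attempts} attempt(s)")
else:
    print(f"gave up after {result.attempts} attempt(s)")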
Please Wait") - exit_code, esbmc_output = ESBMCUtil.esbmc_load_source_code( - source_file=source_file, - source_file_content_index=-1, - esbmc_params=esbmc_params, - auto_clean=temp_auto_clean, - timeout=verifier_timeout, - ) - self.anim.stop() + with anim("Verifying with ESBMC... Please Wait"): + exit_code, esbmc_output = ESBMCUtil.esbmc_load_source_code( + source_file=source_file, + source_file_content_index=-1, + esbmc_params=esbmc_params, + auto_clean=temp_auto_clean, + temp_file_dir=temp_file_dir, + timeout=verifier_timeout, + ) source_file.assign_verifier_output(esbmc_output) del esbmc_output @@ -206,7 +223,14 @@ def print_raw_conversation() -> None: else: returned_source = source_file.latest_content - return FixCodeCommandResult(True, returned_source) + # Check if an output directory is specified and save to it + if output_dir: + assert ( + output_dir.is_dir() + ), "FixCodeCommand: Output directory needs to be valid" + with open(output_dir / source_file.file_path.name, "w") as file: + file.write(source_file.latest_content) + return FixCodeCommandResult(True, attempt, returned_source) try: # Update state @@ -228,4 +252,4 @@ def print_raw_conversation() -> None: if raw_conversation: print_raw_conversation() - return FixCodeCommandResult(False, None) + return FixCodeCommandResult(False, max_attempts, None) diff --git a/esbmc_ai/commands/help_command.py b/esbmc_ai/commands/help_command.py index 39d5ee2..0ae5b00 100644 --- a/esbmc_ai/commands/help_command.py +++ b/esbmc_ai/commands/help_command.py @@ -24,6 +24,8 @@ def execute(self, **_: Optional[Any]) -> Optional[Any]: for command in self.commands: print(f"/{command.command_name}: {command.help_message}") + if command.authors: + print(f"\tAuthors: {command.authors}") print() print("Useful AI Questions:") diff --git a/esbmc_ai/config.py b/esbmc_ai/config.py index 32fcf36..c9dcbbd 100644 --- a/esbmc_ai/config.py +++ b/esbmc_ai/config.py @@ -96,6 +96,7 @@ class Config: raw_conversation: bool = False cfg_path: Path generate_patches: bool + output_dir: Optional[Path] = None _fields: List[ConfigField] = [ ConfigField( @@ -433,6 +434,17 @@ def _load_args(cls, args) -> None: Config.raw_conversation = args.raw_conversation Config.generate_patches = args.generate_patches + if args.output_dir: + path: Path = Path(args.output_dir).expanduser() + if path.is_dir(): + Config.output_dir = path + else: + print( + "Error while parsing arguments: output_dir: dir does not exist:", + Config.output_dir, + ) + sys.exit(1) + @classmethod def _flatten_dict(cls, d, parent_key="", sep="."): """Recursively flattens a nested dictionary.""" diff --git a/esbmc_ai/loading_widget.py b/esbmc_ai/loading_widget.py index 93008e2..3611884 100644 --- a/esbmc_ai/loading_widget.py +++ b/esbmc_ai/loading_widget.py @@ -10,12 +10,36 @@ from time import sleep from itertools import cycle from threading import Thread -from typing import Optional +from typing_extensions import Optional, override -from esbmc_ai import Config +class BaseLoadingWidget: + """Base loading widget, will not display any information.""" + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + _ = exc_type + _ = exc_val + _ = exc_tb + self.stop() + + def __call__(self, text: Optional[str] = None): + _ = text + return self + + def start(self, text: str = "") -> None: + _ = text + + def stop(self) -> None: + pass + + +class LoadingWidget(BaseLoadingWidget): + """Loading widget that can display an animation along with some text.""" -class LoadingWidget(object): 
diff --git a/esbmc_ai/loading_widget.py b/esbmc_ai/loading_widget.py
index 93008e2..3611884 100644
--- a/esbmc_ai/loading_widget.py
+++ b/esbmc_ai/loading_widget.py
@@ -10,12 +10,36 @@
 from time import sleep
 from itertools import cycle
 from threading import Thread
-from typing import Optional
+from typing_extensions import Optional, override
 
-from esbmc_ai import Config
 
+class BaseLoadingWidget:
+    """Base loading widget, will not display any information."""
+
+    def __enter__(self):
+        self.start()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        _ = exc_type
+        _ = exc_val
+        _ = exc_tb
+        self.stop()
+
+    def __call__(self, text: Optional[str] = None):
+        _ = text
+        return self
+
+    def start(self, text: str = "") -> None:
+        _ = text
+
+    def stop(self) -> None:
+        pass
+
+
+class LoadingWidget(BaseLoadingWidget):
+    """Loading widget that can display an animation along with some text."""
 
-class LoadingWidget(object):
     done: bool = False
     thread: Optional[Thread]
     loading_text: str
@@ -38,6 +62,12 @@ def __init__(
         if len(frame) > self.anim_clear_length:
             self.anim_clear_length = len(frame)
 
+    def __call__(self, text: Optional[str] = None):
+        """Allows you to set the text in a with statement easily."""
+        if text:
+            self.loading_text = text
+        return self
+
     def _animate(self) -> None:
         for c in cycle(self.animation):
             if self.done:
@@ -55,36 +85,17 @@ def _animate(self) -> None:
             terminal.write("\r")
             terminal.flush()
 
+    @override
     def start(self, text: str = "Please Wait") -> None:
-        if not Config.get_value("loading_hints"):
-            return
         self.done = False
         self.loading_text = text
         self.thread = Thread(target=self._animate)
         self.thread.daemon = True
         self.thread.start()
 
+    @override
     def stop(self) -> None:
-        if not Config.get_value("loading_hints"):
-            return
         self.done = True
         # Block until end.
         if self.thread:
             self.thread.join()
-
-
-_widgets: list[LoadingWidget] = []
-
-
-def create_loading_widget(
-    anim_speed: float = 0.1,
-    animation: list[str] = ["|", "/", "-", "\\"],
-) -> LoadingWidget:
-    w = LoadingWidget(anim_speed=anim_speed, animation=animation)
-    _widgets.append(w)
-    return w
-
-
-def stop_all() -> None:
-    for w in _widgets:
-        w.stop()
diff --git a/esbmc_ai/logging.py b/esbmc_ai/logging.py
index 3ca1a3a..a132e75 100644
--- a/esbmc_ai/logging.py
+++ b/esbmc_ai/logging.py
@@ -17,22 +17,22 @@ def set_verbose(level: int) -> None:
     _verbose = level
 
 
-def printv(m) -> None:
+def printv(*m: object) -> None:
     """Level 1 verbose printing."""
     if _verbose > 0:
-        print(m)
+        print(*m)
 
 
-def printvv(m) -> None:
+def printvv(*m: object) -> None:
     """Level 2 verbose printing."""
     if _verbose > 1:
-        print(m)
+        print(*m)
 
 
-def printvvv(m) -> None:
+def printvvv(*m: object) -> None:
     """Level 3 verbose printing."""
     if _verbose > 2:
-        print(m)
+        print(*m)
 
 
 def print_horizontal_line(verbosity: int) -> None:
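printv, printvv and printvvv above now forward *args to print(), so call sites can pass several values instead of pre-formatting one string. A brief usage sketch with made-up values follows.

# The verbose printers now behave like print() with multiple arguments.
from esbmc_ai.logging import set_verbose, printv, printvv

set_verbose(2)
printv("Source code format:", "full")           # level 1, printed
printvv("ESBMC exit code:", 0, "timeout:", 60)  # level 2, printed since level >= 2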
diff --git a/esbmc_ai/solution.py b/esbmc_ai/solution.py
index 05cf2a2..75fbc2c 100644
--- a/esbmc_ai/solution.py
+++ b/esbmc_ai/solution.py
@@ -39,9 +39,9 @@ def apply_line_patch(
         return "\n".join(lines)
 
     def __init__(
-        self, file_path: Optional[Path], content: str, file_ext: Optional[str] = None
+        self, file_path: Path, content: str, file_ext: Optional[str] = None
     ) -> None:
-        self._file_path: Optional[Path] = file_path
+        self._file_path: Path = file_path
         # Content file shows the file throughout the repair process. Index 0 is
         # the original.
         self._content: list[str] = [content]
@@ -50,7 +50,7 @@ def __init__(
         self._file_ext: Optional[str] = file_ext
 
     @property
-    def file_path(self) -> Optional[Path]:
+    def file_path(self) -> Path:
         """Returns the file path of this source file."""
         return self._file_path
 
@@ -162,7 +162,7 @@ def save_file(
         the saved file in /tmp and use the file_path file name only."""
 
         file_name: Optional[str] = None
-        dir_path: Optional[Path] = None
+        dir_path: Path
         if file_path:
             # If file path is a file, then use the name and directory. If not
             # then use a temporary name and just store the folder.
@@ -172,19 +172,15 @@ def save_file(
             else:
                 dir_path = file_path
         else:
-            if not self._file_path:
-                raise ValueError(
-                    "Source code file does not have a name or file_path to save to"
-                )
             # Just store the file and use the temp dir.
             file_name = self._file_path.name
-            if temp_dir:
-                dir_path = Path(gettempdir())
+            if not temp_dir:
+                raise ValueError(
+                    "Need to enable temporary directory or provide file path to store to."
+                )
 
-        assert (
-            dir_path
-        ), "dir_path could not be retrieved: file_path or temp_dir need to be set."
+            dir_path = Path(gettempdir())
 
         # Create path if it does not exist.
         if not os.path.exists(dir_path):
@@ -241,29 +237,17 @@ def files(self) -> tuple[SourceFile, ...]:
 
     @property
     def files_mapped(self) -> dict[Path, SourceFile]:
         """Will return the files mapped to their directory. Returns by value."""
-        return {
-            source_file.file_path: source_file
-            for source_file in self._files
-            if source_file.file_path
-        }
-
-    def add_source_file(
-        self, file_path: Optional[Path], content: Optional[str]
-    ) -> None:
-        """Add a source file to the solution."""
-        if file_path:
-            if content:
-                self._files.append(SourceFile(file_path, content))
-            else:
-                with open(file_path, "r") as file:
-                    self._files.append(SourceFile(file_path, file.read()))
-            return
+        return {source_file.file_path: source_file for source_file in self._files}
 
+    def add_source_file(self, file_path: Path, content: Optional[str]) -> None:
+        """Add a source file to the solution. If content is provided then the file
+        will not be loaded from disk."""
+        assert file_path
         if content:
             self._files.append(SourceFile(file_path, content))
-            return
-
-        raise RuntimeError("file_path and content cannot be both invalid!")
+        else:
+            with open(file_path, "r") as file:
+                self._files.append(SourceFile(file_path, file.read()))
 
 
 # Define a global solution (is not required to be used)
diff --git a/tests/test_reverse_order_solution_generator.py b/tests/test_reverse_order_solution_generator.py
index 3b0d732..000fb64 100644
--- a/tests/test_reverse_order_solution_generator.py
+++ b/tests/test_reverse_order_solution_generator.py
@@ -11,7 +11,9 @@
 
 from esbmc_ai.config import default_scenario
 from esbmc_ai.ai_models import AIModel
-from esbmc_ai.reverse_order_solution_generator import ReverseOrderSolutionGenerator
+from esbmc_ai.chats.reverse_order_solution_generator import (
+    ReverseOrderSolutionGenerator,
+)
 
 
 @pytest.fixture(scope="function")
diff --git a/tests/test_solution.py b/tests/test_solution.py
index f5d189f..4ab3c98 100644
--- a/tests/test_solution.py
+++ b/tests/test_solution.py
@@ -14,24 +14,34 @@ def solution() -> Solution:
 
 
 def test_add_source_file(solution) -> None:
-    src: str = "int main(int argc, char** argv) {return 0;}"
-    solution.add_source_file(None, src)
+    src = '#include <stdio.h> int main(int argc, char** argv) { printf("hello world\n"); return 0;}'
+    solution.add_source_file("Testfile1", src)
+    solution.add_source_file("Testfile2", src)
+    solution.add_source_file("Testfile3", src)
+
+    assert len(solution.files) == 3
+
     assert (
-        len(solution.files) == 1
-        and solution.files[0].file_path == None
+        solution.files[0].file_path == "Testfile1"
         and solution.files[0].latest_content == src
     )
 
-    src = '#include <stdio.h> int main(int argc, char** argv) { printf("hello world\n"); return 0;}'
-    solution.add_source_file("Testfile1", src)
     assert (
-        len(solution.files) == 2
-        and solution.files[1].file_path == "Testfile1"
+        solution.files[1].file_path == "Testfile2"
         and solution.files[1].latest_content == src
     )
 
     assert (
-        len(solution.files_mapped) == 1
+        solution.files[2].file_path == "Testfile3"
+        and solution.files[2].latest_content == src
+    )
+
+    assert (
+        len(solution.files_mapped) == 3
         and solution.files_mapped["Testfile1"].file_path == "Testfile1"
         and solution.files_mapped["Testfile1"].initial_content == src
+        and solution.files_mapped["Testfile2"].file_path == "Testfile2"
+        and solution.files_mapped["Testfile2"].initial_content == src
+        and solution.files_mapped["Testfile3"].file_path == "Testfile3"
+        and solution.files_mapped["Testfile3"].initial_content == src
     )
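add_source_file() now requires a file path and reads the file from disk only when content is omitted. A hedged usage sketch follows; it assumes get_solution() returns the module's global Solution instance mentioned at the end of solution.py.

# Both call styles of the reworked add_source_file(): with explicit content
# (nothing is read from disk) and without (the file is opened and read).
from pathlib import Path
from tempfile import NamedTemporaryFile
from esbmc_ai.solution import get_solution

solution = get_solution()
solution.add_source_file(Path("main.c"), "int main(void) { return 0; }")

with NamedTemporaryFile("w", suffix=".c", delete=False) as f:
    f.write("int x;")
solution.add_source_file(Path(f.name), None)  # content=None -> read from disk

print(len(solution.files), list(solution.files_mapped))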
a/upload.sh b/upload.sh index 73274f2..6c90584 100755 --- a/upload.sh +++ b/upload.sh @@ -6,16 +6,21 @@ echo "Upload to PyPi" while true; do read -p "Choose repo (pypi, testpypi): " -r choice case "$choice" in - "pypi"|"testpypi") + "pypi") + repo="pypi" + break + ;; + "testpypi") + repo="testpypi" break ;; *) echo "Wrong option" ;; - esac + esac done echo "For username, type __token__ if you want to use a token. You will only be asked if the information is not in the ~/.pypi file." -python3 -m twine upload --skip-existing --repository pypi dist/* +python3 -m twine upload --skip-existing --repository "$repo" dist/*