diff --git a/.env.example b/.env.example
index 7a86880..e07b8f8 100644
--- a/.env.example
+++ b/.env.example
@@ -1,3 +1,2 @@
 OPENAI_API_KEY=XXXXXXXXXXX
-HUGGINGFACE_API_KEY=YYYYYYYYYY
-ESBMC_AI_CFG_PATH=./config.json
\ No newline at end of file
+ESBMC_AI_CFG_PATH=./config.toml
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 6a4743b..1e964aa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -162,7 +162,7 @@ cython_debug/
 esbmc
 temp/
 
-config_dev.json
+config_dev.toml
 
 # Proprietary source code samples.
 uav_test.sh
diff --git a/Pipfile b/Pipfile
deleted file mode 100644
index 3f83258..0000000
--- a/Pipfile
+++ /dev/null
@@ -1,44 +0,0 @@
-[[source]]
-url = "https://pypi.org/simple"
-verify_ssl = true
-name = "pypi"
-
-[packages]
-openai = "*"
-python-dotenv = "==1.0.0"
-tiktoken = "*"
-aiohttp = "==3.8.4"
-aiosignal = "==1.3.1"
-async-timeout = "==4.0.2"
-attrs = "==23.1.0"
-certifi = "==2022.12.7"
-charset-normalizer = "==3.1.0"
-frozenlist = "==1.3.3"
-idna = "==3.4"
-multidict = "==6.0.4"
-regex = "==2023.3.23"
-requests = "==2.29.0"
-urllib3 = "==1.26.15"
-yarl = "==1.9.2"
-libclang = "*"
-clang = "*"
-langchain = "*"
-langchain-openai = "*"
-langchain-community = "*"
-langchain-ollama = "*"
-lizard = "*"
-
-[dev-packages]
-pylint = "*"
-ipykernel = "*"
-pytest = "*"
-pytest-cov = "*"
-pytest-regtest = "*"
-py = "*" # Dependency of pytest-regtest
-twine = "*"
-hatch = "*"
-transformers = "*"
-torch = "*"
-
-[requires]
-python_version = "3.11"
diff --git a/config.json b/config.json
deleted file mode 100644
index 9e96240..0000000
--- a/config.json
+++ /dev/null
@@ -1,99 +0,0 @@
-{
-    "ai_model": "gpt-3.5-turbo-16k",
-    "ai_custom": {},
-    "esbmc_path": "~/.local/bin/esbmc",
-    "allow_successful": true,
-    "verifier_timeout": 90,
-    "esbmc_params": [
-        "--interval-analysis",
-        "--goto-unwind",
-        "--unlimited-goto-unwind",
-        "--k-induction",
-        "--state-hashing",
-        "--add-symex-value-sets",
-        "--k-step",
-        "2",
-        "--floatbv",
-        "--unlimited-k-steps",
-        "--memory-leak-check",
-        "--context-bound",
-        "2"
-    ],
-    "esbmc_output_type": "full",
-    "source_code_format": "full",
-    "llm_requests": {
-        "max_tries": 5,
-        "timeout": 60
-    },
-    "temp_auto_clean": false,
-    "temp_file_dir": "./temp",
-    "loading_hints": true,
-    "chat_modes": {
-        "user_chat": {
-            "temperature": 1.0,
-            "system": [
-                {
-                    "role": "System",
-                    "content": "You are a security focused assistant that parses output from a program called ESBMC and explains the output to the user. ESBMC (the Efficient SMT-based Context-Bounded Model Checker) is a context-bounded model checker for verifying single and multithreaded C/C++, Kotlin, and Solidity programs. It can automatically verify both predefined safety properties (e.g., bounds check, pointer safety, overflow) and user-defined program assertions. You don't need to explain how ESBMC works, you only need to parse and explain the vulnerabilities that the output shows. For each line of code explained, say what the line number is as well. Do not answer any questions outside of these explicit parameters. If you understand, reply OK."
-                },
-                {
-                    "role": "AI",
-                    "content": "OK"
-                },
-                {
-                    "role": "System",
-                    "content": "Reply OK if you understand that the following text is the program source code:\n\n```c{source_code}```"
-                },
-                {
-                    "role": "AI",
-                    "content": "OK"
-                },
-                {
-                    "role": "System",
-                    "content": "Reply OK if you understand that the following text is the output from ESBMC:\n\n```\n{esbmc_output}\n```"
-                },
-                {
-                    "role": "AI",
-                    "content": "OK"
-                }
-            ],
-            "initial": "Walk me through the source code, while also explaining the output of ESBMC at the relevant parts. You shall not start the reply with an acknowledgement message such as 'Certainly'.",
-            "set_solution": [
-                {
-                    "role": "System",
-                    "content": "Here is the corrected code:\n\n```c\n{source_code_solution}```"
-                },
-                {
-                    "role": "AI",
-                    "content": "OK"
-                }
-            ]
-        },
-        "generate_solution": {
-            "max_attempts": 5,
-            "temperature": 0.0,
-            "message_history": "normal",
-            "scenarios": {
-                "division by zero": {
-                    "system": [
-                        {
-                            "role": "System",
-                            "content": "Here's a C program with a vulnerability:\n```c\n{source_code}\n```\nA Formal Verification tool identified a division by zero issue:\n{esbmc_output}\nTask: Modify the C code to safely handle scenarios where division by zero might occur. The solution should prevent undefined behavior or crashes due to division by zero. \nGuidelines: Focus on making essential changes only. Avoid adding or modifying comments, and ensure the changes are precise and minimal.\nGuidelines: Ensure the revised code avoids undefined behavior and handles division by zero cases effectively.\nGuidelines: Implement safeguards (like comparison) to prevent division by zero instead of using literal divisions like 1.0/0.0.Output: Provide the corrected, complete C code. The solution should compile and run error-free, addressing the division by zero vulnerability.\nStart the code snippet with ```c and end with ```. Reply OK if you understand."
-                        },
-                        {
-                            "role": "AI",
-                            "content": "OK."
-                        }
-                    ]
-                }
-            },
-            "system": [
-                {
-                    "role": "System",
-                    "content": "From now on, act as an Automated Code Repair Tool that repairs AI C code. You will be shown AI C code, along with ESBMC output. Pay close attention to the ESBMC output, which contains a stack trace along with the type of error that occurred and its location that you need to fix. Provide the repaired C code as output, as would an Automated Code Repair Tool. Aside from the corrected source code, do not output any other text."
-                }
-            ],
-            "initial": "The ESBMC output is:\n\n```\n{esbmc_output}\n```\n\nThe source code is:\n\n```c\n{source_code}\n```\n Using the ESBMC output, show the fixed text."
-        }
-    }
-}
\ No newline at end of file
diff --git a/config.toml b/config.toml
new file mode 100644
index 0000000..43b9ca5
--- /dev/null
+++ b/config.toml
@@ -0,0 +1,95 @@
+ai_model = "gpt-3.5-turbo"
+temp_auto_clean = true
+#temp_file_dir = "temp"
+allow_successful = false
+loading_hints = true
+source_code_format = "full"
+
+[esbmc]
+path = "~/.local/bin/esbmc"
+params = [
+    "--interval-analysis",
+    "--goto-unwind",
+    "--unlimited-goto-unwind",
+    "--k-induction",
+    "--state-hashing",
+    "--add-symex-value-sets",
+    "--k-step",
+    "2",
+    "--floatbv",
+    "--unlimited-k-steps",
+    "--compact-trace",
+    "--context-bound",
+    "2",
+]
+output_type = "full"
+timeout = 60
+
+[llm_requests]
+max_tries = 5
+timeout = 60
+
+[user_chat]
+temperature = 1.0
+
+[fix_code]
+temperature = 0.7
+max_attempts = 5
+message_history = "normal"
+
+# PROMPT TEMPLATES - USER CHAT
+
+[prompt_templates.user_chat]
+initial = "Walk me through the source code, while also explaining the output of ESBMC at the relevant parts. You shall not start the reply with an acknowledgement message such as 'Certainly'."
+
+[[prompt_templates.user_chat.system]]
+role = "System"
+content = "You are a security focused assistant that parses output from a program called ESBMC and explains the output to the user. ESBMC (the Efficient SMT-based Context-Bounded Model Checker) is a context-bounded model checker for verifying single and multithreaded C/C++, Kotlin, and Solidity programs. It can automatically verify both predefined safety properties (e.g., bounds check, pointer safety, overflow) and user-defined program assertions. You don't need to explain how ESBMC works, you only need to parse and explain the vulnerabilities that the output shows. For each line of code explained, say what the line number is as well. Do not answer any questions outside of these explicit parameters. If you understand, reply OK."
+
+[[prompt_templates.user_chat.system]]
+role = "AI"
+content = "OK"
+
+[[prompt_templates.user_chat.system]]
+role = "System"
+content = "Reply OK if you understand that the following text is the program source code:\n\n```c{source_code}```"
+
+[[prompt_templates.user_chat.system]]
+role = "AI"
+content = "OK"
+
+[[prompt_templates.user_chat.system]]
+role = "System"
+content = "Reply OK if you understand that the following text is the output from ESBMC:\n\n```\n{esbmc_output}\n```"
+
+[[prompt_templates.user_chat.system]]
+role = "AI"
+content = "OK"
+
+[[prompt_templates.user_chat.set_solution]]
+role = "System"
+content = "Here is the corrected code:\n\n```c\n{source_code_solution}```"
+
+[[prompt_templates.user_chat.set_solution]]
+role = "AI"
+content = "OK"
+
+# PROMPT TEMPLATES - FIX CODE
+
+[prompt_templates.fix_code.base]
+initial = "The ESBMC output is:\n\n```\n{esbmc_output}\n```\n\nThe source code is:\n\n```c\n{source_code}\n```\n Using the ESBMC output, show the fixed text."
+
+[[prompt_templates.fix_code.base.system]]
+role = "System"
+content = "From now on, act as an Automated Code Repair Tool that repairs AI C code. You will be shown AI C code, along with ESBMC output. Pay close attention to the ESBMC output, which contains a stack trace along with the type of error that occurred and its location that you need to fix. Provide the repaired C code as output, as would an Automated Code Repair Tool. Aside from the corrected source code, do not output any other text."
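+# NOTE (reviewer sketch, not part of this change): scenario tables are keyed by
+# the error-type string that ESBMCUtil.esbmc_get_error_type() reports; the
+# "base" table above is the fallback when no scenario matches. A further
+# scenario would be added the same way, e.g. (hypothetical key and prompts):
+#
+# [prompt_templates.fix_code."buffer overflow"]
+# initial = "..."
+#
+# [[prompt_templates.fix_code."buffer overflow".system]]
+# role = "System"
+# content = "..."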
+[prompt_templates.fix_code."division by zero"]
+initial = "The ESBMC output is:\n\n```\n{esbmc_output}\n```\n\nThe source code is:\n\n```c\n{source_code}\n```\n Using the ESBMC output, show the fixed text."
+
+[[prompt_templates.fix_code."division by zero".system]]
+role = "System"
+content = "Here's a C program with a vulnerability:\n```c\n{source_code}\n```\nA Formal Verification tool identified a division by zero issue:\n{esbmc_output}\nTask: Modify the C code to safely handle scenarios where division by zero might occur. The solution should prevent undefined behavior or crashes due to division by zero. \nGuidelines: Focus on making essential changes only. Avoid adding or modifying comments, and ensure the changes are precise and minimal.\nGuidelines: Ensure the revised code avoids undefined behavior and handles division by zero cases effectively.\nGuidelines: Implement safeguards (like comparison) to prevent division by zero instead of using literal divisions like 1.0/0.0.Output: Provide the corrected, complete C code. The solution should compile and run error-free, addressing the division by zero vulnerability.\nStart the code snippet with ```c and end with ```. Reply OK if you understand."
+
+[[prompt_templates.fix_code."division by zero".system]]
+role = "AI"
+content = "OK."
diff --git a/esbmc_ai/__about__.py b/esbmc_ai/__about__.py
index 7a6bb39..70ab742 100644
--- a/esbmc_ai/__about__.py
+++ b/esbmc_ai/__about__.py
@@ -1,4 +1,4 @@
 # Author: Yiannis Charalambous
 
-__version__ = "v0.5.1"
+__version__ = "v0.6.0"
 __author__: str = "Yiannis Charalambous"
diff --git a/esbmc_ai/__init__.py b/esbmc_ai/__init__.py
index ca668db..200305f 100644
--- a/esbmc_ai/__init__.py
+++ b/esbmc_ai/__init__.py
@@ -3,3 +3,7 @@ features such as automatic code fixing and more."""
 
 from esbmc_ai.__about__ import __version__, __author__
+
+from esbmc_ai.config import Config
+
+__all__ = ["Config"]
diff --git a/esbmc_ai/__main__.py b/esbmc_ai/__main__.py
index 1804373..33f8e40 100755
--- a/esbmc_ai/__main__.py
+++ b/esbmc_ai/__main__.py
@@ -20,7 +20,7 @@
 
 import argparse
 
-import esbmc_ai.config as config
+from esbmc_ai import Config
 from esbmc_ai import __author__, __version__
 from esbmc_ai.solution import SourceFile, Solution, get_solution
@@ -34,7 +34,7 @@
 from esbmc_ai.loading_widget import LoadingWidget, create_loading_widget
 from esbmc_ai.chats import UserChat
 from esbmc_ai.logging import print_horizontal_line, printv, printvv
-from esbmc_ai.esbmc_util import esbmc
+from esbmc_ai.esbmc_util import ESBMCUtil
 from esbmc_ai.chat_response import FinishReason, ChatResponse
 from esbmc_ai.ai_models import _ai_model_names
@@ -85,10 +85,11 @@ def check_health() -> None:
     printv("Performing health check...")
     # Check that ESBMC exists.
-    if os.path.exists(config.esbmc_path):
+    esbmc_path: Path = Config.get_value("esbmc.path")
+    if esbmc_path.exists():
         printv("ESBMC has been located")
     else:
-        print(f"Error: ESBMC could not be found in {config.esbmc_path}")
+        print(f"Error: ESBMC could not be found in {esbmc_path}")
         sys.exit(3)
@@ -133,17 +134,17 @@
     if anim:
         anim.start("ESBMC is processing... Please Wait")
-    exit_code, esbmc_output = esbmc(
+    exit_code, esbmc_output = ESBMCUtil.esbmc(
         path=source_file.file_path,
-        esbmc_params=config.esbmc_params,
-        timeout=config.verifier_timeout,
+        esbmc_params=Config.get_value("esbmc.params"),
+        timeout=Config.get_value("esbmc.timeout"),
     )
     if anim:
         anim.stop()
 
     # ESBMC will output 0 for verification success and 1 for verification
     # failed, if anything else gets thrown, it's an ESBMC error.
-    if not config.allow_successful and exit_code == 0:
+    if not Config.get_value("allow_successful") and exit_code == 0:
         printv("Success!")
         print(esbmc_output)
         sys.exit(0)
@@ -166,6 +167,28 @@ def init_commands() -> None:
     fix_code_command.on_solution_signal.add_listener(update_solution)
 
 
+def _execute_fix_code_command(source_file: SourceFile) -> FixCodeCommandResult:
+    """Shortcut method to execute fix code command."""
+    return fix_code_command.execute(
+        ai_model=Config.get_ai_model(),
+        source_file=source_file,
+        generate_patches=Config.generate_patches,
+        message_history=Config.get_value("fix_code.message_history"),
+        api_keys=Config.api_keys,
+        temperature=Config.get_value("fix_code.temperature"),
+        max_attempts=Config.get_value("fix_code.max_attempts"),
+        requests_max_tries=Config.get_llm_requests_max_tries(),
+        requests_timeout=Config.get_llm_requests_timeout(),
+        esbmc_params=Config.get_value("esbmc.params"),
+        raw_conversation=Config.raw_conversation,
+        temp_auto_clean=Config.get_value("temp_auto_clean"),
+        verifier_timeout=Config.get_value("esbmc.timeout"),
+        source_code_format=Config.get_value("source_code_format"),
+        esbmc_output_format=Config.get_value("esbmc.output_type"),
+        scenarios=Config.get_fix_code_scenarios(),
+    )
+
+
 def _run_command_mode(command: ChatCommand, args: argparse.Namespace) -> None:
     path_arg: Path = Path(args.filename)
@@ -184,10 +207,7 @@ def _run_command_mode(command: ChatCommand, args: argparse.Namespace) -> None:
             esbmc_output: str = _run_esbmc(source_file)
             source_file.assign_verifier_output(esbmc_output)
 
-            result: FixCodeCommandResult = fix_code_command.execute(
-                source_file=source_file,
-                generate_patches=config.generate_patches,
-            )
+            result: FixCodeCommandResult = _execute_fix_code_command(source_file)
 
             print(result)
         case _:
@@ -233,7 +253,6 @@
     )
 
     parser.add_argument(
-        "-V",
         "--version",
         action="version",
         version="%(prog)s {version}".format(version=__version__),
@@ -291,26 +310,32 @@
         help="Generate patch files and place them in the same folder as the source files.",
     )
 
+    # parser.add_argument(
+    #     "--generate-default-config",
+    #     action="store_true",
+    #     default=False,
+    #     help="Will generate and save the default config to the current working directory as 'esbmcai.toml'."
+    # )
+
     args: argparse.Namespace = parser.parse_args()
 
     print(f"ESBMC-AI {__version__}")
     print(f"Made by {__author__}")
     print()
 
-    config.load_envs()
-    config.load_config(config.cfg_path)
-    config.load_args(args)
+    Config.init(args)
+    ESBMCUtil.init(Config.get_value("esbmc.path"))
 
     check_health()
 
-    printv(f"Source code format: {config.source_code_format}")
-    printv(f"ESBMC output type: {config.esbmc_output_type}")
+    printv(f"Source code format: {Config.get_value('source_code_format')}")
+    printv(f"ESBMC output type: {Config.get_value('esbmc.output_type')}")
 
     anim: LoadingWidget = create_loading_widget()
 
     # Read the source code and esbmc output.
     printv("Reading source code...")
-    print(f"Running ESBMC with {config.esbmc_params}\n")
+    print(f"Running ESBMC with {Config.get_value('esbmc.params')}\n")
 
     assert isinstance(args.filename, str)
@@ -365,23 +390,23 @@
     source_file.assign_verifier_output(esbmc_output)
     del esbmc_output
 
-    printv(f"Initializing the LLM: {config.ai_model.name}\n")
-    chat_llm: BaseChatModel = config.ai_model.create_llm(
-        api_keys=config.api_keys,
-        temperature=config.chat_prompt_user_mode.temperature,
-        requests_max_tries=config.requests_max_tries,
-        requests_timeout=config.requests_timeout,
+    printv(f"Initializing the LLM: {Config.get_ai_model().name}\n")
+    chat_llm: BaseChatModel = Config.get_ai_model().create_llm(
+        api_keys=Config.api_keys,
+        temperature=Config.get_value("user_chat.temperature"),
+        requests_max_tries=Config.get_value("llm_requests.max_tries"),
+        requests_timeout=Config.get_value("llm_requests.timeout"),
     )
 
     printv("Creating user chat")
     global chat
     chat = UserChat(
-        ai_model_agent=config.chat_prompt_user_mode,
-        ai_model=config.ai_model,
+        ai_model=Config.get_ai_model(),
         llm=chat_llm,
         source_code=source_file.latest_content,
         esbmc_output=source_file.latest_verifier_output,
-        set_solution_messages=config.chat_prompt_user_mode.scenarios["set_solution"],
+        system_messages=Config.get_user_chat_system_messages(),
+        set_solution_messages=Config.get_user_chat_set_solution(),
     )
 
     printv("Initializing commands...")
@@ -389,11 +414,11 @@
 
     # Show the initial output.
     response: ChatResponse
-    if len(config.chat_prompt_user_mode.initial_prompt) > 0:
+    if len(str(Config.get_user_chat_initial().content)) > 0:
         printv("Using initial prompt from file...\n")
         anim.start("Model is parsing ESBMC output... Please Wait")
         response = chat.send_message(
-            message=config.chat_prompt_user_mode.initial_prompt,
+            message=str(Config.get_user_chat_initial().content),
         )
         anim.stop()
@@ -421,10 +446,7 @@
                 print()
                 print("ESBMC-AI will generate a fix for the code...")
 
-                result: FixCodeCommandResult = fix_code_command.execute(
-                    source_file=source_file,
-                    generate_patches=config.generate_patches,
-                )
+                result: FixCodeCommandResult = _execute_fix_code_command(source_file)
 
                 if result.successful:
                     print(
diff --git a/esbmc_ai/ai_models.py b/esbmc_ai/ai_models.py
index 8010336..6c1f228 100644
--- a/esbmc_ai/ai_models.py
+++ b/esbmc_ai/ai_models.py
@@ -4,7 +4,7 @@
 from typing import Any, Iterable, Optional, Union
 from enum import Enum
 from langchain_core.language_models import BaseChatModel
-from pydantic.v1.types import SecretStr
+from pydantic.types import SecretStr
 from typing_extensions import override
 
 from langchain_openai import ChatOpenAI
@@ -115,7 +115,7 @@ def add_safeguards(content: str, char: str, allowed_keys: list[str]) -> str:
             if content.find("}", look_pointer) != -1:
                 # Do it in reverse with reverse keys.
                 content = add_safeguards(content[::-1], "}", reversed_keys)[::-1]
-            new_msg = msg.copy()
+            new_msg = msg.model_copy()
             new_msg.content = content
             result.append(new_msg)
         return result
@@ -158,13 +158,50 @@ def create_llm(
             model_kwargs={},
         )
 
+    @classmethod
+    def get_openai_model_max_tokens(cls, name: str) -> int:
+        """Dynamically resolves the max tokens from a base model."""
+
+        # https://platform.openai.com/docs/models
+        tokens = {
+            "gpt-4o": 128000,
+            "chatgpt-4o": 128000,
+            "o1": 128000,
+            "gpt-4": 8192,
+            "gpt-3.5-turbo": 16385,
+            "gpt-3.5-turbo-instruct": 4096,
+        }
+
+        # Split into - segments and remove each section from the end to find
+        # out which one matches the most.
+
+        # Base Case
+        if name in tokens:
+            return tokens[name]
+
+        # Step Case
+        name_split: list[str] = name.split("-")
+        for i in range(1, len(name_split)):
+            subname: str = "-".join(name_split[:-i])
+            if subname in tokens:
+                return tokens[subname]
+
+        raise ValueError(f"Could not figure out max tokens for model: {name}")
+
 
 class OllamaAIModel(AIModel):
     def __init__(self, name: str, tokens: int, url: str) -> None:
         super().__init__(name, tokens)
         self.url: str = url
-    
+
     @override
-    def create_llm(self, api_keys: APIKeyCollection, temperature: float = 1, requests_max_tries: int = 5, requests_timeout: float = 60) -> BaseChatModel:
+    def create_llm(
+        self,
+        api_keys: APIKeyCollection,
+        temperature: float = 1,
+        requests_max_tries: int = 5,
+        requests_timeout: float = 60,
+    ) -> BaseChatModel:
         # Ollama does not use API keys
         _ = api_keys
         _ = requests_max_tries
@@ -173,7 +210,7 @@
             model=self.name,
             temperature=temperature,
             client_kwargs={
-                "timeout":requests_timeout,
+                "timeout": requests_timeout,
             },
         )
@@ -226,38 +263,10 @@
     return name in _ai_model_names
 
 
-def _get_openai_model_max_tokens(name: str) -> int:
-    """NOTE: OpenAI currently does not expose an API for getting the model
-    length. Maybe add a config input value for this?"""
-
-    # https://platform.openai.com/docs/models
-    tokens = {
-        "gpt-4o": 128000,
-        "gpt-4": 8192,
-        "gpt-3.5-turbo": 16385,
-        "gpt-3.5-turbo-instruct": 4096,
-    }
-
-    # Split into - segments and remove each section from the end to find out
-    # which one matches the most.
-
-    # Base Case
-    if name in tokens:
-        return tokens[name]
-
-    # Step Case
-    name_split: list[str] = name.split("-")
-    for i in range(1, name.count("-")):
-        subname: str = "-".join(name_split[:-i])
-        if subname in tokens:
-            return tokens[subname]
-
-    raise ValueError(f"Could not figure out max tokens for model: {name}")
-
-
 def get_ai_model_by_name(
     name: str, api_keys: Optional[APIKeyCollection] = None
 ) -> AIModel:
+    """Checks for built-in and custom_ai models"""
     # Check OpenAI models.
     if api_keys and api_keys.openai:
         try:
@@ -268,7 +277,7 @@
                 add_custom_ai_model(
                     AIModelOpenAI(
                         model.id,
-                        _get_openai_model_max_tokens(model.id),
+                        AIModelOpenAI.get_openai_model_max_tokens(model.id),
                     ),
                 )
         except ImportError:
diff --git a/esbmc_ai/api_key_collection.py b/esbmc_ai/api_key_collection.py
index 9aeec6f..e34d72e 100644
--- a/esbmc_ai/api_key_collection.py
+++ b/esbmc_ai/api_key_collection.py
@@ -9,4 +9,3 @@ class APIKeyCollection(NamedTuple):
     """Class that is used to pass keys to AIModels."""
 
     openai: Optional[str]
-    huggingface: Optional[str]
diff --git a/esbmc_ai/chat_response.py b/esbmc_ai/chat_response.py
index 216216d..29f1934 100644
--- a/esbmc_ai/chat_response.py
+++ b/esbmc_ai/chat_response.py
@@ -31,7 +31,7 @@ class ChatResponse(NamedTuple):
     finish_reason: FinishReason = FinishReason.null
 
 
-def json_to_base_message(json_string: dict) -> BaseMessage:
+def dict_to_base_message(json_string: dict) -> BaseMessage:
     """Converts a json representation of messages (such as in config.json),
     into LangChain object messages. The three recognized roles are:
     1. System
@@ -49,6 +49,6 @@
             raise Exception()
 
 
-def json_to_base_messages(json_messages: list[dict]) -> list[BaseMessage]:
+def list_to_base_messages(json_messages: list[dict]) -> list[BaseMessage]:
     """Converts a list of messages from JSON format to a list of BaseMessage."""
-    return [json_to_base_message(msg) for msg in json_messages]
+    return [dict_to_base_message(msg) for msg in json_messages]
diff --git a/esbmc_ai/chats/base_chat_interface.py b/esbmc_ai/chats/base_chat_interface.py
index defb2bd..5d2ab07 100644
--- a/esbmc_ai/chats/base_chat_interface.py
+++ b/esbmc_ai/chats/base_chat_interface.py
@@ -10,7 +10,6 @@
 )
 from langchain_core.language_models import BaseChatModel
 
-from esbmc_ai.config import ChatPromptSettings
 from esbmc_ai.chat_response import ChatResponse, FinishReason
 from esbmc_ai.ai_models import AIModel
@@ -21,16 +20,13 @@ class BaseChatInterface(object):
 
     def __init__(
         self,
-        ai_model_agent: ChatPromptSettings,
+        system_messages: list[BaseMessage],
         llm: BaseChatModel,
         ai_model: AIModel,
     ) -> None:
         super().__init__()
         self.ai_model: AIModel = ai_model
-        self.ai_model_agent: ChatPromptSettings = ai_model_agent
-        self._system_messages: list[BaseMessage] = list(
-            ai_model_agent.system_messages.messages
-        )
+        self._system_messages: list[BaseMessage] = system_messages
         self.messages: list[BaseMessage] = []
         self.llm: BaseChatModel = llm
diff --git a/esbmc_ai/chats/latest_state_solution_generator.py b/esbmc_ai/chats/latest_state_solution_generator.py
index 0bda6f8..2e40b7b 100644
--- a/esbmc_ai/chats/latest_state_solution_generator.py
+++ b/esbmc_ai/chats/latest_state_solution_generator.py
@@ -1,5 +1,6 @@
 # Author: Yiannis Charalambous
 
+from typing import Optional
 from typing_extensions import override
 from langchain_core.messages import BaseMessage
 from esbmc_ai.chats.solution_generator import SolutionGenerator
@@ -13,13 +14,15 @@ class LatestStateSolutionGenerator(SolutionGenerator):
     output state."""
 
     @override
-    def generate_solution(self) -> tuple[str, FinishReason]:
+    def generate_solution(
+        self, override_scenario: Optional[str] = None
+    ) -> tuple[str, FinishReason]:
         # Backup message stack and clear before sending base message. We want
         # to keep the message stack intact because we will print it with
         # print_raw_conversation.
         messages: list[BaseMessage] = self.messages
         self.messages: list[BaseMessage] = []
-        solution, finish_reason = super().generate_solution()
+        solution, finish_reason = super().generate_solution(override_scenario)
         # Append last messages to the messages stack
         messages.extend(self.messages)
         # Restore
diff --git a/esbmc_ai/chats/solution_generator.py b/esbmc_ai/chats/solution_generator.py
index 030e4e3..6364e6e 100644
--- a/esbmc_ai/chats/solution_generator.py
+++ b/esbmc_ai/chats/solution_generator.py
@@ -1,23 +1,17 @@
 # Author: Yiannis Charalambous 2023
 
-from re import S
 from typing import Optional
 from langchain_core.language_models import BaseChatModel
 from typing_extensions import override
 from langchain.schema import BaseMessage, HumanMessage
 
 from esbmc_ai.chat_response import ChatResponse, FinishReason
-from esbmc_ai.config import ChatPromptSettings, DynamicAIModelAgent
+from esbmc_ai.config import FixCodeScenarios, default_scenario
 from esbmc_ai.solution import SourceFile
 from esbmc_ai.ai_models import AIModel
 
 from .base_chat_interface import BaseChatInterface
-from esbmc_ai.esbmc_util import (
-    esbmc_get_counter_example,
-    esbmc_get_violated_property,
-    get_source_code_err_line_idx,
-    get_clang_err_line_index,
-)
+from esbmc_ai.esbmc_util import ESBMCUtil
 
 
 class ESBMCTimedOutException(Exception):
@@ -34,12 +28,12 @@ def get_source_code_formatted(
     match source_code_format:
         case "single":
             # Get source code error line from esbmc output
-            line: Optional[int] = get_source_code_err_line_idx(esbmc_output)
+            line: Optional[int] = ESBMCUtil.get_source_code_err_line_idx(esbmc_output)
             if line:
                 return source_code.splitlines(True)[line]
 
             # Check if it parses
-            line = get_clang_err_line_index(esbmc_output)
+            line = ESBMCUtil.get_clang_err_line_index(esbmc_output)
             if line:
                 return source_code.splitlines(True)[line]
@@ -64,12 +58,12 @@ def get_esbmc_output_formatted(esbmc_output_type: str, esbmc_output: str) -> str:
     match esbmc_output_type:
         case "vp":
-            value: Optional[str] = esbmc_get_violated_property(esbmc_output)
+            value: Optional[str] = ESBMCUtil.esbmc_get_violated_property(esbmc_output)
             if not value:
                 raise ValueError("Not found violated property." + esbmc_output)
             return value
         case "ce":
-            value: Optional[str] = esbmc_get_counter_example(esbmc_output)
+            value: Optional[str] = ESBMCUtil.esbmc_get_counter_example(esbmc_output)
             if not value:
                 raise ValueError("Not found counterexample.")
             return value
@@ -82,10 +76,9 @@ class SolutionGenerator(BaseChatInterface):
     def __init__(
         self,
-        ai_model_agent: DynamicAIModelAgent | ChatPromptSettings,
+        scenarios: FixCodeScenarios,
         llm: BaseChatModel,
         ai_model: AIModel,
-        scenario: str = "",
         source_code_format: str = "full",
         esbmc_output_type: str = "full",
     ) -> None:
@@ -93,19 +86,15 @@
         Prompting. Will get the correct scenario from the DynamicAIModelAgent
         supplied and create a ChatPrompt."""
 
-        chat_prompt: ChatPromptSettings = ai_model_agent
-        if isinstance(ai_model_agent, DynamicAIModelAgent):
-            # Convert to chat prompt
-            chat_prompt = DynamicAIModelAgent.to_chat_prompt_settings(
-                ai_model_agent=ai_model_agent, scenario=scenario
-            )
-
         super().__init__(
-            ai_model_agent=chat_prompt,
             ai_model=ai_model,
             llm=llm,
+            system_messages=[],  # Empty as it will be updated in the update method.
         )
 
+        self.scenarios: FixCodeScenarios = scenarios
+        self.scenario: Optional[str] = None
+
         self.esbmc_output_type: str = esbmc_output_type
         self.source_code_format: str = source_code_format
@@ -147,8 +136,12 @@ def get_code_from_solution(cls, solution: str) -> str:
         return solution
 
     def update_state(self, source_code: str, esbmc_output: str) -> None:
-        """Updates the latest state of the code and ESBMC output. This should be
+        """Updates the latest state of the code and ESBMC output. It also updates
+        the scenario, which is the type of error that ESBMC has shown. This should be
         called before generate_solution."""
+
+        self.scenario = ESBMCUtil.esbmc_get_error_type(esbmc_output)
+
         self.source_code_raw = source_code
 
         # Format ESBMC output
@@ -169,16 +162,28 @@
             esbmc_output=self.esbmc_output,
         )
 
-    def generate_solution(self) -> tuple[str, FinishReason]:
+    def generate_solution(
+        self,
+        override_scenario: Optional[str] = None,
+    ) -> tuple[str, FinishReason]:
+
         assert (
             self.source_code_raw is not None
             and self.source_code_formatted is not None
             and self.esbmc_output is not None
         ), "Call update_state before calling generate_solution."
 
-        self.push_to_message_stack(
-            HumanMessage(content=self.ai_model_agent.initial_prompt)
-        )
+        initial_message: str
+        if override_scenario:
+            initial_message = str(self.scenarios[override_scenario]["initial"])
+        else:
+            assert self.scenario, "Call update or set the scenario"
+            if self.scenario in self.scenarios:
+                initial_message = str(self.scenarios[self.scenario]["initial"])
+            else:
+                initial_message = str(self.scenarios[default_scenario]["initial"])
+
+        self.push_to_message_stack(HumanMessage(content=initial_message))
 
         # Apply template substitution to message stack
         self.apply_template_value(
@@ -197,10 +202,12 @@
         match self.source_code_format:
             case "single":
                 # Get source code error line from esbmc output
-                line: Optional[int] = get_source_code_err_line_idx(self.esbmc_output)
+                line: Optional[int] = ESBMCUtil.get_source_code_err_line_idx(
+                    self.esbmc_output
+                )
                 if not line:
                     # Check if it parses
-                    line = get_clang_err_line_index(self.esbmc_output)
+                    line = ESBMCUtil.get_clang_err_line_index(self.esbmc_output)
 
                 assert (
                     line
diff --git a/esbmc_ai/chats/user_chat.py b/esbmc_ai/chats/user_chat.py
index 441c59a..3274f34 100644
--- a/esbmc_ai/chats/user_chat.py
+++ b/esbmc_ai/chats/user_chat.py
@@ -8,7 +8,6 @@
 
 from langchain_community.chat_message_histories import ChatMessageHistory
 
-from esbmc_ai.config import AIAgentConversation, ChatPromptSettings
 from esbmc_ai.ai_models import AIModel
 
 from .base_chat_interface import BaseChatInterface
@@ -19,15 +18,15 @@ class UserChat(BaseChatInterface):
 
     def __init__(
         self,
-        ai_model_agent: ChatPromptSettings,
         ai_model: AIModel,
         llm: BaseChatModel,
         source_code: str,
         esbmc_output: str,
-        set_solution_messages: AIAgentConversation,
+        system_messages: list[BaseMessage],
+        set_solution_messages: list[BaseMessage],
     ) -> None:
         super().__init__(
-            ai_model_agent=ai_model_agent,
+            system_messages=system_messages,
             ai_model=ai_model,
             llm=llm,
         )
@@ -44,7 +43,7 @@ def set_solution(self, source_code: str) -> None:
         """Sets the solution to the problem ESBMC reported, this will inform the AI."""
-        for msg in self.set_solution_messages.messages:
+        for msg in self.set_solution_messages:
             self.push_to_message_stack(msg)
         self.apply_template_value(source_code_solution=source_code)
diff --git a/esbmc_ai/commands/fix_code_command.py b/esbmc_ai/commands/fix_code_command.py
index dce3e2b..f925942 100644
--- a/esbmc_ai/commands/fix_code_command.py
+++ b/esbmc_ai/commands/fix_code_command.py
@@ -4,23 +4,21 @@
 from typing import Any, Optional, Tuple
 from typing_extensions import override
 
+from esbmc_ai.ai_models import AIModel
+from esbmc_ai.api_key_collection import APIKeyCollection
 from esbmc_ai.chat_response import FinishReason
 from esbmc_ai.chats import LatestStateSolutionGenerator, SolutionGenerator
 from esbmc_ai.chats.solution_generator import ESBMCTimedOutException
 from esbmc_ai.commands.command_result import CommandResult
+from esbmc_ai.config import FixCodeScenarios
 from esbmc_ai.reverse_order_solution_generator import ReverseOrderSolutionGenerator
 from esbmc_ai.solution import SourceFile
 
 from .chat_command import ChatCommand
-from .. import config
 from ..msg_bus import Signal
 from ..loading_widget import create_loading_widget
-from ..esbmc_util import (
-    esbmc_get_error_type,
-    esbmc_load_source_code,
-)
+from ..esbmc_util import ESBMCUtil
 from ..logging import print_horizontal_line, printv, printvv
-from subprocess import CalledProcessError
 
 
 class FixCodeCommandResult(CommandResult):
@@ -64,68 +62,78 @@ def print_raw_conversation() -> None:
             print("\n" + "\n\n".join(messages))
             print("ESBMC-AI Notice: End of raw conversation")
 
+        # Handle kwargs
         source_file: SourceFile = kwargs["source_file"]
         assert source_file.file_path
+
         generate_patches: bool = (
             kwargs["generate_patches"] if "generate_patches" in kwargs else False
         )
 
-        # Parse the esbmc output here and determine what "Scenario" to use.
-        scenario: str = esbmc_get_error_type(source_file.initial_verifier_output)
+        message_history: str = (
+            kwargs["message_history"] if "message_history" in kwargs else "normal"
+        )
 
-        printv(f"Scenario: {scenario}")
-        printv(
-            f"Using dynamic prompt..."
-            if scenario in config.chat_prompt_generator_mode.scenarios
-            else "Using generic prompt..."
-        )
+        api_keys: APIKeyCollection = kwargs["api_keys"]
+        ai_model: AIModel = kwargs["ai_model"]
+        temperature: float = kwargs["temperature"]
+        max_tries: int = kwargs["requests_max_tries"]
+        timeout: int = kwargs["requests_timeout"]
+        source_code_format: str = kwargs["source_code_format"]
+        esbmc_output_format: str = kwargs["esbmc_output_format"]
+        scenarios: FixCodeScenarios = kwargs["scenarios"]
+        max_attempts: int = kwargs["max_attempts"]
+        esbmc_params: list[str] = kwargs["esbmc_params"]
+        verifier_timeout: int = kwargs["verifier_timeout"]
+        temp_auto_clean: bool = kwargs["temp_auto_clean"]
+        raw_conversation: bool = (
+            kwargs["raw_conversation"] if "raw_conversation" in kwargs else False
+        )
+        # End of handle kwargs
 
-        match config.fix_code_message_history:
+        match message_history:
             case "normal":
                 solution_generator = SolutionGenerator(
-                    ai_model_agent=config.chat_prompt_generator_mode,
-                    ai_model=config.ai_model,
-                    llm=config.ai_model.create_llm(
-                        api_keys=config.api_keys,
-                        temperature=config.chat_prompt_generator_mode.temperature,
-                        requests_max_tries=config.requests_max_tries,
-                        requests_timeout=config.requests_timeout,
+                    ai_model=ai_model,
+                    llm=ai_model.create_llm(
+                        api_keys=api_keys,
+                        temperature=temperature,
+                        requests_max_tries=max_tries,
+                        requests_timeout=timeout,
                     ),
-                    scenario=scenario,
-                    source_code_format=config.source_code_format,
-                    esbmc_output_type=config.esbmc_output_type,
+                    scenarios=scenarios,
+                    source_code_format=source_code_format,
+                    esbmc_output_type=esbmc_output_format,
                 )
             case "latest_only":
                 solution_generator = LatestStateSolutionGenerator(
-                    ai_model_agent=config.chat_prompt_generator_mode,
-                    ai_model=config.ai_model,
-                    llm=config.ai_model.create_llm(
-                        api_keys=config.api_keys,
-                        temperature=config.chat_prompt_generator_mode.temperature,
-                        requests_max_tries=config.requests_max_tries,
-                        requests_timeout=config.requests_timeout,
+                    ai_model=ai_model,
+                    llm=ai_model.create_llm(
+                        api_keys=api_keys,
+                        temperature=temperature,
+                        requests_max_tries=max_tries,
+                        requests_timeout=timeout,
                     ),
-                    scenario=scenario,
-                    source_code_format=config.source_code_format,
-                    esbmc_output_type=config.esbmc_output_type,
+                    scenarios=scenarios,
+                    source_code_format=source_code_format,
+                    esbmc_output_type=esbmc_output_format,
                 )
             case "reverse":
                 solution_generator = ReverseOrderSolutionGenerator(
-                    ai_model_agent=config.chat_prompt_generator_mode,
-                    ai_model=config.ai_model,
-                    llm=config.ai_model.create_llm(
-                        api_keys=config.api_keys,
-                        temperature=config.chat_prompt_generator_mode.temperature,
-                        requests_max_tries=config.requests_max_tries,
-                        requests_timeout=config.requests_timeout,
+                    ai_model=ai_model,
+                    llm=ai_model.create_llm(
+                        api_keys=api_keys,
+                        temperature=temperature,
+                        requests_max_tries=max_tries,
+                        requests_timeout=timeout,
                    ),
-                    scenario=scenario,
-                    source_code_format=config.source_code_format,
-                    esbmc_output_type=config.esbmc_output_type,
+                    scenarios=scenarios,
+                    source_code_format=source_code_format,
+                    esbmc_output_type=esbmc_output_format,
                 )
             case _:
                 raise NotImplementedError(
-                    f"error: {config.fix_code_message_history} has not been implemented in the Fix Code Command"
+                    f"error: {message_history} has not been implemented in the Fix Code Command"
                 )
 
         try:
@@ -139,8 +147,7 @@
 
         print()
 
-        attempts: int = config.fix_code_max_attempts
-        for attempt in range(1, attempts + 1):
+        for attempt in range(1, max_attempts + 1):
             # Get a response. Use while loop to account for if the message stack
             # gets full, then need to compress and retry.
             while True:
@@ -164,12 +171,12 @@
                 # Pass to ESBMC, a workaround is used where the file is saved
                 # to a temporary location since ESBMC needs it in file format.
                 self.anim.start("Verifying with ESBMC... Please Wait")
-                exit_code, esbmc_output = esbmc_load_source_code(
+                exit_code, esbmc_output = ESBMCUtil.esbmc_load_source_code(
                     source_file=source_file,
                     source_file_content_index=-1,
-                    esbmc_params=config.esbmc_params,
-                    auto_clean=config.temp_auto_clean,
-                    timeout=config.verifier_timeout,
+                    esbmc_params=esbmc_params,
+                    auto_clean=temp_auto_clean,
+                    timeout=verifier_timeout,
                 )
                 self.anim.stop()
@@ -186,7 +193,7 @@
             if exit_code == 0:
                 self.on_solution_signal.emit(source_file.latest_content)
 
-                if config.raw_conversation:
+                if raw_conversation:
                     print_raw_conversation()
 
                 printv("ESBMC-AI Notice: Successfully verified code")
@@ -205,18 +212,18 @@
                     source_file.latest_content, source_file.latest_verifier_output
                 )
             except ESBMCTimedOutException:
-                if config.raw_conversation:
+                if raw_conversation:
                     print_raw_conversation()
                 print("ESBMC-AI Notice: error: ESBMC has timed out...")
                 sys.exit(1)
 
             # Failure case
-            if attempt != attempts:
-                print(f"ESBMC-AI Notice: Failure {attempt}/{attempts}: Retrying...")
+            if attempt != max_attempts:
+                print(f"ESBMC-AI Notice: Failure {attempt}/{max_attempts}: Retrying...")
             else:
-                print(f"ESBMC-AI Notice: Failure {attempt}/{attempts}")
+                print(f"ESBMC-AI Notice: Failure {attempt}/{max_attempts}")
 
-        if config.raw_conversation:
+        if raw_conversation:
             print_raw_conversation()
 
         return FixCodeCommandResult(False, None)
diff --git a/esbmc_ai/config.py b/esbmc_ai/config.py
index d99141c..d107bb6 100644
--- a/esbmc_ai/config.py
+++ b/esbmc_ai/config.py
@@ -1,149 +1,465 @@
 # Author: Yiannis Charalambous 2023
 
 import os
-import json
 import sys
 from platform import system as system_name
 from pathlib import Path
 from dotenv import load_dotenv, find_dotenv
-
-from typing import Any, NamedTuple, Optional, Union, Sequence
-from dataclasses import dataclass
-from langchain.schema import BaseMessage
-
-from esbmc_ai.logging import printv, set_verbose
+from langchain.schema import HumanMessage
+import tomllib as toml
+
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Sequence,
+)
+
+from esbmc_ai.chat_response import list_to_base_messages
+from esbmc_ai.logging import set_verbose
 
 from .ai_models import *
 from .api_key_collection import APIKeyCollection
-from .chat_response import json_to_base_messages
 
-api_keys: APIKeyCollection
+FixCodeScenarios = dict[str, dict[str, str | Sequence[BaseMessage]]]
+"""Type for scenarios. A single scenario contains initial and system components.
+
+* Initial message can be accessed like so: `x["base"]["initial"]`
+* System message can be accessed like so: `x["base"]["system"]`"""
+
+default_scenario: str = "base"
+
+
+class ConfigField(NamedTuple):
+    # The name of the config field and also namespace
+    name: str
+    # If a default value is supplied, then it can be omitted from the config
+    default_value: Any
+    # If true, then the default value will be None, so during validation, if
+    # no value is supplied, None is used as the default value instead of
+    # failing, since a None default value under normal circumstances means
+    # that the field is not optional.
+    default_value_none: bool = False
+
+    # Lambda function to validate if the field has a valid value.
+    # The default is the identity function, which always returns true.
+    validate: Callable[[Any], bool] = lambda _: True
+    # Transform the value once loaded; this allows the value to be saved
+    # as a more complex type than that which is represented in the config
+    # file.
+    on_load: Callable[[Any], Any] = lambda v: v
+    # If defined, will be called and allows custom loading of complex types
+    # that may not match 1-1 in the config. The config file passed as a
+    # parameter here is the original, unflattened version. The value returned
+    # should be the value assigned to this field.
+    on_read: Optional[Callable[[dict[str, Any]], Any]] = None
+    error_message: Optional[str] = None
+
+
+def _validate_prompt_template_conversation(prompt_template: List[Dict]) -> bool:
+    """Used to validate if a prompt template conversation is of the correct format
+    in the config before loading it."""
+
+    for msg in prompt_template:
+        if (
+            not isinstance(msg, dict)
+            or "content" not in msg
+            or "role" not in msg
+            or not isinstance(msg["content"], str)
+            or not isinstance(msg["role"], str)
+        ):
+            return False
+    return True
+
+
+def _validate_prompt_template(conv: Dict[str, List[Dict]]) -> bool:
+    """Used to check if a prompt template (contains conversation and initial message) is
+    of the correct format."""
+    if (
+        "initial" not in conv
+        or not isinstance(conv["initial"], str)
+        or "system" not in conv
+        or not _validate_prompt_template_conversation(conv["system"])
+    ):
+        return False
+    return True
+
+
+class Config:
+    api_keys: APIKeyCollection
+    raw_conversation: bool = False
+    cfg_path: Path
+    generate_patches: bool
+
+    _fields: List[ConfigField] = [
+        ConfigField(
+            name="ai_model",
+            default_value=None,
+            # API keys are loaded from the system env so they are already
+            # available
+            validate=lambda v: isinstance(v, str)
+            and is_valid_ai_model(v, Config.api_keys),
+            on_load=lambda v: get_ai_model_by_name(v, Config.api_keys),
+        ),
+        ConfigField(
+            name="temp_auto_clean",
+            default_value=True,
+            validate=lambda v: isinstance(v, bool),
+        ),
+        ConfigField(
+            name="temp_file_dir",
+            default_value=None,
+            validate=lambda v: isinstance(v, str) and Path(v).is_dir(),
+            on_load=lambda v: Path(v),
+            default_value_none=True,
+        ),
+        ConfigField(
+            name="allow_successful",
+            default_value=False,
+            validate=lambda v: isinstance(v, bool),
+        ),
+        ConfigField(
+            name="loading_hints",
+            default_value=True,
+            validate=lambda v: isinstance(v, bool),
+        ),
+        ConfigField(
+            name="source_code_format",
+            default_value="full",
+            validate=lambda v: isinstance(v, str) and v in ["full", "single"],
+            error_message="source_code_format can only be 'full' or 'single'",
+        ),
+        ConfigField(
+            name="esbmc.path",
+            default_value=None,
+            validate=lambda v: isinstance(v, str) and Path(v).expanduser().is_file(),
+            on_load=lambda v: Path(v).expanduser(),
+        ),
+        ConfigField(
+            name="esbmc.params",
+            default_value=[
+                "--interval-analysis",
+                "--goto-unwind",
+                "--unlimited-goto-unwind",
+                "--k-induction",
+                "--state-hashing",
+                "--add-symex-value-sets",
+                "--k-step",
+                "2",
+                "--floatbv",
+                "--unlimited-k-steps",
+                "--compact-trace",
+                "--context-bound",
+                "2",
+            ],
+            validate=lambda v: isinstance(v, List),
+        ),
+        ConfigField(
+            name="esbmc.output_type",
+            default_value="full",
+            validate=lambda v: v in ["full", "vp", "ce"],
+        ),
+        ConfigField(
+            name="esbmc.timeout",
+            default_value=60,
+            validate=lambda v: isinstance(v, int),
+        ),
+        ConfigField(
+            name="llm_requests.max_tries",
+            default_value=5,
+            validate=lambda v: isinstance(v, int),
+        ),
+        ConfigField(
+            name="llm_requests.timeout",
+            default_value=60,
+            validate=lambda v: isinstance(v, int),
+        ),
+        ConfigField(
+            name="user_chat.temperature",
+            default_value=1.0,
+            validate=lambda v: isinstance(v, float) and v >= 0 and v <= 2.0,
+            error_message="Temperature needs to be a value between 0 and 2.0",
+        ),
+        ConfigField(
+            name="fix_code.temperature",
+            default_value=1.0,
+            validate=lambda v: isinstance(v, float) and v >= 0 and v <= 2.0,
+            error_message="Temperature needs to be a value between 0 and 2.0",
+        ),
+        ConfigField(
+            name="fix_code.max_attempts",
+            default_value=5,
+            validate=lambda v: isinstance(v, int),
+        ),
+        ConfigField(
+            name="fix_code.message_history",
+            default_value="normal",
+            validate=lambda v: v in ["normal", "latest_only", "reverse"],
+            error_message='fix_code.message_history can only be "normal", "latest_only", "reverse"',
+        ),
+        ConfigField(
+            name="prompt_templates.user_chat.initial",
+            default_value=None,
+            validate=lambda v: isinstance(v, str),
+            on_load=lambda v: HumanMessage(content=v),
+        ),
+        ConfigField(
+            name="prompt_templates.user_chat.system",
+            default_value=None,
+            validate=lambda v: _validate_prompt_template_conversation(v),
+            on_load=lambda v: list_to_base_messages(v),
+        ),
+        ConfigField(
+            name="prompt_templates.user_chat.set_solution",
+            default_value=None,
+            validate=lambda v: _validate_prompt_template_conversation(v),
+            on_load=lambda v: list_to_base_messages(v),
+        ),
+        # Here we have a list of prompt templates that are for each scenario.
+        # The base scenario prompt template is required.
+        ConfigField(
+            name="prompt_templates.fix_code",
+            default_value=None,
+            validate=lambda v: default_scenario in v
+            and all(
+                [
+                    _validate_prompt_template(prompt_template)
+                    for prompt_template in v.values()
+                ]
+            ),
+            on_read=lambda config_file: {
+                scenario: {
+                    "initial": HumanMessage(content=conv["initial"]),
+                    "system": list_to_base_messages(conv["system"]),
+                }
+                for scenario, conv in config_file["prompt_templates"][
+                    "fix_code"
+                ].items()
+            },
+        ),
+    ]
+    _values: Dict[str, Any] = {}
+
+    # Define some shortcuts for the values here (instead of having to use get_value)
+
+    @classmethod
+    def get_ai_model(cls) -> AIModel:
+        return cls.get_value("ai_model")
 
-esbmc_path: str = "~/.local/bin/esbmc"
-esbmc_params: list[str] = [
-    "--interval-analysis",
-    "--goto-unwind",
-    "--unlimited-goto-unwind",
-    "--k-induction",
-    "--state-hashing",
-    "--add-symex-value-sets",
-    "--k-step",
-    "2",
-    "--floatbv",
-    "--unlimited-k-steps",
-    "--compact-trace",
-    "--context-bound",
-    "2",
-]
+    @classmethod
+    def get_llm_requests_max_tries(cls) -> int:
+        return cls.get_value("llm_requests.max_tries")
 
-temp_auto_clean: bool = True
-temp_file_dir: Optional[str] = None
-ai_model: AIModel
+    @classmethod
+    def get_llm_requests_timeout(cls) -> float:
+        return cls.get_value("llm_requests.timeout")
 
-esbmc_output_type: str = "full"
-source_code_format: str = "full"
+    @classmethod
+    def get_user_chat_initial(cls) -> BaseMessage:
+        return cls.get_value("prompt_templates.user_chat.initial")
 
-fix_code_max_attempts: int = 5
-fix_code_message_history: str = ""
+    @classmethod
+    def get_user_chat_system_messages(cls) -> list[BaseMessage]:
+        return cls.get_value("prompt_templates.user_chat.system")
 
-requests_max_tries: int = 5
-requests_timeout: float = 60
-verifier_timeout: float = 60
+    @classmethod
+    def get_user_chat_set_solution(cls) -> list[BaseMessage]:
+        return cls.get_value("prompt_templates.user_chat.set_solution")
 
-loading_hints: bool = False
-allow_successful: bool = False
-# Show the raw conversation after the command ends
-raw_conversation: bool = False
+    @classmethod
+    def get_fix_code_scenarios(cls) -> FixCodeScenarios:
+        return cls.get_value("prompt_templates.fix_code")
 
-cfg_path: str
+    @classmethod
+    def init(cls, args: Any) -> None:
+        cls._load_envs()
 
-generate_patches: bool
+        if not Config.cfg_path.is_file():
+            print(f"Error: Config not found: {Config.cfg_path}")
+            sys.exit(1)
 
+        with open(Config.cfg_path, "r") as file:
+            original_config_file: dict[str, Any] = toml.loads(file.read())
+
+        # Load custom AIs
+        if "ai_custom" in original_config_file:
+            _load_custom_ai(original_config_file["ai_custom"])
+
+        # Flatten dict as the _fields are defined in a flattened format for
+        # convenience.
+        config_file: dict[str, Any] = cls._flatten_dict(original_config_file)
+
+        # Load all the config file field entries
+        for field in cls._fields:
+            # If on_read is overwritten, then the reading process is manually
+            # defined so fall back to that.
+            if field.on_read:
+                cls._values[field.name] = field.on_read(original_config_file)
+                continue
+
+            # Proceed to default read
+
+            # Is field entry found in config?
+            if field.name in config_file:
+                # Check if None and not allowed!
+                if (
+                    field.default_value == None
+                    and not field.default_value_none
+                    and config_file[field.name] == None
+                ):
+                    raise ValueError(
+                        f"The config entry {field.name} has a None value when it can't be"
+                    )
+
+                # Validate field
+                assert field.validate(config_file[field.name]), (
+                    field.error_message
+                    if field.error_message
+                    else f"Field: {field.name} is invalid: {config_file[field.name]}"
+                )
 
-# TODO Get rid of this class as soon as ConfigTool with the pyautoconfig
-class AIAgentConversation(NamedTuple):
-    """Immutable class describing the conversation definition for an AI agent. The
-    class represents the system messages of the AI agent defined and contains a load
-    class method for efficiently loading it from config."""
+                # Assign field from config file
+                cls._values[field.name] = field.on_load(config_file[field.name])
+            elif field.default_value == None and not field.default_value_none:
+                raise KeyError(f"{field.name} is missing from config file")
+            else:
+                # Use default value
+                cls._values[field.name] = field.default_value
 
-    messages: tuple[BaseMessage, ...]
+        cls._load_args(args)
 
     @classmethod
-    def from_seq(cls, message_list: Sequence[BaseMessage]) -> "AIAgentConversation":
-        return cls(messages=tuple(message_list))
+    def get_value(cls, name: str) -> Any:
+        return cls._values[name]
 
     @classmethod
-    def load_from_config(
-        cls, messages_list: list[dict[str, str]]
-    ) -> "AIAgentConversation":
-        return cls(messages=tuple(json_to_base_messages(messages_list)))
-
-
-@dataclass
-class ChatPromptSettings:
-    """Settings for the AI Model. These settings act as an actor/agent, allowing the
-    AI model to be applied into a specific scenario."""
-
-    system_messages: AIAgentConversation
-    """The generic prompt system messages of the AI. Generic meaning it is used in
-    every scenario, as opposed to dynamic system message. The value is a list of
-    converstaions."""
-    initial_prompt: str
-    """The generic initial prompt to use for the agent."""
-    temperature: float
-
-
-@dataclass
-class DynamicAIModelAgent(ChatPromptSettings):
-    """Extension of the ChatPromptSettings to include dynamic"""
-
-    scenarios: dict[str, AIAgentConversation]
-    """Scenarios dictionary that contains system messages for different errors that
-    ESBMC can give. More information can be found in the
-    [wiki](https://github.com/Yiannis128/esbmc-ai/wiki/Configuration#dynamic-prompts).
-    Reads from the config file the following hierarchy:
-    * Dictionary mapping of error type to dictionary. Accepts the following entries:
-      * `system` mapping to an array. The array contains the conversation for the
-      system message for this particular error."""
+    def set_value(cls, name: str, value: Any) -> None:
+        cls._values[name] = value
 
     @classmethod
-    def to_chat_prompt_settings(
-        cls, ai_model_agent: "DynamicAIModelAgent", scenario: str
-    ) -> ChatPromptSettings:
-        """DynamicAIModelAgent extensions are not used by BaseChatInterface derived classes
-        directly, since they only use the SystemMessages of ChatPromptSettings. This applies
-        the correct scenario as a System Message and returns a pure ChatPromptSettings object
-        for use. **Will return a shallow copy even if the system message is to be used**.
+    def _load_envs(cls) -> None:
+        """Environment variables are loaded in the following order:
+
+        1. Environment variables already loaded. Any variable not present will be looked for in
+        .env files in the following locations.
+        2. .env file in the current directory, moving upwards in the directory tree.
+        3. esbmc-ai.env file in the current directory, moving upwards in the directory tree.
+        4. esbmc-ai.env file in $HOME/.config/ for Linux/macOS and %userprofile% for Windows.
+
+        Note: ESBMC_AI_CFG_PATH undergoes tilde user expansion and also environment
+        variable expansion.
         """
-        if scenario in ai_model_agent.scenarios:
-            return ChatPromptSettings(
-                initial_prompt=ai_model_agent.initial_prompt,
-                system_messages=ai_model_agent.scenarios[scenario],
-                temperature=ai_model_agent.temperature,
-            )
+
+        def get_env_vars() -> None:
+            """Gets all the system environment variables that are currently loaded."""
+            for k in keys:
+                value: Optional[str] = os.getenv(k)
+                if value != None:
+                    values[k] = value
+
+        keys: list[str] = ["OPENAI_API_KEY", "ESBMC_AI_CFG_PATH"]
+        values: dict[str, str] = {}
+
+        # Load from system env
+        get_env_vars()
+
+        # Find .env in current working directory and load it.
+        dotenv_file_path: str = find_dotenv(usecwd=True)
+        if dotenv_file_path != "":
+            load_dotenv(dotenv_path=dotenv_file_path, override=False, verbose=True)
         else:
-            return ChatPromptSettings(
-                initial_prompt=ai_model_agent.initial_prompt,
-                system_messages=ai_model_agent.system_messages,
-                temperature=ai_model_agent.temperature,
-            )
+            # Find esbmc-ai.env in current working directory and load it.
+            dotenv_file_path: str = find_dotenv(filename="esbmc-ai.env", usecwd=True)
+            if dotenv_file_path != "":
+                load_dotenv(dotenv_path=dotenv_file_path, override=False, verbose=True)
+
+        get_env_vars()
+
+        # Look for .env in home folder.
+        home_path: Path = Path.home()
+        match system_name():
+            case "Linux" | "Darwin":
+                home_path /= ".config/esbmc-ai.env"
+            case "Windows":
+                home_path /= "esbmc-ai.env"
+            case _:
+                raise ValueError(f"Unknown OS type: {system_name()}")
 
-chat_prompt_user_mode: DynamicAIModelAgent
-chat_prompt_generator_mode: DynamicAIModelAgent
-chat_prompt_optimize_code: ChatPromptSettings
+        load_dotenv(home_path, override=False, verbose=True)
 
-esbmc_params_optimize_code: list[str] = [
-    "--incremental-bmc",
-    "--no-bounds-check",
-    "--no-pointer-check",
-    "--no-div-by-zero-check",
-]
+        get_env_vars()
 
+        # Check if all the values are set, else error.
+        for key in keys:
+            if key not in values:
+                print(f"Error: No ${key} in environment.")
+                sys.exit(1)
+
+        cls.api_keys = APIKeyCollection(
+            openai=str(os.getenv("OPENAI_API_KEY")),
+        )
+
+        cls.cfg_path = Path(
+            os.path.expanduser(os.path.expandvars(str(os.getenv("ESBMC_AI_CFG_PATH"))))
+        )
+
+    @classmethod
+    def _load_args(cls, args) -> None:
+        set_verbose(args.verbose)
+
+        # AI Model -m
+        if args.ai_model != "":
+            if is_valid_ai_model(args.ai_model, cls.api_keys):
+                ai_model = get_ai_model_by_name(args.ai_model, cls.api_keys)
+                cls.set_value("ai_model", ai_model)
+            else:
+                print(f"Error: invalid --ai-model parameter {args.ai_model}")
+                sys.exit(4)
+
+        # If append flag is set, then append.
+        if args.append:
+            esbmc_params: List[str] = cls.get_value("esbmc.params")
+            esbmc_params.extend(args.remaining)
+            cls.set_value("esbmc.params", esbmc_params)
+        elif len(args.remaining) != 0:
+            cls.set_value("esbmc.params", args.remaining)
+
+        Config.raw_conversation = args.raw_conversation
+        Config.generate_patches = args.generate_patches
+
+    @classmethod
+    def _flatten_dict(cls, d, parent_key="", sep="."):
+        """Recursively flattens a nested dictionary."""
+        items = {}
+        for k, v in d.items():
+            new_key = parent_key + sep + k if parent_key else k
+            if isinstance(v, dict):
+                items.update(cls._flatten_dict(v, new_key, sep=sep))
+            else:
+                items[new_key] = v
+        return items
 
 def _load_custom_ai(config: dict) -> None:
-    ai_custom: dict = config
-    for name, ai_data in ai_custom.items():
+    """Loads custom AI defined in the config and associates it with the AIModels
+    module."""
+
+    def _load_config_value(
+        config_file: dict, name: str, default: object = None
+    ) -> tuple[Any, bool]:
+        if name in config_file:
+            return config_file[name], True
+        else:
+            print(
+                f"Warning: {name} not found in config... Using default value: {default}"
+            )
+            return default, False
+
+    for name, ai_data in config.items():
         # Load the max tokens
         custom_ai_max_tokens, ok = _load_config_value(
             config_file=ai_data,
@@ -153,7 +469,7 @@ def _load_custom_ai(config: dict) -> None:
         assert (
             isinstance(custom_ai_max_tokens, int) and custom_ai_max_tokens > 0
         ), f'custom_ai_max_tokens in ai_custom entry "{name}" needs to be an int and greater than 0.'
-        
+
         # Load the URL
         custom_ai_url, ok = _load_config_value(
             config_file=ai_data,
@@ -167,327 +483,23 @@ def _load_custom_ai(config: dict) -> None:
             name="server_type",
             default="localhost:11434",
         )
-        assert ok, f"server_type for custom AI '{name}' is invalid, it needs to be a valid string"
+        assert (
+            ok
+        ), f"server_type for custom AI '{name}' is invalid, it needs to be a valid string"
 
         # Create correct type of LLM
         llm: AIModel
         match server_type:
             case "ollama":
                 llm = OllamaAIModel(
-                        name=name,
-                        tokens=custom_ai_max_tokens,
-                        url=custom_ai_url,
+                    name=name,
+                    tokens=custom_ai_max_tokens,
+                    url=custom_ai_url,
                 )
             case _:
-                raise NotImplementedError(f"The custom AI server type is not implemented: {server_type}")
+                raise NotImplementedError(
+                    f"The custom AI server type is not implemented: {server_type}"
+                )
 
         # Add the custom AI.
         add_custom_ai_model(llm)
-
-
-def load_envs() -> None:
-    """Environment variables are loaded in the following order:
-
-    1. Environment variables already loaded. Any variable not present will be looked for in
-    .env files in the following locations.
-    2. .env file in the current directory, moving upwards in the directory tree.
-    3. esbmc-ai.env file in the current directory, moving upwards in the directory tree.
-    4. esbmc-ai.env file in $HOME/.config/ for Linux/macOS and %userprofile% for Windows.
-
-    Note: ESBMC_AI_CFG_PATH undergoes tilde user expansion and also environment
-    variable expansion.
-    """
-
-    def get_env_vars() -> None:
-        """Gets all the system environment variables that are currently loaded."""
-        for k in keys:
-            value: Optional[str] = os.getenv(k)
-            if value != None:
-                values[k] = value
-
-    keys: list[str] = ["OPENAI_API_KEY", "HUGGINGFACE_API_KEY", "ESBMC_AI_CFG_PATH"]
-    values: dict[str, str] = {}
-
-    # Load from system env
-    get_env_vars()
-
-    # Find .env in current working directory and load it.
-    dotenv_file_path: str = find_dotenv(usecwd=True)
-    if dotenv_file_path != "":
-        load_dotenv(dotenv_path=dotenv_file_path, override=False, verbose=True)
-    else:
-        # Find esbmc-ai.env in current working directory and load it.
-        dotenv_file_path: str = find_dotenv(filename="esbmc-ai.env", usecwd=True)
-        if dotenv_file_path != "":
-            load_dotenv(dotenv_path=dotenv_file_path, override=False, verbose=True)
-
-    get_env_vars()
-
-    # Look for .env in home folder.
-    home_path: Path = Path.home()
-    match system_name():
-        case "Linux" | "Darwin":
-            home_path /= ".config/esbmc-ai.env"
-        case "Windows":
-            home_path /= "esbmc-ai.env"
-        case _:
-            raise ValueError(f"Unknown OS type: {system_name()}")
-
-    load_dotenv(home_path, override=False, verbose=True)
-
-    get_env_vars()
-
-    # Check if all the values are set, else error.
-    for key in keys:
-        if key not in values:
-            print(f"Error: No ${key} in environment.")
-            sys.exit(1)
-
-    global api_keys
-    api_keys = APIKeyCollection(
-        openai=str(os.getenv("OPENAI_API_KEY")),
-        huggingface=str(os.getenv("HUGGINGFACE_API_KEY")),
-    )
-
-    global cfg_path
-    cfg_path = os.path.expanduser(
-        os.path.expandvars(str(os.getenv("ESBMC_AI_CFG_PATH")))
-    )
-
-
-def _load_ai_data(config: dict) -> None:
-    # User chat mode will store extra AIAgentConversations into scenarios.
-    global chat_prompt_user_mode
-    chat_prompt_user_mode = DynamicAIModelAgent(
-        system_messages=AIAgentConversation.load_from_config(
-            config["chat_modes"]["user_chat"]["system"]
-        ),
-        initial_prompt=config["chat_modes"]["user_chat"]["initial"],
-        temperature=config["chat_modes"]["user_chat"]["temperature"],
-        scenarios={
-            "set_solution": AIAgentConversation.load_from_config(
-                messages_list=config["chat_modes"]["user_chat"]["set_solution"],
-            ),
-        },
-    )
-
-    # Generator mode loads scenarios normally.
-    json_fcm_scenarios: dict = config["chat_modes"]["generate_solution"]["scenarios"]
-    fcm_scenarios: dict = {
-        scenario: AIAgentConversation.load_from_config(messages["system"])
-        for scenario, messages in json_fcm_scenarios.items()
-    }
-    global chat_prompt_generator_mode
-    chat_prompt_generator_mode = DynamicAIModelAgent(
-        system_messages=AIAgentConversation.load_from_config(
-            config["chat_modes"]["generate_solution"]["system"]
-        ),
-        initial_prompt=config["chat_modes"]["generate_solution"]["initial"],
-        temperature=config["chat_modes"]["generate_solution"]["temperature"],
-        scenarios=fcm_scenarios,
-    )
-
-
-def _load_config_value(
-    config_file: dict, name: str, default: object = None
-) -> tuple[Any, bool]:
-    if name in config_file:
-        return config_file[name], True
-    else:
-        print(f"Warning: {name} not found in config... Using default value: {default}")
-        return default, False
-
-
-def _load_config_bool(
-    config_file: dict,
-    name: str,
-    default: bool = False,
-) -> bool:
-    value, _ = _load_config_value(config_file, name, default)
-    if isinstance(value, bool):
-        return value
-    else:
-        raise TypeError(
-            f"Error: config invalid {name} value: {value} "
-            + "Make sure it is a bool value."
-        )
-
-
-def _load_config_real_number(
-    config_file: dict, name: str, default: object = None
-) -> Union[int, float]:
-    value, _ = _load_config_value(config_file, name, default)
-    # Type check
-    if type(value) is float or type(value) is int:
-        return value
-    else:
-        raise TypeError(
-            f"Error: config invalid {name} value: {value} "
-            + "Make sure it is a float or int..."
-        )
-
-
-def load_config(file_path: str) -> None:
-    if not os.path.exists(file_path) and os.path.isfile(file_path):
-        print(f"Error: Config not found: {file_path}")
-        sys.exit(1)
-
-    config_file = None
-    with open(file_path, mode="r") as file:
-        config_file = json.load(file)
-
-    global esbmc_params
-    esbmc_params, _ = _load_config_value(
-        config_file,
-        "esbmc_params",
-        esbmc_params,
-    )
-
-    global fix_code_max_attempts
-    fix_code_max_attempts = int(
-        _load_config_real_number(
-            config_file=config_file["chat_modes"]["generate_solution"],
-            name="max_attempts",
-            default=fix_code_max_attempts,
-        )
-    )
-
-    global source_code_format
-    source_code_format, _ = _load_config_value(
-        config_file=config_file,
-        name="source_code_format",
-        default=source_code_format,
-    )
-
-    if source_code_format not in ["full", "single"]:
-        raise Exception(
-            f"Source code format in the config is not valid: {source_code_format}"
-        )
-
-    global esbmc_output_type
-    esbmc_output_type, _ = _load_config_value(
-        config_file=config_file,
-        name="esbmc_output_type",
-        default=esbmc_output_type,
-    )
-
-    if esbmc_output_type not in ["full", "vp", "ce"]:
-        raise Exception(
-            f"ESBMC output type in the config is not valid: {esbmc_output_type}"
-        )
-
-    global fix_code_message_history
-    fix_code_message_history, _ = _load_config_value(
-        config_file=config_file["chat_modes"]["generate_solution"],
-        name="message_history",
-    )
-
-    if fix_code_message_history not in ["normal", "latest_only", "reverse"]:
-        raise ValueError(
-            f"error: fix code mode message history not valid: {fix_code_message_history}"
-        )
-
-    global requests_max_tries
-    requests_max_tries = int(
-        _load_config_real_number(
-            config_file=config_file["llm_requests"],
-            name="max_tries",
-            default=requests_max_tries,
-        )
-    )
-
-    global requests_timeout
-    requests_timeout = _load_config_real_number(
-        config_file=config_file["llm_requests"],
-        name="timeout",
-        default=requests_timeout,
-    )
-
-    global verifier_timeout
-    verifier_timeout = _load_config_real_number(
-        config_file=config_file,
-        name="verifier_timeout",
-        default=verifier_timeout,
-    )
-
-    global temp_auto_clean
-    temp_auto_clean, _ = _load_config_value(
-        config_file,
-        "temp_auto_clean",
-        temp_auto_clean,
-    )
-
-    global temp_file_dir
-    temp_file_dir, _ = _load_config_value(
-        config_file,
-        "temp_file_dir",
-        temp_file_dir,
-    )
-
-    global allow_successful
-    allow_successful = _load_config_bool(
-        config_file,
-        "allow_successful",
-        False,
-    )
-
-    global loading_hints
-    loading_hints = _load_config_bool(
-        config_file,
-        "loading_hints",
-        True,
-    )
-
-    # Load the custom ai configs.
- _load_custom_ai(config_file["ai_custom"]) - - global ai_model - ai_model_name, _ = _load_config_value( - config_file, - "ai_model", - ) - if is_valid_ai_model(ai_model_name, api_keys): - # Load the ai_model from loaded models. - ai_model = get_ai_model_by_name(ai_model_name, api_keys) - else: - print(f"Error: {ai_model_name} is not a valid AI model") - sys.exit(4) - - global esbmc_path - # Health check verifies this later in the init process. - esbmc_path, _ = _load_config_value( - config_file, - "esbmc_path", - esbmc_path, - ) - # Expand variables and tilde. - esbmc_path = os.path.expanduser(os.path.expandvars(esbmc_path)) - - # Load the AI data from the file that will command the AI for all modes. - printv("Initializing AI data") - _load_ai_data(config=config_file) - - -def load_args(args) -> None: - set_verbose(args.verbose) - - global ai_model - if args.ai_model != "": - if is_valid_ai_model(args.ai_model, api_keys): - ai_model = get_ai_model_by_name(args.ai_model, api_keys) - else: - print(f"Error: invalid --ai-model parameter {args.ai_model}") - sys.exit(4) - - global raw_conversation - raw_conversation = args.raw_conversation - - global esbmc_params - # If append flag is set, then append. - if args.append: - esbmc_params.extend(args.remaining) - elif len(args.remaining) != 0: - esbmc_params = args.remaining - - global generate_patches - generate_patches = args.generate_patches diff --git a/esbmc_ai/esbmc_util.py b/esbmc_ai/esbmc_util.py index a034a2e..99c1ed1 100644 --- a/esbmc_ai/esbmc_util.py +++ b/esbmc_ai/esbmc_util.py @@ -7,152 +7,172 @@ from typing import Optional from esbmc_ai.solution import SourceFile - -from . import config - - -def esbmc_get_violated_property(esbmc_output: str) -> Optional[str]: - """Gets the violated property line of the ESBMC output.""" - # Find "Violated property:" string in ESBMC output - lines: list[str] = esbmc_output.splitlines() - for ix, line in enumerate(lines): - if "Violated property:" == line: - return "\n".join(lines[ix : ix + 3]) - return None - - -def esbmc_get_counter_example(esbmc_output: str) -> Optional[str]: - """Gets ESBMC output after and including [Counterexample]""" - idx: int = esbmc_output.find("[Counterexample]\n") - if idx == -1: - return None - else: - return esbmc_output[idx:] - - -def esbmc_get_error_type(esbmc_output: str) -> str: - """Gets the error of violated property, the entire line.""" - # TODO Test me - # Start search from the marker. - marker: str = "Violated property:\n" - violated_property_index: int = esbmc_output.rfind(marker) + len(marker) - from_loc_error_msg: str = esbmc_output[violated_property_index:] - # Find second new line which contains the location of the violated - # property and that should point to the line with the type of error. - # In this case, the type of error is the "scenario". - scenario_index: int = from_loc_error_msg.find("\n") - scenario: str = from_loc_error_msg[scenario_index + 1 :] - scenario_end_l_index: int = scenario.find("\n") - scenario = scenario[:scenario_end_l_index].strip() - return scenario - - -def get_source_code_err_line(esbmc_output: str) -> Optional[int]: - # Find "Violated property:" string in ESBMC output - violated_property: Optional[str] = esbmc_get_violated_property(esbmc_output) - if violated_property: - # Get the line of the violated property. 
- pos_line: str = violated_property.splitlines()[1] - pos_line_split: list[str] = pos_line.split(" ") - for ix, word in enumerate(pos_line_split): - if word == "line": - # Get the line number - return int(pos_line_split[ix + 1]) - return None - - -def get_source_code_err_line_idx(esbmc_output: str) -> Optional[int]: - line: Optional[int] = get_source_code_err_line(esbmc_output) - if line: - return line - 1 - else: +from esbmc_ai.config import default_scenario + + +class ESBMCUtil: + @classmethod + def init(cls, esbmc_path: Path) -> None: + cls.esbmc_path: Path = esbmc_path + + @classmethod + def esbmc_get_violated_property(cls, esbmc_output: str) -> Optional[str]: + """Gets the violated property line of the ESBMC output.""" + # Find "Violated property:" string in ESBMC output + lines: list[str] = esbmc_output.splitlines() + for ix, line in enumerate(lines): + if "Violated property:" == line: + return "\n".join(lines[ix : ix + 3]) return None - -def get_clang_err_line(clang_output: str) -> Optional[int]: - """For when the code does not compile, gets the error line reported in the clang - output. This is useful for `esbmc_output_type single`""" - lines: list[str] = clang_output.splitlines() - for line in lines: - # Find the first line containing a filename along with error. - line_split: list[str] = line.split(":") - if len(line_split) < 4: - continue - # Check for the filename - if line_split[0].endswith(".c") and " error" in line_split[3]: - return int(line_split[1]) - - return None - - -def get_clang_err_line_index(clang_output: str) -> Optional[int]: - line: Optional[int] = get_clang_err_line(clang_output) - if line: - return line - 1 - else: + @classmethod + def esbmc_get_counter_example(cls, esbmc_output: str) -> Optional[str]: + """Gets ESBMC output after and including [Counterexample]""" + idx: int = esbmc_output.find("[Counterexample]\n") + if idx == -1: + return None + else: + return esbmc_output[idx:] + + @classmethod + def esbmc_get_error_type(cls, esbmc_output: str) -> str: + """Gets the error of violated property, the entire line.""" + # TODO Test me + # Start search from the marker. + marker: str = "Violated property:\n" + violated_property_index: int = esbmc_output.rfind(marker) + len(marker) + from_loc_error_msg: str = esbmc_output[violated_property_index:] + # Find second new line which contains the location of the violated + # property and that should point to the line with the type of error. + # In this case, the type of error is the "scenario". + scenario_index: int = from_loc_error_msg.find("\n") + scenario: str = from_loc_error_msg[scenario_index + 1 :] + scenario_end_l_index: int = scenario.find("\n") + scenario = scenario[:scenario_end_l_index].strip() + + if not scenario: + return default_scenario + + return scenario + + @classmethod + def get_source_code_err_line(cls, esbmc_output: str) -> Optional[int]: + # Find "Violated property:" string in ESBMC output + violated_property: Optional[str] = cls.esbmc_get_violated_property(esbmc_output) + if violated_property: + # Get the line of the violated property. 
+        pos_line: str = violated_property.splitlines()[1]
+        pos_line_split: list[str] = pos_line.split(" ")
+        for ix, word in enumerate(pos_line_split):
+            if word == "line":
+                # Get the line number
+                return int(pos_line_split[ix + 1])
         return None

+    @classmethod
+    def get_source_code_err_line_idx(cls, esbmc_output: str) -> Optional[int]:
+        line: Optional[int] = cls.get_source_code_err_line(esbmc_output)
+        if line:
+            return line - 1
+        else:
+            return None
+
+    @classmethod
+    def get_clang_err_line(cls, clang_output: str) -> Optional[int]:
+        """For when the code does not compile, gets the error line reported in the clang
+        output. This is useful for `esbmc_output_type single`."""
+        lines: list[str] = clang_output.splitlines()
+        for line in lines:
+            # Find the first line containing a filename along with error.
+            line_split: list[str] = line.split(":")
+            if len(line_split) < 4:
+                continue
+            # Check for the filename
+            if line_split[0].endswith(".c") and " error" in line_split[3]:
+                return int(line_split[1])

-def esbmc(path: Path, esbmc_params: list, timeout: Optional[float] = None):
-    """Exit code will be 0 if verification successful, 1 if verification
-    failed. And any other number for compilation error/general errors."""
-    # Build parameters
-    esbmc_cmd = [config.esbmc_path]
-    esbmc_cmd.extend(esbmc_params)
-    esbmc_cmd.append(str(path))
+        return None

-    if "--timeout" in esbmc_cmd:
-        print(
-            'Do not add --timeout to ESBMC parameters, instead specify it in "verifier_timeout".'
-        )
-        sys.exit(1)
-
-    esbmc_cmd.extend(["--timeout", str(timeout)])
-
-    # Add slack time to process to allow verifier to timeout and end gracefully.
-    process_timeout: Optional[float] = timeout + 10 if timeout else None
-
-    # Run ESBMC and get output
-    process: CompletedProcess = run(
-        esbmc_cmd,
-        stdout=PIPE,
-        stderr=STDOUT,
-        timeout=process_timeout,
-    )
-
-    output: str = process.stdout.decode("utf-8")
-    return process.returncode, output
-
-
-def esbmc_load_source_code(
-    source_file: SourceFile,
-    source_file_content_index: int,
-    esbmc_params: list = config.esbmc_params,
-    auto_clean: bool = config.temp_auto_clean,
-    timeout: Optional[float] = None,
-):
-
-    file_path: Path
-    if config.temp_file_dir:
-        file_path = source_file.save_file(
-            file_path=Path(config.temp_file_dir),
-            temp_dir=False,
-            index=source_file_content_index,
-        )
-    else:
-        file_path = source_file.save_file(
-            file_path=None,
-            temp_dir=True,
-            index=source_file_content_index,
+    @classmethod
+    def get_clang_err_line_index(cls, clang_output: str) -> Optional[int]:
+        line: Optional[int] = cls.get_clang_err_line(clang_output)
+        if line:
+            return line - 1
+        else:
+            return None
+
+    @classmethod
+    def esbmc(
+        cls,
+        path: Path,
+        esbmc_params: list,
+        timeout: Optional[int] = None,
+    ):
+        """Exit code will be 0 if verification successful, 1 if verification
+        failed. And any other number for compilation error/general errors."""
+        # Build parameters
+        esbmc_cmd = [str(cls.esbmc_path)]
+        esbmc_cmd.extend(esbmc_params)
+        esbmc_cmd.append(str(path))
+
+        if "--timeout" in esbmc_cmd:
+            print(
+                'Do not add --timeout to ESBMC parameters, instead specify it in "verifier_timeout".'
+            )
+            sys.exit(1)
+
+        # Only pass --timeout when one is configured: str(None) is not a
+        # valid timeout value.
+        if timeout is not None:
+            esbmc_cmd.extend(["--timeout", str(timeout)])
+
+        # Add slack time to process to allow verifier to timeout and end gracefully.
+        process_timeout: Optional[float] = timeout + 10 if timeout else None
+
+        # Run ESBMC and get output
+        process: CompletedProcess = run(
+            esbmc_cmd,
+            stdout=PIPE,
+            stderr=STDOUT,
+            timeout=process_timeout,
         )
-    # Call ESBMC to temporary folder.
- results = esbmc(file_path, esbmc_params, timeout=timeout) + output: str = process.stdout.decode("utf-8") + return process.returncode, output + + @classmethod + def esbmc_load_source_code( + cls, + source_file: SourceFile, + source_file_content_index: int, + esbmc_params: list, + auto_clean: bool, + temp_file_dir: Optional[Path] = None, + timeout: Optional[int] = None, + ): + + file_path: Path + if temp_file_dir: + file_path = source_file.save_file( + file_path=Path(temp_file_dir), + temp_dir=False, + index=source_file_content_index, + ) + else: + file_path = source_file.save_file( + file_path=None, + temp_dir=True, + index=source_file_content_index, + ) + + # Call ESBMC to temporary folder. + results = cls.esbmc( + path=file_path, + esbmc_params=esbmc_params, + timeout=timeout, + ) - # Delete temp files and path - if auto_clean: - # Remove file - os.remove(file_path) + # Delete temp files and path + if auto_clean: + # Remove file + os.remove(file_path) - # Return - return results + # Return + return results diff --git a/esbmc_ai/loading_widget.py b/esbmc_ai/loading_widget.py index 6cde88a..93008e2 100644 --- a/esbmc_ai/loading_widget.py +++ b/esbmc_ai/loading_widget.py @@ -12,7 +12,7 @@ from threading import Thread from typing import Optional -from esbmc_ai import config +from esbmc_ai import Config class LoadingWidget(object): @@ -56,7 +56,7 @@ def _animate(self) -> None: terminal.flush() def start(self, text: str = "Please Wait") -> None: - if not config.loading_hints: + if not Config.get_value("loading_hints"): return self.done = False self.loading_text = text @@ -65,7 +65,7 @@ def start(self, text: str = "Please Wait") -> None: self.thread.start() def stop(self) -> None: - if not config.loading_hints: + if not Config.get_value("loading_hints"): return self.done = True # Block until end. 
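With `esbmc_util.py` and `loading_widget.py` both migrated to class-level state, the call pattern changes accordingly. Below is a minimal usage sketch, not part of the patch: the binary path, source file, and timeout value are illustrative, and `Config` is assumed to have been populated by the application entry point before any values are read.

```python
from pathlib import Path

from esbmc_ai import Config
from esbmc_ai.esbmc_util import ESBMCUtil

# Bind the wrapper to the ESBMC binary once, instead of reading a
# module-level config value on every call (path is illustrative).
ESBMCUtil.init(Path("~/.local/bin/esbmc").expanduser())

# Values are looked up through the flattened key space produced by
# Config._flatten_dict, e.g. a TOML table [esbmc] yields "esbmc.params".
params: list[str] = Config.get_value("esbmc.params")

# Exit code 0 means verified, 1 means verification failed; anything else
# signals a compilation or general error.
return_code, output = ESBMCUtil.esbmc(
    path=Path("main.c"),
    esbmc_params=params,
    timeout=90,
)
print(return_code)
```
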
diff --git a/pyproject.toml b/pyproject.toml index 418549d..f613127 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,26 +23,24 @@ keywords = [ ] dependencies = [ - "openai", - "python-dotenv", - "tiktoken", - "aiohttp", - "aiosignal", - "async-timeout", - "attrs", - "certifi", - "charset-normalizer", - "frozenlist", - "idna", - "multidict", - "regex", - "requests", - "urllib3", - "yarl", - "libclang", - "clang", - "langchain", - "langchain-openai", + "openai", + "python-dotenv==1.0.0", + "tiktoken", + "aiosignal==1.3.1", + "async-timeout==4.0.2", + "attrs==23.1.0", + "certifi==2022.12.7", + "charset-normalizer==3.1.0", + "idna==3.4", + "regex==2023.3.23", + "requests==2.29.0", + "urllib3==1.26.15", + "yarl==1.9.2", + "langchain", + "langchain-openai", + "langchain-community", + "langchain-ollama", + "lizard", ] #[project.optional-dependencies] @@ -53,14 +51,29 @@ dependencies = [ # "...", #] +[tool.hatch.envs.default] +# Dependency of pytest-regtest: py +dependencies = [ + "pylint", + "ipykernel", + "pytest", + "pytest-cov", + "pytest-regtest", + "py", + "twine", + "hatch", + "transformers", + "torch", +] + [project.scripts] esbmc-ai = "esbmc_ai.__main__:main" [project.urls] -Homepage = "https://github.com/Yiannis128/esbmc-ai" -"Source Code" = "https://github.com/Yiannis128/esbmc-ai" -Documentation = "https://github.com/Yiannis128/esbmc-ai/wiki" -Issues = "https://github.com/Yiannis128/esbmc-ai/issues" +Homepage = "https://github.com/esbmc/esbmc-ai" +"Source Code" = "https://github.com/esbmc/esbmc-ai" +Documentation = "https://github.com/esbmc/esbmc-ai/wiki" +Issues = "https://github.com/esbmc/esbmc-ai/issues" [tool.hatch.version] path = "esbmc_ai/__about__.py" diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 583ea95..0000000 --- a/requirements.txt +++ /dev/null @@ -1,49 +0,0 @@ --i https://pypi.org/simple -aiohttp==3.8.4; python_version >= '3.6' -aiosignal==1.3.1; python_version >= '3.7' -annotated-types==0.6.0; python_version >= '3.8' -anyio==4.3.0; python_version >= '3.8' -async-timeout==4.0.2; python_version >= '3.6' -attrs==23.1.0; python_version >= '3.7' -certifi==2022.12.7; python_version >= '3.6' -charset-normalizer==3.1.0; python_full_version >= '3.7.0' -clang==17.0.6 -dataclasses-json==0.6.4; python_version >= '3.7' and python_version < '4.0' -distro==1.9.0; python_version >= '3.6' -frozenlist==1.3.3; python_version >= '3.7' -greenlet==3.0.3; platform_machine == 'aarch64' or (platform_machine == 'ppc64le' or (platform_machine == 'x86_64' or (platform_machine == 'amd64' or (platform_machine == 'AMD64' or (platform_machine == 'win32' or platform_machine == 'WIN32'))))) -h11==0.14.0; python_version >= '3.7' -httpcore==1.0.5; python_version >= '3.8' -httpx==0.27.0; python_version >= '3.8' -idna==3.4; python_version >= '3.5' -jsonpatch==1.33; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6' -jsonpointer==2.4; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6' -langchain==0.1.16; python_version < '4.0' and python_full_version >= '3.8.1' -langchain-community==0.0.34; python_version < '4.0' and python_full_version >= '3.8.1' -langchain-core==0.1.45; python_version < '4.0' and python_full_version >= '3.8.1' -langchain-openai==0.1.3; python_version < '4.0' and python_full_version >= '3.8.1' -langchain-text-splitters==0.0.1; python_version < '4.0' and python_full_version >= '3.8.1' -langsmith==0.1.49; python_version < '4.0' and python_full_version >= '3.8.1' -libclang==18.1.1 
-marshmallow==3.21.1; python_version >= '3.8' -multidict==6.0.4; python_version >= '3.7' -mypy-extensions==1.0.0; python_version >= '3.5' -numpy==1.26.4; python_version >= '3.9' -openai==1.23.2; python_full_version >= '3.7.1' -orjson==3.10.1; python_version >= '3.8' -packaging==23.2; python_version >= '3.7' -pydantic==2.7.0; python_version >= '3.8' -pydantic-core==2.18.1; python_version >= '3.8' -python-dotenv==1.0.0; python_version >= '3.8' -pyyaml==6.0.1; python_version >= '3.6' -regex==2023.3.23; python_version >= '3.8' -requests==2.29.0; python_version >= '3.7' -sniffio==1.3.1; python_version >= '3.7' -sqlalchemy==2.0.29; python_version >= '3.7' -tenacity==8.2.3; python_version >= '3.7' -tiktoken==0.6.0; python_version >= '3.8' -tqdm==4.66.2; python_version >= '3.7' -typing-extensions==4.11.0; python_version >= '3.8' -typing-inspect==0.9.0 -urllib3==1.26.15; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5' -yarl==1.9.2; python_version >= '3.7' diff --git a/tests/regtest/_regtest_outputs/test_base_chat_interface.test_push_message_stack.out b/tests/regtest/_regtest_outputs/test_base_chat_interface.test_push_message_stack.out index 6cf7991..dad5f03 100644 --- a/tests/regtest/_regtest_outputs/test_base_chat_interface.test_push_message_stack.out +++ b/tests/regtest/_regtest_outputs/test_base_chat_interface.test_push_message_stack.out @@ -1,2 +1,5 @@ -(SystemMessage(content='System message'), AIMessage(content='OK')) -[AIMessage(content='Test 1'), HumanMessage(content='Test 2'), SystemMessage(content='Test 3')] +system: System message +ai: OK +ai: Test 1 +human: Test 2 +system: Test 3 diff --git a/tests/regtest/test_base_chat_interface.py b/tests/regtest/test_base_chat_interface.py index 80151ee..cb4adf9 100644 --- a/tests/regtest/test_base_chat_interface.py +++ b/tests/regtest/test_base_chat_interface.py @@ -8,7 +8,6 @@ from esbmc_ai.ai_models import AIModel from esbmc_ai.chat_response import ChatResponse from esbmc_ai.chats.base_chat_interface import BaseChatInterface -from esbmc_ai.config import AIAgentConversation, ChatPromptSettings @pytest.fixture @@ -24,11 +23,7 @@ def setup(): ] chat: BaseChatInterface = BaseChatInterface( - ai_model_agent=ChatPromptSettings( - initial_prompt="", - system_messages=AIAgentConversation.from_seq(system_messages), - temperature=1.0, - ), + system_messages=system_messages, ai_model=ai_model, llm=llm, ) @@ -50,8 +45,11 @@ def test_push_message_stack(regtest, setup) -> None: chat.push_to_message_stack(messages[2]) with regtest: - print(chat.ai_model_agent.system_messages.messages) - print(chat.messages) + for msg in chat._system_messages: + print(f"{msg.type}: {msg.content}") + + for msg in chat.messages: + print(f"{msg.type}: {msg.content}") def test_send_message(regtest, setup) -> None: @@ -65,12 +63,13 @@ def test_send_message(regtest, setup) -> None: with regtest: print("System Messages:") - for m in chat.ai_model_agent.system_messages.messages: + for m in chat._system_messages: print(f"{m.type}: {m.content}") print("Chat Messages:") for m in chat.messages: print(f"{m.type}: {m.content}") print("Responses:") for m in chat_responses: - print(f"{m.message.type}({m.total_tokens} - {m.finish_reason}): {m.message.content}") - + print( + f"{m.message.type}({m.total_tokens} - {m.finish_reason}): {m.message.content}" + ) diff --git a/tests/test_ai_models.py b/tests/test_ai_models.py index 7cc8397..8f95179 100644 --- a/tests/test_ai_models.py +++ b/tests/test_ai_models.py @@ -10,22 +10,22 @@ from pytest import raises from 
esbmc_ai.ai_models import ( + AIModelOpenAI, add_custom_ai_model, is_valid_ai_model, AIModel, _AIModels, get_ai_model_by_name, OllamaAIModel, - _get_openai_model_max_tokens, ) """TODO Find a way to mock the OpenAI API and test GPT LLM code.""" -def test_is_valid_ai_model() -> None: - assert is_valid_ai_model(_AIModels.FALCON_7B.value) - assert is_valid_ai_model(_AIModels.STARCHAT_BETA.value) - assert is_valid_ai_model("falcon-7b") +# def test_is_valid_ai_model() -> None: +# assert is_valid_ai_model(_AIModels.FALCON_7B.value) +# assert is_valid_ai_model(_AIModels.STARCHAT_BETA.value) +# assert is_valid_ai_model("falcon-7b") def test_is_not_valid_ai_model() -> None: @@ -58,7 +58,7 @@ def test_add_custom_ai_model() -> None: def test_get_ai_model_by_name() -> None: # Try with first class AI - assert get_ai_model_by_name("falcon-7b") + # assert get_ai_model_by_name("falcon-7b") # Try with custom AI. # Add custom AI model if not added by previous tests. @@ -142,12 +142,14 @@ def test_escape_messages() -> None: def test__get_openai_model_max_tokens() -> None: - assert _get_openai_model_max_tokens("gpt-4o") == 128000 - assert _get_openai_model_max_tokens("gpt-4-turbo") == 8192 - assert _get_openai_model_max_tokens("gpt-3.5-turbo") == 16385 - assert _get_openai_model_max_tokens("gpt-3.5-turbo-instruct") == 4096 - assert _get_openai_model_max_tokens("gpt-3.5-turbo-aaaaaa") == 16385 - assert _get_openai_model_max_tokens("gpt-3.5-turbo-instruct-bbb") == 4096 + assert AIModelOpenAI.get_openai_model_max_tokens("gpt-4o") == 128000 + assert AIModelOpenAI.get_openai_model_max_tokens("gpt-4-turbo") == 8192 + assert AIModelOpenAI.get_openai_model_max_tokens("gpt-3.5-turbo") == 16385 + assert AIModelOpenAI.get_openai_model_max_tokens("gpt-3.5-turbo-instruct") == 4096 + assert AIModelOpenAI.get_openai_model_max_tokens("gpt-3.5-turbo-aaaaaa") == 16385 + assert ( + AIModelOpenAI.get_openai_model_max_tokens("gpt-3.5-turbo-instruct-bbb") == 4096 + ) with raises(ValueError): - _get_openai_model_max_tokens("aaaaa") + AIModelOpenAI.get_openai_model_max_tokens("aaaaa") diff --git a/tests/test_base_chat_interface.py b/tests/test_base_chat_interface.py index e4738c4..ee6be2d 100644 --- a/tests/test_base_chat_interface.py +++ b/tests/test_base_chat_interface.py @@ -7,7 +7,6 @@ from esbmc_ai.ai_models import AIModel from esbmc_ai.chats.base_chat_interface import BaseChatInterface from esbmc_ai.chat_response import ChatResponse -from esbmc_ai.config import AIAgentConversation, ChatPromptSettings @pytest.fixture(scope="module") @@ -28,16 +27,14 @@ def test_push_message_stack(setup) -> None: ai_model, system_messages = setup chat: BaseChatInterface = BaseChatInterface( - ai_model_agent=ChatPromptSettings( - AIAgentConversation.from_seq(system_messages), - initial_prompt="", - temperature=1.0, - ), + system_messages=system_messages, ai_model=ai_model, llm=llm, ) - assert chat.ai_model_agent.system_messages.messages == tuple(system_messages) + for msg, chat_msg in zip(system_messages, chat._system_messages): + assert msg.type == chat_msg.type + assert msg.content == chat_msg.content messages: list[BaseMessage] = [ AIMessage(content="Test 1"), @@ -61,11 +58,7 @@ def test_send_message(setup) -> None: ai_model, system_messages = setup chat: BaseChatInterface = BaseChatInterface( - ai_model_agent=ChatPromptSettings( - AIAgentConversation.from_seq(system_messages), - initial_prompt="", - temperature=1.0, - ), + system_messages=system_messages, ai_model=ai_model, llm=llm, ) @@ -101,11 +94,7 @@ def test_apply_template() -> None: llm: 
FakeListChatModel = FakeListChatModel(responses=responses) chat: BaseChatInterface = BaseChatInterface( - ai_model_agent=ChatPromptSettings( - AIAgentConversation.from_seq(system_messages), - initial_prompt="{source_code}{esbmc_output}", - temperature=1.0, - ), + system_messages=system_messages, ai_model=ai_model, llm=llm, ) diff --git a/tests/test_config.py b/tests/test_config.py index 9b774b0..b663717 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -7,80 +7,12 @@ from esbmc_ai.ai_models import is_valid_ai_model -def test_load_config_value() -> None: - result, ok = config._load_config_value( - { - "test": "value", - }, - "test", - ) - assert ok and result == "value" - - -def test_load_config_value_default_value() -> None: - result, ok = config._load_config_value( - { - "test": "value", - }, - "test", - "wrong", - ) - assert ok and result == "value" - - -def test_load_config_value_default_value_not_exists() -> None: - result, ok = config._load_config_value( - {}, - "test2", - "wrong", - ) - assert not ok and result == "wrong" - - -def test_load_config_real_number() -> None: - result = config._load_config_real_number( - { - "test": 1.0, - }, - "test", - ) - assert result == 1.0 - - -def test_load_config_real_number_default_value() -> None: - result = config._load_config_real_number({}, "test", 1.1) - assert result == 1.1 - - -def test_load_config_real_number_wrong_value() -> None: - with raises(TypeError): - result = config._load_config_real_number( - { - "test": "wrong value", - }, - "test", - ) - assert result == None - - -def test_load_config_real_number_wrong_value_default() -> None: - with raises(TypeError): - result = config._load_config_real_number( - { - "test": "wrong value", - }, - "test", - 1.0, - ) - assert result == None - - def test_load_custom_ai() -> None: custom_ai_config: dict = { "example_ai": { "max_tokens": 4096, "url": "www.example.com", - "server_type": "ollama" + "server_type": "ollama", } } diff --git a/tests/test_esbmc_util.py b/tests/test_esbmc_util.py index d5e7090..bf74e43 100644 --- a/tests/test_esbmc_util.py +++ b/tests/test_esbmc_util.py @@ -3,12 +3,7 @@ import pytest from os import listdir -from esbmc_ai.esbmc_util import ( - esbmc_get_counter_example, - esbmc_get_violated_property, - get_source_code_err_line, - get_clang_err_line, -) +from esbmc_ai.esbmc_util import ESBMCUtil @pytest.fixture(scope="module") @@ -27,22 +22,22 @@ def test_get_source_code_err_line(setup_get_data): data_esbmc_output: dict[str, str] = setup_get_data esbmc_output: str = data_esbmc_output["cartpole_48_safe.c-amalgamation-6.c"] - assert get_source_code_err_line(esbmc_output) == 323 + assert ESBMCUtil.get_source_code_err_line(esbmc_output) == 323 esbmc_output = data_esbmc_output["cartpole_92_safe.c-amalgamation-14.c"] - assert get_source_code_err_line(esbmc_output) == 221 + assert ESBMCUtil.get_source_code_err_line(esbmc_output) == 221 esbmc_output = data_esbmc_output["cartpole_95_safe.c-amalgamation-80.c"] - assert get_source_code_err_line(esbmc_output) == 285 + assert ESBMCUtil.get_source_code_err_line(esbmc_output) == 285 esbmc_output = data_esbmc_output["cartpole_26_safe.c-amalgamation-74.c"] - assert get_source_code_err_line(esbmc_output) == 299 + assert ESBMCUtil.get_source_code_err_line(esbmc_output) == 299 esbmc_output = data_esbmc_output["robot_5_safe.c-amalgamation-13.c"] - assert get_source_code_err_line(esbmc_output) == 350 + assert ESBMCUtil.get_source_code_err_line(esbmc_output) == 350 esbmc_output = data_esbmc_output["vdp_1_safe.c-amalgamation-28.c"] - 
assert get_source_code_err_line(esbmc_output) == 247 + assert ESBMCUtil.get_source_code_err_line(esbmc_output) == 247 def test_esbmc_get_counter_example(setup_get_data) -> None: @@ -50,27 +45,27 @@ def test_esbmc_get_counter_example(setup_get_data) -> None: esbmc_output: str = data_esbmc_output["cartpole_48_safe.c-amalgamation-6.c"] ce_idx: int = esbmc_output.find("[Counterexample]") - assert esbmc_get_counter_example(esbmc_output) == esbmc_output[ce_idx:] + assert ESBMCUtil.esbmc_get_counter_example(esbmc_output) == esbmc_output[ce_idx:] esbmc_output = data_esbmc_output["cartpole_92_safe.c-amalgamation-14.c"] ce_idx = esbmc_output.find("[Counterexample]") - assert esbmc_get_counter_example(esbmc_output) == esbmc_output[ce_idx:] + assert ESBMCUtil.esbmc_get_counter_example(esbmc_output) == esbmc_output[ce_idx:] esbmc_output = data_esbmc_output["cartpole_95_safe.c-amalgamation-80.c"] ce_idx = esbmc_output.find("[Counterexample]") - assert esbmc_get_counter_example(esbmc_output) == esbmc_output[ce_idx:] + assert ESBMCUtil.esbmc_get_counter_example(esbmc_output) == esbmc_output[ce_idx:] esbmc_output = data_esbmc_output["cartpole_26_safe.c-amalgamation-74.c"] ce_idx = esbmc_output.find("[Counterexample]") - assert esbmc_get_counter_example(esbmc_output) == esbmc_output[ce_idx:] + assert ESBMCUtil.esbmc_get_counter_example(esbmc_output) == esbmc_output[ce_idx:] esbmc_output = data_esbmc_output["robot_5_safe.c-amalgamation-13.c"] ce_idx = esbmc_output.find("[Counterexample]") - assert esbmc_get_counter_example(esbmc_output) == esbmc_output[ce_idx:] + assert ESBMCUtil.esbmc_get_counter_example(esbmc_output) == esbmc_output[ce_idx:] esbmc_output = data_esbmc_output["vdp_1_safe.c-amalgamation-28.c"] ce_idx = esbmc_output.find("[Counterexample]") - assert esbmc_get_counter_example(esbmc_output) == esbmc_output[ce_idx:] + assert ESBMCUtil.esbmc_get_counter_example(esbmc_output) == esbmc_output[ce_idx:] def test_esbmc_get_violated_property(setup_get_data) -> None: @@ -79,32 +74,50 @@ def test_esbmc_get_violated_property(setup_get_data) -> None: esbmc_output: str = data_esbmc_output["cartpole_48_safe.c-amalgamation-6.c"] start_idx: int = esbmc_output.find("Violated property:") end_idx: int = esbmc_output.find("VERIFICATION FAILED") - 3 - assert esbmc_get_violated_property(esbmc_output) == esbmc_output[start_idx:end_idx] + assert ( + ESBMCUtil.esbmc_get_violated_property(esbmc_output) + == esbmc_output[start_idx:end_idx] + ) esbmc_output = data_esbmc_output["cartpole_92_safe.c-amalgamation-14.c"] start_idx = esbmc_output.find("Violated property:") end_idx = esbmc_output.find("VERIFICATION FAILED") - 3 - assert esbmc_get_violated_property(esbmc_output) == esbmc_output[start_idx:end_idx] + assert ( + ESBMCUtil.esbmc_get_violated_property(esbmc_output) + == esbmc_output[start_idx:end_idx] + ) esbmc_output = data_esbmc_output["cartpole_95_safe.c-amalgamation-80.c"] start_idx = esbmc_output.find("Violated property:") end_idx = esbmc_output.find("VERIFICATION FAILED") - 3 - assert esbmc_get_violated_property(esbmc_output) == esbmc_output[start_idx:end_idx] + assert ( + ESBMCUtil.esbmc_get_violated_property(esbmc_output) + == esbmc_output[start_idx:end_idx] + ) esbmc_output = data_esbmc_output["cartpole_26_safe.c-amalgamation-74.c"] start_idx = esbmc_output.find("Violated property:") end_idx = esbmc_output.find("VERIFICATION FAILED") - 3 - assert esbmc_get_violated_property(esbmc_output) == esbmc_output[start_idx:end_idx] + assert ( + ESBMCUtil.esbmc_get_violated_property(esbmc_output) + == 
esbmc_output[start_idx:end_idx] + ) esbmc_output = data_esbmc_output["robot_5_safe.c-amalgamation-13.c"] start_idx = esbmc_output.find("Violated property:") end_idx = esbmc_output.find("VERIFICATION FAILED") - 3 - assert esbmc_get_violated_property(esbmc_output) == esbmc_output[start_idx:end_idx] + assert ( + ESBMCUtil.esbmc_get_violated_property(esbmc_output) + == esbmc_output[start_idx:end_idx] + ) esbmc_output = data_esbmc_output["vdp_1_safe.c-amalgamation-28.c"] start_idx = esbmc_output.find("Violated property:") end_idx = esbmc_output.find("VERIFICATION FAILED") - 3 - assert esbmc_get_violated_property(esbmc_output) == esbmc_output[start_idx:end_idx] + assert ( + ESBMCUtil.esbmc_get_violated_property(esbmc_output) + == esbmc_output[start_idx:end_idx] + ) @pytest.fixture(scope="module") @@ -122,5 +135,5 @@ def setup_clang_parse_errors() -> dict[str, str]: def test_get_clang_err_line_index(setup_clang_parse_errors) -> None: data_esbmc_output = setup_clang_parse_errors print(data_esbmc_output["threading.c"]) - line = get_clang_err_line(data_esbmc_output["threading.c"]) + line = ESBMCUtil.get_clang_err_line(data_esbmc_output["threading.c"]) assert line == 26 diff --git a/tests/test_latest_state_solution_generator.py b/tests/test_latest_state_solution_generator.py index 1db2ffa..9bf66a9 100644 --- a/tests/test_latest_state_solution_generator.py +++ b/tests/test_latest_state_solution_generator.py @@ -1,14 +1,14 @@ # Author: Yiannis Charalambous -from typing import Optional +from typing import Any, Optional from langchain_core.language_models import FakeListChatModel import pytest from langchain.schema import HumanMessage, AIMessage, SystemMessage +from esbmc_ai.config import default_scenario from esbmc_ai.ai_models import AIModel from esbmc_ai.chat_response import ChatResponse -from esbmc_ai.config import AIAgentConversation, ChatPromptSettings from esbmc_ai.chats.latest_state_solution_generator import LatestStateSolutionGenerator @@ -27,64 +27,73 @@ def setup_llm_model(): def test_send_message(setup_llm_model) -> None: llm, model = setup_llm_model - chat_settings = ChatPromptSettings( - system_messages=AIAgentConversation( - messages=( - SystemMessage(content="Test message 1"), - HumanMessage(content="Test message 2"), - AIMessage(content="Test message 3"), - ), - ), - initial_prompt="Initial test message", - temperature=1.0, - ) solution_generator = LatestStateSolutionGenerator( + scenarios={ + "base": { + "initial": "Initial test message", + "system": [ + SystemMessage(content="Test message 1"), + HumanMessage(content="Test message 2"), + AIMessage(content="Test message 3"), + ], + } + }, llm=llm, ai_model=model, - ai_model_agent=chat_settings, + ) + + # Create an object that can be edited by reference + class Referenced: + def __init__(self, value: Any) -> None: + self.value: Any = value + + initial_prompt: Referenced = Referenced( + solution_generator.scenarios[default_scenario]["initial"] ) def send_message_mock(message: Optional[str] = None) -> ChatResponse: assert len(solution_generator.messages) == 1 - assert solution_generator.messages[0] == HumanMessage( - content=chat_settings.initial_prompt, - ) + assert solution_generator.messages[0].content == initial_prompt.value + assert solution_generator.messages[0].type == HumanMessage(content="").type return ChatResponse() # Use the LLM method to check if the code is overwritten solution_generator.send_message = send_message_mock - + # Call update state once since `generate_solution` requires it solution_generator.update_state("", "") + 
# Check now if the message stack is wiped per generate solution call. solution_generator.generate_solution() - chat_settings.initial_prompt = "aaaaaaa" + initial_prompt.value = "aaaaaaa" + solution_generator.scenarios[default_scenario]["initial"] = "aaaaaaa" + solution_generator.generate_solution() - chat_settings.initial_prompt = "bbbbbbb" + initial_prompt.value = "bbbbbbb" + solution_generator.scenarios[default_scenario]["initial"] = "bbbbbbb" + solution_generator.generate_solution() - chat_settings.initial_prompt = "ccccccc" + initial_prompt.value = "ccccccc" + solution_generator.scenarios[default_scenario]["initial"] = "ccccccc" def test_message_stack(setup_llm_model) -> None: llm, model = setup_llm_model - chat_settings = ChatPromptSettings( - system_messages=AIAgentConversation( - messages=( - SystemMessage(content="Test message 1"), - HumanMessage(content="Test message 2"), - AIMessage(content="Test message 3"), - ), - ), - initial_prompt="Initial test message", - temperature=1.0, - ) - solution_generator = LatestStateSolutionGenerator( llm=llm, ai_model=model, - ai_model_agent=chat_settings, + scenarios={ + "base": { + "initial": "Initial test message", + "system": ( + SystemMessage(content="Test message 1"), + HumanMessage(content="Test message 2"), + AIMessage(content="Test message 3"), + ), + } + }, ) with pytest.raises(AssertionError): @@ -94,10 +103,10 @@ def test_message_stack(setup_llm_model) -> None: solution, _ = solution_generator.generate_solution() assert solution == llm.responses[0] - solution_generator.ai_model_agent.initial_prompt = "Test message 2" + solution_generator.scenarios[default_scenario]["initial"] = "Test message 2" solution, _ = solution_generator.generate_solution() assert solution == llm.responses[1] - solution_generator.ai_model_agent.initial_prompt = "Test message 3" + solution_generator.scenarios[default_scenario]["initial"] = "Test message 3" solution, _ = solution_generator.generate_solution() assert solution == llm.responses[2] diff --git a/tests/test_reverse_order_solution_generator.py b/tests/test_reverse_order_solution_generator.py index ebcbef8..0e89df1 100644 --- a/tests/test_reverse_order_solution_generator.py +++ b/tests/test_reverse_order_solution_generator.py @@ -9,8 +9,8 @@ SystemMessage, ) +from esbmc_ai.config import default_scenario from esbmc_ai.ai_models import AIModel -from esbmc_ai.config import AIAgentConversation, ChatPromptSettings from esbmc_ai.reverse_order_solution_generator import ReverseOrderSolutionGenerator @@ -36,22 +36,19 @@ def test_send_message(setup_llm_model) -> None: def test_message_stack(setup_llm_model) -> None: llm, model = setup_llm_model - chat_settings = ChatPromptSettings( - system_messages=AIAgentConversation( - messages=( - SystemMessage(content="Test message 1"), - HumanMessage(content="Test message 2"), - AIMessage(content="Test message 3"), - ), - ), - initial_prompt="Initial test message", - temperature=1.0, - ) - solution_generator = ReverseOrderSolutionGenerator( llm=llm, ai_model=model, - ai_model_agent=chat_settings, + scenarios={ + "base": { + "initial": "Initial test message", + "system": ( + SystemMessage(content="Test message 1"), + HumanMessage(content="Test message 2"), + AIMessage(content="Test message 3"), + ), + } + }, ) with pytest.raises(AssertionError): @@ -61,10 +58,10 @@ def test_message_stack(setup_llm_model) -> None: solution, _ = solution_generator.generate_solution() assert solution == llm.responses[0] - solution_generator.ai_model_agent.initial_prompt = "Test message 2" + 
solution_generator.scenarios[default_scenario]["initial"] = "Test message 2" solution, _ = solution_generator.generate_solution() assert solution == llm.responses[1] - solution_generator.ai_model_agent.initial_prompt = "Test message 3" + solution_generator.scenarios[default_scenario]["initial"] = "Test message 3" solution, _ = solution_generator.generate_solution() assert solution == llm.responses[2] diff --git a/tests/test_solution_generator.py b/tests/test_solution_generator.py index ff86e7b..0a8f107 100644 --- a/tests/test_solution_generator.py +++ b/tests/test_solution_generator.py @@ -5,7 +5,6 @@ import pytest from esbmc_ai.ai_models import AIModel -from esbmc_ai.config import AIAgentConversation, ChatPromptSettings from esbmc_ai.chats.solution_generator import SolutionGenerator @@ -25,22 +24,19 @@ def setup_llm_model(): def test_call_update_state_first(setup_llm_model) -> None: llm, model = setup_llm_model - chat_settings = ChatPromptSettings( - system_messages=AIAgentConversation( - messages=( - SystemMessage(content="Test message 1"), - HumanMessage(content="Test message 2"), - AIMessage(content="Test message 3"), - ), - ), - initial_prompt="Initial test message", - temperature=1.0, - ) - solution_generator = SolutionGenerator( llm=llm, ai_model=model, - ai_model_agent=chat_settings, + scenarios={ + "base": { + "initial": "Initial test message", + "system": ( + SystemMessage(content="Test message 1"), + HumanMessage(content="Test message 2"), + AIMessage(content="Test message 3"), + ), + } + }, ) with pytest.raises(AssertionError): diff --git a/tests/test_user_chat.py b/tests/test_user_chat.py index 0d39909..d7f5a8c 100644 --- a/tests/test_user_chat.py +++ b/tests/test_user_chat.py @@ -7,7 +7,6 @@ from esbmc_ai.ai_models import AIModel from esbmc_ai.chat_response import ChatResponse, FinishReason -from esbmc_ai.config import AIAgentConversation, ChatPromptSettings from esbmc_ai.chats.user_chat import UserChat @@ -18,46 +17,46 @@ def setup(): AIMessage(content="OK"), ] - set_solution_messages = [ - SystemMessage(content="Corrected output"), - ] - summary_text = "THIS IS A SUMMARY OF THE CONVERSATION" chat: UserChat = UserChat( - ai_model_agent=ChatPromptSettings( - system_messages=AIAgentConversation.from_seq(system_messages), - initial_prompt="This is initial prompt", - temperature=1.0, - ), + system_messages=system_messages, ai_model=AIModel(name="test", tokens=12), llm=FakeListChatModel(responses=[summary_text]), source_code="This is source code", esbmc_output="This is esbmc output", - set_solution_messages=AIAgentConversation.from_seq(set_solution_messages), + set_solution_messages=[ + SystemMessage(content="Corrected output"), + ], ) return chat, summary_text, system_messages -def test_compress_message_stack(setup) -> None: +@pytest.fixture +def initial_prompt(): + return "This is initial prompt" + + +def test_compress_message_stack(setup, initial_prompt) -> None: chat, summary_text, system_messages = setup - chat.messages = [SystemMessage(content=chat.ai_model_agent.initial_prompt)] + chat.messages = [SystemMessage(content=initial_prompt)] chat.compress_message_stack() # Check system messages - assert chat.ai_model_agent.system_messages.messages == tuple(system_messages) + for msg, chat_msg in zip(system_messages, chat._system_messages): + assert msg.type == chat_msg.type and msg.content == chat_msg.content # Check normal messages - assert chat.messages == [SystemMessage(content=summary_text)] + assert chat.messages[0].content == summary_text -def 
test_automatic_compress(setup) -> None: +def test_automatic_compress(setup, initial_prompt) -> None: chat, summary_text, system_messages = setup # Make the prompt extra large. - big_prompt: str = chat.ai_model_agent.initial_prompt * 10 + big_prompt: str = initial_prompt * 10 response: ChatResponse = chat.send_message(big_prompt) @@ -66,7 +65,8 @@ def test_automatic_compress(setup) -> None: chat.compress_message_stack() # Check system messages - assert chat.ai_model_agent.system_messages.messages == tuple(system_messages) + for msg, chat_msg in zip(system_messages, chat._system_messages): + assert msg.type == chat_msg.type and msg.content == chat_msg.content # Check normal messages - Should be summarized automatically - assert chat.messages == [SystemMessage(content=summary_text)] + assert chat.messages[0].content == summary_text
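
The test migrations above all replace `ChatPromptSettings`/`AIAgentConversation` with a plain `scenarios` mapping. For reference, a minimal sketch of the new shape, mirroring the fixtures used in the tests; the canned response and token count are placeholders, and `default_scenario` is assumed to resolve to the `"base"` entry, as the tests imply.

```python
from langchain_core.language_models import FakeListChatModel
from langchain.schema import SystemMessage

from esbmc_ai.ai_models import AIModel
from esbmc_ai.chats.solution_generator import SolutionGenerator
from esbmc_ai.config import default_scenario

# Each scenario keys an ESBMC error type to its own system conversation
# and initial prompt; "base" acts as the fallback entry.
generator = SolutionGenerator(
    llm=FakeListChatModel(responses=["fixed code"]),
    ai_model=AIModel(name="test", tokens=1024),
    scenarios={
        "base": {
            "initial": "Initial test message",
            "system": (SystemMessage(content="Test message 1"),),
        }
    },
)

# update_state must be called before generate_solution, otherwise the
# generator raises an AssertionError (see test_call_update_state_first).
generator.update_state("", "")
solution, finish_reason = generator.generate_solution()

# The initial prompt of a scenario can be swapped at runtime, as the
# message-stack tests above do:
generator.scenarios[default_scenario]["initial"] = "Try a different prompt"
```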