From 4c8a1e59a6442308a121c27f7ffa951d0db85d47 Mon Sep 17 00:00:00 2001
From: Yiannis Charalambous
Date: Mon, 4 Nov 2024 15:38:00 +0000
Subject: [PATCH] Updated pylint

---
 .pylintrc                                     |  5 +-
 esbmc_ai/__main__.py                          | 13 +++--
 esbmc_ai/ai_models.py                         |  3 +-
 esbmc_ai/chats/base_chat_interface.py         | 50 ++++++++++---------
 .../chats/latest_state_solution_generator.py  |  3 ++
 esbmc_ai/chats/solution_generator.py          | 29 ++++++++---
 esbmc_ai/chats/user_chat.py                   | 10 +++-
 esbmc_ai/commands/__init__.py                 |  2 +
 esbmc_ai/commands/fix_code_command.py         |  6 ++-
 esbmc_ai/config.py                            | 16 +++---
 10 files changed, 88 insertions(+), 49 deletions(-)

diff --git a/.pylintrc b/.pylintrc
index b5c747f..ed32435 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -431,7 +431,10 @@ disable=raw-checker-failed,
         use-symbolic-message-instead,
         use-implicit-booleaness-not-comparison-to-string,
         use-implicit-booleaness-not-comparison-to-zero,
-        unspecified-encoding
+        unspecified-encoding,
+        too-many-arguments,
+        too-many-positional-arguments,
+        too-many-instance-attributes
 
 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option

diff --git a/esbmc_ai/__main__.py b/esbmc_ai/__main__.py
index 33f8e40..ee80d02 100755
--- a/esbmc_ai/__main__.py
+++ b/esbmc_ai/__main__.py
@@ -417,10 +417,15 @@ def main() -> None:
     if len(str(Config.get_user_chat_initial().content)) > 0:
         printv("Using initial prompt from file...\n")
         anim.start("Model is parsing ESBMC output... Please Wait")
-        response = chat.send_message(
-            message=str(Config.get_user_chat_initial().content),
-        )
-        anim.stop()
+        try:
+            response = chat.send_message(
+                message=str(Config.get_user_chat_initial().content),
+            )
+        except Exception as e:
+            print(f"There was an error while generating a response: {e}")
+            sys.exit(1)
+        finally:
+            anim.stop()
 
     if response.finish_reason == FinishReason.length:
         raise RuntimeError(f"The token length is too large: {chat.ai_model.tokens}")

diff --git a/esbmc_ai/ai_models.py b/esbmc_ai/ai_models.py
index 6c1f228..db13e6e 100644
--- a/esbmc_ai/ai_models.py
+++ b/esbmc_ai/ai_models.py
@@ -243,6 +243,7 @@ def is_valid_ai_model(
     """Accepts both the AIModel object and the name as parameter. It checks the
     openai servers to see if a model is defined on their servers, if not, then
     it checks the internally defined AI models list."""
+    from openai import Client
 
     # Get the name of the model
     name: str = ai_model.name if isinstance(ai_model, AIModel) else ai_model
@@ -251,8 +252,6 @@ def is_valid_ai_model(
     # NOTE: This is not tested as no way to mock API currently.
     if api_keys and api_keys.openai:
         try:
-            from openai import Client
-
             for model in Client(api_key=api_keys.openai).models.list().data:
                 if model.id == name:
                     return True

diff --git a/esbmc_ai/chats/base_chat_interface.py b/esbmc_ai/chats/base_chat_interface.py
index 5d2ab07..351c16f 100644
--- a/esbmc_ai/chats/base_chat_interface.py
+++ b/esbmc_ai/chats/base_chat_interface.py
@@ -1,5 +1,8 @@
 # Author: Yiannis Charalambous
 
+"""Contains code for the base class for interacting with the LLMs in a
+conversation-based way."""
+
 from abc import abstractmethod
 from typing import Optional
 
@@ -14,7 +17,7 @@ from esbmc_ai.ai_models import AIModel
 
 
-class BaseChatInterface(object):
+class BaseChatInterface:
     """Base class for interacting with an LLM.
     It allows for interactions with text generation LLMs and also chat LLMs."""
 
@@ -32,12 +35,14 @@ def __init__(
 
     @abstractmethod
     def compress_message_stack(self) -> None:
+        """Compress the message stack, is abstract and needs to be implemented."""
        raise NotImplementedError()
 
     def push_to_message_stack(
         self,
         message: BaseMessage,
     ) -> None:
+        """Pushes a message to the message stack without querying the LLM."""
         self.messages.append(message)
 
     def apply_template_value(self, **kwargs: str) -> None:
@@ -81,31 +86,28 @@ def send_message(self, message: Optional[str] = None) -> ChatResponse:
         all_messages = self._system_messages.copy()
         all_messages.extend(self.messages.copy())
 
-        response: ChatResponse
-        try:
-            response_message: BaseMessage = self.llm.invoke(input=all_messages)
+        response_message: BaseMessage = self.llm.invoke(input=all_messages)
+
+        self.push_to_message_stack(message=response_message)
 
-            self.push_to_message_stack(message=response_message)
+        # Check if token limit has been exceeded.
+        all_messages.append(response_message)
+        new_tokens: int = self.llm.get_num_tokens_from_messages(
+            messages=all_messages,
+        )
 
-            # Check if token limit has been exceeded.
-            all_messages.append(response_message)
-            new_tokens: int = self.llm.get_num_tokens_from_messages(
-                messages=all_messages,
+        response: ChatResponse
+        if new_tokens > self.ai_model.tokens:
+            response = ChatResponse(
+                finish_reason=FinishReason.length,
+                message=response_message,
+                total_tokens=self.ai_model.tokens,
+            )
+        else:
+            response = ChatResponse(
+                finish_reason=FinishReason.stop,
+                message=response_message,
+                total_tokens=new_tokens,
             )
-            if new_tokens > self.ai_model.tokens:
-                response = ChatResponse(
-                    finish_reason=FinishReason.length,
-                    message=response_message,
-                    total_tokens=self.ai_model.tokens,
-                )
-            else:
-                response = ChatResponse(
-                    finish_reason=FinishReason.stop,
-                    message=response_message,
-                    total_tokens=new_tokens,
-                )
-        except Exception as e:
-            print(f"There was an unkown error when generating a response: {e}")
-            exit(1)
 
         return response

diff --git a/esbmc_ai/chats/latest_state_solution_generator.py b/esbmc_ai/chats/latest_state_solution_generator.py
index 2e40b7b..88af213 100644
--- a/esbmc_ai/chats/latest_state_solution_generator.py
+++ b/esbmc_ai/chats/latest_state_solution_generator.py
@@ -1,5 +1,8 @@
 # Author: Yiannis Charalambous
 
+"""Contains code that extends the default solution generator to only use the
+latest state of the code only (removes history)"""
+
 from typing import Optional
 from typing_extensions import override
 from langchain_core.messages import BaseMessage

diff --git a/esbmc_ai/chats/solution_generator.py b/esbmc_ai/chats/solution_generator.py
index 6364e6e..30afcec 100644
--- a/esbmc_ai/chats/solution_generator.py
+++ b/esbmc_ai/chats/solution_generator.py
@@ -1,5 +1,7 @@
 # Author: Yiannis Charalambous 2023
 
+"""Contains code for automatically repairing code using ESBMC."""
+
 from typing import Optional
 from langchain_core.language_models import BaseChatModel
 from typing_extensions import override
@@ -10,21 +12,25 @@ from esbmc_ai.solution import SourceFile
 
 from esbmc_ai.ai_models import AIModel
-from .base_chat_interface import BaseChatInterface
 from esbmc_ai.esbmc_util import ESBMCUtil
+from .base_chat_interface import BaseChatInterface
 
 
 class ESBMCTimedOutException(Exception):
-    pass
+    """Error that means that ESBMC timed out and so the error could not be
+    determined."""
 
 
 class SourceCodeParseError(Exception):
-    pass
+    """Error that means that SolutionGenerator could not parse the source code
+    to return the right format."""
 
 
 def get_source_code_formatted(
     source_code_format: str, source_code: str, esbmc_output: str
 ) -> str:
+    """Gets the formatted output source code, based on the source_code_format
+    passed."""
     match source_code_format:
         case "single":
             # Get source code error line from esbmc output
@@ -49,11 +55,14 @@ def get_source_code_formatted(
 
 
 def get_esbmc_output_formatted(esbmc_output_type: str, esbmc_output: str) -> str:
+    """Gets the formatted output ESBMC output, based on the esbmc_output_type
+    passed."""
     # Check for parsing error
     if "ERROR: PARSING ERROR" in esbmc_output:
         # Parsing errors are usually small in nature.
         raise SourceCodeParseError()
-    elif "ERROR: Timed out" in esbmc_output:
+
+    if "ERROR: Timed out" in esbmc_output:
         raise ESBMCTimedOutException()
 
     match esbmc_output_type:
@@ -74,6 +83,9 @@ def get_esbmc_output_formatted(esbmc_output_type: str, esbmc_output: str) -> str
 
 
 class SolutionGenerator(BaseChatInterface):
+    """Class that generates a solution using verifier output and source code
+    that contains a bug."""
+
     def __init__(
         self,
         scenarios: FixCodeScenarios,
@@ -166,6 +178,8 @@ def generate_solution(
         self,
         override_scenario: Optional[str] = None,
     ) -> tuple[str, FinishReason]:
+        """Queries the AI model to get a solution. Accepts an override scenario
+        parameter, in which case the scenario won't be resolved automatically."""
 
         assert (
             self.source_code_raw is not None
@@ -209,9 +223,10 @@ def generate_solution(
 
             # Check if it parses
             line = ESBMCUtil.get_clang_err_line_index(self.esbmc_output)
-            assert (
-                line
-            ), "fix code command: error line could not be found to apply brutal patch replacement"
+            assert line, (
+                "fix code command: error line could not be found to apply "
+                "brutal patch replacement"
+            )
             solution = SourceFile.apply_line_patch(
                 self.source_code_raw, solution, line, line
             )

diff --git a/esbmc_ai/chats/user_chat.py b/esbmc_ai/chats/user_chat.py
index 3274f34..5d7a781 100644
--- a/esbmc_ai/chats/user_chat.py
+++ b/esbmc_ai/chats/user_chat.py
@@ -1,5 +1,7 @@
 # Author: Yiannis Charalambous 2023
 
+"""Contains class that handles the UserChat of ESBMC-AI"""
+
 from typing_extensions import override
 
 from langchain.memory import ConversationSummaryMemory
@@ -14,6 +16,9 @@
 
 
 class UserChat(BaseChatInterface):
+    """Simple interface that talks to the LLM and stores the result. The class
+    also stores the fixed results from fix code command."""
+
     solution: str = ""
 
     def __init__(
@@ -50,8 +55,9 @@ def set_solution(self, source_code: str) -> None:
 
     @override
     def compress_message_stack(self) -> None:
-        """Uses ConversationSummaryMemory from Langchain to summarize the conversation of all the non-protected
-        messages into one summary message which is added into the conversation as a SystemMessage.
+        """Uses ConversationSummaryMemory from Langchain to summarize the
+        conversation of all the non-protected messages into one summary message
+        which is added into the conversation as a SystemMessage.
""" memory: ConversationSummaryMemory = ConversationSummaryMemory.from_messages( diff --git a/esbmc_ai/commands/__init__.py b/esbmc_ai/commands/__init__.py index d7767fb..01e7577 100644 --- a/esbmc_ai/commands/__init__.py +++ b/esbmc_ai/commands/__init__.py @@ -3,6 +3,8 @@ from .fix_code_command import FixCodeCommand from .help_command import HelpCommand +"""This module contains built-in commands that can be executed by ESBMC-AI.""" + __all__ = [ "ChatCommand", "ExitCommand", diff --git a/esbmc_ai/commands/fix_code_command.py b/esbmc_ai/commands/fix_code_command.py index f925942..12c9847 100644 --- a/esbmc_ai/commands/fix_code_command.py +++ b/esbmc_ai/commands/fix_code_command.py @@ -1,7 +1,7 @@ # Author: Yiannis Charalambous import sys -from typing import Any, Optional, Tuple +from typing import Any, Optional from typing_extensions import override from esbmc_ai.ai_models import AIModel @@ -42,6 +42,8 @@ def __str__(self) -> str: class FixCodeCommand(ChatCommand): + """Command for automatically fixing code using a verifier.""" + on_solution_signal: Signal = Signal() def __init__(self) -> None: @@ -71,7 +73,7 @@ def print_raw_conversation() -> None: ) message_history: str = ( - kwargs["message_history"] if "message_history" else "normal" + kwargs["message_history"] if "message_history" in kwargs else "normal" ) api_keys: APIKeyCollection = kwargs["api_keys"] diff --git a/esbmc_ai/config.py b/esbmc_ai/config.py index d107bb6..cfd730b 100644 --- a/esbmc_ai/config.py +++ b/esbmc_ai/config.py @@ -90,6 +90,8 @@ def _validate_prompt_template(conv: Dict[str, List[Dict]]) -> bool: class Config: + """Config loader for ESBMC-AI""" + api_keys: APIKeyCollection raw_conversation: bool = False cfg_path: Path @@ -114,7 +116,7 @@ class Config: name="temp_file_dir", default_value=None, validate=lambda v: isinstance(v, str) and Path(v).is_file(), - on_load=lambda v: Path(v), + on_load=Path, default_value_none=True, ), ConfigField( @@ -181,13 +183,13 @@ class Config: ConfigField( name="user_chat.temperature", default_value=1.0, - validate=lambda v: isinstance(v, float) and v >= 0 and v <= 2.0, + validate=lambda v: isinstance(v, float) and 0 <= v <= 2.0, error_message="Temperature needs to be a value between 0 and 2.0", ), ConfigField( name="fix_code.temperature", default_value=1.0, - validate=lambda v: isinstance(v, float) and v >= 0 and v <= 2.0, + validate=lambda v: isinstance(v, float) and 0 <= v <= 2, error_message="Temperature needs to be a value between 0 and 2.0", ), ConfigField( @@ -210,14 +212,14 @@ class Config: ConfigField( name="prompt_templates.user_chat.system", default_value=None, - validate=lambda v: _validate_prompt_template_conversation(v), - on_load=lambda v: list_to_base_messages(v), + validate=_validate_prompt_template_conversation, + on_load=list_to_base_messages, ), ConfigField( name="prompt_templates.user_chat.set_solution", default_value=None, - validate=lambda v: _validate_prompt_template_conversation(v), - on_load=lambda v: list_to_base_messages(v), + validate=_validate_prompt_template_conversation, + on_load=list_to_base_messages, ), # Here we have a list of prompt templates that are for each scenario. # The base scenario prompt template is required.