Commit 4c8a1e5: Updated pylint

Yiannis128 committed Nov 4, 2024
1 parent 9a85a4a commit 4c8a1e5

Showing 10 changed files with 88 additions and 49 deletions.
5 changes: 4 additions & 1 deletion .pylintrc
@@ -431,7 +431,10 @@ disable=raw-checker-failed,
         use-symbolic-message-instead,
         use-implicit-booleaness-not-comparison-to-string,
         use-implicit-booleaness-not-comparison-to-zero,
-        unspecified-encoding
+        unspecified-encoding,
+        too-many-arguments,
+        too-many-positional-arguments,
+        too-many-instance-attributes
 
 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
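Note: the three newly silenced checkers (too-many-arguments, too-many-positional-arguments, too-many-instance-attributes) are design warnings rather than bug detectors. If a later cleanup wants them back on globally, pylint also accepts scoped suppressions; a minimal sketch with a hypothetical function:

    # Hypothetical example: silence the checker for one definition only,
    # instead of disabling it project-wide in .pylintrc.
    def run_check(path, timeout, depth, unwind, strategy, verbose):  # pylint: disable=too-many-arguments
        """Six parameters exceed pylint's default max-args limit of five."""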
13 changes: 9 additions & 4 deletions esbmc_ai/__main__.py
@@ -417,10 +417,15 @@ def main() -> None:
     if len(str(Config.get_user_chat_initial().content)) > 0:
         printv("Using initial prompt from file...\n")
         anim.start("Model is parsing ESBMC output... Please Wait")
-        response = chat.send_message(
-            message=str(Config.get_user_chat_initial().content),
-        )
-        anim.stop()
+        try:
+            response = chat.send_message(
+                message=str(Config.get_user_chat_initial().content),
+            )
+        except Exception as e:
+            print(f"There was an error while generating a response: {e}")
+            sys.exit(1)
+        finally:
+            anim.stop()
 
     if response.finish_reason == FinishReason.length:
         raise RuntimeError(f"The token length is too large: {chat.ai_model.tokens}")
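The try/except/finally added here guarantees the loading animation is stopped on both the success and the error path (a finally clause runs even when sys.exit raises SystemExit). A minimal sketch of the same pattern with a hypothetical spinner object:

    import sys

    def query_with_spinner(spinner, send):
        """Hypothetical helper mirroring the change above."""
        spinner.start("Model is parsing ESBMC output... Please Wait")
        try:
            return send()
        except Exception as e:
            print(f"There was an error while generating a response: {e}")
            sys.exit(1)
        finally:
            spinner.stop()  # runs on success, on error, and after sys.exit(1)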
3 changes: 1 addition & 2 deletions esbmc_ai/ai_models.py
@@ -243,6 +243,7 @@ def is_valid_ai_model(
     """Accepts both the AIModel object and the name as parameter. It checks the
     openai servers to see if a model is defined on their servers, if not, then
     it checks the internally defined AI models list."""
+    from openai import Client
 
     # Get the name of the model
     name: str = ai_model.name if isinstance(ai_model, AIModel) else ai_model
@@ -251,8 +252,6 @@ def is_valid_ai_model(
     # NOTE: This is not tested as no way to mock API currently.
     if api_keys and api_keys.openai:
         try:
-            from openai import Client
-
             for model in Client(api_key=api_keys.openai).models.list().data:
                 if model.id == name:
                     return True
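Hoisting `from openai import Client` to the top of the function makes the name available to the whole body instead of only the try block, while still avoiding the import cost at module load. A rough sketch of the remote lookup the function performs (the helper name is hypothetical; the Client call mirrors the code above):

    from openai import Client

    def model_served_by_openai(name: str, api_key: str) -> bool:
        """Hypothetical sketch: ask the OpenAI API whether a model id exists."""
        try:
            models = Client(api_key=api_key).models.list().data
            return any(model.id == name for model in models)
        except Exception:
            return False  # fall back to the internally defined model list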
50 changes: 26 additions & 24 deletions esbmc_ai/chats/base_chat_interface.py
@@ -1,5 +1,8 @@
 # Author: Yiannis Charalambous
 
+"""Contains code for the base class for interacting with the LLMs in a
+conversation-based way."""
+
 from abc import abstractmethod
 from typing import Optional
 
@@ -14,7 +17,7 @@
 from esbmc_ai.ai_models import AIModel
 
 
-class BaseChatInterface(object):
+class BaseChatInterface:
     """Base class for interacting with an LLM. It allows for interactions with
     text generation LLMs and also chat LLMs."""
 
@@ -32,12 +35,14 @@ def __init__(
 
     @abstractmethod
     def compress_message_stack(self) -> None:
+        """Compresses the message stack; this is abstract and needs to be implemented."""
         raise NotImplementedError()
 
     def push_to_message_stack(
         self,
         message: BaseMessage,
     ) -> None:
+        """Pushes a message to the message stack without querying the LLM."""
         self.messages.append(message)
 
     def apply_template_value(self, **kwargs: str) -> None:
@@ -81,31 +86,28 @@ def send_message(self, message: Optional[str] = None) -> ChatResponse:
         all_messages = self._system_messages.copy()
         all_messages.extend(self.messages.copy())
 
-        response: ChatResponse
-        try:
-            response_message: BaseMessage = self.llm.invoke(input=all_messages)
+        response_message: BaseMessage = self.llm.invoke(input=all_messages)
 
-            self.push_to_message_stack(message=response_message)
+        self.push_to_message_stack(message=response_message)
 
-            # Check if token limit has been exceeded.
-            all_messages.append(response_message)
-            new_tokens: int = self.llm.get_num_tokens_from_messages(
-                messages=all_messages,
-            )
+        # Check if token limit has been exceeded.
+        all_messages.append(response_message)
+        new_tokens: int = self.llm.get_num_tokens_from_messages(
+            messages=all_messages,
+        )
 
-            if new_tokens > self.ai_model.tokens:
-                response = ChatResponse(
-                    finish_reason=FinishReason.length,
-                    message=response_message,
-                    total_tokens=self.ai_model.tokens,
-                )
-            else:
-                response = ChatResponse(
-                    finish_reason=FinishReason.stop,
-                    message=response_message,
-                    total_tokens=new_tokens,
-                )
-        except Exception as e:
-            print(f"There was an unknown error when generating a response: {e}")
-            exit(1)
+        response: ChatResponse
+        if new_tokens > self.ai_model.tokens:
+            response = ChatResponse(
+                finish_reason=FinishReason.length,
+                message=response_message,
+                total_tokens=self.ai_model.tokens,
+            )
+        else:
+            response = ChatResponse(
+                finish_reason=FinishReason.stop,
+                message=response_message,
+                total_tokens=new_tokens,
+            )
 
         return response
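With the broad try/except removed, send_message now lets exceptions propagate to the caller (handled in __main__.py above) instead of printing and exiting from library code. What remains is a token-budget check; a condensed sketch with hypothetical stand-ins for self.llm and self.ai_model:

    def finish_reason_for(llm, messages, reply, token_budget):
        """Hypothetical sketch: count the tokens of the conversation plus the
        new reply, then map an exceeded budget to 'length', otherwise 'stop'."""
        total = llm.get_num_tokens_from_messages(messages=[*messages, reply])
        if total > token_budget:
            return "length", token_budget
        return "stop", total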
3 changes: 3 additions & 0 deletions esbmc_ai/chats/latest_state_solution_generator.py
@@ -1,5 +1,8 @@
 # Author: Yiannis Charalambous
 
+"""Contains code that extends the default solution generator to use only the
+latest state of the code (removes history)."""
+
 from typing import Optional
 from typing_extensions import override
 from langchain_core.messages import BaseMessage
29 changes: 22 additions & 7 deletions esbmc_ai/chats/solution_generator.py
@@ -1,5 +1,7 @@
 # Author: Yiannis Charalambous 2023
 
+"""Contains code for automatically repairing code using ESBMC."""
+
 from typing import Optional
 from langchain_core.language_models import BaseChatModel
 from typing_extensions import override
@@ -10,21 +12,25 @@
 from esbmc_ai.solution import SourceFile
 
 from esbmc_ai.ai_models import AIModel
-from .base_chat_interface import BaseChatInterface
 from esbmc_ai.esbmc_util import ESBMCUtil
+from .base_chat_interface import BaseChatInterface
 
 
 class ESBMCTimedOutException(Exception):
-    pass
+    """Error that means that ESBMC timed out and so the error could not be
+    determined."""
 
 
 class SourceCodeParseError(Exception):
-    pass
+    """Error that means that SolutionGenerator could not parse the source code
+    to return the right format."""
 
 
 def get_source_code_formatted(
     source_code_format: str, source_code: str, esbmc_output: str
 ) -> str:
+    """Gets the formatted output source code, based on the source_code_format
+    passed."""
     match source_code_format:
         case "single":
             # Get source code error line from esbmc output
@@ -49,11 +55,14 @@ def get_source_code_formatted(
 
 
 def get_esbmc_output_formatted(esbmc_output_type: str, esbmc_output: str) -> str:
+    """Gets the formatted output ESBMC output, based on the esbmc_output_type
+    passed."""
     # Check for parsing error
     if "ERROR: PARSING ERROR" in esbmc_output:
         # Parsing errors are usually small in nature.
         raise SourceCodeParseError()
-    elif "ERROR: Timed out" in esbmc_output:
+
+    if "ERROR: Timed out" in esbmc_output:
         raise ESBMCTimedOutException()
 
     match esbmc_output_type:
@@ -74,6 +83,9 @@
 
 
 class SolutionGenerator(BaseChatInterface):
+    """Class that generates a solution using verifier output and source code
+    that contains a bug."""
+
     def __init__(
         self,
         scenarios: FixCodeScenarios,
@@ -166,6 +178,8 @@ def generate_solution(
         self,
         override_scenario: Optional[str] = None,
     ) -> tuple[str, FinishReason]:
+        """Queries the AI model to get a solution. Accepts an override scenario
+        parameter, in which case the scenario won't be resolved automatically."""
 
         assert (
             self.source_code_raw is not None
@@ -209,9 +223,10 @@ def generate_solution(
             # Check if it parses
             line = ESBMCUtil.get_clang_err_line_index(self.esbmc_output)
 
-            assert (
-                line
-            ), "fix code command: error line could not be found to apply brutal patch replacement"
+            assert line, (
+                "fix code command: error line could not be found to apply "
+                "brutal patch replacement"
+            )
             solution = SourceFile.apply_line_patch(
                 self.source_code_raw, solution, line, line
             )
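The reworked assert relies on implicit concatenation of adjacent string literals: the parenthesized pair compiles to the exact same message as the original one-liner while fitting the line-length limit. A small self-contained check:

    # The two adjacent literals join at compile time into a single string.
    msg = (
        "fix code command: error line could not be found to apply "
        "brutal patch replacement"
    )
    assert msg == (
        "fix code command: error line could not be found to apply brutal patch replacement"
    )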
10 changes: 8 additions & 2 deletions esbmc_ai/chats/user_chat.py
@@ -1,5 +1,7 @@
 # Author: Yiannis Charalambous 2023
 
+"""Contains the class that handles the UserChat of ESBMC-AI."""
+
 from typing_extensions import override
 
 from langchain.memory import ConversationSummaryMemory
@@ -14,6 +16,9 @@
 
 
 class UserChat(BaseChatInterface):
+    """Simple interface that talks to the LLM and stores the result. The class
+    also stores the fixed results from the fix code command."""
+
     solution: str = ""
 
     def __init__(
@@ -50,8 +55,9 @@ def set_solution(self, source_code: str) -> None:
 
     @override
     def compress_message_stack(self) -> None:
-        """Uses ConversationSummaryMemory from Langchain to summarize the conversation of all the non-protected
-        messages into one summary message which is added into the conversation as a SystemMessage.
+        """Uses ConversationSummaryMemory from Langchain to summarize the
+        conversation of all the non-protected messages into one summary message
+        which is added into the conversation as a SystemMessage.
         """
 
         memory: ConversationSummaryMemory = ConversationSummaryMemory.from_messages(
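compress_message_stack trades the full history for one summary message. A rough sketch of the LangChain call it builds on; ConversationSummaryMemory.from_messages appears in the code above, while the ChatMessageHistory wrapper and the helper name here are assumptions:

    from langchain.memory import ChatMessageHistory, ConversationSummaryMemory

    def summarize_messages(llm, messages):
        """Hypothetical sketch: fold a list of BaseMessage objects into one
        summary string via ConversationSummaryMemory."""
        history = ChatMessageHistory(messages=list(messages))
        memory = ConversationSummaryMemory.from_messages(llm=llm, chat_memory=history)
        return memory.buffer  # the accumulated summary text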
2 changes: 2 additions & 0 deletions esbmc_ai/commands/__init__.py
@@ -3,6 +3,8 @@
 from .fix_code_command import FixCodeCommand
 from .help_command import HelpCommand
 
+"""This module contains built-in commands that can be executed by ESBMC-AI."""
+
 __all__ = [
     "ChatCommand",
     "ExitCommand",
6 changes: 4 additions & 2 deletions esbmc_ai/commands/fix_code_command.py
@@ -1,7 +1,7 @@
 # Author: Yiannis Charalambous
 
 import sys
-from typing import Any, Optional, Tuple
+from typing import Any, Optional
 from typing_extensions import override
 
 from esbmc_ai.ai_models import AIModel
@@ -42,6 +42,8 @@ def __str__(self) -> str:
 
 
 class FixCodeCommand(ChatCommand):
+    """Command for automatically fixing code using a verifier."""
+
     on_solution_signal: Signal = Signal()
 
     def __init__(self) -> None:
@@ -71,7 +73,7 @@ def print_raw_conversation() -> None:
             )
 
         message_history: str = (
-            kwargs["message_history"] if "message_history" else "normal"
+            kwargs["message_history"] if "message_history" in kwargs else "normal"
         )
 
         api_keys: APIKeyCollection = kwargs["api_keys"]
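The condition fixed here tested the literal string "message_history", which is always truthy, so the else branch was unreachable and a missing key crashed with KeyError. A tiny reproduction, plus the dict.get spelling that says the same thing more compactly:

    kwargs = {}  # hypothetical call with no message_history argument

    # Old form: bool("message_history") is always True, so this line always
    # indexed kwargs and raised KeyError when the key was absent:
    # value = kwargs["message_history"] if "message_history" else "normal"

    value = kwargs["message_history"] if "message_history" in kwargs else "normal"
    assert value == "normal"
    assert kwargs.get("message_history", "normal") == "normal"  # equivalent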
16 changes: 9 additions & 7 deletions esbmc_ai/config.py
@@ -90,6 +90,8 @@ def _validate_prompt_template(conv: Dict[str, List[Dict]]) -> bool:
 
 
 class Config:
+    """Config loader for ESBMC-AI"""
+
     api_keys: APIKeyCollection
     raw_conversation: bool = False
     cfg_path: Path
@@ -114,7 +116,7 @@ class Config:
             name="temp_file_dir",
             default_value=None,
             validate=lambda v: isinstance(v, str) and Path(v).is_file(),
-            on_load=lambda v: Path(v),
+            on_load=Path,
             default_value_none=True,
         ),
         ConfigField(
@@ -181,13 +183,13 @@ class Config:
         ConfigField(
             name="user_chat.temperature",
             default_value=1.0,
-            validate=lambda v: isinstance(v, float) and v >= 0 and v <= 2.0,
+            validate=lambda v: isinstance(v, float) and 0 <= v <= 2.0,
             error_message="Temperature needs to be a value between 0 and 2.0",
         ),
         ConfigField(
             name="fix_code.temperature",
             default_value=1.0,
-            validate=lambda v: isinstance(v, float) and v >= 0 and v <= 2.0,
+            validate=lambda v: isinstance(v, float) and 0 <= v <= 2,
            error_message="Temperature needs to be a value between 0 and 2.0",
         ),
         ConfigField(
@@ -210,14 +212,14 @@ class Config:
         ConfigField(
             name="prompt_templates.user_chat.system",
             default_value=None,
-            validate=lambda v: _validate_prompt_template_conversation(v),
-            on_load=lambda v: list_to_base_messages(v),
+            validate=_validate_prompt_template_conversation,
+            on_load=list_to_base_messages,
         ),
         ConfigField(
             name="prompt_templates.user_chat.set_solution",
             default_value=None,
-            validate=lambda v: _validate_prompt_template_conversation(v),
-            on_load=lambda v: list_to_base_messages(v),
+            validate=_validate_prompt_template_conversation,
+            on_load=list_to_base_messages,
         ),
         # Here we have a list of prompt templates that are for each scenario.
         # The base scenario prompt template is required.
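Two small idioms recur in this file's changes: a lambda that only forwards its argument can be replaced by the callable itself (on_load=Path), and v >= 0 and v <= 2.0 collapses into a chained comparison that evaluates v once. A minimal demonstration:

    from pathlib import Path

    on_load = Path  # behaves the same as: lambda v: Path(v)
    assert on_load("a/b") == Path("a/b")

    def valid_temperature(v) -> bool:
        """Chained comparison reads as a single range check."""
        return isinstance(v, float) and 0 <= v <= 2.0

    assert valid_temperature(1.0) and not valid_temperature(2.5)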
