Better BaseChatInterface + Logging #119

Merged 7 commits on Apr 3, 2024
2 changes: 1 addition & 1 deletion esbmc_ai/__about__.py
@@ -1,4 +1,4 @@
# Author: Yiannis Charalambous

__version__ = "v0.4.0.post0"
__version__ = "v0.5.0.dev6"
__author__: str = "Yiannis Charalambous"
19 changes: 13 additions & 6 deletions esbmc_ai/__main__.py
@@ -31,7 +31,7 @@

from esbmc_ai.loading_widget import LoadingWidget, create_loading_widget
from esbmc_ai.user_chat import UserChat
from esbmc_ai.logging import printv, printvv
from esbmc_ai.logging import print_horizontal_line, printv, printvv
from esbmc_ai.esbmc_util import esbmc
from esbmc_ai.chat_response import FinishReason, ChatResponse
from esbmc_ai.ai_models import _ai_model_names
@@ -223,6 +223,14 @@ def main() -> None:
+ ", +custom models}",
)

parser.add_argument(
"-r",
"--raw-conversation",
action="store_true",
default=False,
help="Show the raw conversation at the end of a command. Good for debugging...",
)

parser.add_argument(
"-a",
"--append",
@@ -271,18 +279,17 @@ def main() -> None:
set_main_source_file(SourceFile(args.filename, file.read()))

anim.start("ESBMC is processing... Please Wait")
exit_code, esbmc_output, esbmc_err_output = esbmc(
exit_code, esbmc_output = esbmc(
path=get_main_source_file_path(),
esbmc_params=config.esbmc_params,
timeout=config.verifier_timeout,
)
anim.stop()

# Print verbose lvl 2
printvv("-" * os.get_terminal_size().columns)
print_horizontal_line(2)
printvv(esbmc_output)
printvv(esbmc_err_output)
printvv("-" * os.get_terminal_size().columns)
print_horizontal_line(2)

# ESBMC will output 0 for verification success and 1 for verification
# failed, if anything else gets thrown, it's an ESBMC error.
@@ -292,7 +299,7 @@
sys.exit(0)
elif exit_code != 0 and exit_code != 1:
print(f"ESBMC exit code: {exit_code}")
print(f"ESBMC Output:\n\n{esbmc_err_output}")
print(f"ESBMC Output:\n\n{esbmc_output}")
sys.exit(1)

# Command mode: Check if command is called and call it.
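As context for the change above, here is a minimal sketch of how a caller consumes the reworked esbmc() wrapper, which now returns an (exit_code, output) pair rather than the old (exit_code, stdout, stderr) triple. The handle_verification name and the run_esbmc parameter are illustrative placeholders; the exit-code convention (0 verified, 1 verification failed, anything else an ESBMC error) is the one described in the comments above.

def handle_verification(run_esbmc, path: str, esbmc_params: list, timeout: float) -> None:
    # Illustrative only: run_esbmc stands in for esbmc_ai.esbmc_util.esbmc.
    exit_code, esbmc_output = run_esbmc(
        path=path,
        esbmc_params=esbmc_params,
        timeout=timeout,
    )
    if exit_code == 0:
        print("Verification successful.")
    elif exit_code == 1:
        print("Verification failed; the output is handed to the LLM commands.")
    else:
        # Anything other than 0 or 1 is treated as an ESBMC error,
        # mirroring the handling in __main__.py.
        print(f"ESBMC exit code: {exit_code}")
        print(f"ESBMC Output:\n\n{esbmc_output}")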
54 changes: 43 additions & 11 deletions esbmc_ai/base_chat_interface.py
@@ -1,6 +1,7 @@
# Author: Yiannis Charalambous

from abc import abstractmethod
from typing import Optional

from langchain.base_language import BaseLanguageModel
from langchain.schema import (
@@ -26,40 +27,71 @@ def __init__(
super().__init__()
self.ai_model: AIModel = ai_model
self.ai_model_agent: ChatPromptSettings = ai_model_agent
self._system_messages: list[BaseMessage] = list(
ai_model_agent.system_messages.messages
)
self.messages: list[BaseMessage] = []
self.llm: BaseLanguageModel = llm
self.template_values: dict[str, str] = {}

@abstractmethod
def compress_message_stack(self) -> None:
raise NotImplementedError()

def set_template_value(self, key: str, value: str) -> None:
"""Replaces a template key with the value provided when the chat template is
applied."""
self.template_values[key] = value

def push_to_message_stack(
self,
message: BaseMessage,
) -> None:
self.messages.append(message)

def send_message(self, message: str) -> ChatResponse:
def apply_template_value(self, **kwargs: str) -> None:
"""Will substitute an f-string in the message stack and system messages to
the provided value."""

system_message_prompts: PromptValue = self.ai_model.apply_chat_template(
messages=self._system_messages,
**kwargs,
)
self._system_messages = system_message_prompts.to_messages()

message_prompts: PromptValue = self.ai_model.apply_chat_template(
messages=self.messages,
**kwargs,
)
self.messages = message_prompts.to_messages()

def get_applied_messages(self, **kwargs: str) -> tuple[BaseMessage, ...]:
"""Applies the f-string substituion and returns the result instead of assigning
it to the message stack."""
message_prompts: PromptValue = self.ai_model.apply_chat_template(
messages=self.messages,
**kwargs,
)
return tuple(message_prompts.to_messages())

def get_applied_system_messages(self, **kwargs: str) -> tuple[BaseMessage, ...]:
"""Same as `get_applied_messages` but for system messages."""
message_prompts: PromptValue = self.ai_model.apply_chat_template(
messages=self._system_messages,
**kwargs,
)
return tuple(message_prompts.to_messages())

def send_message(self, message: Optional[str] = None) -> ChatResponse:
"""Sends a message to the AI model. Returns solution."""
self.push_to_message_stack(message=HumanMessage(content=message))
if message:
self.push_to_message_stack(message=HumanMessage(content=message))

all_messages = list(self.ai_model_agent.system_messages.messages)
all_messages.extend(self.messages)
all_messages = self._system_messages.copy()
all_messages.extend(self.messages.copy())

# Transform message stack to ChatPromptValue: If this is a ChatLLM then the
# function will simply be an identity function that does nothing and simply
# returns the messages as a ChatPromptValue. If this is a text generation
# LLM, then the function should inject the config message around the
# conversation to make the LLM behave like a ChatLLM.
# Do not replace any values.
message_prompts: PromptValue = self.ai_model.apply_chat_template(
messages=all_messages,
**self.template_values,
)

response: ChatResponse
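To make the reworked interface concrete, here is a short usage sketch. The chat argument is assumed to be an instance of a concrete subclass (BaseChatInterface itself is abstract), and the source_code / esbmc_output keys are examples of template placeholders the prompts might contain; only the method names are taken from the diff.

def demo_template_api(chat) -> None:
    # `chat` is assumed to be an instance of a concrete BaseChatInterface
    # subclass; only the method names below come from the diff above.

    # apply_template_value() substitutes the given keys in place, in both the
    # private system-message list and the message stack.
    chat.apply_template_value(
        source_code="int main() { return 0; }",
        esbmc_output="VERIFICATION SUCCESSFUL",
    )

    # get_applied_messages() / get_applied_system_messages() return the
    # substituted messages without mutating either stack.
    preview = chat.get_applied_messages(source_code="...", esbmc_output="...")
    print(f"{len(preview)} messages after substitution")

    # send_message() now takes an optional message: calling it with no
    # argument re-queries the model against the existing stack without
    # pushing a new HumanMessage first.
    response = chat.send_message("Why does verification fail?")
    follow_up = chat.send_message()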
32 changes: 23 additions & 9 deletions esbmc_ai/commands/fix_code_command.py
@@ -16,7 +16,7 @@
esbmc_load_source_code,
)
from ..solution_generator import SolutionGenerator, get_esbmc_output_formatted
from ..logging import printv, printvv
from ..logging import print_horizontal_line, printv, printvv

# TODO Remove built in messages and move them to config.

@@ -33,6 +33,14 @@ def __init__(self) -> None:

@override
def execute(self, **kwargs: Any) -> Tuple[bool, str]:
def print_raw_conversation() -> None:
print("Notice: Printing raw conversation...")
all_messages = solution_generator._system_messages.copy()
all_messages.extend(solution_generator.messages.copy())
messages: list[str] = [f"{msg.type}: {msg.content}" for msg in all_messages]
print("\n" + "\n\n".join(messages))
print("Notice: End of conversation")

file_name: str = kwargs["file_name"]
source_code: str = kwargs["source_code"]
esbmc_output: str = kwargs["esbmc_output"]
@@ -84,15 +92,15 @@ def execute(self, **kwargs: Any) -> Tuple[bool, str]:

# Print verbose lvl 2
printvv("\nGeneration:")
printvv("-" * get_terminal_size().columns)
print_horizontal_line(2)
printvv(llm_solution)
printvv("-" * get_terminal_size().columns)
print_horizontal_line(2)
printvv("")

# Pass to ESBMC, a workaround is used where the file is saved
# to a temporary location since ESBMC needs it in file format.
self.anim.start("Verifying with ESBMC... Please Wait")
exit_code, esbmc_output, esbmc_err_output = esbmc_load_source_code(
exit_code, esbmc_output = esbmc_load_source_code(
file_path=file_name,
source_code=llm_solution,
esbmc_params=config.esbmc_params,
@@ -113,13 +121,16 @@ def execute(self, **kwargs: Any) -> Tuple[bool, str]:
pass

# Print verbose lvl 2
printvv("-" * get_terminal_size().columns)
print_horizontal_line(2)
printvv(esbmc_output)
printvv(esbmc_err_output)
printvv("-" * get_terminal_size().columns)
print_horizontal_line(2)

if exit_code == 0:
self.on_solution_signal.emit(llm_solution)

if config.raw_conversation:
print_raw_conversation()

return False, llm_solution

# Failure case
@@ -128,22 +139,25 @@ def execute(self, **kwargs: Any) -> Tuple[bool, str]:
if idx < max_retries - 1:

# Inform solution generator chat about the ESBMC response.
# TODO Add option to customize in config.
if exit_code != 1:
# The program did not compile.
solution_generator.push_to_message_stack(
message=HumanMessage(
content=f"The source code you provided does not compile. Fix the compilation errors. Use ESBMC output to fix the compilation errors:\n\n```\n{esbmc_output}\n```"
content=f"Here is the ESBMC output:\n\n```\n{esbmc_output}\n```"
)
)
else:
solution_generator.push_to_message_stack(
message=HumanMessage(
content=f"ESBMC has reported that verification failed, use the ESBMC output to find out what is wrong, and fix it. Here is ESBMC output:\n\n```\n{esbmc_output}\n```"
content=f"Here is the ESBMC output:\n\n```\n{esbmc_output}\n```"
)
)

solution_generator.push_to_message_stack(
AIMessage(content="Understood.")
)

if config.raw_conversation:
print_raw_conversation()
return True, "Failed all attempts..."
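The retry loop now feeds ESBMC output back as a plain HumanMessage followed by a stock AIMessage acknowledgement. The same step, extracted into a standalone sketch for clarity (feed_back_esbmc_output is a hypothetical name; the message wording mirrors the diff):

from langchain.schema import AIMessage, HumanMessage

def feed_back_esbmc_output(solution_generator, esbmc_output: str) -> None:
    # Hypothetical standalone helper; the real code performs these pushes
    # inline inside execute().
    solution_generator.push_to_message_stack(
        message=HumanMessage(
            content=f"Here is the ESBMC output:\n\n```\n{esbmc_output}\n```"
        )
    )
    # The stock acknowledgement keeps the stack alternating user/assistant
    # before the next generation attempt.
    solution_generator.push_to_message_stack(AIMessage(content="Understood."))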
7 changes: 6 additions & 1 deletion esbmc_ai/config.py
@@ -11,7 +11,7 @@
from dataclasses import dataclass
from langchain.schema import BaseMessage

from .logging import *
from esbmc_ai.logging import printv, set_verbose
from .ai_models import *
from .api_key_collection import APIKeyCollection
from .chat_response import json_to_base_messages
@@ -51,6 +51,8 @@

loading_hints: bool = False
allow_successful: bool = False
# Show the raw conversation after the command ends
raw_conversation: bool = False

cfg_path: str

@@ -475,6 +477,9 @@ def load_args(args) -> None:
print(f"Error: invalid --ai-model parameter {args.ai_model}")
sys.exit(4)

global raw_conversation
raw_conversation = args.raw_conversation

global esbmc_params
# If append flag is set, then append.
if args.append:
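The new option follows the module-level-global pattern config.py already uses for its other settings. A minimal, self-contained sketch of that flow, assuming a freshly built parser rather than the real one in __main__.py:

import argparse

raw_conversation: bool = False  # module-level default, mirroring config.py

def load_args(args) -> None:
    # Mirrors config.load_args(): copy the parsed flag into the module global.
    global raw_conversation
    raw_conversation = args.raw_conversation

parser = argparse.ArgumentParser()
parser.add_argument(
    "-r",
    "--raw-conversation",
    action="store_true",
    default=False,
)
load_args(parser.parse_args(["-r"]))
assert raw_conversation is True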
8 changes: 2 additions & 6 deletions esbmc_ai/esbmc_util.py
@@ -94,12 +94,8 @@ def esbmc(path: str, esbmc_params: list, timeout: Optional[float] = None):
timeout=process_timeout,
)

output_bytes: bytes = process.stdout
err_bytes: bytes = process.stderr
output: str = str(output_bytes).replace("\\n", "\n")
err: str = str(err_bytes).replace("\\n", "\n")

return process.returncode, output, err
output: str = process.stdout.decode("utf-8")
return process.returncode, output


def esbmc_load_source_code(
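The replaced lines had worked around undecoded bytes by str()-ing them and rewriting escaped newlines; the new version decodes stdout as UTF-8 and drops the separate stderr value. A self-contained sketch of the same pattern (the run_and_capture name and the echo command are placeholders, not the real ESBMC invocation):

import subprocess
from typing import Optional

def run_and_capture(cmd: list[str], timeout: Optional[float] = None) -> tuple[int, str]:
    # Capture output as bytes, then decode stdout as UTF-8 instead of
    # str()-ing the bytes object and patching up "\\n" sequences.
    process = subprocess.run(cmd, capture_output=True, timeout=timeout)
    output: str = process.stdout.decode("utf-8")
    return process.returncode, output

print(run_and_capture(["echo", "hello"]))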
26 changes: 20 additions & 6 deletions esbmc_ai/logging.py
@@ -2,28 +2,42 @@

"""Logging module for verbose printing."""

verbose: int = 0
from os import get_terminal_size

_verbose: int = 0


def get_verbose_level() -> int:
return _verbose


def set_verbose(level: int) -> None:
"""Sets the verbosity level."""
global verbose
verbose = level
global _verbose
_verbose = level


def printv(m) -> None:
"""Level 1 verbose printing."""
if verbose > 0:
if _verbose > 0:
print(m)


def printvv(m) -> None:
"""Level 2 verbose printing."""
if verbose > 1:
if _verbose > 1:
print(m)


def printvvv(m) -> None:
"""Level 3 verbose printing."""
if verbose > 2:
if _verbose > 2:
print(m)


def print_horizontal_line(verbosity: int) -> None:
if verbosity >= _verbose:
try:
printvv("-" * get_terminal_size().columns)
except OSError:
pass
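The verbosity level is now module-private behind set_verbose() / get_verbose_level(), and print_horizontal_line() wraps the terminal-width rule while swallowing the OSError that get_terminal_size() raises when output is not attached to a terminal. A short usage sketch using only names from the diff:

from esbmc_ai.logging import (
    get_verbose_level,
    print_horizontal_line,
    printv,
    printvv,
    set_verbose,
)

set_verbose(2)
print(get_verbose_level())           # -> 2
printv("printed at verbosity level 1 and above")
print_horizontal_line(2)             # terminal-wide rule, gated on the verbosity check
printvv("printed at verbosity level 2 and above")
print_horizontal_line(2)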
12 changes: 8 additions & 4 deletions esbmc_ai/solution_generator.py
@@ -1,9 +1,10 @@
# Author: Yiannis Charalambous 2023

from re import S
from typing import Optional
from typing_extensions import override
from langchain.base_language import BaseLanguageModel
from langchain.schema import BaseMessage
from langchain.schema import BaseMessage, HumanMessage

from esbmc_ai.chat_response import ChatResponse, FinishReason
from esbmc_ai.config import ChatPromptSettings, DynamicAIModelAgent
@@ -86,14 +87,17 @@ def __init__(

self.source_code_format: str = source_code_format
self.source_code_raw: str = source_code
self.source_code = get_source_code_formatted(

source_code_formatted: str = get_source_code_formatted(
source_code_format=self.source_code_format,
source_code=self.source_code_raw,
esbmc_output=self.esbmc_output,
)

self.set_template_value("source_code", self.source_code)
self.set_template_value("esbmc_output", self.esbmc_output)
self.apply_template_value(
source_code=source_code_formatted,
esbmc_output=self.esbmc_output,
)

@override
def compress_message_stack(self) -> None: