diff --git a/config.json b/config.json
index b53172d..2324bae 100644
--- a/config.json
+++ b/config.json
@@ -72,6 +72,7 @@
     "generate_solution": {
       "max_attempts": 5,
      "temperature": 1.3,
+      "message_history": "normal",
      "scenarios": {
        "division by zero": {
          "system": [
diff --git a/esbmc_ai/commands/fix_code_command.py b/esbmc_ai/commands/fix_code_command.py
index 6a8edb8..97abcab 100644
--- a/esbmc_ai/commands/fix_code_command.py
+++ b/esbmc_ai/commands/fix_code_command.py
@@ -5,6 +5,8 @@
 from typing_extensions import override
 
 from esbmc_ai.chat_response import FinishReason
+from esbmc_ai.latest_state_solution_generator import LatestStateSolutionGenerator
+from esbmc_ai.reverse_order_solution_generator import ReverseOrderSolutionGenerator
 
 from .chat_command import ChatCommand
 from .. import config
@@ -17,8 +19,6 @@
 from ..solution_generator import (
     ESBMCTimedOutException,
     SolutionGenerator,
-    SourceCodeParseError,
-    get_esbmc_output_formatted,
 )
 
 from ..logging import print_horizontal_line, printv, printvv
@@ -58,21 +58,58 @@ def print_raw_conversation() -> None:
             else "Using generic prompt..."
         )
 
+        match config.fix_code_message_history:
+            case "normal":
+                solution_generator = SolutionGenerator(
+                    ai_model_agent=config.chat_prompt_generator_mode,
+                    ai_model=config.ai_model,
+                    llm=config.ai_model.create_llm(
+                        api_keys=config.api_keys,
+                        temperature=config.chat_prompt_generator_mode.temperature,
+                        requests_max_tries=config.requests_max_tries,
+                        requests_timeout=config.requests_timeout,
+                    ),
+                    scenario=scenario,
+                    source_code_format=config.source_code_format,
+                    esbmc_output_type=config.esbmc_output_type,
+                )
+            case "latest_only":
+                solution_generator = LatestStateSolutionGenerator(
+                    ai_model_agent=config.chat_prompt_generator_mode,
+                    ai_model=config.ai_model,
+                    llm=config.ai_model.create_llm(
+                        api_keys=config.api_keys,
+                        temperature=config.chat_prompt_generator_mode.temperature,
+                        requests_max_tries=config.requests_max_tries,
+                        requests_timeout=config.requests_timeout,
+                    ),
+                    scenario=scenario,
+                    source_code_format=config.source_code_format,
+                    esbmc_output_type=config.esbmc_output_type,
+                )
+            case "reverse":
+                solution_generator = ReverseOrderSolutionGenerator(
+                    ai_model_agent=config.chat_prompt_generator_mode,
+                    ai_model=config.ai_model,
+                    llm=config.ai_model.create_llm(
+                        api_keys=config.api_keys,
+                        temperature=config.chat_prompt_generator_mode.temperature,
+                        requests_max_tries=config.requests_max_tries,
+                        requests_timeout=config.requests_timeout,
+                    ),
+                    scenario=scenario,
+                    source_code_format=config.source_code_format,
+                    esbmc_output_type=config.esbmc_output_type,
+                )
+            case _:
+                raise NotImplementedError(
+                    f"error: {config.fix_code_message_history} has not been implemented in the Fix Code Command"
+                )
+
         try:
-            solution_generator = SolutionGenerator(
-                ai_model_agent=config.chat_prompt_generator_mode,
+            solution_generator.update_state(
                 source_code=source_code,
                 esbmc_output=esbmc_output,
-                ai_model=config.ai_model,
-                llm=config.ai_model.create_llm(
-                    api_keys=config.api_keys,
-                    temperature=config.chat_prompt_generator_mode.temperature,
-                    requests_max_tries=config.requests_max_tries,
-                    requests_timeout=config.requests_timeout,
-                ),
-                scenario=scenario,
-                source_code_format=config.source_code_format,
-                esbmc_output_type=config.esbmc_output_type,
             )
         except ESBMCTimedOutException:
             print("error: ESBMC has timed out...")
@@ -90,9 +127,7 @@ def print_raw_conversation() -> None:
             llm_solution, finish_reason = solution_generator.generate_solution()
             self.anim.stop()
             if finish_reason == FinishReason.length:
-                self.anim.start("Compressing message stack... Please Wait")
                 solution_generator.compress_message_stack()
-                self.anim.stop()
             else:
                 source_code = llm_solution
                 break
@@ -133,16 +168,9 @@ def print_raw_conversation() -> None:
                 return False, source_code
 
-            # TODO Move this process into Solution Generator since have (beginning) is done
-            # inside, and the other half is done here.
-            # Get formatted ESBMC output
             try:
-                esbmc_output = get_esbmc_output_formatted(
-                    esbmc_output_type=config.esbmc_output_type,
-                    esbmc_output=esbmc_output,
-                )
-            except SourceCodeParseError:
-                pass
+                # Update state
+                solution_generator.update_state(source_code, esbmc_output)
             except ESBMCTimedOutException:
                 if config.raw_conversation:
                     print_raw_conversation()
@@ -156,9 +184,6 @@ def print_raw_conversation() -> None:
                 else ""
             )
 
-            # Update state
-            solution_generator.update_state(source_code, esbmc_output)
-
             if config.raw_conversation:
                 print_raw_conversation()
 
diff --git a/esbmc_ai/config.py b/esbmc_ai/config.py
index 1506178..3ce23fa 100644
--- a/esbmc_ai/config.py
+++ b/esbmc_ai/config.py
@@ -44,6 +44,7 @@
 source_code_format: str = "full"
 
 fix_code_max_attempts: int = 5
+fix_code_message_history: str = ""
 
 requests_max_tries: int = 5
 requests_timeout: float = 60
@@ -57,6 +58,7 @@
 cfg_path: str
 
 
+# TODO Get rid of this class as soon as ConfigTool with pyautoconfig is ready.
 class AIAgentConversation(NamedTuple):
     """Immutable class describing the conversation definition for an AI agent. The
     class represents the system messages of the AI agent defined and contains a load
@@ -384,6 +386,17 @@ def load_config(file_path: str) -> None:
             f"ESBMC output type in the config is not valid: {esbmc_output_type}"
         )
 
+    global fix_code_message_history
+    fix_code_message_history, _ = _load_config_value(
+        config_file=config_file["chat_modes"]["generate_solution"],
+        name="message_history",
+    )
+
+    if fix_code_message_history not in ["normal", "latest_only", "reverse"]:
+        raise ValueError(
+            f"error: fix code mode message history not valid: {fix_code_message_history}"
+        )
+
     global requests_max_tries
     requests_max_tries = int(
         _load_config_real_number(
diff --git a/esbmc_ai/latest_state_solution_generator.py b/esbmc_ai/latest_state_solution_generator.py
new file mode 100644
index 0000000..81d15a5
--- /dev/null
+++ b/esbmc_ai/latest_state_solution_generator.py
@@ -0,0 +1,27 @@
+# Author: Yiannis Charalambous
+
+from typing_extensions import override
+from langchain_core.messages import BaseMessage
+from esbmc_ai.solution_generator import SolutionGenerator
+from esbmc_ai.chat_response import FinishReason
+
+# TODO Test me
+
+
+class LatestStateSolutionGenerator(SolutionGenerator):
+    """SolutionGenerator that only shows the latest source code and verifier
+    output state."""
+
+    @override
+    def generate_solution(self) -> tuple[str, FinishReason]:
+        # Back up the message stack and clear it before sending the base
+        # message. We want to keep the message stack intact because we will
+        # print it with print_raw_conversation.
+        messages: list[BaseMessage] = self.messages
+        self.messages: list[BaseMessage] = []
+        solution, finish_reason = super().generate_solution()
+        # Append the newly generated messages to the backed-up message stack.
+        messages.extend(self.messages)
+        # Restore
+        self.messages = messages
+        return solution, finish_reason
diff --git a/esbmc_ai/reverse_order_solution_generator.py b/esbmc_ai/reverse_order_solution_generator.py
new file mode 100644
index 0000000..6d47938
--- /dev/null
+++ b/esbmc_ai/reverse_order_solution_generator.py
@@ -0,0 +1,34 @@
+# Author: Yiannis Charalambous
+
+from langchain.schema import BaseMessage, HumanMessage
+from typing_extensions import override, Optional
+from esbmc_ai.solution_generator import (
+    SolutionGenerator,
+    get_source_code_formatted,
+    get_source_code_err_line_idx,
+    get_clang_err_line_index,
+    apply_line_patch,
+)
+from esbmc_ai.chat_response import FinishReason, ChatResponse
+
+# TODO Test me
+
+
+class ReverseOrderSolutionGenerator(SolutionGenerator):
+    """SolutionGenerator that shows the source code and verifier output state in
+    reverse order."""
+
+    @override
+    def send_message(self, message: Optional[str] = None) -> ChatResponse:
+        # Reverse the messages
+        messages: list[BaseMessage] = self.messages.copy()
+        self.messages.reverse()
+
+        response: ChatResponse = super().send_message(message)
+
+        # Append the new message received from the LLM to the original-order copy.
+        messages.append(self.messages[-1])
+        # Restore
+        self.messages = messages
+
+        return response
diff --git a/esbmc_ai/solution_generator.py b/esbmc_ai/solution_generator.py
index f22f54e..95b8eec 100644
--- a/esbmc_ai/solution_generator.py
+++ b/esbmc_ai/solution_generator.py
@@ -82,19 +82,23 @@ def get_esbmc_output_formatted(esbmc_output_type: str, esbmc_output: str) -> str:
 class SolutionGenerator(BaseChatInterface):
     def __init__(
         self,
-        ai_model_agent: DynamicAIModelAgent,
+        ai_model_agent: DynamicAIModelAgent | ChatPromptSettings,
         llm: BaseLanguageModel,
-        source_code: str,
-        esbmc_output: str,
         ai_model: AIModel,
         scenario: str = "",
         source_code_format: str = "full",
         esbmc_output_type: str = "full",
     ) -> None:
-        # Convert to chat prompt
-        chat_prompt: ChatPromptSettings = DynamicAIModelAgent.to_chat_prompt_settings(
-            ai_model_agent=ai_model_agent, scenario=scenario
-        )
+        """Initializes the solution generator. Provides dynamic prompting: if
+        a DynamicAIModelAgent is supplied, the correct scenario is selected
+        and converted into a ChatPrompt."""
+
+        chat_prompt: ChatPromptSettings = ai_model_agent
+        if isinstance(ai_model_agent, DynamicAIModelAgent):
+            # Convert to chat prompt
+            chat_prompt = DynamicAIModelAgent.to_chat_prompt_settings(
+                ai_model_agent=ai_model_agent, scenario=scenario
+            )
 
         super().__init__(
             ai_model_agent=chat_prompt,
@@ -102,30 +106,19 @@ def __init__(
             llm=llm,
         )
 
-        self.initial_prompt = ai_model_agent.initial_prompt
-
         self.esbmc_output_type: str = esbmc_output_type
         self.source_code_format: str = source_code_format
 
-        self.source_code_raw: str = source_code
-        # Used for resetting state.
-        self._original_source_code: str = source_code
-
-        # Format ESBMC output
-        try:
-            self.esbmc_output = get_esbmc_output_formatted(
-                esbmc_output_type=self.esbmc_output_type,
-                esbmc_output=esbmc_output,
-            )
-        except SourceCodeParseError:
-            # When clang output is displayed, show it entirely as it doesn't get very
-            # big.
-            self.esbmc_output = esbmc_output
+        self.source_code_raw: Optional[str] = None
+        self.source_code_formatted: Optional[str] = None
+        self.esbmc_output: Optional[str] = None
 
     @override
     def compress_message_stack(self) -> None:
         # Resets the conversation - cannot summarize code
+        # If generate_solution is called after this point, it will start fresh
+        # with the currently set state.
         self.messages: list[BaseMessage] = []
-        self.source_code_raw = self._original_source_code
 
     @classmethod
     def get_code_from_solution(cls, solution: str) -> str:
@@ -153,27 +146,43 @@ def get_code_from_solution(cls, solution: str) -> str:
             pass
         return solution
 
-    def update_state(
-        self, source_code: Optional[str] = None, esbmc_output: Optional[str] = None
-    ) -> None:
-        if source_code:
-            self.source_code_raw = source_code
-        if esbmc_output:
-            self.esbmc_output = esbmc_output
+    def update_state(self, source_code: str, esbmc_output: str) -> None:
+        """Updates the latest state of the code and ESBMC output. This should
+        be called before generate_solution."""
+        self.source_code_raw = source_code
 
-    def generate_solution(self) -> tuple[str, FinishReason]:
-        self.push_to_message_stack(HumanMessage(content=self.initial_prompt))
+        # Format ESBMC output
+        try:
+            self.esbmc_output = get_esbmc_output_formatted(
+                esbmc_output_type=self.esbmc_output_type,
+                esbmc_output=esbmc_output,
+            )
+        except SourceCodeParseError:
+            # When clang output is displayed, show it entirely as it doesn't
+            # get very big.
+            self.esbmc_output = esbmc_output
 
         # Format source code
-        source_code_formatted: str = get_source_code_formatted(
+        self.source_code_formatted = get_source_code_formatted(
             source_code_format=self.source_code_format,
-            source_code=self.source_code_raw,
+            source_code=source_code,
             esbmc_output=self.esbmc_output,
         )
 
+    def generate_solution(self) -> tuple[str, FinishReason]:
+        assert (
+            self.source_code_raw is not None
+            and self.source_code_formatted is not None
+            and self.esbmc_output is not None
+        ), "Call update_state before calling generate_solution."
+
+        self.push_to_message_stack(
+            HumanMessage(content=self.ai_model_agent.initial_prompt)
+        )
+
         # Apply template substitution to message stack
         self.apply_template_value(
-            source_code=source_code_formatted,
+            source_code=self.source_code_formatted,
             esbmc_output=self.esbmc_output,
         )
diff --git a/tests/test_latest_state_solution_generator.py b/tests/test_latest_state_solution_generator.py
new file mode 100644
index 0000000..6bf4631
--- /dev/null
+++ b/tests/test_latest_state_solution_generator.py
@@ -0,0 +1,92 @@
+# Author: Yiannis Charalambous
+
+import pytest
+
+from langchain.schema import HumanMessage, AIMessage, SystemMessage
+from langchain_community.llms.fake import FakeListLLM
+
+from esbmc_ai.ai_models import AIModel
+from esbmc_ai.config import AIAgentConversation, ChatPromptSettings
+from esbmc_ai.latest_state_solution_generator import LatestStateSolutionGenerator
+
+
+@pytest.fixture(scope="function")
+def setup_llm_model():
+    llm = FakeListLLM(
+        responses=[
+            "This is a test response",
+            "Another test response",
+            "One more!",
+        ],
+    )
+    model = AIModel("test model", 1000)
+    return llm, model
+
+
+def test_call_update_state_first(setup_llm_model) -> None:
+    llm, model = setup_llm_model
+
+    chat_settings = ChatPromptSettings(
+        system_messages=AIAgentConversation(
+            messages=(
+                SystemMessage(content="Test message 1"),
+                HumanMessage(content="Test message 2"),
+                AIMessage(content="Test message 3"),
+            ),
+        ),
+        initial_prompt="Initial test message",
+        temperature=1.0,
+    )
+
+    solution_generator = LatestStateSolutionGenerator(
+        llm=llm,
+        ai_model=model,
+        ai_model_agent=chat_settings,
+    )
+
+    with pytest.raises(AssertionError):
+        solution_generator.generate_solution()
+
+
+def test_message_stack(setup_llm_model) -> None:
+    llm, model = setup_llm_model
+
+    chat_settings = ChatPromptSettings(
+        system_messages=AIAgentConversation(
+            messages=(
+                SystemMessage(content="Test message 1"),
+                HumanMessage(content="Test message 2"),
+                AIMessage(content="Test message 3"),
+            ),
+        ),
+        initial_prompt="Initial test message",
+        temperature=1.0,
+    )
+
+    solution_generator = LatestStateSolutionGenerator(
+        llm=llm,
+        ai_model=model,
+        ai_model_agent=chat_settings,
+    )
+
+    with pytest.raises(AssertionError):
+        solution_generator.generate_solution()
+
+    solution_generator.update_state("", "")
+
+    solution, _ = solution_generator.generate_solution()
+    assert solution == llm.responses[0]
+    solution_generator.ai_model_agent.initial_prompt = "Test message 2"
+    solution, _ = solution_generator.generate_solution()
+    assert solution == llm.responses[1]
+    solution_generator.ai_model_agent.initial_prompt = "Test message 3"
+    solution, _ = solution_generator.generate_solution()
+    assert solution == llm.responses[2]
+
+    # Test history is intact
+    assert solution_generator.messages[0].content == "Initial test message"
+    assert solution_generator.messages[1].content == "This is a test response"
+    assert solution_generator.messages[2].content == "Test message 2"
+    assert solution_generator.messages[3].content == "Another test response"
+    assert solution_generator.messages[4].content == "Test message 3"
+    assert solution_generator.messages[5].content == "One more!"
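Aside: the test above exercises the swap-and-restore pattern that LatestStateSolutionGenerator.generate_solution introduces. The standalone sketch below restates that pattern outside the class hierarchy, purely for illustration; MiniChat and LatestOnlyMiniChat are hypothetical stand-ins for BaseChatInterface and its subclass, not code from this patch.

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage


class MiniChat:
    """Hypothetical stand-in for BaseChatInterface: holds a message stack."""

    def __init__(self) -> None:
        self.messages: list[BaseMessage] = []

    def ask(self, prompt: str) -> str:
        # Pretend to query an LLM: record the prompt and a canned reply.
        self.messages.append(HumanMessage(content=prompt))
        self.messages.append(AIMessage(content=f"reply to: {prompt}"))
        return str(self.messages[-1].content)


class LatestOnlyMiniChat(MiniChat):
    """Sends only the latest prompt while keeping the full history."""

    def ask(self, prompt: str) -> str:
        backup = self.messages  # keep full history for raw-conversation printing
        self.messages = []  # the LLM sees only the latest state
        reply = super().ask(prompt)
        backup.extend(self.messages)  # fold the new exchange back into history
        self.messages = backup
        return reply


chat = LatestOnlyMiniChat()
chat.ask("state 1")
chat.ask("state 2")
# Four messages survive in the history, yet each call sent a single prompt.
assert len(chat.messages) == 4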
diff --git a/tests/test_reverse_order_solution_generator.py b/tests/test_reverse_order_solution_generator.py
new file mode 100644
index 0000000..65bb69d
--- /dev/null
+++ b/tests/test_reverse_order_solution_generator.py
@@ -0,0 +1,92 @@
+# Author: Yiannis Charalambous
+
+import pytest
+
+from langchain.schema import HumanMessage, AIMessage, SystemMessage
+from langchain_community.llms.fake import FakeListLLM
+
+from esbmc_ai.ai_models import AIModel
+from esbmc_ai.config import AIAgentConversation, ChatPromptSettings
+from esbmc_ai.reverse_order_solution_generator import ReverseOrderSolutionGenerator
+
+
+@pytest.fixture(scope="function")
+def setup_llm_model():
+    llm = FakeListLLM(
+        responses=[
+            "This is a test response",
+            "Another test response",
+            "One more!",
+        ],
+    )
+    model = AIModel("test model", 1000)
+    return llm, model
+
+
+def test_call_update_state_first(setup_llm_model) -> None:
+    llm, model = setup_llm_model
+
+    chat_settings = ChatPromptSettings(
+        system_messages=AIAgentConversation(
+            messages=(
+                SystemMessage(content="Test message 1"),
+                HumanMessage(content="Test message 2"),
+                AIMessage(content="Test message 3"),
+            ),
+        ),
+        initial_prompt="Initial test message",
+        temperature=1.0,
+    )
+
+    solution_generator = ReverseOrderSolutionGenerator(
+        llm=llm,
+        ai_model=model,
+        ai_model_agent=chat_settings,
+    )
+
+    with pytest.raises(AssertionError):
+        solution_generator.generate_solution()
+
+
+def test_message_stack(setup_llm_model) -> None:
+    llm, model = setup_llm_model
+
+    chat_settings = ChatPromptSettings(
+        system_messages=AIAgentConversation(
+            messages=(
+                SystemMessage(content="Test message 1"),
+                HumanMessage(content="Test message 2"),
+                AIMessage(content="Test message 3"),
+            ),
+        ),
+        initial_prompt="Initial test message",
+        temperature=1.0,
+    )
+
+    solution_generator = ReverseOrderSolutionGenerator(
+        llm=llm,
+        ai_model=model,
+        ai_model_agent=chat_settings,
+    )
+
+    with pytest.raises(AssertionError):
+        solution_generator.generate_solution()
+
+    solution_generator.update_state("", "")
+
+    solution, _ = solution_generator.generate_solution()
+    assert solution == llm.responses[0]
+    solution_generator.ai_model_agent.initial_prompt = "Test message 2"
+    solution, _ = solution_generator.generate_solution()
+    assert solution == llm.responses[1]
+    solution_generator.ai_model_agent.initial_prompt = "Test message 3"
+    solution, _ = solution_generator.generate_solution()
+    assert solution == llm.responses[2]
+
+    # Test history is intact
+    assert solution_generator.messages[0].content == "Initial test message"
+    assert solution_generator.messages[1].content == "This is a test response"
+    assert solution_generator.messages[2].content == "Test message 2"
+    assert solution_generator.messages[3].content == "Another test response"
+    assert solution_generator.messages[4].content == "Test message 3"
+    assert solution_generator.messages[5].content == "One more!"
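Note on the match statement added to fix_code_command.py: its three branches differ only in the class being constructed, while the keyword arguments are repeated verbatim. A more compact equivalent is sketched below. It is illustrative only and not part of the patch; make_solution_generator and GENERATORS are hypothetical names, while the config attributes and generator classes are the ones used above.

from esbmc_ai import config
from esbmc_ai.latest_state_solution_generator import LatestStateSolutionGenerator
from esbmc_ai.reverse_order_solution_generator import ReverseOrderSolutionGenerator
from esbmc_ai.solution_generator import SolutionGenerator

# Maps each message_history value to its generator class; mirrors the
# validation list in config.py ("normal", "latest_only", "reverse").
GENERATORS: dict[str, type[SolutionGenerator]] = {
    "normal": SolutionGenerator,
    "latest_only": LatestStateSolutionGenerator,
    "reverse": ReverseOrderSolutionGenerator,
}


def make_solution_generator(scenario: str) -> SolutionGenerator:
    """Sketch of a factory equivalent to the match statement in the patch."""
    if config.fix_code_message_history not in GENERATORS:
        raise NotImplementedError(
            f"error: {config.fix_code_message_history} has not been "
            "implemented in the Fix Code Command"
        )
    generator_class = GENERATORS[config.fix_code_message_history]
    # The constructor arguments are written once, regardless of the subclass.
    return generator_class(
        ai_model_agent=config.chat_prompt_generator_mode,
        ai_model=config.ai_model,
        llm=config.ai_model.create_llm(
            api_keys=config.api_keys,
            temperature=config.chat_prompt_generator_mode.temperature,
            requests_max_tries=config.requests_max_tries,
            requests_timeout=config.requests_timeout,
        ),
        scenario=scenario,
        source_code_format=config.source_code_format,
        esbmc_output_type=config.esbmc_output_type,
    )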