From bc08d75988cdff880ac5eeac47c521a39aaf7bf9 Mon Sep 17 00:00:00 2001
From: Yiannis Charalambous
Date: Thu, 12 Sep 2024 22:08:53 +0100
Subject: [PATCH 1/5] Add Ollama Pipfile dep

---
 Pipfile | 1 +
 1 file changed, 1 insertion(+)

diff --git a/Pipfile b/Pipfile
index 3f3c2e2..3f83258 100644
--- a/Pipfile
+++ b/Pipfile
@@ -25,6 +25,7 @@ clang = "*"
 langchain = "*"
 langchain-openai = "*"
 langchain-community = "*"
+langchain-ollama = "*"
 lizard = "*"
 
 [dev-packages]

From 49382f92e5609ef447959f75e6fe7086bcd7c7d7 Mon Sep 17 00:00:00 2001
From: Yiannis Charalambous
Date: Thu, 12 Sep 2024 22:11:09 +0100
Subject: [PATCH 2/5] Added Ollama models support

---
 esbmc_ai/__main__.py                  |   5 +-
 esbmc_ai/ai_models.py                 | 157 ++++----------------------
 esbmc_ai/chats/base_chat_interface.py |  29 ++---
 esbmc_ai/chats/solution_generator.py  |   4 +-
 esbmc_ai/chats/user_chat.py           |   4 +-
 esbmc_ai/config.py                    |  48 +++-----
 6 files changed, 57 insertions(+), 190 deletions(-)

diff --git a/esbmc_ai/__main__.py b/esbmc_ai/__main__.py
index e9a715d..1804373 100755
--- a/esbmc_ai/__main__.py
+++ b/esbmc_ai/__main__.py
@@ -11,12 +11,13 @@
 import readline
 from typing import Optional
 
+from langchain_core.language_models import BaseChatModel
+
 from esbmc_ai.commands.fix_code_command import FixCodeCommandResult
 
 _ = readline
 
 import argparse
-from langchain.base_language import BaseLanguageModel
 
 import esbmc_ai.config as config
 
@@ -365,7 +366,7 @@ def main() -> None:
         del esbmc_output
 
     printv(f"Initializing the LLM: {config.ai_model.name}\n")
-    chat_llm: BaseLanguageModel = config.ai_model.create_llm(
+    chat_llm: BaseChatModel = config.ai_model.create_llm(
         api_keys=config.api_keys,
         temperature=config.chat_prompt_user_mode.temperature,
         requests_max_tries=config.requests_max_tries,
diff --git a/esbmc_ai/ai_models.py b/esbmc_ai/ai_models.py
index 16120f7..8010336 100644
--- a/esbmc_ai/ai_models.py
+++ b/esbmc_ai/ai_models.py
@@ -3,23 +3,14 @@
 from abc import abstractmethod
 from typing import Any, Iterable, Optional, Union
 from enum import Enum
+from langchain_core.language_models import BaseChatModel
 from pydantic.v1.types import SecretStr
 from typing_extensions import override
 
-from langchain.prompts import PromptTemplate
-from langchain.base_language import BaseLanguageModel
-
 from langchain_openai import ChatOpenAI
-from langchain_community.llms.huggingface_text_gen_inference import (
-    HuggingFaceTextGenInference,
-)
+from langchain_ollama import ChatOllama
 
-from langchain.prompts.chat import (
-    AIMessagePromptTemplate,
-    ChatPromptTemplate,
-    HumanMessagePromptTemplate,
-    SystemMessagePromptTemplate,
-)
+from langchain.prompts.chat import ChatPromptTemplate
 from langchain.schema import (
     BaseMessage,
     PromptValue,
@@ -30,6 +21,8 @@
 
 
 class AIModel(object):
+    """This base class represents an abstract AI model."""
+
     name: str
     tokens: int
 
@@ -48,7 +41,7 @@ def create_llm(
         self,
         api_keys: APIKeyCollection,
         temperature: float = 1.0,
         requests_max_tries: int = 5,
         requests_timeout: float = 60,
-    ) -> BaseLanguageModel:
+    ) -> BaseChatModel:
         """Initializes a large language model with the provided parameters."""
         raise NotImplementedError()
 
@@ -132,7 +125,9 @@ def apply_chat_template(
         self,
         messages: Iterable[BaseMessage],
         **format_values: Any,
     ) -> PromptValue:
-        # Default one, identity function essentially.
+        """Applies the values in format_values onto the message chat template.
+        For example, if the message contains the token {source} and format_values
+        contains a value for {source}, then it will be substituted."""
         escaped_messages = AIModel.escape_messages(messages, list(format_values.keys()))
         message_tuples = AIModel.convert_messages_to_tuples(escaped_messages)
         return ChatPromptTemplate.from_messages(messages=message_tuples).format_prompt(
@@ -151,7 +146,7 @@ def create_llm(
         temperature: float = 1.0,
         requests_max_tries: int = 5,
         requests_timeout: float = 60,
-    ) -> BaseLanguageModel:
+    ) -> BaseChatModel:
         assert api_keys.openai, "No OpenAI api key has been specified..."
         return ChatOpenAI(
             model=self.name,
@@ -163,108 +158,23 @@ def create_llm(
             model_kwargs={},
         )
 
-
-class AIModelTextGen(AIModel):
-    """Below are only used for models that need them, such as models that
-    are using the provider "text_inference_server"."""
-
-    def __init__(
-        self,
-        name: str,
-        tokens: int,
-        url: str,
-        config_message: str = "{history}\n\n{user_prompt}",
-        system_template: str = "{content}",
-        human_template: str = "{content}",
-        ai_template: str = "{content}",
-        stop_sequences: list[str] = [],
-    ) -> None:
-        super().__init__(name, tokens)
-
-        self.url: str = url
-        self.chat_template: PromptTemplate = PromptTemplate.from_template(
-            template=config_message,
-        )
-        """The chat template to place all messages in."""
-
-        self.system_template: SystemMessagePromptTemplate = (
-            SystemMessagePromptTemplate.from_template(
-                template=system_template,
-            )
-        )
-        """Template for each system message."""
-
-        self.human_template: HumanMessagePromptTemplate = (
-            HumanMessagePromptTemplate.from_template(
-                template=human_template,
-            )
-        )
-        """Template for each human message."""
-
-        self.ai_template: AIMessagePromptTemplate = (
-            AIMessagePromptTemplate.from_template(
-                template=ai_template,
-            )
-        )
-        """Template for each AI message."""
-
-        self.stop_sequences: list[str] = stop_sequences
-
-    @override
-    def create_llm(
-        self,
-        api_keys: APIKeyCollection,
-        temperature: float = 1.0,
-        requests_max_tries: int = 5,
-        requests_timeout: float = 60,
-    ) -> BaseLanguageModel:
-        return HuggingFaceTextGenInference(
-            client=None,
-            async_client=None,
-            inference_server_url=self.url,
-            server_kwargs={
-                "headers": {"Authorization": f"Bearer {api_keys.huggingface}"}
-            },
-            # FIXME Need to find a way to make output bigger. When token
-            # tracking for this LLM type is added.
-            max_new_tokens=5000,
-            temperature=temperature,
-            stop_sequences=self.stop_sequences,
-            max_retries=requests_max_tries,
-            timeout=requests_timeout,
-        )
-
-    @override
-    def apply_chat_template(
-        self,
-        messages: Iterable[BaseMessage],
-        **format_values: Any,
-    ) -> PromptValue:
-        """Text generation LLMs take single string of text as input. So the conversation
-        is converted into a string and returned back in a single prompt value. The config
-        message is also applied to the conversation."""
-
-        escaped_messages = AIModel.escape_messages(messages, list(format_values.keys()))
-
-        formatted_messages: list[BaseMessage] = []
-        for msg in escaped_messages:
-            formatted_msg: BaseMessage
-            if msg.type == "ai":
-                formatted_msg = self.ai_template.format(content=msg.content)
-            elif msg.type == "system":
-                formatted_msg = self.system_template.format(content=msg.content)
-            elif msg.type == "human":
-                formatted_msg = self.human_template.format(content=msg.content)
-            else:
-                raise ValueError(
-                    f"Got unsupported message type: {msg.type}: {msg.content}"
-                )
-            formatted_messages.append(formatted_msg)
-
-        return self.chat_template.format_prompt(
-            history="\n\n".join([str(msg.content) for msg in formatted_messages[:-1]]),
-            user_prompt=formatted_messages[-1].content,
-            **format_values,
-        )
+
+class OllamaAIModel(AIModel):
+    """An AI model that is hosted on an Ollama server."""
+
+    def __init__(self, name: str, tokens: int, url: str) -> None:
+        super().__init__(name, tokens)
+        self.url: str = url
+
+    @override
+    def create_llm(
+        self,
+        api_keys: APIKeyCollection,
+        temperature: float = 1.0,
+        requests_max_tries: int = 5,
+        requests_timeout: float = 60,
+    ) -> BaseChatModel:
+        # Ollama does not use API keys or request retries.
+        _ = api_keys
+        _ = requests_max_tries
+        return ChatOllama(
+            base_url=self.url,
+            model=self.name,
+            temperature=temperature,
+            client_kwargs={
+                "timeout": requests_timeout,
+            },
+        )
 
 
 class _AIModels(Enum):
     """Private enum that contains predefined AI Models. OpenAI models are not
     defined because they are fetched from the API."""
 
-    FALCON_7B = AIModelTextGen(
-        name="falcon-7b",
-        tokens=8192,
-        url="https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct",
-        config_message='>>DOMAIN<>SUMMARY<<{history}\n\n{user_prompt}\n\n',
-        ai_template=">>ANSWER<<{content}",
-        human_template=">>QUESTION<>ANSWER<<",
-        system_template="System: {content}",
-    )
-    STARCHAT_BETA = AIModelTextGen(
-        name="starchat-beta",
-        tokens=8192,
-        url="https://api-inference.huggingface.co/models/HuggingFaceH4/starchat-beta",
-        config_message="{history}\n{user_prompt}\n<|assistant|>\n",
-        system_template="<|system|>\n{content}\n<|end|>",
-        ai_template="<|assistant|>\n{content}\n<|end|>",
-        human_template="<|user|>\n{content}\n<|end|>",
-        stop_sequences=["<|end|>"],
-    )
+    # FALCON_7B = OllamaAIModel(...)
+    pass
 
 
 _custom_ai_models: list[AIModel] = []
diff --git a/esbmc_ai/chats/base_chat_interface.py b/esbmc_ai/chats/base_chat_interface.py
index 6344734..defb2bd 100644
--- a/esbmc_ai/chats/base_chat_interface.py
+++ b/esbmc_ai/chats/base_chat_interface.py
@@ -3,14 +3,12 @@
 from abc import abstractmethod
 from typing import Optional
 
-from langchain.base_language import BaseLanguageModel
 from langchain.schema import (
-    AIMessage,
     BaseMessage,
     HumanMessage,
-    LLMResult,
     PromptValue,
 )
+from langchain_core.language_models import BaseChatModel
 
 from esbmc_ai.config import ChatPromptSettings
 from esbmc_ai.chat_response import ChatResponse, FinishReason
@@ -18,10 +16,13 @@
 
 
 class BaseChatInterface(object):
+    """Base class for interacting with an LLM. It allows for interactions with
+    text generation LLMs and also chat LLMs."""
+
     def __init__(
         self,
         ai_model_agent: ChatPromptSettings,
-        llm: BaseLanguageModel,
+        llm: BaseChatModel,
         ai_model: AIModel,
     ) -> None:
         super().__init__()
@@ -31,7 +32,7 @@ def __init__(
             ai_model_agent.system_messages.messages
         )
         self.messages: list[BaseMessage] = []
-        self.llm: BaseLanguageModel = llm
+        self.llm: BaseChatModel = llm
 
     @abstractmethod
     def compress_message_stack(self) -> None:
@@ -84,25 +85,9 @@ def send_message(self, message: Optional[str] = None) -> ChatResponse:
         all_messages = self._system_messages.copy()
         all_messages.extend(self.messages.copy())
 
-        # Transform message stack to ChatPromptValue: If this is a ChatLLM then the
-        # function will simply be an identity function that does nothing and simply
-        # returns the messages as a ChatPromptValue. If this is a text generation
-        # LLM, then the function should inject the config message around the
-        # conversation to make the LLM behave like a ChatLLM.
-        # Do not replace any values.
-        message_prompts: PromptValue = self.ai_model.apply_chat_template(
-            messages=all_messages,
-        )
-
         response: ChatResponse
         try:
-            result: LLMResult = self.llm.generate_prompt(
-                prompts=[message_prompts],
-            )
-
-            response_message: BaseMessage = AIMessage(
-                content=result.generations[0][0].text
-            )
+            response_message: BaseMessage = self.llm.invoke(input=all_messages)
 
             self.push_to_message_stack(message=response_message)
 
diff --git a/esbmc_ai/chats/solution_generator.py b/esbmc_ai/chats/solution_generator.py
index 6187ee3..030e4e3 100644
--- a/esbmc_ai/chats/solution_generator.py
+++ b/esbmc_ai/chats/solution_generator.py
@@ -2,8 +2,8 @@
 
 from re import S
 from typing import Optional
+from langchain_core.language_models import BaseChatModel
 from typing_extensions import override
-from langchain.base_language import BaseLanguageModel
 from langchain.schema import BaseMessage, HumanMessage
 
 from esbmc_ai.chat_response import ChatResponse, FinishReason
@@ -83,7 +83,7 @@ class SolutionGenerator(BaseChatInterface):
     def __init__(
         self,
         ai_model_agent: DynamicAIModelAgent | ChatPromptSettings,
-        llm: BaseLanguageModel,
+        llm: BaseChatModel,
         ai_model: AIModel,
         scenario: str = "",
         source_code_format: str = "full",
diff --git a/esbmc_ai/chats/user_chat.py b/esbmc_ai/chats/user_chat.py
index 3c1cdd6..c044d6b 100644
--- a/esbmc_ai/chats/user_chat.py
+++ b/esbmc_ai/chats/user_chat.py
@@ -1,8 +1,8 @@
 # Author: Yiannis Charalambous 2023
 
+from langchain_core.language_models import BaseChatModel
 from typing_extensions import override
 
-from langchain.base_language import BaseLanguageModel
 from langchain.memory import ConversationSummaryMemory
 from langchain_community.chat_message_histories import ChatMessageHistory
 
@@ -21,7 +21,7 @@ def __init__(
         self,
         ai_model_agent: ChatPromptSettings,
         ai_model: AIModel,
-        llm: BaseLanguageModel,
+        llm: BaseChatModel,
         source_code: str,
         esbmc_output: str,
         set_solution_messages: AIAgentConversation,
diff --git a/esbmc_ai/config.py b/esbmc_ai/config.py
index 28b4ccc..d99141c 100644
--- a/esbmc_ai/config.py
+++ b/esbmc_ai/config.py
@@ -153,48 +153,36 @@ def _load_custom_ai(config: dict) -> None:
         assert (
             isinstance(custom_ai_max_tokens, int) and custom_ai_max_tokens > 0
         ), f'custom_ai_max_tokens in ai_custom entry "{name}" needs to be an int and greater than 0.'
+        # Load the URL
         custom_ai_url, ok = _load_config_value(
             config_file=ai_data,
             name="url",
         )
         assert ok, f'url field not found in "ai_custom" entry "{name}".'
-        stop_sequences, ok = _load_config_value(
-            config_file=ai_data,
-            name="stop_sequences",
-        )
-        # Load the config message
-        config_message: dict[str, str] = ai_data["config_message"]
-        template, ok = _load_config_value(
-            config_file=config_message,
-            name="template",
-        )
-        human, ok = _load_config_value(
-            config_file=config_message,
-            name="human",
-        )
-        ai, ok = _load_config_value(
-            config_file=config_message,
-            name="ai",
-        )
-        system, ok = _load_config_value(
-            config_file=config_message,
-            name="system",
-        )
+
+        # Get the provider type. Defaults to the only implemented server type.
+        server_type, ok = _load_config_value(
+            config_file=ai_data,
+            name="server_type",
+            default="ollama",
+        )
+        assert (
+            ok
+        ), f"server_type for custom AI '{name}' is invalid: it needs to be a valid string"
 
-        # Add the custom AI.
-        add_custom_ai_model(
-            AIModelTextGen(
-                name=name,
-                tokens=custom_ai_max_tokens,
-                url=custom_ai_url,
-                config_message=template,
-                ai_template=ai,
-                human_template=human,
-                system_template=system,
-                stop_sequences=stop_sequences,
-            )
-        )
+        # Create the correct type of LLM
+        llm: AIModel
+        match server_type:
+            case "ollama":
+                llm = OllamaAIModel(
+                    name=name,
+                    tokens=custom_ai_max_tokens,
+                    url=custom_ai_url,
+                )
+            case _:
+                raise NotImplementedError(
+                    f"The custom AI server type is not implemented: {server_type}"
+                )
+
+        # Add the custom AI.
+        add_custom_ai_model(llm)
 
 
 def load_envs() -> None:

From 1aafdd3b4421b3537a056ce321c4f47005c25732 Mon Sep 17 00:00:00 2001
From: Yiannis Charalambous
Date: Thu, 12 Sep 2024 22:12:02 +0100
Subject: [PATCH 3/5] Update config

---
 config.json | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/config.json b/config.json
index dd1f40a..9e96240 100644
--- a/config.json
+++ b/config.json
@@ -90,10 +90,10 @@
       "system": [
         {
           "role": "System",
-          "content": "From now on, act as an Automated Code Repair Tool that repairs AI C code. You will be shown AI C code, along with ESBMC output. Pay close attention to the ESBMC output, which contains a stack trace along with the type of error that occurred and its location. "
+          "content": "From now on, act as an Automated Code Repair Tool that repairs AI C code. You will be shown AI C code, along with ESBMC output. Pay close attention to the ESBMC output, which contains a stack trace along with the type of error that occurred and its location that you need to fix. Provide the repaired C code as output, as would an Automated Code Repair Tool. Aside from the corrected source code, do not output any other text."
         }
       ],
-      "initial": "Provide the repaired C code as output, as would an Automated Code Repair Tool. Aside from the corrected source code, do not output any other text. The ESBMC output is {esbmc_output} The source code is {source_code}"
+      "initial": "The ESBMC output is:\n\n```\n{esbmc_output}\n```\n\nThe source code is:\n\n```c\n{source_code}\n```\n\nUsing the ESBMC output, show the fixed source code."
     }
   }
 }
\ No newline at end of file

From 00346e7aef8c6abe0b1fd61a8822ef28c1646923 Mon Sep 17 00:00:00 2001
From: Yiannis Charalambous
Date: Thu, 12 Sep 2024 22:12:27 +0100
Subject: [PATCH 4/5] Updated tests to have Ollama support

---
 ...st_base_chat_interface.test_send_message.out | 17 ++++++++++++++---
 tests/regtest/test_base_chat_interface.py       | 17 ++++++++++++-----
 tests/test_ai_models.py                         | 10 +++-------
 tests/test_base_chat_interface.py               |  8 ++++----
 tests/test_config.py                            |  7 +------
 tests/test_latest_state_solution_generator.py   |  4 ++--
 tests/test_reverse_order_solution_generator.py  |  4 ++--
 tests/test_user_chat.py                         |  4 ++--
 8 files changed, 40 insertions(+), 31 deletions(-)

diff --git a/tests/regtest/_regtest_outputs/test_base_chat_interface.test_send_message.out b/tests/regtest/_regtest_outputs/test_base_chat_interface.test_send_message.out
index a5824ff..11993c6 100644
--- a/tests/regtest/_regtest_outputs/test_base_chat_interface.test_send_message.out
+++ b/tests/regtest/_regtest_outputs/test_base_chat_interface.test_send_message.out
@@ -1,3 +1,14 @@
-(SystemMessage(content='System message'), AIMessage(content='OK'))
-[HumanMessage(content='Test 1'), AIMessage(content='OK 1'), HumanMessage(content='Test 2'), AIMessage(content='OK 2'), HumanMessage(content='Test 3'), AIMessage(content='OK 3')]
-[ChatResponse(message=AIMessage(content='OK 1'), total_tokens=15, finish_reason=), ChatResponse(message=AIMessage(content='OK 2'), total_tokens=23, finish_reason=), ChatResponse(message=AIMessage(content='OK 3'), total_tokens=31, finish_reason=)]
+System Messages:
+system: System message
+ai: OK
+Chat Messages:
+human: Test 1
+ai: OK 1
+human: Test 2
+ai: OK 2
+human: Test 3
+ai: OK 3
+Responses:
+ai(15 - FinishReason.stop): OK 1
+ai(23 - FinishReason.stop): OK 2
+ai(31 - FinishReason.stop): OK 3
diff --git a/tests/regtest/test_base_chat_interface.py b/tests/regtest/test_base_chat_interface.py
index 8945f9a..80151ee 100644
--- a/tests/regtest/test_base_chat_interface.py
+++ b/tests/regtest/test_base_chat_interface.py
@@ -1,8 +1,8 @@
 # Author: Yiannis Charalambous
 
+from langchain_core.language_models import FakeListChatModel
 import pytest
 
-from langchain_community.llms import FakeListLLM
 from langchain.schema import BaseMessage, HumanMessage, AIMessage, SystemMessage
 
 from esbmc_ai.ai_models import AIModel
@@ -14,7 +14,7 @@
 @pytest.fixture
 def setup():
     responses: list[str] = ["OK 1", "OK 2", "OK 3"]
-    llm: FakeListLLM = FakeListLLM(responses=responses)
+    llm: FakeListChatModel = FakeListChatModel(responses=responses)
 
     ai_model: AIModel = AIModel("test", 1024)
 
@@ -64,6 +64,13 @@ def test_send_message(regtest, setup) -> None:
     ]
 
     with regtest:
-        print(chat.ai_model_agent.system_messages.messages)
-        print(chat.messages)
-        print(chat_responses)
+        print("System Messages:")
+        for m in chat.ai_model_agent.system_messages.messages:
+            print(f"{m.type}: {m.content}")
+        print("Chat Messages:")
+        for m in chat.messages:
+            print(f"{m.type}: {m.content}")
+        print("Responses:")
+        for m in chat_responses:
+            print(
+                f"{m.message.type}({m.total_tokens} - {m.finish_reason}): "
+                f"{m.message.content}"
+            )
diff --git a/tests/test_ai_models.py b/tests/test_ai_models.py
index 028bff3..7cc8397 100644
--- a/tests/test_ai_models.py
+++ b/tests/test_ai_models.py
@@ -15,7 +15,7 @@
     AIModel,
     _AIModels,
     get_ai_model_by_name,
-    AIModelTextGen,
+    OllamaAIModel,
     _get_openai_model_max_tokens,
 )
@@ -93,19 +93,15 @@ def test_apply_chat_template() -> None:
     assert prompt == ChatPromptValue(messages=messages)
 
-    # Test the text gen method
-    custom_model_2: AIModelTextGen = AIModelTextGen(
+    # Test the Ollama model
+    custom_model_2: OllamaAIModel = OllamaAIModel(
         name="custom",
         tokens=999,
         url="",
-        config_message="{history}\n\n{user_prompt}",
-        ai_template="AI: {content}",
-        human_template="Human: {content}",
-        system_template="System: {content}",
     )
 
     prompt_text: str = custom_model_2.apply_chat_template(messages=messages).to_string()
 
-    assert prompt_text == "System: M1\n\nHuman: M2\n\nAI: M3"
+    assert prompt_text == "System: M1\nHuman: M2\nAI: M3"
 
 
 def test_escape_messages() -> None:
diff --git a/tests/test_base_chat_interface.py b/tests/test_base_chat_interface.py
index 42b7b1a..e4738c4 100644
--- a/tests/test_base_chat_interface.py
+++ b/tests/test_base_chat_interface.py
@@ -1,8 +1,8 @@
 # Author: Yiannis Charalambous
 
+from langchain_core.language_models import FakeListChatModel
 import pytest
 
-from langchain_community.llms import FakeListLLM
 from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
 from esbmc_ai.ai_models import AIModel
 from esbmc_ai.chats.base_chat_interface import BaseChatInterface
@@ -23,7 +23,7 @@ def setup():
 
 
 def test_push_message_stack(setup) -> None:
-    llm: FakeListLLM = FakeListLLM(responses=[])
+    llm: FakeListChatModel = FakeListChatModel(responses=[])
 
     ai_model, system_messages = setup
 
@@ -56,7 +56,7 @@ def test_push_message_stack(setup) -> None:
 def test_send_message(setup) -> None:
     responses: list[str] = ["OK 1", "OK 2", "OK 3"]
-    llm: FakeListLLM = FakeListLLM(responses=responses)
+    llm: FakeListChatModel = FakeListChatModel(responses=responses)
 
     ai_model, system_messages = setup
 
@@ -98,7 +98,7 @@ def test_apply_template() -> None:
         "Replace with also replaced message",
         "replacedalso replaced",
     ]
-    llm: FakeListLLM = FakeListLLM(responses=responses)
+    llm: FakeListChatModel = FakeListChatModel(responses=responses)
 
     chat: BaseChatInterface = BaseChatInterface(
         ai_model_agent=ChatPromptSettings(
diff --git a/tests/test_config.py b/tests/test_config.py
index 6fa3741..9b774b0 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -80,12 +80,7 @@ def test_load_custom_ai() -> None:
         "example_ai": {
             "max_tokens": 4096,
             "url": "www.example.com",
-            "config_message": {
-                "template": "example",
-                "system": "{content}",
-                "ai": "{content}",
-                "human": "{content}",
-            },
+            "server_type": "ollama",
         }
     }
 
diff --git a/tests/test_latest_state_solution_generator.py b/tests/test_latest_state_solution_generator.py
index 679e2d6..1db2ffa 100644
--- a/tests/test_latest_state_solution_generator.py
+++ b/tests/test_latest_state_solution_generator.py
@@ -1,10 +1,10 @@
 # Author: Yiannis Charalambous
 
 from typing import Optional
+from langchain_core.language_models import FakeListChatModel
 import pytest
 from langchain.schema import HumanMessage, AIMessage, SystemMessage
-from langchain_community.llms.fake import FakeListLLM
 from esbmc_ai.ai_models import AIModel
 from esbmc_ai.chat_response import ChatResponse
@@ -14,7 +14,7 @@
 @pytest.fixture(scope="function")
 def setup_llm_model():
-    llm = FakeListLLM(
+    llm = FakeListChatModel(
         responses=[
             "This is a test response",
             "Another test response",
diff --git a/tests/test_reverse_order_solution_generator.py b/tests/test_reverse_order_solution_generator.py
index fa8c4ef..ebcbef8 100644
--- a/tests/test_reverse_order_solution_generator.py
+++ b/tests/test_reverse_order_solution_generator.py
@@ -1,5 +1,6 @@
 # Author: Yiannis Charalambous
 
+from langchain_core.language_models import FakeListChatModel
 import pytest
 
 from langchain.schema import (
     HumanMessage,
     AIMessage,
     SystemMessage,
 )
-from langchain_community.llms.fake import FakeListLLM
 
 from esbmc_ai.ai_models import AIModel
 from esbmc_ai.config import AIAgentConversation, ChatPromptSettings
@@ -16,7 +16,7 @@
 
 @pytest.fixture(scope="function")
 def setup_llm_model():
-    llm = FakeListLLM(
+    llm = FakeListChatModel(
         responses=[
             "This is a test response",
             "Another test response",
diff --git a/tests/test_user_chat.py b/tests/test_user_chat.py
index f0198ca..0d39909 100644
--- a/tests/test_user_chat.py
+++ b/tests/test_user_chat.py
@@ -1,8 +1,8 @@
 # Author: Yiannis Charalambous
 
+from langchain_core.language_models import FakeListChatModel
 import pytest
 
-from langchain_community.llms import FakeListLLM
 from langchain.schema import AIMessage, SystemMessage
 
 from esbmc_ai.ai_models import AIModel
@@ -30,7 +30,7 @@ def setup():
             temperature=1.0,
         ),
         ai_model=AIModel(name="test", tokens=12),
-        llm=FakeListLLM(responses=[summary_text]),
+        llm=FakeListChatModel(responses=[summary_text]),
         source_code="This is source code",
         esbmc_output="This is esbmc output",
         set_solution_messages=AIAgentConversation.from_seq(set_solution_messages),

From 7c0195d3deaf85ebf406700fb2da2a4ef926c7c9 Mon Sep 17 00:00:00 2001
From: Yiannis Charalambous
Date: Thu, 12 Sep 2024 22:21:06 +0100
Subject: [PATCH 5/5] Update

---
 esbmc_ai/chats/__init__.py  | 3 +++
 esbmc_ai/chats/user_chat.py | 4 ++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/esbmc_ai/chats/__init__.py b/esbmc_ai/chats/__init__.py
index e2b4c2a..2844a35 100644
--- a/esbmc_ai/chats/__init__.py
+++ b/esbmc_ai/chats/__init__.py
@@ -1,5 +1,8 @@
 # Author: Yiannis Charalambous
 
+"""This module contains the different chat interfaces, along with
+`BaseChatInterface`, which provides the necessary boilerplate for an LLM-based chat."""
+
 from .base_chat_interface import BaseChatInterface
 from .latest_state_solution_generator import LatestStateSolutionGenerator
 from .solution_generator import SolutionGenerator
diff --git a/esbmc_ai/chats/user_chat.py b/esbmc_ai/chats/user_chat.py
index c044d6b..441c59a 100644
--- a/esbmc_ai/chats/user_chat.py
+++ b/esbmc_ai/chats/user_chat.py
@@ -1,12 +1,12 @@
 # Author: Yiannis Charalambous 2023
 
-from langchain_core.language_models import BaseChatModel
 from typing_extensions import override
 
 from langchain.memory import ConversationSummaryMemory
+from langchain.schema import BaseMessage, SystemMessage
+from langchain_core.language_models import BaseChatModel
 from langchain_community.chat_message_histories import ChatMessageHistory
 
-from langchain.schema import BaseMessage, SystemMessage
 
 from esbmc_ai.config import AIAgentConversation, ChatPromptSettings
 from esbmc_ai.ai_models import AIModel
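
Example "ai_custom" config.json entry for the new Ollama server type, a sketch
based on the fields read by _load_custom_ai in patch 2 and mirrored in
test_config.py (the model name, token count, and URL here are illustrative
placeholders, not values shipped in this series):

    "ai_custom": {
        "llama3.1": {
            "max_tokens": 8192,
            "url": "localhost:11434",
            "server_type": "ollama"
        }
    }

The entry key is passed to OllamaAIModel as the model name, so it should match
a model that the target Ollama server has already pulled.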