refactor: push memory windowing down to redis history to save io
edwardzjl committed Jan 5, 2024
1 parent ce16656 commit 5104b5a
Showing 5 changed files with 104 additions and 9 deletions.
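
The gist of the change: ConversationBufferWindowMemory pulled the entire message list out of Redis and windowed it in Python, so IO grew with conversation length; the new windowed_messages asks Redis for only the window via LRANGE with negative indices. A rough sketch of the difference, assuming a plain redis-py client and the chatbot:messages: key layout configured in dependencies.py below (the concrete key is hypothetical):

import json

import redis

r = redis.Redis()
key = "chatbot:messages:alice:conv-1"  # hypothetical key under the key_prefix below

# Before: fetch everything, then window in Python. Transfer cost is O(n)
# in the total number of messages.
window = r.lrange(key, 0, -1)[-10:]

# After: let Redis return only the last k * 2 entries (k = 5 pairs here).
# Transfer cost is O(k) regardless of conversation length.
window = r.lrange(key, -10, -1)

messages = [json.loads(m.decode("utf-8")) for m in window]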
api/chatbot/chains.py (8 additions, 0 deletions)
@@ -3,6 +3,8 @@
 from langchain.chains import LLMChain
 from langchain.memory.chat_memory import BaseChatMemory
 
+from chatbot.memory import ChatbotMemory
+
 
 class LLMConvChain(LLMChain):
     """Conversation chain that persists message separately on chain start and end."""
@@ -15,6 +17,9 @@ def prep_inputs(self, inputs: dict[str, Any] | Any) -> dict[str, str]:
         """
         inputs = super().prep_inputs(inputs)
         # we need to access the history so we need to ensure it's BaseChatMemory and then we can access it by memory.chat_memory
+        if self.memory is not None and isinstance(self.memory, ChatbotMemory):
+            message = inputs[self.user_input_variable]
+            self.memory.history.add_user_message(message)
         if self.memory is not None and isinstance(self.memory, BaseChatMemory):
             message = inputs[self.user_input_variable]
             self.memory.chat_memory.add_user_message(message)
@@ -31,6 +36,9 @@ def prep_outputs(
         """
         self._validate_outputs(outputs)
         # we need to access the history so we need to ensure it's BaseChatMemory and then we can access it by memory.chat_memory
+        if self.memory is not None and isinstance(self.memory, ChatbotMemory):
+            text = outputs[self.output_key]
+            self.memory.history.add_ai_message(text)
         if self.memory is not None and isinstance(self.memory, BaseChatMemory):
             text = outputs[self.output_key]
             self.memory.chat_memory.add_ai_message(text)
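
Why prep_inputs and prep_outputs rather than LangChain's usual save_context after the run: each side of the exchange is persisted as soon as it is known, presumably so the user message is already in Redis even if the LLM call fails. A minimal illustration of the two write paths, using the in-memory default history (this snippet is not part of the commit):

from chatbot.memory import ChatbotMemory

memory = ChatbotMemory()

# Default LangChain flow: both sides saved together, after the run.
# memory.save_context({"input": "hi"}, {"output": "hello!"})

# LLMConvChain's flow: each side written the moment it exists.
memory.history.add_user_message("hi")    # prep_inputs, before the LLM call
memory.history.add_ai_message("hello!")  # prep_outputs, after the LLM call

assert len(memory.history.messages) == 2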
api/chatbot/dependencies.py (5 additions, 5 deletions)
@@ -2,7 +2,6 @@
 
 from fastapi import Depends, Header
 from langchain.chains.base import Chain
-from langchain.memory import ConversationBufferWindowMemory
 from langchain_community.llms.huggingface_text_gen_inference import (
     HuggingFaceTextGenInference,
 )
@@ -19,7 +18,8 @@
 from chatbot.callbacks import TracingLLMCallbackHandler
 from chatbot.chains import LLMConvChain
 from chatbot.config import settings
-from chatbot.history import ContextAwareMessageHistory
+from chatbot.history import ChatbotMessageHistory
+from chatbot.memory import ChatbotMemory
 from chatbot.prompts.chatml import AI_SUFFIX, HUMAN_PREFIX, ChatMLPromptTemplate
 
 
@@ -42,7 +42,7 @@ def EmailHeader(alias: Optional[str] = None, **kwargs):
 
 
 def MessageHistory() -> BaseChatMessageHistory:
-    return ContextAwareMessageHistory(
+    return ChatbotMessageHistory(
         url=str(settings.redis_om_url),
         key_prefix="chatbot:messages:",
         session_id="sid",  # a fake session id as it is required
@@ -52,10 +52,10 @@ def MessageHistory() -> BaseChatMessageHistory:
 def ChatMemory(
     history: Annotated[BaseChatMessageHistory, Depends(MessageHistory)]
 ) -> BaseMemory:
-    return ConversationBufferWindowMemory(
+    return ChatbotMemory(
         memory_key="history",
         input_key="input",
-        chat_memory=history,
+        history=history,
         return_messages=True,
     )
 
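
For reference, a hypothetical route showing how the dependencies compose (the real chat endpoint lives elsewhere in the repo): FastAPI resolves MessageHistory first and injects it into ChatMemory, so each request gets a ChatbotMemory backed by the Redis history.

from typing import Annotated

from fastapi import APIRouter, Depends
from langchain_core.memory import BaseMemory

from chatbot.dependencies import ChatMemory

router = APIRouter()


@router.get("/api/demo")  # illustrative endpoint, not part of the commit
async def demo(memory: Annotated[BaseMemory, Depends(ChatMemory)]) -> dict:
    # The injected object is a ChatbotMemory; its buffer property issues a
    # single windowed LRANGE against Redis.
    return {"buffered_messages": len(memory.buffer)}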
api/chatbot/history.py (8 additions, 1 deletion)
@@ -8,14 +8,21 @@
 from chatbot.utils import utcnow
 
 
-class ContextAwareMessageHistory(RedisChatMessageHistory):
+class ChatbotMessageHistory(RedisChatMessageHistory):
     """Context aware history which also persists extra information in `additional_kwargs`."""
 
     @property
     def key(self) -> str:
         """Construct the record key to use"""
         return self.key_prefix + (session_id.get() or self.session_id)
 
+    def windowed_messages(self, window_size: int = 5) -> list[BaseMessage]:
+        """Retrieve the last k pairs of messages from Redis"""
+        _items = self.redis_client.lrange(self.key, -window_size * 2, -1)
+        items = [json.loads(m.decode("utf-8")) for m in _items]
+        messages = messages_from_dict(items)
+        return messages
+
     @property
     def messages(self) -> list[BaseMessage]:  # type: ignore
         """Retrieve the messages from Redis"""
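
One edge case worth noting: LRANGE clamps out-of-range indices, so a conversation shorter than the window comes back whole instead of raising, and windowed_messages needs no length check. A quick check of that Redis behavior (assuming a local server; independent of this repo):

import redis

r = redis.Redis()
r.rpush("demo:list", *range(4))
# Asking for the last 10 entries of a 4-entry list returns all 4.
assert r.lrange("demo:list", -10, -1) == [b"0", b"1", b"2", b"3"]
r.delete("demo:list")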
api/chatbot/memory.py (new file, 80 additions)
@@ -0,0 +1,80 @@
+from typing import Any, Optional
+
+from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
+from langchain.memory.utils import get_prompt_input_key
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.memory import BaseMemory
+from langchain_core.messages import BaseMessage
+from pydantic.v1 import Field, validator
+
+from chatbot.history import ChatbotMessageHistory
+
+
+class ChatbotMemory(BaseMemory):
+    history: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory)
+    output_key: Optional[str] = None
+    input_key: Optional[str] = None
+    return_messages: bool = True
+    memory_key: str = "history"  #: :meta private:
+    k: int = 5
+    """Number of messages to store in buffer."""
+
+    @validator("k")
+    def k_must_be_positive(cls, v: int) -> int:
+        if v <= 0:
+            raise ValueError("k must be greater than 0")
+        return v
+
+    @property
+    def buffer(self) -> str | list[BaseMessage]:
+        """String buffer of memory."""
+        return self.buffer_as_messages if self.return_messages else self.buffer_as_str
+
+    @property
+    def buffer_as_messages(self) -> list[BaseMessage]:
+        """Exposes the buffer as a list of messages in case return_messages is False."""
+        if isinstance(self.history, ChatbotMessageHistory):
+            return self.history.windowed_messages(self.k)
+        return self.history.messages[-self.k * 2 :] if self.k > 0 else []
+
+    @property
+    def buffer_as_str(self) -> str:
+        # not going to support this
+        raise NotImplementedError
+
+    @property
+    def memory_variables(self) -> list[str]:
+        """Will always return list of memory variables.
+        :meta private:
+        """
+        return [self.memory_key]
+
+    def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
+        """Return history buffer."""
+        return {self.memory_key: self.buffer}
+
+    def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
+        """Save context from this conversation to buffer."""
+        input_str, output_str = self._get_input_output(inputs, outputs)
+        self.history.add_user_message(input_str)
+        self.history.add_ai_message(output_str)
+
+    def clear(self) -> None:
+        """Clear memory contents."""
+        self.history.clear()
+
+    def _get_input_output(
+        self, inputs: dict[str, Any], outputs: dict[str, str]
+    ) -> tuple[str, str]:
+        if self.input_key is None:
+            prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
+        else:
+            prompt_input_key = self.input_key
+        if self.output_key is None:
+            if len(outputs) != 1:
+                raise ValueError(f"One output key expected, got {outputs.keys()}")
+            output_key = list(outputs.keys())[0]
+        else:
+            output_key = self.output_key
+        return inputs[prompt_input_key], outputs[output_key]
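
A quick self-contained exercise of the class (default in-memory history, no Redis, so this hits the Python-side fallback slice; not part of the commit):

from chatbot.memory import ChatbotMemory

memory = ChatbotMemory(k=2)
for i in range(5):
    memory.save_context({"input": f"question {i}"}, {"output": f"answer {i}"})

# With a plain ChatMessageHistory, buffer_as_messages falls back to
# slicing in Python: only the last k pairs (4 messages) are returned.
window = memory.load_memory_variables({})["history"]
assert len(window) == 4
assert window[0].content == "question 3"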
api/chatbot/routers/message.py (3 additions, 3 deletions)
@@ -6,7 +6,7 @@
 
 from chatbot.context import session_id
 from chatbot.dependencies import MessageHistory, UserIdHeader
-from chatbot.history import ContextAwareMessageHistory
+from chatbot.history import ChatbotMessageHistory
 
 router = APIRouter(
     prefix="/api/conversations/{conversation_id}/messages",
@@ -24,7 +24,7 @@ async def thumbup(
     """Using message index as the uuid is in the message body which is json dumped into redis,
     and is impossible to filter on.
     Also separate thumbup and thumbdown into two endpoints to make it more RESTful."""
-    if not isinstance(history, ContextAwareMessageHistory):
+    if not isinstance(history, ChatbotMessageHistory):
         # should never happen
         return
     session_id.set(f"{userid}:{conversation_id}")
@@ -44,7 +44,7 @@ async def thumbdown(
     """Using message index as the uuid is in the message body which is json dumped into redis,
     and is impossible to filter on.
     Also separate thumbup and thumbdown into two endpoints to make it more RESTful."""
-    if not isinstance(history, ContextAwareMessageHistory):
+    if not isinstance(history, ChatbotMessageHistory):
         # should never happen
         return
     session_id.set(f"{userid}:{conversation_id}")
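
The session_id.set(...) calls above work because ChatbotMessageHistory.key reads the ContextVar on every access instead of using the placeholder "sid" passed at construction. Roughly (a sketch, assuming chatbot.context.session_id is a plain ContextVar[str] and that building the client needs no live Redis):

from chatbot.context import session_id
from chatbot.dependencies import MessageHistory

history = MessageHistory()
session_id.set("alice:conv-42")
# key_prefix + ContextVar value, overriding the placeholder session id.
assert history.key == "chatbot:messages:alice:conv-42"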
