[Draft, Feedback Needed] Memory in AgentChat #4438

Draft · wants to merge 6 commits into main
File: autogen_agentchat/agents/_assistant_agent.py
@@ -31,6 +31,7 @@
)
from ..state import AssistantAgentState
from ._base_chat_agent import BaseChatAgent
from ..memory._base_memory import Memory, MemoryQueryResult

event_logger = logging.getLogger(EVENT_LOGGER_NAME)

@@ -216,7 +217,8 @@ def __init__(
name: str,
model_client: ChatCompletionClient,
*,
tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
tools: List[Tool | Callable[..., Any] |
Callable[..., Awaitable[Any]]] | None = None,
handoffs: List[HandoffBase | str] | None = None,
description: str = "An agent that provides assistance with ability to use tools.",
system_message: str
@@ -226,14 +228,19 @@ def __init__(
):
super().__init__(name=name, description=description)
self._model_client = model_client
self._memory = memory

self._system_messages: List[SystemMessage | UserMessage |
AssistantMessage | FunctionExecutionResultMessage] = []
if system_message is None:
self._system_messages = []
else:
self._system_messages = [SystemMessage(content=system_message)]
self._tools: List[Tool] = []
if tools is not None:
if model_client.capabilities["function_calling"] is False:
raise ValueError("The model does not support function calling.")
raise ValueError(
"The model does not support function calling.")
for tool in tools:
if isinstance(tool, Tool):
self._tools.append(tool)
@@ -242,7 +249,8 @@ def __init__(
description = tool.__doc__
else:
description = ""
self._tools.append(FunctionTool(tool, description=description))
self._tools.append(FunctionTool(
tool, description=description))
else:
raise ValueError(f"Unsupported tool type: {type(tool)}")
# Check if tool names are unique.
@@ -254,19 +262,22 @@ def __init__(
self._handoffs: Dict[str, HandoffBase] = {}
if handoffs is not None:
if model_client.capabilities["function_calling"] is False:
raise ValueError("The model does not support function calling, which is needed for handoffs.")
raise ValueError(
"The model does not support function calling, which is needed for handoffs.")
for handoff in handoffs:
if isinstance(handoff, str):
handoff = HandoffBase(target=handoff)
if isinstance(handoff, HandoffBase):
self._handoff_tools.append(handoff.handoff_tool)
self._handoffs[handoff.name] = handoff
else:
raise ValueError(f"Unsupported handoff type: {type(handoff)}")
raise ValueError(
f"Unsupported handoff type: {type(handoff)}")
# Check if handoff tool names are unique.
handoff_tool_names = [tool.name for tool in self._handoff_tools]
if len(handoff_tool_names) != len(set(handoff_tool_names)):
raise ValueError(f"Handoff names must be unique: {handoff_tool_names}")
raise ValueError(
f"Handoff names must be unique: {handoff_tool_names}")
# Check if handoff tool names not in tool names.
if any(name in tool_names for name in handoff_tool_names):
raise ValueError(
@@ -288,7 +299,8 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
async for message in self.on_messages_stream(messages, cancellation_token):
if isinstance(message, Response):
return message
raise AssertionError("The stream should have returned the final result.")
raise AssertionError(
"The stream should have returned the final result.")

async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
@@ -297,41 +309,54 @@ async def on_messages_stream
for msg in messages:
if isinstance(msg, MultiModalMessage) and self._model_client.capabilities["vision"] is False:
raise ValueError("The model does not support vision.")
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
self._model_context.append(UserMessage(
content=msg.content, source=msg.source))

# Inner messages.
inner_messages: List[AgentEvent | ChatMessage] = []

# Generate an inference result based on the current model context.
llm_messages = self._system_messages + self._model_context
# Prepare messages for model with memory context if available
llm_messages = self._system_messages
if memory_context:
llm_messages = llm_messages + \
[SystemMessage(content=memory_context)]
llm_messages = llm_messages + self._model_context

# Generate inference result
result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)

# Add the response to the model context.
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
self._model_context.append(AssistantMessage(
content=result.content, source=self.name))

# Check if the response is a string and return it.
if isinstance(result.content, str):
yield Response(
chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
chat_message=TextMessage(
content=result.content, source=self.name, models_usage=result.usage),
inner_messages=inner_messages,
)
return

# Process tool calls.
assert isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content)
tool_call_msg = ToolCallRequestEvent(content=result.content, source=self.name, models_usage=result.usage)
assert isinstance(result.content, list) and all(
isinstance(item, FunctionCall) for item in result.content)
tool_call_msg = ToolCallRequestEvent(
content=result.content, source=self.name, models_usage=result.usage)
event_logger.debug(tool_call_msg)
# Add the tool call message to the output.
inner_messages.append(tool_call_msg)
yield tool_call_msg

# Execute the tool calls.
results = await asyncio.gather(*[self._execute_tool_call(call, cancellation_token) for call in result.content])
tool_call_result_msg = ToolCallExecutionEvent(content=results, source=self.name)
tool_call_result_msg = ToolCallExecutionEvent(
content=results, source=self.name)
event_logger.debug(tool_call_result_msg)
self._model_context.append(FunctionExecutionResultMessage(content=results))
self._model_context.append(
FunctionExecutionResultMessage(content=results))
inner_messages.append(tool_call_result_msg)
yield tool_call_result_msg

@@ -349,7 +374,8 @@ async def on_messages_stream
)
# Return the output messages to signal the handoff.
yield Response(
chat_message=HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name),
chat_message=HandoffMessage(
content=handoffs[0].message, target=handoffs[0].target, source=self.name),
inner_messages=inner_messages,
)
return
@@ -360,10 +386,12 @@ async def on_messages_stream
result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
assert isinstance(result.content, str)
# Add the response to the model context.
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
self._model_context.append(AssistantMessage(
content=result.content, source=self.name))
# Yield the response.
yield Response(
chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
chat_message=TextMessage(
content=result.content, source=self.name, models_usage=result.usage),
inner_messages=inner_messages,
)
else:
@@ -379,7 +407,8 @@ async def on_messages_stream
)
tool_call_summary = "\n".join(tool_call_summaries)
yield Response(
chat_message=TextMessage(content=tool_call_summary, source=self.name),
chat_message=TextMessage(
content=tool_call_summary, source=self.name),
inner_messages=inner_messages,
)

@@ -390,9 +419,11 @@ async def _execute_tool_call(
try:
if not self._tools + self._handoff_tools:
raise ValueError("No tools are available.")
tool = next((t for t in self._tools + self._handoff_tools if t.name == tool_call.name), None)
tool = next((t for t in self._tools +
self._handoff_tools if t.name == tool_call.name), None)
if tool is None:
raise ValueError(f"The tool '{tool_call.name}' is not available.")
raise ValueError(
f"The tool '{tool_call.name}' is not available.")
arguments = json.loads(tool_call.arguments)
result = await tool.run_json(arguments, cancellation_token)
result_as_str = tool.return_value_as_string(result)
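The updated `llm_messages` assembly above injects a `memory_context` system message, but this hunk does not show how that string is produced. Below is a minimal sketch of deriving it from `Memory.query` results; the `build_memory_context` helper, its formatting, and the `autogen_agentchat.memory._base_memory` import path (inferred from the relative import added at the top of this file) are assumptions for discussion, not code from this PR.

```python
from typing import List

from autogen_core import CancellationToken

from autogen_agentchat.memory._base_memory import (
    Memory,
    MemoryContent,
    MemoryQueryResult,
    MimeType,
)


async def build_memory_context(
    memory: Memory | None,
    query_text: str,
    cancellation_token: CancellationToken | None = None,
) -> str:
    """Query memory with the latest user text and flatten the hits into one string."""
    if memory is None or not query_text:
        return ""
    results: List[MemoryQueryResult] = await memory.query(
        MemoryContent(content=query_text, mime_type=MimeType.TEXT),
        cancellation_token=cancellation_token,
    )
    if not results:
        return ""
    # One entry per line; a real implementation might also apply score_threshold.
    return "Relevant memory:\n" + "\n".join(str(r.entry.content.content) for r in results)
```

Inside `on_messages_stream`, the agent could then call something like `memory_context = await build_memory_context(self._memory, latest_user_text, cancellation_token)` before assembling `llm_messages`.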
File: autogen_agentchat/memory/_base_memory.py (new file)
@@ -0,0 +1,135 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Protocol, Union, runtime_checkable

from autogen_core import CancellationToken, Image
from pydantic import BaseModel, ConfigDict, Field
from autogen_core.model_context import ChatCompletionContext


class MimeType(Enum):
"""Supported MIME types for memory content."""

TEXT = "text/plain"
JSON = "application/json"
MARKDOWN = "text/markdown"
IMAGE = "image/*"
BINARY = "application/octet-stream"


ContentType = Union[str, bytes, dict, Image]


class MemoryContent(BaseModel):
"""A content item with type information."""

content: ContentType
mime_type: MimeType

model_config = ConfigDict(arbitrary_types_allowed=True)


class BaseMemoryConfig(BaseModel):
"""Base configuration for memory implementations."""

k: int = Field(default=5, description="Number of results to return")
score_threshold: float | None = Field(default=None, description="Minimum relevance score")

model_config = ConfigDict(arbitrary_types_allowed=True)


class MemoryEntry(BaseModel):
"""A memory entry containing content and metadata."""

content: MemoryContent
"""The content item with type information."""

metadata: Dict[str, Any] = Field(default_factory=dict)
"""Optional metadata associated with the memory entry."""

timestamp: datetime = Field(default_factory=datetime.now)
"""When the memory was created."""

source: str | None = None
"""Optional source identifier for the memory."""

model_config = ConfigDict(arbitrary_types_allowed=True)


class MemoryQueryResult(BaseModel):
"""Result from a memory query including the entry and its relevance score."""

entry: MemoryEntry
"""The memory entry."""

score: float
"""Relevance score for this result. Higher means more relevant."""

model_config = ConfigDict(arbitrary_types_allowed=True)


@runtime_checkable
class Memory(Protocol):
"""Protocol defining the interface for memory implementations."""

@property
def name(self) -> str | None:
"""The name of this memory implementation."""
...

@property
def config(self) -> BaseMemoryConfig:
"""The configuration for this memory implementation."""
...

async def transform(
self,
model_context: ChatCompletionContext,
) -> ChatCompletionContext:
"""
Transform the model context using relevant memory content.

Args:
model_context: The context to transform

Returns:
The transformed context
"""
...

async def query(
self,
query: MemoryContent,
cancellation_token: "CancellationToken | None" = None,
**kwargs: Any,
) -> List[MemoryQueryResult]:
"""
Query the memory store and return relevant entries.

Args:
query: Query content item
cancellation_token: Optional token to cancel operation
**kwargs: Additional implementation-specific parameters

Returns:
List of memory entries with relevance scores
"""
...

async def add(self, entry: MemoryEntry, cancellation_token: "CancellationToken | None" = None) -> None:
"""
Add a new entry to memory.

Args:
entry: The memory entry to add
cancellation_token: Optional token to cancel operation
"""
...

async def clear(self) -> None:
"""Clear all entries from memory."""
...

async def cleanup(self) -> None:
"""Clean up any resources used by the memory implementation."""
...
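To make the protocol concrete, here is a minimal list-backed implementation sketch. It is illustrative only: the `ListMemory` name, the substring-based scoring, and the pass-through `transform` are assumptions for discussion, not code from this PR.

```python
from typing import Any, List

from autogen_core import CancellationToken
from autogen_core.model_context import ChatCompletionContext

from autogen_agentchat.memory._base_memory import (
    BaseMemoryConfig,
    MemoryContent,
    MemoryEntry,
    MemoryQueryResult,
    MimeType,
)


class ListMemory:
    """Naive in-memory store; an entry is 'relevant' if the query text appears in its text content."""

    def __init__(self, name: str | None = None, config: BaseMemoryConfig | None = None) -> None:
        self._name = name
        self._config = config or BaseMemoryConfig()
        self._entries: List[MemoryEntry] = []

    @property
    def name(self) -> str | None:
        return self._name

    @property
    def config(self) -> BaseMemoryConfig:
        return self._config

    async def add(self, entry: MemoryEntry, cancellation_token: CancellationToken | None = None) -> None:
        self._entries.append(entry)

    async def query(
        self,
        query: MemoryContent,
        cancellation_token: CancellationToken | None = None,
        **kwargs: Any,
    ) -> List[MemoryQueryResult]:
        if not isinstance(query.content, str):
            return []
        needle = query.content.lower()
        hits = [
            MemoryQueryResult(entry=entry, score=1.0)
            for entry in self._entries
            if isinstance(entry.content.content, str) and needle in entry.content.content.lower()
        ]
        # Honor the configured result count; scoring here is binary, so no sort is needed.
        return hits[: self._config.k]

    async def transform(self, model_context: ChatCompletionContext) -> ChatCompletionContext:
        # A real implementation would fold query results into the context;
        # this sketch returns it unchanged.
        return model_context

    async def clear(self) -> None:
        self._entries.clear()

    async def cleanup(self) -> None:
        pass


# Usage sketch (names are illustrative):
# memory = ListMemory(name="notes")
# await memory.add(MemoryEntry(
#     content=MemoryContent(content="The user prefers metric units.", mime_type=MimeType.TEXT),
#     source="user-profile",
# ))
# hits = await memory.query(MemoryContent(content="metric", mime_type=MimeType.TEXT))
```

Because `Memory` is a `runtime_checkable` `Protocol`, `ListMemory` satisfies it structurally without subclassing, so `isinstance(ListMemory(), Memory)` holds at runtime.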