Skip to content

Commit

Permalink
Agentchat refactor (#4062)
Browse files Browse the repository at this point in the history
* Agentchat refactor

* Move termination stop message to a separate field in task result

* Update quick start example

* Use string stop reason instead of stop message in task result for simpler API

* Use main function
  • Loading branch information
ekzhu authored Nov 5, 2024
1 parent 1098768 commit c3283c6
Show file tree
Hide file tree
Showing 18 changed files with 284 additions and 412 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
UserMessage,
)
from autogen_core.components.tools import FunctionTool, Tool
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import BaseModel, Field, model_validator

from .. import EVENT_LOGGER_NAME
from ..base import Response
Expand All @@ -33,30 +33,6 @@
event_logger = logging.getLogger(EVENT_LOGGER_NAME)


class ToolCallEvent(BaseModel):
"""A tool call event."""

source: str
"""The source of the event."""

tool_calls: List[FunctionCall]
"""The tool call message."""

model_config = ConfigDict(arbitrary_types_allowed=True)


class ToolCallResultEvent(BaseModel):
"""A tool call result event."""

source: str
"""The source of the event."""

tool_call_results: List[FunctionExecutionResult]
"""The tool call result message."""

model_config = ConfigDict(arbitrary_types_allowed=True)


class Handoff(BaseModel):
"""Handoff configuration for :class:`AssistantAgent`."""

Expand Down Expand Up @@ -264,19 +240,21 @@ async def on_messages_stream(

# Run tool calls until the model produces a string response.
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
tool_call_msg = ToolCallMessage(content=result.content, source=self.name, models_usage=result.usage)
event_logger.debug(tool_call_msg)
# Add the tool call message to the output.
inner_messages.append(ToolCallMessage(content=result.content, source=self.name, models_usage=result.usage))
yield ToolCallMessage(content=result.content, source=self.name, models_usage=result.usage)
inner_messages.append(tool_call_msg)
yield tool_call_msg

# Execute the tool calls.
results = await asyncio.gather(
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
)
event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
tool_call_result_msg = ToolCallResultMessage(content=results, source=self.name)
event_logger.debug(tool_call_result_msg)
self._model_context.append(FunctionExecutionResultMessage(content=results))
inner_messages.append(ToolCallResultMessage(content=results, source=self.name))
yield ToolCallResultMessage(content=results, source=self.name)
inner_messages.append(tool_call_result_msg)
yield tool_call_result_msg

# Detect handoff requests.
handoffs: List[Handoff] = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from autogen_core.base import CancellationToken

from ..base import ChatAgent, Response, TaskResult
from ..messages import ChatMessage, InnerMessage, TextMessage
from ..messages import AgentMessage, ChatMessage, InnerMessage, TextMessage


class BaseChatAgent(ChatAgent, ABC):
Expand Down Expand Up @@ -62,7 +62,7 @@ async def run(
cancellation_token = CancellationToken()
first_message = TextMessage(content=task, source="user")
response = await self.on_messages([first_message], cancellation_token)
messages: List[InnerMessage | ChatMessage] = [first_message]
messages: List[AgentMessage] = [first_message]
if response.inner_messages is not None:
messages += response.inner_messages
messages.append(response.chat_message)
Expand All @@ -73,14 +73,14 @@ async def run_stream(
task: str,
*,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the agent with the given task and return a stream of messages
and the final task result as the last item in the stream."""
if cancellation_token is None:
cancellation_token = CancellationToken()
first_message = TextMessage(content=task, source="user")
yield first_message
messages: List[InnerMessage | ChatMessage] = [first_message]
messages: List[AgentMessage] = [first_message]
async for message in self.on_messages_stream([first_message], cancellation_token):
if isinstance(message, Response):
yield message.chat_message
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@

from autogen_core.base import CancellationToken

from ..messages import ChatMessage, InnerMessage
from ..messages import AgentMessage


@dataclass
class TaskResult:
"""Result of running a task."""

messages: Sequence[InnerMessage | ChatMessage]
messages: Sequence[AgentMessage]
"""Messages produced by the task."""

stop_reason: str | None = None
"""The reason the task stopped."""


class TaskRunner(Protocol):
"""A task runner."""
Expand All @@ -31,7 +34,7 @@ def run_stream(
task: str,
*,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the task and produces a stream of messages and the final result
:class:`TaskResult` as the last item in the stream."""
...
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from typing import List, Sequence

from ..messages import ChatMessage, StopMessage
from ..messages import AgentMessage, StopMessage


class TerminatedException(BaseException): ...
Expand Down Expand Up @@ -50,7 +50,7 @@ def terminated(self) -> bool:
...

@abstractmethod
async def __call__(self, messages: Sequence[ChatMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
"""Check if the conversation should be terminated based on the messages received
since the last time the condition was called.
Return a StopMessage if the conversation should be terminated, or None otherwise.
Expand Down Expand Up @@ -88,7 +88,7 @@ def __init__(self, *conditions: TerminationCondition) -> None:
def terminated(self) -> bool:
return all(condition.terminated for condition in self._conditions)

async def __call__(self, messages: Sequence[ChatMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
if self.terminated:
raise TerminatedException("Termination condition has already been reached.")
# Check all remaining conditions.
Expand Down Expand Up @@ -120,7 +120,7 @@ def __init__(self, *conditions: TerminationCondition) -> None:
def terminated(self) -> bool:
return any(condition.terminated for condition in self._conditions)

async def __call__(self, messages: Sequence[ChatMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
if self.terminated:
raise RuntimeError("Termination condition has already been reached")
stop_messages = await asyncio.gather(*[condition(messages) for condition in self._conditions])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,62 +3,18 @@
import sys
from datetime import datetime

from ..agents._assistant_agent import ToolCallEvent, ToolCallResultEvent
from ..messages import ChatMessage, StopMessage, TextMessage
from ..teams._events import (
GroupChatPublishEvent,
GroupChatSelectSpeakerEvent,
TerminationEvent,
)
from pydantic import BaseModel


class ConsoleLogHandler(logging.Handler):
@staticmethod
def serialize_chat_message(message: ChatMessage) -> str:
if isinstance(message, TextMessage | StopMessage):
return message.content
else:
d = message.model_dump()
assert "content" in d
return json.dumps(d["content"], indent=2)

def emit(self, record: logging.LogRecord) -> None:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, GroupChatPublishEvent):
if record.msg.source is None:
sys.stdout.write(
f"\n{'-'*75} \n"
f"\033[91m[{ts}]:\033[0m\n"
f"\n{self.serialize_chat_message(record.msg.agent_message)}"
)
else:
sys.stdout.write(
f"\n{'-'*75} \n"
f"\033[91m[{ts}], {record.msg.source.type}:\033[0m\n"
f"\n{self.serialize_chat_message(record.msg.agent_message)}"
)
sys.stdout.flush()
elif isinstance(record.msg, ToolCallEvent):
sys.stdout.write(
f"\n{'-'*75} \n" f"\033[91m[{ts}], Tool Call:\033[0m\n" f"\n{str(record.msg.model_dump())}"
)
sys.stdout.flush()
elif isinstance(record.msg, ToolCallResultEvent):
sys.stdout.write(
f"\n{'-'*75} \n" f"\033[91m[{ts}], Tool Call Result:\033[0m\n" f"\n{str(record.msg.model_dump())}"
)
sys.stdout.flush()
elif isinstance(record.msg, GroupChatSelectSpeakerEvent):
sys.stdout.write(
f"\n{'-'*75} \n" f"\033[91m[{ts}], Selected Next Speaker:\033[0m\n" f"\n{record.msg.selected_speaker}"
)
sys.stdout.flush()
elif isinstance(record.msg, TerminationEvent):
sys.stdout.write(
f"\n{'-'*75} \n"
f"\033[91m[{ts}], Termination:\033[0m\n"
f"\n{self.serialize_chat_message(record.msg.agent_message)}"
if isinstance(record.msg, BaseModel):
record.msg = json.dumps(
{
"timestamp": ts,
"message": record.msg.model_dump_json(indent=2),
"type": record.msg.__class__.__name__,
},
)
sys.stdout.flush()
else:
raise ValueError(f"Unexpected log record: {record.msg}")
sys.stdout.write(f"{record.msg}\n")
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
import json
import logging
from dataclasses import asdict, is_dataclass
from datetime import datetime
from typing import Any

from ..agents._assistant_agent import ToolCallEvent, ToolCallResultEvent
from ..teams._events import (
GroupChatPublishEvent,
GroupChatSelectSpeakerEvent,
TerminationEvent,
)
from pydantic import BaseModel


class FileLogHandler(logging.Handler):
Expand All @@ -20,65 +13,12 @@ def __init__(self, filename: str) -> None:

def emit(self, record: logging.LogRecord) -> None:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, GroupChatPublishEvent | TerminationEvent):
log_entry = json.dumps(
if isinstance(record.msg, BaseModel):
record.msg = json.dumps(
{
"timestamp": ts,
"source": record.msg.source,
"agent_message": record.msg.agent_message.model_dump(),
"message": record.msg.model_dump(),
"type": record.msg.__class__.__name__,
},
default=self.json_serializer,
)
elif isinstance(record.msg, GroupChatSelectSpeakerEvent):
log_entry = json.dumps(
{
"timestamp": ts,
"source": record.msg.source,
"selected_speaker": record.msg.selected_speaker,
"type": "SelectSpeakerEvent",
},
default=self.json_serializer,
)
elif isinstance(record.msg, ToolCallEvent):
log_entry = json.dumps(
{
"timestamp": ts,
"tool_calls": record.msg.model_dump(),
"type": "ToolCallEvent",
},
default=self.json_serializer,
)
elif isinstance(record.msg, ToolCallResultEvent):
log_entry = json.dumps(
{
"timestamp": ts,
"tool_call_results": record.msg.model_dump(),
"type": "ToolCallResultEvent",
},
default=self.json_serializer,
)
else:
raise ValueError(f"Unexpected log record: {record.msg}")
file_record = logging.LogRecord(
name=record.name,
level=record.levelno,
pathname=record.pathname,
lineno=record.lineno,
msg=log_entry,
args=(),
exc_info=record.exc_info,
)
self.file_handler.emit(file_record)

def close(self) -> None:
self.file_handler.close()
super().close()

@staticmethod
def json_serializer(obj: Any) -> Any:
if is_dataclass(obj) and not isinstance(obj, type):
return asdict(obj)
elif isinstance(obj, type):
return str(obj)
return str(obj)
self.file_handler.emit(record)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from autogen_core.components import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult, RequestUsage
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict


class BaseMessage(BaseModel):
Expand All @@ -14,6 +14,8 @@ class BaseMessage(BaseModel):
models_usage: RequestUsage | None = None
"""The model client usage incurred when producing this message."""

model_config = ConfigDict(arbitrary_types_allowed=True)


class TextMessage(BaseMessage):
"""A text message."""
Expand Down Expand Up @@ -75,6 +77,10 @@ class ToolCallResultMessage(BaseMessage):
"""Messages for agent-to-agent communication."""


AgentMessage = InnerMessage | ChatMessage
"""All message types."""


__all__ = [
"BaseMessage",
"TextMessage",
Expand All @@ -85,4 +91,6 @@ class ToolCallResultMessage(BaseMessage):
"ToolCallMessage",
"ToolCallResultMessage",
"ChatMessage",
"InnerMessage",
"AgentMessage",
]
Loading

0 comments on commit c3283c6

Please sign in to comment.