[SDK/CLI] Add chat group poc (#2249)
# Description

Add a chat group SDK proof of concept (POC):

```python
from promptflow._sdk.entities._chat_group._chat_group import ChatGroup
from promptflow._sdk.entities._chat_group._chat_role import ChatRole
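# FLOWS_DIR (a pathlib.Path to the sample flows) and topic (a str) are not
# shown here; they are assumed to be defined by the surrounding test setup.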

copilot = ChatRole(
    flow=FLOWS_DIR / "chat_group_copilot",
    role="assistant",
    inputs=dict(
        question=topic,
        model="gpt-3.5-turbo",
        conversation_history="${parent.conversation_history}",
    ),
)
simulation = ChatRole(
    flow=FLOWS_DIR / "chat_group_simulation",
    role="user",
    inputs=dict(
        topic=topic,
        persona="criticizer",
        conversation_history="${parent.conversation_history}",
    ),
)

chat_group = ChatGroup(
    roles=[copilot, simulation],
    max_turns=4,
    max_tokens=1000,
    max_time=1000,
    stop_signal="[STOP]",
)
chat_group.invoke()

# history has 4 records
history = chat_group.conversation_history
assert len(history) == 4
assert history[0][0] == history[2][0] == copilot.role
assert history[1][0] == history[3][0] == simulation.role

```
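Note: as `_update_conversation_history` in the diff below shows, each history record is a `(role_name, result_dict)` tuple, which is what the index-`[0]` assertions above rely on. A minimal sketch for inspecting the transcript:

```python
# Each record is a (role_name, result) tuple; result is the role flow's output dict.
for speaker, result in chat_group.conversation_history:
    print(f"{speaker}: {result}")
```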

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which has an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
0mza987 authored Mar 13, 2024
1 parent 3b4bbe7 commit 39a9f8c
Showing 15 changed files with 583 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/promptflow/promptflow/_sdk/_constants.py
@@ -151,6 +151,9 @@ def _prepare_home_dir() -> Path:

AzureMLWorkspaceTriad = namedtuple("AzureMLWorkspace", ["subscription_id", "resource_group_name", "workspace_name"])

# chat group
STOP_SIGNAL = "[STOP]"


class RunTypes:
    BATCH = "batch"
@@ -445,6 +448,11 @@ class LineRunFieldName:
    EVALUATIONS = "evaluations"


class ChatGroupSpeakOrder(str, Enum):
    SEQUENTIAL = "sequential"
    LLM = "llm"


TRACE_LIST_DEFAULT_LIMIT = 1000


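Since `ChatGroupSpeakOrder` mixes in `str`, its members compare equal to their raw string values, so serialized specs and enum members stay interchangeable. A quick sketch (import path as in the diff below):

```python
from promptflow._sdk._constants import ChatGroupSpeakOrder

# str-mixin enum members compare equal to their plain string values.
assert ChatGroupSpeakOrder.SEQUENTIAL == "sequential"
assert ChatGroupSpeakOrder("llm") is ChatGroupSpeakOrder.LLM
```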
18 changes: 18 additions & 0 deletions src/promptflow/promptflow/_sdk/_errors.py
@@ -209,3 +209,21 @@ class ExperimentCommandRunError(SDKError):
    """Exception raised if experiment validation failed."""

    pass


class ChatGroupError(SDKError):
    """Exception raised if chat group operation failed."""

    pass


class ChatRoleError(SDKError):
    """Exception raised if chat agent operation failed."""

    pass


class UnexpectedAttributeError(SDKError):
    """Exception raised if unexpected attribute is found."""

    pass
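All three new exceptions subclass `SDKError`, so existing `SDKError` handlers already cover chat group failures; catching the narrower type is straightforward. A minimal sketch (imports as in the description above):

```python
from promptflow._sdk._errors import ChatGroupError
from promptflow._sdk.entities._chat_group._chat_group import ChatGroup

try:
    ChatGroup(roles=[])  # rejected: roles must be a non-empty list of ChatRole
except ChatGroupError as exc:
    print(f"Invalid chat group: {exc}")
```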
@@ -0,0 +1,4 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
212 changes: 212 additions & 0 deletions src/promptflow/promptflow/_sdk/entities/_chat_group/_chat_group.py
@@ -0,0 +1,212 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import time
from collections import Counter
from itertools import cycle
from typing import Any, Dict, List, Optional

from promptflow._sdk._constants import STOP_SIGNAL, ChatGroupSpeakOrder
from promptflow._sdk._errors import ChatGroupError
from promptflow._sdk.entities._chat_group._chat_role import ChatRole
from promptflow._utils.logger_utils import get_cli_sdk_logger

logger = get_cli_sdk_logger()


class ChatGroup:
    """Chat group entity, can invoke a multi-turn conversation with multiple chat roles.

    :param roles: List of chat roles in the chat group.
    :type roles: List[ChatRole]
    :param speak_order: Speak order of the chat group. Defaults to sequential, which is the order of the roles list.
    :type speak_order: ChatGroupSpeakOrder
    :param max_turns: Maximum turns of the chat group. Defaults to None, which means no limit.
    :type max_turns: Optional[int]
    :param max_tokens: Maximum tokens of the chat group. Defaults to None, which means no limit.
    :type max_tokens: Optional[int]
    :param max_time: Maximum time of the chat group. Defaults to None, which means no limit.
    :type max_time: Optional[int]
    :param stop_signal: Stop signal of the chat group. Defaults to "[STOP]".
    :type stop_signal: Optional[str]
    :param entry_role: Entry role of the chat group. Defaults to None, which means the first role in the roles list.
        Only meaningful when speak order is not sequential.
    """

    def __init__(
        self,
        roles: List[ChatRole],
        speak_order: ChatGroupSpeakOrder = ChatGroupSpeakOrder.SEQUENTIAL,
        max_turns: Optional[int] = None,
        max_tokens: Optional[int] = None,
        max_time: Optional[int] = None,
        stop_signal: Optional[str] = STOP_SIGNAL,
        entry_role: Optional[ChatRole] = None,
    ):
        self._roles = roles
        self._speak_order = speak_order
        self._roles_dict, self._speak_order_list = self._prepare_roles(roles, entry_role, speak_order)
        self._max_turns, self._max_tokens, self._max_time = self._validate_int_parameters(
            max_turns, max_tokens, max_time
        )
        self._stop_signal = stop_signal
        self._entry_role = entry_role
        self._conversation_history = []

    @property
    def conversation_history(self):
        return self._conversation_history

    def _prepare_roles(self, roles: List[ChatRole], entry_role: ChatRole, speak_order: ChatGroupSpeakOrder):
        """Prepare roles"""
        logger.info("Preparing roles in chat group.")
        # check roles is a non-empty list of ChatRole
        if not isinstance(roles, list) or len(roles) == 0 or not all(isinstance(role, ChatRole) for role in roles):
            raise ChatGroupError(f"Agents should be a non-empty list of ChatRole. Got {roles!r} instead.")

        # check entry_role is in roles
        if entry_role is not None and entry_role not in roles:
            raise ChatGroupError(f"Entry role {entry_role.role} is not in roles list {roles!r}.")

        # check if there is duplicate role name
        role_names = [role.role for role in roles]
        if len(role_names) != len(set(role_names)):
            counter = Counter(role_names)
            duplicate_roles = [role for role in counter if counter[role] > 1]
            raise ChatGroupError(f"Duplicate roles are not allowed: {duplicate_roles!r}.")

        speak_order_list = self._get_speak_order(roles, entry_role, speak_order)
        roles_dict = {role.role: role for role in roles}
        return roles_dict, cycle(speak_order_list)

    def _get_speak_order(
        self, roles: List[ChatRole], entry_role: Optional[ChatRole], speak_order: ChatGroupSpeakOrder
    ) -> List[str]:
        """Calculate speak order"""
        if speak_order == ChatGroupSpeakOrder.SEQUENTIAL:
            if entry_role:
                logger.warn(
                    f"Entry role {entry_role.role!r} is ignored when speak order is sequential. "
                    f"The first role in the list will be the entry role: {roles[0].role!r}."
                )

            speak_order_list = [role.role for role in roles]
            logger.info(f"Role speak order is {speak_order_list!r}.")
            return speak_order_list
        else:
            raise NotImplementedError(f"Speak order {speak_order.value!r} is not supported yet.")

    @staticmethod
    def _validate_int_parameters(max_turns: int, max_tokens: int, max_time: int):
        """Validate int parameters"""
        logger.debug("Validating integer parameters for chat group.")
        if max_turns is not None and not isinstance(max_turns, int):
            raise ChatGroupError(f"max_turns should be an integer. Got {type(max_turns)!r} instead.")
        if max_tokens is not None and not isinstance(max_tokens, int):
            raise ChatGroupError(f"max_tokens should be an integer. Got {type(max_tokens)!r} instead.")
        if max_time is not None and not isinstance(max_time, int):
            raise ChatGroupError(f"max_time should be an integer. Got {type(max_time)!r} instead.")

        logger.info(
            f"Chat group maximum turns: {max_turns!r}, maximum tokens: {max_tokens!r}, maximum time: {max_time!r}."
        )
        return max_turns, max_tokens, max_time

    def invoke(self):
        """Invoke the chat group"""
        logger.info("Invoking chat group.")

        chat_round = 0
        chat_token = 0
        chat_start_time = time.time()
        while True:
            chat_round += 1

            # select current role and run
            current_role = self._select_role()
            logger.info(f"[Round {chat_round}] Chat role {current_role.role!r} is speaking.")
            role_input_values = self._get_role_input_values(current_role)
            # TODO: Hide flow-invoker and executor log for execution
            result = current_role.invoke(**role_input_values)
            logger.info(f"[Round {chat_round}] Chat role {current_role.role!r} result: {result!r}.")

            # post process after role's invocation
            self._update_information_with_result(current_role, result)
            # TODO: Get used token from result and update chat_token

            # check if the chat group should continue
            continue_chat = self._check_continue_condition(chat_round, chat_token, chat_start_time)
            if not continue_chat:
                logger.info(
                    f"Chat group stops at round {chat_round!r}, token cost {chat_token!r}, "
                    f"time cost {round(time.time() - chat_start_time, 2)} seconds."
                )
                break

    def _select_role(self) -> ChatRole:
        """Select next role"""
        if self._speak_order == ChatGroupSpeakOrder.LLM:
            return self._predict_next_role_with_llm()
        next_role_name = next(self._speak_order_list)
        return self._roles_dict[next_role_name]

    def _get_role_input_values(self, role: ChatRole) -> Dict[str, Any]:
        """Get role input values"""
        input_values = {}
        for key in role.inputs:
            role_input = role.inputs[key]
            value = role_input.get("value", None)
            # only conversation history binding needs to be processed here, other values are specified when
            # initializing the chat role.
            if value == "${parent.conversation_history}":
                value = self._conversation_history
            input_values[key] = value
        logger.debug(f"Input values for role {role.role!r}: {input_values!r}")
        return input_values

    def _update_information_with_result(self, role: ChatRole, result: dict) -> None:
        """Update information with result"""
        logger.debug(f"Updating chat group information with result from role {role.role!r}: {result!r}.")

        # 1. update group chat history
        self._update_conversation_history(role, result)

        # 2. update the role output value
        for key, value in result.items():
            if key in role.outputs:
                role.outputs[key]["value"] = value

    def _update_conversation_history(self, role: ChatRole, result: dict) -> None:
        """Update conversation history"""
        self._conversation_history.append((role.role, result))

    def _check_continue_condition(self, chat_round: int, chat_token: int, chat_start_time: float) -> bool:
        continue_chat = True
        time_cost = time.time() - chat_start_time

        # 1. check if the chat round reaches the maximum
        if self._max_turns is not None and chat_round >= self._max_turns:
            logger.warn(f"Chat round {chat_round!r} reaches the maximum {self._max_turns!r}.")
            continue_chat = False

        # 2. check if the chat token reaches the maximum
        if self._max_tokens is not None and chat_token >= self._max_tokens:
            logger.warn(f"Chat token {chat_token!r} reaches the maximum {self._max_tokens!r}.")
            continue_chat = False

        # 3. check if the chat time reaches the maximum
        if self._max_time is not None and time_cost >= self._max_time:
            logger.warn(f"Chat time reaches the maximum {self._max_time!r} seconds.")
            continue_chat = False

        # TODO: How to apply stop signal since a role can have multiple outputs?
        if continue_chat:
            logger.info(
                f"Chat group continues at round {chat_round!r}, "
                f"token cost {chat_token!r}, time cost {round(time_cost, 2)!r} seconds."
            )
        return continue_chat

    def _predict_next_role_with_llm(self) -> ChatRole:
        """Predict next role for non-deterministic speak order."""
        raise NotImplementedError(f"Speak order {self._speak_order} is not supported yet.")
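Taken together, `_get_speak_order` and the `itertools.cycle` in `_prepare_roles` explain the assertions in the description: with two roles in sequential order and `max_turns=4`, the speakers simply alternate. The same mechanics in a standalone sketch:

```python
from itertools import cycle

# Sequential speak order is just the roles list, cycled once per round.
order = cycle(["assistant", "user"])
history = [(next(order), {"output": f"turn {i}"}) for i in range(4)]

assert [speaker for speaker, _ in history] == ["assistant", "user", "assistant", "user"]
```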
@@ -0,0 +1,37 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from collections import UserDict
from typing import Any

from promptflow._sdk._errors import UnexpectedAttributeError


class AttrDict(UserDict):
    def __init__(self, inputs: dict, **kwargs: Any):
        super().__init__(**inputs, **kwargs)

    def __getattr__(self, item: Any):
        return self.__getitem__(item)

    def __getitem__(self, item: Any):
        if item not in self:
            raise UnexpectedAttributeError(f"Invalid attribute {item!r}, expected one of {list(self.keys())}.")
        res = super().__getitem__(item)
        return res


class ChatRoleInputs(AttrDict):
    """Chat role inputs"""


class ChatRoleOutputs(AttrDict):
    """Chat role outputs"""


class ChatGroupInputs(AttrDict):
    """Chat group inputs"""


class ChatGroupOutputs(AttrDict):
    """Chat group outputs"""