Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: Fix Anthropic parallel tool call #10005

Merged
merged 5 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import logging
import sys
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable
from typing import Any, ClassVar

if sys.version_info >= (3, 12):
Expand All @@ -26,7 +26,10 @@
from semantic_kernel.connectors.ai.anthropic.prompt_execution_settings.anthropic_prompt_execution_settings import (
AnthropicChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.anthropic.services.utils import MESSAGE_CONVERTERS
from semantic_kernel.connectors.ai.anthropic.services.utils import (
MESSAGE_CONVERTERS,
update_settings_from_function_call_configuration,
)
from semantic_kernel.connectors.ai.anthropic.settings.anthropic_settings import AnthropicSettings
from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
from semantic_kernel.connectors.ai.function_call_choice_configuration import FunctionCallChoiceConfiguration
Expand All @@ -43,10 +46,10 @@
from semantic_kernel.contents.utils.finish_reason import FinishReason as SemanticKernelFinishReason
from semantic_kernel.exceptions.service_exceptions import (
ServiceInitializationError,
ServiceInvalidRequestError,
ServiceInvalidResponseError,
ServiceResponseException,
)
from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata
from semantic_kernel.utils.experimental_decorator import experimental_class
from semantic_kernel.utils.telemetry.model_diagnostics.decorators import (
trace_chat_completion,
Expand Down Expand Up @@ -130,6 +133,19 @@ def get_prompt_execution_settings_class(self) -> type["PromptExecutionSettings"]
def service_url(self) -> str | None:
return str(self.async_client.base_url)

@override
def _update_function_choice_settings_callback(
self,
) -> Callable[[FunctionCallChoiceConfiguration, "PromptExecutionSettings", FunctionChoiceType], None]:
return update_settings_from_function_call_configuration

@override
def _reset_function_choice_settings(self, settings: "PromptExecutionSettings") -> None:
if hasattr(settings, "tool_choice"):
settings.tool_choice = None
if hasattr(settings, "tools"):
settings.tools = None

@override
@trace_chat_completion(MODEL_PROVIDER_NAME)
async def _inner_get_chat_message_contents(
Expand Down Expand Up @@ -171,6 +187,7 @@ async def _inner_get_streaming_chat_message_contents(
async for message in response:
yield message

@override
def _prepare_chat_history_for_request(
self,
chat_history: "ChatHistory",
Expand All @@ -194,14 +211,37 @@ def _prepare_chat_history_for_request(
system_message_content = None
system_message_count = 0
formatted_messages: list[dict[str, Any]] = []
for message in chat_history.messages:
# Skip system messages after the first one is found
if message.role == AuthorRole.SYSTEM:
for i in range(len(chat_history)):
prev_message = chat_history[i - 1] if i > 0 else None
curr_message = chat_history[i]
if curr_message.role == AuthorRole.SYSTEM:
# Skip system messages after the first one is found
if system_message_count == 0:
system_message_content = message.content
system_message_content = curr_message.content
system_message_count += 1
elif curr_message.role == AuthorRole.USER or curr_message.role == AuthorRole.ASSISTANT:
formatted_messages.append(MESSAGE_CONVERTERS[curr_message.role](curr_message))
elif curr_message.role == AuthorRole.TOOL:
if prev_message is None:
# Under no circumstances should a tool message be the first message in the chat history
raise ServiceInvalidRequestError("Tool message found without a preceding message.")
if prev_message.role == AuthorRole.USER or prev_message.role == AuthorRole.SYSTEM:
# A tool message should not be found after a user or system message
# Please NOTE that in SK there are the USER role and the TOOL role, but in Anthropic
# the tool messages are considered as USER messages. We are checking against the SK roles.
raise ServiceInvalidRequestError("Tool message found after a user or system message.")

formatted_message = MESSAGE_CONVERTERS[curr_message.role](curr_message)
if prev_message.role == AuthorRole.ASSISTANT:
# The first tool message after an assistant message should be a new message
formatted_messages.append(formatted_message)
else:
# Append the tool message to the previous tool message.
# This indicates that the assistant message requested multiple parallel tool calls.
# Anthropic requires that parallel Tool messages are grouped together in a single message.
formatted_messages[-1][content_key] += formatted_message[content_key]
else:
formatted_messages.append(MESSAGE_CONVERTERS[message.role](message))
raise ServiceInvalidRequestError(f"Unsupported role in chat history: {curr_message.role}")

if system_message_count > 1:
logger.warning(
Expand Down Expand Up @@ -277,50 +317,6 @@ def _create_streaming_chat_message_content(
items=items,
)

def update_settings_from_function_call_configuration_anthropic(
self,
function_choice_configuration: FunctionCallChoiceConfiguration,
settings: "PromptExecutionSettings",
type: "FunctionChoiceType",
) -> None:
"""Update the settings from a FunctionChoiceConfiguration."""
if (
function_choice_configuration.available_functions
and hasattr(settings, "tools")
and hasattr(settings, "tool_choice")
):
settings.tools = [
self.kernel_function_metadata_to_function_call_format_anthropic(f)
for f in function_choice_configuration.available_functions
]

if (
settings.function_choice_behavior
and settings.function_choice_behavior.type_ == FunctionChoiceType.REQUIRED
) or type == FunctionChoiceType.REQUIRED:
settings.tool_choice = {"type": "any"}
else:
settings.tool_choice = {"type": type.value}

def kernel_function_metadata_to_function_call_format_anthropic(
self,
metadata: KernelFunctionMetadata,
) -> dict[str, Any]:
"""Convert the kernel function metadata to function calling format."""
return {
"name": metadata.fully_qualified_name,
"description": metadata.description or "",
"input_schema": {
"type": "object",
"properties": {p.name: p.schema_data for p in metadata.parameters},
"required": [p.name for p in metadata.parameters if p.is_required],
},
}

@override
def _update_function_choice_settings_callback(self):
return self.update_settings_from_function_call_configuration_anthropic

async def _send_chat_request(self, settings: AnthropicChatPromptExecutionSettings) -> list["ChatMessageContent"]:
"""Send the chat request."""
try:
Expand Down Expand Up @@ -382,10 +378,3 @@ def _get_tool_calls_from_message(self, message: Message) -> list[FunctionCallCon
)

return tool_calls

@override
def _reset_function_choice_settings(self, settings: "PromptExecutionSettings") -> None:
if hasattr(settings, "tool_choice"):
settings.tool_choice = None
if hasattr(settings, "tools"):
settings.tools = None
74 changes: 59 additions & 15 deletions python/semantic_kernel/connectors/ai/anthropic/services/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
from collections.abc import Callable, Mapping
from typing import Any

from semantic_kernel.connectors.ai.function_call_choice_configuration import FunctionCallChoiceConfiguration
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceType
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.function_result_content import FunctionResultContent
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata

logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,29 +54,32 @@ def _format_assistant_message(message: ChatMessageContent) -> dict[str, Any]:
"type": "tool_use",
"id": item.id or "",
"name": item.name or "",
"input": item.arguments if isinstance(item.arguments, Mapping) else json.loads(item.arguments or ""),
"input": item.arguments
if isinstance(item.arguments, Mapping)
else json.loads(item.arguments)
if item.arguments
else {},
})
else:
logger.warning(
f"Unsupported item type in Assistant message while formatting chat history for Anthropic: {type(item)}"
)

formatted_message: dict[str, Any] = {"role": "assistant", "content": []}

if message.content:
# Only include the text content if it is not empty.
# Otherwise, the Anthropic client will throw an error.
formatted_message["content"].append({ # type: ignore
"type": "text",
"text": message.content,
})
if tool_calls:
return {
"role": "assistant",
"content": [
{
"type": "text",
"text": message.content,
},
*tool_calls,
],
}
# Only include the tool calls if there are any.
# Otherwise, the Anthropic client will throw an error.
formatted_message["content"].extend(tool_calls) # type: ignore

return {
"role": "assistant",
"content": message.content,
}
return formatted_message


def _format_tool_message(message: ChatMessageContent) -> dict[str, Any]:
Expand Down Expand Up @@ -108,3 +115,40 @@ def _format_tool_message(message: ChatMessageContent) -> dict[str, Any]:
AuthorRole.ASSISTANT: _format_assistant_message,
AuthorRole.TOOL: _format_tool_message,
}


def update_settings_from_function_call_configuration(
function_choice_configuration: FunctionCallChoiceConfiguration,
settings: PromptExecutionSettings,
type: FunctionChoiceType,
) -> None:
"""Update the settings from a FunctionChoiceConfiguration."""
if (
function_choice_configuration.available_functions
and hasattr(settings, "tools")
and hasattr(settings, "tool_choice")
):
settings.tools = [
kernel_function_metadata_to_function_call_format(f)
for f in function_choice_configuration.available_functions
]

if (
settings.function_choice_behavior and settings.function_choice_behavior.type_ == FunctionChoiceType.REQUIRED
) or type == FunctionChoiceType.REQUIRED:
settings.tool_choice = {"type": "any"}
else:
settings.tool_choice = {"type": type.value}


def kernel_function_metadata_to_function_call_format(metadata: KernelFunctionMetadata) -> dict[str, Any]:
"""Convert the kernel function metadata to function calling format."""
return {
"name": metadata.fully_qualified_name,
"description": metadata.description or "",
"input_schema": {
"type": "object",
"properties": {p.name: p.schema_data for p in metadata.parameters},
"required": [p.name for p in metadata.parameters if p.is_required],
},
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@
onnx_setup: bool = is_service_setup_for_testing(
["ONNX_GEN_AI_CHAT_MODEL_FOLDER"], raise_if_not_set=False
) # Tests are optional for ONNX
anthropic_setup: bool = is_service_setup_for_testing(
["ANTHROPIC_API_KEY", "ANTHROPIC_CHAT_MODEL_ID"], raise_if_not_set=False
) # We don't have an Anthropic deployment
anthropic_setup: bool = is_service_setup_for_testing(["ANTHROPIC_API_KEY", "ANTHROPIC_CHAT_MODEL_ID"])
# When testing Bedrock, after logging into AWS CLI this has been set, so we can use it to check if the service is setup
bedrock_setup: bool = is_service_setup_for_testing(["AWS_DEFAULT_REGION"], raise_if_not_set=False)

Expand Down
Loading
Loading