Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: Anthropic function calling fixes #9938

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from semantic_kernel.connectors.ai.anthropic.prompt_execution_settings.anthropic_prompt_execution_settings import (
AnthropicChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.anthropic.services.utils import MESSAGE_CONVERTERS
from semantic_kernel.connectors.ai.anthropic.settings.anthropic_settings import AnthropicSettings
from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
from semantic_kernel.connectors.ai.function_call_choice_configuration import FunctionCallChoiceConfiguration
Expand All @@ -34,7 +35,6 @@
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ITEM_TYPES, ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.function_result_content import FunctionResultContent
from semantic_kernel.contents.streaming_chat_message_content import ITEM_TYPES as STREAMING_ITEM_TYPES
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
from semantic_kernel.contents.streaming_text_content import StreamingTextContent
Expand Down Expand Up @@ -192,69 +192,25 @@ def _prepare_chat_history_for_request(
A tuple containing the prepared chat history and the first SYSTEM message content.
"""
system_message_content = None
remaining_messages: list[dict[str, Any]] = []
system_message_found = False
system_message_count = 0
formatted_messages: list[dict[str, Any]] = []
for message in chat_history.messages:
# Skip system messages after the first one is found
if message.role == AuthorRole.SYSTEM:
if not system_message_found:
if system_message_count == 0:
system_message_content = message.content
system_message_found = True
elif message.role == AuthorRole.TOOL:
# if tool result message isn't the most recent message, add it to the remaining messages
if not remaining_messages or remaining_messages[-1][role_key] != AuthorRole.USER:
remaining_messages.append({
role_key: AuthorRole.USER,
content_key: [],
})

# add the tool result to the most recent message
tool_results_message = remaining_messages[-1]
for item in message.items:
if isinstance(item, FunctionResultContent):
tool_results_message["content"].append({
"type": "tool_result",
"tool_use_id": item.id,
content_key: str(item.result),
})
elif message.finish_reason == SemanticKernelFinishReason.TOOL_CALLS:
if not stream:
if not message.inner_content:
raise ServiceInvalidResponseError(
"Expected a message with an Anthropic Message as inner content."
)

remaining_messages.append({
role_key: AuthorRole.ASSISTANT,
content_key: [content_block.to_dict() for content_block in message.inner_content.content],
})
else:
content: list[TextBlock | ToolUseBlock] = []
# for remaining items, add them to the content
for item in message.items:
if isinstance(item, TextContent):
content.append(TextBlock(text=item.text, type="text"))
elif isinstance(item, FunctionCallContent):
item_arguments = (
item.arguments if not isinstance(item.arguments, str) else json.loads(item.arguments)
)

content.append(
ToolUseBlock(id=item.id, input=item_arguments, name=item.name, type="tool_use")
)

remaining_messages.append({
role_key: AuthorRole.ASSISTANT,
content_key: content,
})
system_message_count += 1
else:
# The API requires only role and content keys for the remaining messages
remaining_messages.append({
role_key: getattr(message, role_key),
content_key: getattr(message, content_key),
})
formatted_messages.append(MESSAGE_CONVERTERS[message.role](message))

return remaining_messages, system_message_content
if system_message_count > 1:
logger.warning(
"Anthropic service only supports one system message, but %s system messages were found."
" Only the first system message will be included in the request.",
system_message_count,
)

return formatted_messages, system_message_content

# endregion

Expand Down
110 changes: 110 additions & 0 deletions python/semantic_kernel/connectors/ai/anthropic/services/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (c) Microsoft. All rights reserved.

import json
import logging
from collections.abc import Callable, Mapping
from typing import Any

from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.function_result_content import FunctionResultContent
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.contents.utils.author_role import AuthorRole

logger: logging.Logger = logging.getLogger(__name__)


def _format_user_message(message: ChatMessageContent) -> dict[str, Any]:
    """Build the Anthropic request entry for a user message.

    Only ``message.content`` is forwarded; the message's individual items
    are not inspected here.

    Args:
        message: The user message to convert.

    Returns:
        A dict with ``"role"`` set to ``"user"`` and ``"content"`` set to the
        message's content.
    """
    user_text = message.content
    return dict(role="user", content=user_text)


def _format_assistant_message(message: ChatMessageContent) -> dict[str, Any]:
    """Format an assistant message to the expected object for the Anthropic client.

    Function call items become ``tool_use`` content blocks. Text is taken from
    ``message.content`` rather than from the individual text items. Items of any
    other type are skipped with a warning.

    Args:
        message: The assistant message.

    Returns:
        The formatted assistant message. Without function calls, ``"content"``
        is ``message.content`` unchanged. With function calls, ``"content"`` is
        a list of blocks: an optional leading ``text`` block (only when
        ``message.content`` is non-empty) followed by the ``tool_use`` blocks.

    Raises:
        json.JSONDecodeError: If a function call carries a non-empty string
            of arguments that is not valid JSON.
    """
    tool_calls: list[dict[str, Any]] = []

    for item in message.items:
        if isinstance(item, TextContent):
            # Text is emitted once from ``message.content`` below, so the
            # individual text items are not converted here.
            continue
        if isinstance(item, FunctionCallContent):
            arguments = item.arguments
            if not isinstance(arguments, Mapping):
                # A call without arguments (None or "") maps to an empty input
                # object; json.loads("") would raise instead.
                arguments = json.loads(arguments) if arguments else {}
            tool_calls.append({
                "type": "tool_use",
                "id": item.id or "",
                "name": item.name or "",
                "input": arguments,
            })
        else:
            logger.warning(
                "Unsupported item type in Assistant message while formatting chat history for Anthropic: %s",
                type(item),
            )

    if not tool_calls:
        return {
            "role": "assistant",
            "content": message.content,
        }

    content_blocks: list[dict[str, Any]] = []
    if message.content:
        # Only include the text block when there is text: an empty text block
        # alongside tool_use blocks is rejected by the Anthropic API.
        content_blocks.append({
            "type": "text",
            "text": message.content,
        })
    content_blocks.extend(tool_calls)

    return {
        "role": "assistant",
        "content": content_blocks,
    }


def _format_tool_message(message: ChatMessageContent) -> dict[str, Any]:
    """Build the Anthropic request entry for a tool message.

    Each function result item in the message becomes one ``tool_result``
    block whose content is the stringified result. Items of any other type
    are skipped with a warning. The resulting entry is sent with the
    ``"user"`` role.

    Args:
        message: The tool message to convert.

    Returns:
        A dict with ``"role"`` set to ``"user"`` and ``"content"`` set to the
        list of ``tool_result`` blocks (empty when the message has no
        function result items).
    """
    tool_results: list[dict[str, Any]] = []

    for entry in message.items:
        if isinstance(entry, FunctionResultContent):
            tool_results.append({
                "type": "tool_result",
                "tool_use_id": entry.id,
                "content": str(entry.result),
            })
        else:
            logger.warning(
                f"Unsupported item type in Tool message while formatting chat history for Anthropic: {type(entry)}"
            )

    return {"role": "user", "content": tool_results}


# Dispatch table from a message's author role to the function that formats
# that message for the Anthropic client. Look up by ``message.role`` and call
# the result with the message.
# There is no entry for AuthorRole.SYSTEM, so indexing with a role that is not
# listed here raises KeyError.
# NOTE(review): callers appear to pull system messages out separately before
# using this table - confirm.
MESSAGE_CONVERTERS: dict[AuthorRole, Callable[[ChatMessageContent], dict[str, Any]]] = {
AuthorRole.USER: _format_user_message,
AuthorRole.ASSISTANT: _format_assistant_message,
AuthorRole.TOOL: _format_tool_message,
}
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,12 @@ class FunctionChoiceTestTypes(str, Enum):
),
pytest.param(
"anthropic",
{},
{
# Anthropic expects tools in the request when it sees tool use in the chat history.
"function_choice_behavior": FunctionChoiceBehavior.Auto(
auto_invoke=True, filters={"excluded_plugins": ["task_plugin"]}
),
},
[
[
ChatMessageContent(
Expand All @@ -460,9 +465,12 @@ class FunctionChoiceTestTypes(str, Enum):
ChatMessageContent(
role=AuthorRole.ASSISTANT,
items=[
# Anthropic will often include a chain of thought in the tool call by default.
# If this is not in the message, it will complain about the missing chain of thought.
TextContent(text="I will find the revenue for you."),
FunctionCallContent(
id="123456789", name="finance-search", arguments='{"company": "contoso", "year": 2024}'
)
),
],
),
ChatMessageContent(
Expand Down
20 changes: 20 additions & 0 deletions python/tests/integration/completions/test_chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class Reasoning(KernelBaseModel):
pytestmark = pytest.mark.parametrize(
"service_id, execution_settings_kwargs, inputs, kwargs",
[
# region OpenAI
pytest.param(
"openai",
{},
Expand All @@ -63,6 +64,8 @@ class Reasoning(KernelBaseModel):
{},
id="openai_json_schema_response_format",
),
# endregion
# region Azure
pytest.param(
"azure",
{},
Expand All @@ -83,6 +86,8 @@ class Reasoning(KernelBaseModel):
{},
id="azure_custom_client",
),
# endregion
# region Azure AI Inference
pytest.param(
"azure_ai_inference",
{},
Expand All @@ -93,6 +98,8 @@ class Reasoning(KernelBaseModel):
{},
id="azure_ai_inference_text_input",
),
# endregion
# region Anthropic
pytest.param(
"anthropic",
{},
Expand All @@ -104,6 +111,8 @@ class Reasoning(KernelBaseModel):
marks=pytest.mark.skipif(not anthropic_setup, reason="Anthropic Environment Variables not set"),
id="anthropic_text_input",
),
# endregion
# region Mistral AI
pytest.param(
"mistral_ai",
{},
Expand All @@ -115,6 +124,8 @@ class Reasoning(KernelBaseModel):
marks=pytest.mark.skipif(not mistral_ai_setup, reason="Mistral AI Environment Variables not set"),
id="mistral_ai_text_input",
),
# endregion
# region Ollama
pytest.param(
"ollama",
{},
Expand All @@ -129,6 +140,8 @@ class Reasoning(KernelBaseModel):
),
id="ollama_text_input",
),
# endregion
# region Onnx Gen AI
pytest.param(
"onnx_gen_ai",
{},
Expand All @@ -140,6 +153,8 @@ class Reasoning(KernelBaseModel):
marks=pytest.mark.skipif(not onnx_setup, reason="Need a Onnx Model setup"),
id="onnx_gen_ai",
),
# endregion
# region Google AI
pytest.param(
"google_ai",
{},
Expand All @@ -151,6 +166,8 @@ class Reasoning(KernelBaseModel):
marks=pytest.mark.skip(reason="Skipping due to 429s from Google AI."),
id="google_ai_text_input",
),
# endregion
# region Vertex AI
pytest.param(
"vertex_ai",
{},
Expand All @@ -162,6 +179,8 @@ class Reasoning(KernelBaseModel):
marks=pytest.mark.skipif(not vertex_ai_setup, reason="Vertex AI Environment Variables not set"),
id="vertex_ai_text_input",
),
# endregion
# region Bedrock
pytest.param(
"bedrock_amazon_titan",
{},
Expand Down Expand Up @@ -228,6 +247,7 @@ class Reasoning(KernelBaseModel):
marks=pytest.mark.skip(reason="Skipping due to occasional throttling from Bedrock."),
id="bedrock_mistralai_text_input",
),
# endregion
],
)

Expand Down
Loading