Skip to content

Commit

Permalink
Python: Anthropic function calling fixes (#9938)
Browse files Browse the repository at this point in the history
### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->
The current implementation of the Anthropic connector relies on the
`inner_content` of chat messages to prepare the chat history for the
Anthropic client. This only works when the chat history was created
by the Anthropic connector itself; it fails when the chat history has been
produced by other connectors, or when it is hard-coded, as in tests.

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->
1. Prepare the chat history for the Anthropic client by parsing the
actual Semantic Kernel item types.
2. Fix tests for the Anthropic connector.


### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
TaoChenOSU authored Dec 16, 2024
1 parent b427208 commit 4a21254
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from semantic_kernel.connectors.ai.anthropic.prompt_execution_settings.anthropic_prompt_execution_settings import (
AnthropicChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.anthropic.services.utils import MESSAGE_CONVERTERS
from semantic_kernel.connectors.ai.anthropic.settings.anthropic_settings import AnthropicSettings
from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
from semantic_kernel.connectors.ai.function_call_choice_configuration import FunctionCallChoiceConfiguration
Expand All @@ -34,7 +35,6 @@
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ITEM_TYPES, ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.function_result_content import FunctionResultContent
from semantic_kernel.contents.streaming_chat_message_content import ITEM_TYPES as STREAMING_ITEM_TYPES
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
from semantic_kernel.contents.streaming_text_content import StreamingTextContent
Expand Down Expand Up @@ -192,69 +192,25 @@ def _prepare_chat_history_for_request(
A tuple containing the prepared chat history and the first SYSTEM message content.
"""
system_message_content = None
remaining_messages: list[dict[str, Any]] = []
system_message_found = False
system_message_count = 0
formatted_messages: list[dict[str, Any]] = []
for message in chat_history.messages:
# Skip system messages after the first one is found
if message.role == AuthorRole.SYSTEM:
if not system_message_found:
if system_message_count == 0:
system_message_content = message.content
system_message_found = True
elif message.role == AuthorRole.TOOL:
# if tool result message isn't the most recent message, add it to the remaining messages
if not remaining_messages or remaining_messages[-1][role_key] != AuthorRole.USER:
remaining_messages.append({
role_key: AuthorRole.USER,
content_key: [],
})

# add the tool result to the most recent message
tool_results_message = remaining_messages[-1]
for item in message.items:
if isinstance(item, FunctionResultContent):
tool_results_message["content"].append({
"type": "tool_result",
"tool_use_id": item.id,
content_key: str(item.result),
})
elif message.finish_reason == SemanticKernelFinishReason.TOOL_CALLS:
if not stream:
if not message.inner_content:
raise ServiceInvalidResponseError(
"Expected a message with an Anthropic Message as inner content."
)

remaining_messages.append({
role_key: AuthorRole.ASSISTANT,
content_key: [content_block.to_dict() for content_block in message.inner_content.content],
})
else:
content: list[TextBlock | ToolUseBlock] = []
# for remaining items, add them to the content
for item in message.items:
if isinstance(item, TextContent):
content.append(TextBlock(text=item.text, type="text"))
elif isinstance(item, FunctionCallContent):
item_arguments = (
item.arguments if not isinstance(item.arguments, str) else json.loads(item.arguments)
)

content.append(
ToolUseBlock(id=item.id, input=item_arguments, name=item.name, type="tool_use")
)

remaining_messages.append({
role_key: AuthorRole.ASSISTANT,
content_key: content,
})
system_message_count += 1
else:
# The API requires only role and content keys for the remaining messages
remaining_messages.append({
role_key: getattr(message, role_key),
content_key: getattr(message, content_key),
})
formatted_messages.append(MESSAGE_CONVERTERS[message.role](message))

return remaining_messages, system_message_content
if system_message_count > 1:
logger.warning(
"Anthropic service only supports one system message, but %s system messages were found."
" Only the first system message will be included in the request.",
system_message_count,
)

return formatted_messages, system_message_content

# endregion

Expand Down
110 changes: 110 additions & 0 deletions python/semantic_kernel/connectors/ai/anthropic/services/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (c) Microsoft. All rights reserved.

import json
import logging
from collections.abc import Callable, Mapping
from typing import Any

from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.function_result_content import FunctionResultContent
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.contents.utils.author_role import AuthorRole

logger: logging.Logger = logging.getLogger(__name__)


def _format_user_message(message: ChatMessageContent) -> dict[str, Any]:
    """Convert a Semantic Kernel user message into an Anthropic request message.

    Args:
        message: The user message to convert.

    Returns:
        A dict with ``role`` and ``content`` keys in the shape the Anthropic
        client expects; the content is the message's plain-text content.
    """
    formatted_message: dict[str, Any] = {"role": "user", "content": message.content}
    return formatted_message


def _format_assistant_message(message: ChatMessageContent) -> dict[str, Any]:
"""Format an assistant message to the expected object for the Anthropic client.
Args:
message: The assistant message.
Returns:
The formatted assistant message.
"""
tool_calls: list[dict[str, Any]] = []

for item in message.items:
if isinstance(item, TextContent):
# Assuming the assistant message will have only one text content item
# and we assign the content directly to the message content, which is a string.
continue
if isinstance(item, FunctionCallContent):
tool_calls.append({
"type": "tool_use",
"id": item.id or "",
"name": item.name or "",
"input": item.arguments if isinstance(item.arguments, Mapping) else json.loads(item.arguments or ""),
})
else:
logger.warning(
f"Unsupported item type in Assistant message while formatting chat history for Anthropic: {type(item)}"
)

if tool_calls:
return {
"role": "assistant",
"content": [
{
"type": "text",
"text": message.content,
},
*tool_calls,
],
}

return {
"role": "assistant",
"content": message.content,
}


def _format_tool_message(message: ChatMessageContent) -> dict[str, Any]:
    """Convert a Semantic Kernel tool message into an Anthropic request message.

    Args:
        message: The tool message carrying function results.

    Returns:
        A ``user``-role message whose content is a list of ``tool_result``
        blocks, one per function result item; other item types are skipped
        with a warning.
    """
    tool_result_blocks: list[dict[str, Any]] = []
    for entry in message.items:
        if isinstance(entry, FunctionResultContent):
            tool_result_blocks.append({
                "type": "tool_result",
                "tool_use_id": entry.id,
                "content": str(entry.result),
            })
        else:
            logger.warning(
                f"Unsupported item type in Tool message while formatting chat history for Anthropic: {type(entry)}"
            )

    return {
        "role": "user",
        "content": tool_result_blocks,
    }


# Maps each supported author role to the converter that turns a message of
# that role into the dict shape the Anthropic client expects.
# AuthorRole.SYSTEM is intentionally absent: the caller filters system
# messages out of the chat history and handles their content separately.
MESSAGE_CONVERTERS: dict[AuthorRole, Callable[[ChatMessageContent], dict[str, Any]]] = {
    AuthorRole.USER: _format_user_message,
    AuthorRole.ASSISTANT: _format_assistant_message,
    AuthorRole.TOOL: _format_tool_message,
}
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,12 @@ class FunctionChoiceTestTypes(str, Enum):
),
pytest.param(
"anthropic",
{},
{
# Anthropic expects tools in the request when it sees tool use in the chat history.
"function_choice_behavior": FunctionChoiceBehavior.Auto(
auto_invoke=True, filters={"excluded_plugins": ["task_plugin"]}
),
},
[
[
ChatMessageContent(
Expand All @@ -460,9 +465,12 @@ class FunctionChoiceTestTypes(str, Enum):
ChatMessageContent(
role=AuthorRole.ASSISTANT,
items=[
# Anthropic will often include a chain of thought in the tool call by default.
# If this is not in the message, it will complain about the missing chain of thought.
TextContent(text="I will find the revenue for you."),
FunctionCallContent(
id="123456789", name="finance-search", arguments='{"company": "contoso", "year": 2024}'
)
),
],
),
ChatMessageContent(
Expand Down
20 changes: 20 additions & 0 deletions python/tests/integration/completions/test_chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class Reasoning(KernelBaseModel):
pytestmark = pytest.mark.parametrize(
"service_id, execution_settings_kwargs, inputs, kwargs",
[
# region OpenAI
pytest.param(
"openai",
{},
Expand All @@ -63,6 +64,8 @@ class Reasoning(KernelBaseModel):
{},
id="openai_json_schema_response_format",
),
# endregion
# region Azure
pytest.param(
"azure",
{},
Expand All @@ -83,6 +86,8 @@ class Reasoning(KernelBaseModel):
{},
id="azure_custom_client",
),
# endregion
# region Azure AI Inference
pytest.param(
"azure_ai_inference",
{},
Expand All @@ -93,6 +98,8 @@ class Reasoning(KernelBaseModel):
{},
id="azure_ai_inference_text_input",
),
# endregion
# region Anthropic
pytest.param(
"anthropic",
{},
Expand All @@ -104,6 +111,8 @@ class Reasoning(KernelBaseModel):
marks=pytest.mark.skipif(not anthropic_setup, reason="Anthropic Environment Variables not set"),
id="anthropic_text_input",
),
# endregion
# region Mistral AI
pytest.param(
"mistral_ai",
{},
Expand All @@ -115,6 +124,8 @@ class Reasoning(KernelBaseModel):
marks=pytest.mark.skipif(not mistral_ai_setup, reason="Mistral AI Environment Variables not set"),
id="mistral_ai_text_input",
),
# endregion
# region Ollama
pytest.param(
"ollama",
{},
Expand All @@ -129,6 +140,8 @@ class Reasoning(KernelBaseModel):
),
id="ollama_text_input",
),
# endregion
# region Onnx Gen AI
pytest.param(
"onnx_gen_ai",
{},
Expand All @@ -140,6 +153,8 @@ class Reasoning(KernelBaseModel):
marks=pytest.mark.skipif(not onnx_setup, reason="Need a Onnx Model setup"),
id="onnx_gen_ai",
),
# endregion
# region Google AI
pytest.param(
"google_ai",
{},
Expand All @@ -151,6 +166,8 @@ class Reasoning(KernelBaseModel):
marks=pytest.mark.skip(reason="Skipping due to 429s from Google AI."),
id="google_ai_text_input",
),
# endregion
# region Vertex AI
pytest.param(
"vertex_ai",
{},
Expand All @@ -162,6 +179,8 @@ class Reasoning(KernelBaseModel):
marks=pytest.mark.skipif(not vertex_ai_setup, reason="Vertex AI Environment Variables not set"),
id="vertex_ai_text_input",
),
# endregion
# region Bedrock
pytest.param(
"bedrock_amazon_titan",
{},
Expand Down Expand Up @@ -228,6 +247,7 @@ class Reasoning(KernelBaseModel):
marks=pytest.mark.skip(reason="Skipping due to occasional throttling from Bedrock."),
id="bedrock_mistralai_text_input",
),
# endregion
],
)

Expand Down

0 comments on commit 4a21254

Please sign in to comment.