diff --git a/python/semantic_kernel/connectors/ai/anthropic/services/anthropic_chat_completion.py b/python/semantic_kernel/connectors/ai/anthropic/services/anthropic_chat_completion.py index f8490edba2cd..87e967184234 100644 --- a/python/semantic_kernel/connectors/ai/anthropic/services/anthropic_chat_completion.py +++ b/python/semantic_kernel/connectors/ai/anthropic/services/anthropic_chat_completion.py @@ -3,7 +3,7 @@ import json import logging import sys -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable from typing import Any, ClassVar if sys.version_info >= (3, 12): @@ -26,7 +26,10 @@ from semantic_kernel.connectors.ai.anthropic.prompt_execution_settings.anthropic_prompt_execution_settings import ( AnthropicChatPromptExecutionSettings, ) -from semantic_kernel.connectors.ai.anthropic.services.utils import MESSAGE_CONVERTERS +from semantic_kernel.connectors.ai.anthropic.services.utils import ( + MESSAGE_CONVERTERS, + update_settings_from_function_call_configuration, +) from semantic_kernel.connectors.ai.anthropic.settings.anthropic_settings import AnthropicSettings from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase from semantic_kernel.connectors.ai.function_call_choice_configuration import FunctionCallChoiceConfiguration @@ -43,10 +46,10 @@ from semantic_kernel.contents.utils.finish_reason import FinishReason as SemanticKernelFinishReason from semantic_kernel.exceptions.service_exceptions import ( ServiceInitializationError, + ServiceInvalidRequestError, ServiceInvalidResponseError, ServiceResponseException, ) -from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata from semantic_kernel.utils.experimental_decorator import experimental_class from semantic_kernel.utils.telemetry.model_diagnostics.decorators import ( trace_chat_completion, @@ -130,6 +133,19 @@ def get_prompt_execution_settings_class(self) -> type["PromptExecutionSettings"] def service_url(self) -> str | None: return str(self.async_client.base_url) + @override + def _update_function_choice_settings_callback( + self, + ) -> Callable[[FunctionCallChoiceConfiguration, "PromptExecutionSettings", FunctionChoiceType], None]: + return update_settings_from_function_call_configuration + + @override + def _reset_function_choice_settings(self, settings: "PromptExecutionSettings") -> None: + if hasattr(settings, "tool_choice"): + settings.tool_choice = None + if hasattr(settings, "tools"): + settings.tools = None + @override @trace_chat_completion(MODEL_PROVIDER_NAME) async def _inner_get_chat_message_contents( @@ -172,6 +188,7 @@ async def _inner_get_streaming_chat_message_contents( async for message in response: yield message + @override def _prepare_chat_history_for_request( self, chat_history: "ChatHistory", @@ -195,14 +212,37 @@ def _prepare_chat_history_for_request( system_message_content = None system_message_count = 0 formatted_messages: list[dict[str, Any]] = [] - for message in chat_history.messages: - # Skip system messages after the first one is found - if message.role == AuthorRole.SYSTEM: + for i in range(len(chat_history)): + prev_message = chat_history[i - 1] if i > 0 else None + curr_message = chat_history[i] + if curr_message.role == AuthorRole.SYSTEM: + # Skip system messages after the first one is found if system_message_count == 0: - system_message_content = message.content + system_message_content = curr_message.content system_message_count += 1 + elif curr_message.role == AuthorRole.USER or curr_message.role == AuthorRole.ASSISTANT: + formatted_messages.append(MESSAGE_CONVERTERS[curr_message.role](curr_message)) + elif curr_message.role == AuthorRole.TOOL: + if prev_message is None: + # Under no circumstances should a tool message be the first message in the chat history + raise ServiceInvalidRequestError("Tool message found without a preceding message.") + if prev_message.role == AuthorRole.USER or prev_message.role == AuthorRole.SYSTEM: + # A tool message should not be found after a user or system message + # Please NOTE that in SK there are the USER role and the TOOL role, but in Anthropic + # the tool messages are considered as USER messages. We are checking against the SK roles. + raise ServiceInvalidRequestError("Tool message found after a user or system message.") + + formatted_message = MESSAGE_CONVERTERS[curr_message.role](curr_message) + if prev_message.role == AuthorRole.ASSISTANT: + # The first tool message after an assistant message should be a new message + formatted_messages.append(formatted_message) + else: + # Append the tool message to the previous tool message. + # This indicates that the assistant message requested multiple parallel tool calls. + # Anthropic requires that parallel Tool messages are grouped together in a single message. + formatted_messages[-1][content_key] += formatted_message[content_key] else: - formatted_messages.append(MESSAGE_CONVERTERS[message.role](message)) + raise ServiceInvalidRequestError(f"Unsupported role in chat history: {curr_message.role}") if system_message_count > 1: logger.warning( @@ -280,50 +320,6 @@ def _create_streaming_chat_message_content( function_invoke_attempt=function_invoke_attempt, ) - def update_settings_from_function_call_configuration_anthropic( - self, - function_choice_configuration: FunctionCallChoiceConfiguration, - settings: "PromptExecutionSettings", - type: "FunctionChoiceType", - ) -> None: - """Update the settings from a FunctionChoiceConfiguration.""" - if ( - function_choice_configuration.available_functions - and hasattr(settings, "tools") - and hasattr(settings, "tool_choice") - ): - settings.tools = [ - self.kernel_function_metadata_to_function_call_format_anthropic(f) - for f in function_choice_configuration.available_functions - ] - - if ( - settings.function_choice_behavior - and settings.function_choice_behavior.type_ == FunctionChoiceType.REQUIRED - ) or type == FunctionChoiceType.REQUIRED: - settings.tool_choice = {"type": "any"} - else: - settings.tool_choice = {"type": type.value} - - def kernel_function_metadata_to_function_call_format_anthropic( - self, - metadata: KernelFunctionMetadata, - ) -> dict[str, Any]: - """Convert the kernel function metadata to function calling format.""" - return { - "name": metadata.fully_qualified_name, - "description": metadata.description or "", - "input_schema": { - "type": "object", - "properties": {p.name: p.schema_data for p in metadata.parameters}, - "required": [p.name for p in metadata.parameters if p.is_required], - }, - } - - @override - def _update_function_choice_settings_callback(self): - return self.update_settings_from_function_call_configuration_anthropic - async def _send_chat_request(self, settings: AnthropicChatPromptExecutionSettings) -> list["ChatMessageContent"]: """Send the chat request.""" try: @@ -389,10 +385,3 @@ def _get_tool_calls_from_message(self, message: Message) -> list[FunctionCallCon ) return tool_calls - - @override - def _reset_function_choice_settings(self, settings: "PromptExecutionSettings") -> None: - if hasattr(settings, "tool_choice"): - settings.tool_choice = None - if hasattr(settings, "tools"): - settings.tools = None diff --git a/python/semantic_kernel/connectors/ai/anthropic/services/utils.py b/python/semantic_kernel/connectors/ai/anthropic/services/utils.py index 774d93615927..31acecb0468f 100644 --- a/python/semantic_kernel/connectors/ai/anthropic/services/utils.py +++ b/python/semantic_kernel/connectors/ai/anthropic/services/utils.py @@ -5,11 +5,15 @@ from collections.abc import Callable, Mapping from typing import Any +from semantic_kernel.connectors.ai.function_call_choice_configuration import FunctionCallChoiceConfiguration +from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceType +from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.function_call_content import FunctionCallContent from semantic_kernel.contents.function_result_content import FunctionResultContent from semantic_kernel.contents.text_content import TextContent from semantic_kernel.contents.utils.author_role import AuthorRole +from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata logger: logging.Logger = logging.getLogger(__name__) @@ -50,29 +54,32 @@ def _format_assistant_message(message: ChatMessageContent) -> dict[str, Any]: "type": "tool_use", "id": item.id or "", "name": item.name or "", - "input": item.arguments if isinstance(item.arguments, Mapping) else json.loads(item.arguments or ""), + "input": item.arguments + if isinstance(item.arguments, Mapping) + else json.loads(item.arguments) + if item.arguments + else {}, }) else: logger.warning( f"Unsupported item type in Assistant message while formatting chat history for Anthropic: {type(item)}" ) + formatted_message: dict[str, Any] = {"role": "assistant", "content": []} + + if message.content: + # Only include the text content if it is not empty. + # Otherwise, the Anthropic client will throw an error. + formatted_message["content"].append({ # type: ignore + "type": "text", + "text": message.content, + }) if tool_calls: - return { - "role": "assistant", - "content": [ - { - "type": "text", - "text": message.content, - }, - *tool_calls, - ], - } + # Only include the tool calls if there are any. + # Otherwise, the Anthropic client will throw an error. + formatted_message["content"].extend(tool_calls) # type: ignore - return { - "role": "assistant", - "content": message.content, - } + return formatted_message def _format_tool_message(message: ChatMessageContent) -> dict[str, Any]: @@ -108,3 +115,40 @@ def _format_tool_message(message: ChatMessageContent) -> dict[str, Any]: AuthorRole.ASSISTANT: _format_assistant_message, AuthorRole.TOOL: _format_tool_message, } + + +def update_settings_from_function_call_configuration( + function_choice_configuration: FunctionCallChoiceConfiguration, + settings: PromptExecutionSettings, + type: FunctionChoiceType, +) -> None: + """Update the settings from a FunctionChoiceConfiguration.""" + if ( + function_choice_configuration.available_functions + and hasattr(settings, "tools") + and hasattr(settings, "tool_choice") + ): + settings.tools = [ + kernel_function_metadata_to_function_call_format(f) + for f in function_choice_configuration.available_functions + ] + + if ( + settings.function_choice_behavior and settings.function_choice_behavior.type_ == FunctionChoiceType.REQUIRED + ) or type == FunctionChoiceType.REQUIRED: + settings.tool_choice = {"type": "any"} + else: + settings.tool_choice = {"type": type.value} + + +def kernel_function_metadata_to_function_call_format(metadata: KernelFunctionMetadata) -> dict[str, Any]: + """Convert the kernel function metadata to function calling format.""" + return { + "name": metadata.fully_qualified_name, + "description": metadata.description or "", + "input_schema": { + "type": "object", + "properties": {p.name: p.schema_data for p in metadata.parameters}, + "required": [p.name for p in metadata.parameters if p.is_required], + }, + } diff --git a/python/tests/integration/completions/chat_completion_test_base.py b/python/tests/integration/completions/chat_completion_test_base.py index 61152512ae11..a31882951c9b 100644 --- a/python/tests/integration/completions/chat_completion_test_base.py +++ b/python/tests/integration/completions/chat_completion_test_base.py @@ -66,9 +66,7 @@ onnx_setup: bool = is_service_setup_for_testing( ["ONNX_GEN_AI_CHAT_MODEL_FOLDER"], raise_if_not_set=False ) # Tests are optional for ONNX -anthropic_setup: bool = is_service_setup_for_testing( - ["ANTHROPIC_API_KEY", "ANTHROPIC_CHAT_MODEL_ID"], raise_if_not_set=False -) # We don't have an Anthropic deployment +anthropic_setup: bool = is_service_setup_for_testing(["ANTHROPIC_API_KEY", "ANTHROPIC_CHAT_MODEL_ID"]) # When testing Bedrock, after logging into AWS CLI this has been set, so we can use it to check if the service is setup bedrock_setup: bool = is_service_setup_for_testing(["AWS_DEFAULT_REGION"], raise_if_not_set=False) diff --git a/python/tests/unit/connectors/ai/anthropic/conftest.py b/python/tests/unit/connectors/ai/anthropic/conftest.py new file mode 100644 index 000000000000..dc7d54cae463 --- /dev/null +++ b/python/tests/unit/connectors/ai/anthropic/conftest.py @@ -0,0 +1,400 @@ +# Copyright (c) Microsoft. All rights reserved. +from collections.abc import AsyncGenerator +from unittest.mock import AsyncMock, MagicMock + +import pytest +from anthropic import AsyncAnthropic +from anthropic.lib.streaming import TextEvent +from anthropic.lib.streaming._types import InputJsonEvent +from anthropic.types import ( + ContentBlockStopEvent, + InputJSONDelta, + Message, + MessageDeltaUsage, + MessageStopEvent, + RawContentBlockDeltaEvent, + RawContentBlockStartEvent, + RawMessageDeltaEvent, + RawMessageStartEvent, + TextBlock, + TextDelta, + ToolUseBlock, + Usage, +) +from anthropic.types.raw_message_delta_event import Delta + +from semantic_kernel.connectors.ai.anthropic.prompt_execution_settings.anthropic_prompt_execution_settings import ( + AnthropicChatPromptExecutionSettings, +) +from semantic_kernel.contents.chat_message_content import ( + ChatMessageContent, + FunctionCallContent, + FunctionResultContent, + TextContent, +) +from semantic_kernel.contents.const import ContentTypes +from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent, StreamingTextContent +from semantic_kernel.contents.utils.author_role import AuthorRole +from semantic_kernel.contents.utils.finish_reason import FinishReason + + +@pytest.fixture +def mock_tool_calls_message() -> ChatMessageContent: + return ChatMessageContent( + ai_model_id="claude-3-opus-20240229", + metadata={}, + content_type="message", + role=AuthorRole.ASSISTANT, + name=None, + items=[ + TextContent( + inner_content=None, + ai_model_id=None, + metadata={}, + content_type="text", + text="", + encoding=None, + ), + FunctionCallContent( + inner_content=None, + ai_model_id=None, + metadata={}, + content_type=ContentTypes.FUNCTION_CALL_CONTENT, + id="test_function_call_content", + index=1, + name="math-Add", + function_name="Add", + plugin_name="math", + arguments={"input": 3, "amount": 3}, + ), + ], + encoding=None, + finish_reason=FinishReason.TOOL_CALLS, + ) + + +@pytest.fixture +def mock_parallel_tool_calls_message() -> ChatMessageContent: + return ChatMessageContent( + ai_model_id="claude-3-opus-20240229", + metadata={}, + content_type="message", + role=AuthorRole.ASSISTANT, + name=None, + items=[ + TextContent( + inner_content=None, + ai_model_id=None, + metadata={}, + content_type="text", + text="", + encoding=None, + ), + FunctionCallContent( + inner_content=None, + ai_model_id=None, + metadata={}, + content_type=ContentTypes.FUNCTION_CALL_CONTENT, + id="test_function_call_content_1", + index=1, + name="math-Add", + function_name="Add", + plugin_name="math", + arguments={"input": 3, "amount": 3}, + ), + FunctionCallContent( + inner_content=None, + ai_model_id=None, + metadata={}, + content_type=ContentTypes.FUNCTION_CALL_CONTENT, + id="test_function_call_content_2", + index=1, + name="math-Subtract", + function_name="Subtract", + plugin_name="math", + arguments={"input": 6, "amount": 3}, + ), + ], + encoding=None, + finish_reason=FinishReason.TOOL_CALLS, + ) + + +@pytest.fixture +def mock_streaming_tool_calls_message() -> list: + stream_events = [ + RawMessageStartEvent( + message=Message( + id="test_message_id", + content=[], + model="claude-3-opus-20240229", + role="assistant", + stop_reason=None, + stop_sequence=None, + type="message", + usage=Usage(input_tokens=1720, output_tokens=2), + ), + type="message_start", + ), + RawContentBlockStartEvent(content_block=TextBlock(text="", type="text"), index=0, type="content_block_start"), + RawContentBlockDeltaEvent( + delta=TextDelta(text="", type="text_delta"), index=0, type="content_block_delta" + ), + TextEvent(type="text", text="", snapshot=""), + RawContentBlockDeltaEvent( + delta=TextDelta(text="", type="text_delta"), index=0, type="content_block_delta" + ), + TextEvent(type="text", text="", snapshot=""), + ContentBlockStopEvent( + index=0, type="content_block_stop", content_block=TextBlock(text="", type="text") + ), + RawContentBlockStartEvent( + content_block=ToolUseBlock(id="test_tool_use_message_id", input={}, name="math-Add", type="tool_use"), + index=1, + type="content_block_start", + ), + RawContentBlockDeltaEvent( + delta=InputJSONDelta(partial_json='{"input": 3, "amount": 3}', type="input_json_delta"), + index=1, + type="content_block_delta", + ), + InputJsonEvent(type="input_json", partial_json='{"input": 3, "amount": 3}', snapshot={"input": 3, "amount": 3}), + ContentBlockStopEvent( + index=1, + type="content_block_stop", + content_block=ToolUseBlock( + id="test_tool_use_block_id", input={"input": 3, "amount": 3}, name="math-Add", type="tool_use" + ), + ), + RawMessageDeltaEvent( + delta=Delta(stop_reason="tool_use", stop_sequence=None), + type="message_delta", + usage=MessageDeltaUsage(output_tokens=159), + ), + MessageStopEvent( + type="message_stop", + message=Message( + id="test_message_id", + content=[ + TextBlock(text="", type="text"), + ToolUseBlock( + id="test_tool_use_block_id", input={"input": 3, "amount": 3}, name="math-Add", type="tool_use" + ), + ], + model="claude-3-opus-20240229", + role="assistant", + stop_reason="tool_use", + stop_sequence=None, + type="message", + usage=Usage(input_tokens=100, output_tokens=100), + ), + ), + ] + + async def async_generator(): + for event in stream_events: + yield event + + stream_mock = AsyncMock() + stream_mock.__aenter__.return_value = async_generator() + + return stream_mock + + +@pytest.fixture +def mock_tool_call_result_message() -> ChatMessageContent: + return ChatMessageContent( + inner_content=None, + ai_model_id=None, + metadata={}, + content_type="message", + role=AuthorRole.TOOL, + name=None, + items=[ + FunctionResultContent( + id="test_function_call_content", + result=6, + ) + ], + encoding=None, + finish_reason=FinishReason.TOOL_CALLS, + ) + + +@pytest.fixture +def mock_parallel_tool_call_result_message() -> ChatMessageContent: + return ChatMessageContent( + inner_content=None, + ai_model_id=None, + metadata={}, + content_type="message", + role=AuthorRole.TOOL, + name=None, + items=[ + FunctionResultContent( + id="test_function_call_content_1", + result=6, + ), + FunctionResultContent( + id="test_function_call_content_2", + result=3, + ), + ], + encoding=None, + finish_reason=FinishReason.TOOL_CALLS, + ) + + +@pytest.fixture +def mock_streaming_chat_message_content() -> StreamingChatMessageContent: + return StreamingChatMessageContent( + choice_index=0, + ai_model_id="claude-3-opus-20240229", + metadata={}, + role=AuthorRole.ASSISTANT, + name=None, + items=[ + StreamingTextContent( + inner_content=None, + ai_model_id=None, + metadata={}, + content_type="text", + text="", + encoding=None, + choice_index=0, + ), + FunctionCallContent( + inner_content=None, + ai_model_id=None, + metadata={}, + content_type=ContentTypes.FUNCTION_CALL_CONTENT, + id="tool_id", + index=0, + name="math-Add", + function_name="Add", + plugin_name="math", + arguments='{"input": 3, "amount": 3}', + ), + ], + encoding=None, + finish_reason=FinishReason.TOOL_CALLS, + ) + + +@pytest.fixture +def mock_settings() -> AnthropicChatPromptExecutionSettings: + return AnthropicChatPromptExecutionSettings() + + +@pytest.fixture +def mock_chat_message_response() -> Message: + return Message( + id="test_message_id", + content=[TextBlock(text="Hello, how are you?", type="text")], + model="claude-3-opus-20240229", + role="assistant", + stop_reason="end_turn", + stop_sequence=None, + type="message", + usage=Usage(input_tokens=10, output_tokens=10), + ) + + +@pytest.fixture +def mock_streaming_message_response() -> AsyncGenerator: + raw_message_start_event = RawMessageStartEvent( + message=Message( + id="test_message_id", + content=[], + model="claude-3-opus-20240229", + role="assistant", + stop_reason=None, + stop_sequence=None, + type="message", + usage=Usage(input_tokens=41, output_tokens=3), + ), + type="message_start", + ) + + raw_content_block_start_event = RawContentBlockStartEvent( + content_block=TextBlock(text="", type="text"), + index=0, + type="content_block_start", + ) + + raw_content_block_delta_event = RawContentBlockDeltaEvent( + delta=TextDelta(text="Hello! It", type="text_delta"), + index=0, + type="content_block_delta", + ) + + text_event = TextEvent( + type="text", + text="Hello! It", + snapshot="Hello! It", + ) + + content_block_stop_event = ContentBlockStopEvent( + index=0, + type="content_block_stop", + content_block=TextBlock(text="Hello! It's nice to meet you.", type="text"), + ) + + raw_message_delta_event = RawMessageDeltaEvent( + delta=Delta(stop_reason="end_turn", stop_sequence=None), + type="message_delta", + usage=MessageDeltaUsage(output_tokens=84), + ) + + message_stop_event = MessageStopEvent( + type="message_stop", + message=Message( + id="test_message_stop_id", + content=[TextBlock(text="Hello! It's nice to meet you.", type="text")], + model="claude-3-opus-20240229", + role="assistant", + stop_reason="end_turn", + stop_sequence=None, + type="message", + usage=Usage(input_tokens=41, output_tokens=84), + ), + ) + + # Combine all mock events into a list + stream_events = [ + raw_message_start_event, + raw_content_block_start_event, + raw_content_block_delta_event, + text_event, + content_block_stop_event, + raw_message_delta_event, + message_stop_event, + ] + + async def async_generator(): + for event in stream_events: + yield event + + # Create an AsyncMock for the stream + stream_mock = AsyncMock() + stream_mock.__aenter__.return_value = async_generator() + + return stream_mock + + +@pytest.fixture +def mock_anthropic_client_completion(mock_chat_message_response: Message) -> AsyncAnthropic: + client = MagicMock(spec=AsyncAnthropic) + messages_mock = MagicMock() + messages_mock.create = AsyncMock(return_value=mock_chat_message_response) + client.messages = messages_mock + return client + + +@pytest.fixture +def mock_anthropic_client_completion_stream(mock_streaming_message_response: AsyncGenerator) -> AsyncAnthropic: + client = MagicMock(spec=AsyncAnthropic) + messages_mock = MagicMock() + messages_mock.stream.return_value = mock_streaming_message_response + client.messages = messages_mock + return client diff --git a/python/tests/unit/connectors/ai/anthropic/services/test_anthropic_chat_completion.py b/python/tests/unit/connectors/ai/anthropic/services/test_anthropic_chat_completion.py index d368dd901c4d..bff83bfe89d6 100644 --- a/python/tests/unit/connectors/ai/anthropic/services/test_anthropic_chat_completion.py +++ b/python/tests/unit/connectors/ai/anthropic/services/test_anthropic_chat_completion.py @@ -1,27 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. -from collections.abc import AsyncGenerator from unittest.mock import AsyncMock, MagicMock, patch import pytest from anthropic import AsyncAnthropic -from anthropic.lib.streaming import TextEvent -from anthropic.lib.streaming._types import InputJsonEvent -from anthropic.types import ( - ContentBlockStopEvent, - InputJSONDelta, - Message, - MessageDeltaUsage, - MessageStopEvent, - RawContentBlockDeltaEvent, - RawContentBlockStartEvent, - RawMessageDeltaEvent, - RawMessageStartEvent, - TextBlock, - TextDelta, - ToolUseBlock, - Usage, -) -from anthropic.types.raw_message_delta_event import Delta +from anthropic.types import Message from semantic_kernel.connectors.ai.anthropic.prompt_execution_settings.anthropic_prompt_execution_settings import ( AnthropicChatPromptExecutionSettings, @@ -33,406 +15,20 @@ OpenAIChatPromptExecutionSettings, ) from semantic_kernel.contents.chat_history import ChatHistory -from semantic_kernel.contents.chat_message_content import ( - ChatMessageContent, - FunctionCallContent, - FunctionResultContent, - TextContent, -) -from semantic_kernel.contents.const import ContentTypes -from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent, StreamingTextContent +from semantic_kernel.contents.chat_message_content import ChatMessageContent, FunctionCallContent, TextContent +from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent from semantic_kernel.contents.utils.author_role import AuthorRole -from semantic_kernel.contents.utils.finish_reason import FinishReason from semantic_kernel.exceptions import ( ServiceInitializationError, ServiceInvalidExecutionSettingsError, ServiceResponseException, ) -from semantic_kernel.functions.function_result import FunctionResult +from semantic_kernel.exceptions.service_exceptions import ServiceInvalidRequestError from semantic_kernel.functions.kernel_arguments import KernelArguments from semantic_kernel.functions.kernel_function_decorator import kernel_function -from semantic_kernel.functions.kernel_function_from_method import KernelFunctionMetadata -from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata from semantic_kernel.kernel import Kernel -@pytest.fixture -def mock_tool_calls_message() -> ChatMessageContent: - return ChatMessageContent( - inner_content=Message( - id="test_message_id", - content=[ - TextBlock(text="", type="text"), - ToolUseBlock( - id="test_tool_use_blocks", - input={"input": 3, "amount": 3}, - name="math-Add", - type="tool_use", - ), - ], - model="claude-3-opus-20240229", - role="assistant", - stop_reason="tool_use", - stop_sequence=None, - type="message", - usage=Usage(input_tokens=1720, output_tokens=194), - ), - ai_model_id="claude-3-opus-20240229", - metadata={}, - content_type="message", - role=AuthorRole.ASSISTANT, - name=None, - items=[ - FunctionCallContent( - inner_content=None, - ai_model_id=None, - metadata={}, - content_type=ContentTypes.FUNCTION_CALL_CONTENT, - id="test_function_call_content", - index=1, - name="math-Add", - function_name="Add", - plugin_name="math", - arguments={"input": 3, "amount": 3}, - ), - TextContent( - inner_content=None, - ai_model_id=None, - metadata={}, - content_type="text", - text="", - encoding=None, - ), - ], - encoding=None, - finish_reason=FinishReason.TOOL_CALLS, - ) - - -@pytest.fixture -def mock_streaming_tool_calls_message() -> list: - stream_events = [ - RawMessageStartEvent( - message=Message( - id="test_message_id", - content=[], - model="claude-3-opus-20240229", - role="assistant", - stop_reason=None, - stop_sequence=None, - type="message", - usage=Usage(input_tokens=1720, output_tokens=2), - ), - type="message_start", - ), - RawContentBlockStartEvent(content_block=TextBlock(text="", type="text"), index=0, type="content_block_start"), - RawContentBlockDeltaEvent( - delta=TextDelta(text="", type="text_delta"), index=0, type="content_block_delta" - ), - TextEvent(type="text", text="", snapshot=""), - RawContentBlockDeltaEvent( - delta=TextDelta(text="", type="text_delta"), index=0, type="content_block_delta" - ), - TextEvent(type="text", text="", snapshot=""), - ContentBlockStopEvent( - index=0, type="content_block_stop", content_block=TextBlock(text="", type="text") - ), - RawContentBlockStartEvent( - content_block=ToolUseBlock(id="test_tool_use_message_id", input={}, name="math-Add", type="tool_use"), - index=1, - type="content_block_start", - ), - RawContentBlockDeltaEvent( - delta=InputJSONDelta(partial_json='{"input": 3, "amount": 3}', type="input_json_delta"), - index=1, - type="content_block_delta", - ), - InputJsonEvent(type="input_json", partial_json='{"input": 3, "amount": 3}', snapshot={"input": 3, "amount": 3}), - ContentBlockStopEvent( - index=1, - type="content_block_stop", - content_block=ToolUseBlock( - id="test_tool_use_block_id", input={"input": 3, "amount": 3}, name="math-Add", type="tool_use" - ), - ), - RawMessageDeltaEvent( - delta=Delta(stop_reason="tool_use", stop_sequence=None), - type="message_delta", - usage=MessageDeltaUsage(output_tokens=159), - ), - MessageStopEvent( - type="message_stop", - message=Message( - id="test_message_id", - content=[ - TextBlock(text="", type="text"), - ToolUseBlock( - id="test_tool_use_block_id", input={"input": 3, "amount": 3}, name="math-Add", type="tool_use" - ), - ], - model="claude-3-opus-20240229", - role="assistant", - stop_reason="tool_use", - stop_sequence=None, - type="message", - usage=Usage(input_tokens=100, output_tokens=100), - ), - ), - ] - - async def async_generator(): - for event in stream_events: - yield event - - stream_mock = AsyncMock() - stream_mock.__aenter__.return_value = async_generator() - - return stream_mock - - -@pytest.fixture -def mock_tool_call_result_message() -> ChatMessageContent: - return ChatMessageContent( - inner_content=None, - ai_model_id=None, - metadata={}, - content_type="message", - role=AuthorRole.TOOL, - name=None, - items=[ - FunctionResultContent( - id="tool_01", - inner_content=FunctionResult( - function=KernelFunctionMetadata( - name="Add", - plugin_name="math", - description="Returns the Addition result of the values provided.", - parameters=[ - KernelParameterMetadata( - name="input", - description="the first number to add", - default_value=None, - type_="int", - is_required=True, - type_object=int, - schema_data={"type": "integer", "description": "the first number to add"}, - function_schema_include=True, - ), - KernelParameterMetadata( - name="amount", - description="the second number to add", - default_value=None, - type_="int", - is_required=True, - type_object=int, - schema_data={"type": "integer", "description": "the second number to add"}, - function_schema_include=True, - ), - ], - is_prompt=False, - is_asynchronous=False, - return_parameter=KernelParameterMetadata( - name="return", - description="the output is a number", - default_value=None, - type_="int", - is_required=True, - type_object=int, - schema_data={"type": "integer", "description": "the output is a number"}, - function_schema_include=True, - ), - additional_properties={}, - ), - value=6, - metadata={}, - ), - value=6, - ) - ], - encoding=None, - finish_reason=FinishReason.TOOL_CALLS, - ) - - -# mock StreamingChatMessageContent -@pytest.fixture -def mock_streaming_chat_message_content() -> StreamingChatMessageContent: - return StreamingChatMessageContent( - choice_index=0, - inner_content=[ - RawContentBlockDeltaEvent( - delta=TextDelta(text="", type="text_delta"), index=0, type="content_block_delta" - ), - RawContentBlockDeltaEvent( - delta=TextDelta(text="", type="text_delta"), index=0, type="content_block_delta" - ), - ContentBlockStopEvent( - index=1, - type="content_block_stop", - content_block=ToolUseBlock( - id="tool_id", - input={"input": 3, "amount": 3}, - name="math-Add", - type="tool_use", - ), - ), - RawMessageDeltaEvent( - delta=Delta(stop_reason="tool_use", stop_sequence=None), - type="message_delta", - usage=MessageDeltaUsage(output_tokens=175), - ), - ], - ai_model_id="claude-3-opus-20240229", - metadata={}, - role=AuthorRole.ASSISTANT, - name=None, - items=[ - StreamingTextContent( - inner_content=None, - ai_model_id=None, - metadata={}, - content_type="text", - text="", - encoding=None, - choice_index=0, - ), - FunctionCallContent( - inner_content=None, - ai_model_id=None, - metadata={}, - content_type=ContentTypes.FUNCTION_CALL_CONTENT, - id="tool_id", - index=0, - name="math-Add", - function_name="Add", - plugin_name="math", - arguments='{"input": 3, "amount": 3}', - ), - ], - encoding=None, - finish_reason=FinishReason.TOOL_CALLS, - ) - - -@pytest.fixture -def mock_settings() -> AnthropicChatPromptExecutionSettings: - return AnthropicChatPromptExecutionSettings() - - -@pytest.fixture -def mock_chat_message_response() -> Message: - return Message( - id="test_message_id", - content=[TextBlock(text="Hello, how are you?", type="text")], - model="claude-3-opus-20240229", - role="assistant", - stop_reason="end_turn", - stop_sequence=None, - type="message", - usage=Usage(input_tokens=10, output_tokens=10), - ) - - -@pytest.fixture -def mock_streaming_message_response() -> AsyncGenerator: - raw_message_start_event = RawMessageStartEvent( - message=Message( - id="test_message_id", - content=[], - model="claude-3-opus-20240229", - role="assistant", - stop_reason=None, - stop_sequence=None, - type="message", - usage=Usage(input_tokens=41, output_tokens=3), - ), - type="message_start", - ) - - raw_content_block_start_event = RawContentBlockStartEvent( - content_block=TextBlock(text="", type="text"), - index=0, - type="content_block_start", - ) - - raw_content_block_delta_event = RawContentBlockDeltaEvent( - delta=TextDelta(text="Hello! It", type="text_delta"), - index=0, - type="content_block_delta", - ) - - text_event = TextEvent( - type="text", - text="Hello! It", - snapshot="Hello! It", - ) - - content_block_stop_event = ContentBlockStopEvent( - index=0, - type="content_block_stop", - content_block=TextBlock(text="Hello! It's nice to meet you.", type="text"), - ) - - raw_message_delta_event = RawMessageDeltaEvent( - delta=Delta(stop_reason="end_turn", stop_sequence=None), - type="message_delta", - usage=MessageDeltaUsage(output_tokens=84), - ) - - message_stop_event = MessageStopEvent( - type="message_stop", - message=Message( - id="test_message_stop_id", - content=[TextBlock(text="Hello! It's nice to meet you.", type="text")], - model="claude-3-opus-20240229", - role="assistant", - stop_reason="end_turn", - stop_sequence=None, - type="message", - usage=Usage(input_tokens=41, output_tokens=84), - ), - ) - - # Combine all mock events into a list - stream_events = [ - raw_message_start_event, - raw_content_block_start_event, - raw_content_block_delta_event, - text_event, - content_block_stop_event, - raw_message_delta_event, - message_stop_event, - ] - - async def async_generator(): - for event in stream_events: - yield event - - # Create an AsyncMock for the stream - stream_mock = AsyncMock() - stream_mock.__aenter__.return_value = async_generator() - - return stream_mock - - -@pytest.fixture -def mock_anthropic_client_completion(mock_chat_message_response: Message) -> AsyncAnthropic: - client = MagicMock(spec=AsyncAnthropic) - messages_mock = MagicMock() - messages_mock.create = AsyncMock(return_value=mock_chat_message_response) - client.messages = messages_mock - return client - - -@pytest.fixture -def mock_anthropic_client_completion_stream(mock_streaming_message_response: AsyncGenerator) -> AsyncAnthropic: - client = MagicMock(spec=AsyncAnthropic) - messages_mock = MagicMock() - messages_mock.stream.return_value = mock_streaming_message_response - client.messages = messages_mock - return client - - async def test_complete_chat_contents( kernel: Kernel, mock_settings: AnthropicChatPromptExecutionSettings, @@ -753,7 +349,7 @@ async def test_prepare_chat_history_for_request_with_system_message(mock_anthrop assert system_message_content == "System message" assert remaining_messages == [ {"role": AuthorRole.USER, "content": "User message"}, - {"role": AuthorRole.ASSISTANT, "content": "Assistant message"}, + {"role": AuthorRole.ASSISTANT, "content": [{"type": "text", "text": "Assistant message"}]}, ] assert not any(msg["role"] == AuthorRole.SYSTEM for msg in remaining_messages) @@ -780,35 +376,121 @@ async def test_prepare_chat_history_for_request_with_tool_message( ) assert system_message_content is None - assert len(remaining_messages) == 3 + assert remaining_messages == [ + {"role": AuthorRole.USER, "content": "What is 3+3?"}, + { + "role": AuthorRole.ASSISTANT, + "content": [ + {"type": "text", "text": mock_tool_calls_message.items[0].text}, + { + "type": "tool_use", + "id": mock_tool_calls_message.items[1].id, + "name": mock_tool_calls_message.items[1].name, + "input": mock_tool_calls_message.items[1].arguments, + }, + ], + }, + { + "role": AuthorRole.USER, + "content": [ + { + "type": "tool_result", + "tool_use_id": mock_tool_call_result_message.items[0].id, + "content": str(mock_tool_call_result_message.items[0].result), + } + ], + }, + ] -async def test_prepare_chat_history_for_request_with_tool_message_streaming( +async def test_prepare_chat_history_for_request_with_parallel_tool_message( + mock_anthropic_client_completion_stream: MagicMock, + mock_parallel_tool_calls_message: ChatMessageContent, + mock_parallel_tool_call_result_message: ChatMessageContent, +): + chat_history = ChatHistory() + chat_history.add_user_message("What is 3+3?") + chat_history.add_message(mock_parallel_tool_calls_message) + chat_history.add_message(mock_parallel_tool_call_result_message) + + chat_completion_client = AnthropicChatCompletion( + ai_model_id="test_model_id", + service_id="test", + api_key="", + async_client=mock_anthropic_client_completion_stream, + ) + + remaining_messages, system_message_content = chat_completion_client._prepare_chat_history_for_request( + chat_history, role_key="role", content_key="content" + ) + + assert system_message_content is None + assert remaining_messages == [ + {"role": AuthorRole.USER, "content": "What is 3+3?"}, + { + "role": AuthorRole.ASSISTANT, + "content": [ + {"type": "text", "text": mock_parallel_tool_calls_message.items[0].text}, + *[ + { + "type": "tool_use", + "id": function_call_content.id, + "name": function_call_content.name, + "input": function_call_content.arguments, + } + for function_call_content in mock_parallel_tool_calls_message.items[1:] + ], + ], + }, + { + "role": AuthorRole.USER, + "content": [ + { + "type": "tool_result", + "tool_use_id": function_result_content.id, + "content": str(function_result_content.result), + } + for function_result_content in mock_parallel_tool_call_result_message.items + ], + }, + ] + + +async def test_prepare_chat_history_for_request_with_tool_message_right_after_user_message( mock_anthropic_client_completion_stream: MagicMock, - mock_streaming_chat_message_content: StreamingChatMessageContent, mock_tool_call_result_message: ChatMessageContent, ): chat_history = ChatHistory() chat_history.add_user_message("What is 3+3?") - chat_history.add_message(mock_streaming_chat_message_content) chat_history.add_message(mock_tool_call_result_message) - chat_completion = AnthropicChatCompletion( + chat_completion_client = AnthropicChatCompletion( ai_model_id="test_model_id", service_id="test", api_key="", async_client=mock_anthropic_client_completion_stream, ) - remaining_messages, system_message_content = chat_completion._prepare_chat_history_for_request( - chat_history, - role_key="role", - content_key="content", - stream=True, + with pytest.raises(ServiceInvalidRequestError, match="Tool message found after a user or system message."): + chat_completion_client._prepare_chat_history_for_request(chat_history, role_key="role", content_key="content") + + +async def test_prepare_chat_history_for_request_with_tool_message_as_the_first_message( + mock_anthropic_client_completion_stream: MagicMock, + mock_tool_call_result_message: ChatMessageContent, +): + chat_history = ChatHistory() + chat_history.add_message(mock_tool_call_result_message) + + chat_completion_client = AnthropicChatCompletion( + ai_model_id="test_model_id", + service_id="test", + api_key="", + async_client=mock_anthropic_client_completion_stream, ) - assert system_message_content is None - assert len(remaining_messages) == 3 + with pytest.raises(ServiceInvalidRequestError, match="Tool message found without a preceding message."): + chat_completion_client._prepare_chat_history_for_request(chat_history, role_key="role", content_key="content") async def test_send_chat_stream_request_tool_calls(