From 188b2a7f0628581efb6aa6848eb6159e05095d08 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 20 Dec 2024 15:20:54 +0100 Subject: [PATCH] feat: support for tools in `OpenAIChatGenerator` (#8666) * move chatmsg>openai conversion to chatmsg dataclass * implementation and tests cleanup * release note * try fixing azure chat generator * add serde test for toolinvoker * small fix --- haystack/components/generators/chat/azure.py | 5 + .../generators/chat/hugging_face_api.py | 14 +- haystack/components/generators/chat/openai.py | 350 +++++++++++------- haystack/components/generators/openai.py | 3 +- .../components/generators/openai_utils.py | 23 -- haystack/dataclasses/chat_message.py | 45 +++ haystack/dataclasses/tool.py | 6 +- .../notes/openai-tools-26f58a981c4066ef.yaml | 4 + test/components/generators/chat/conftest.py | 14 - .../chat/test_hugging_face_local.py | 8 + .../components/generators/chat/test_openai.py | 316 +++++++++++++--- test/components/generators/conftest.py | 69 +++- test/components/generators/test_openai.py | 14 +- .../generators/test_openai_utils.py | 16 - test/components/tools/test_tool_invoker.py | 57 +++ test/conftest.py | 26 -- test/dataclasses/test_chat_message.py | 55 ++- 17 files changed, 720 insertions(+), 305 deletions(-) delete mode 100644 haystack/components/generators/openai_utils.py create mode 100644 releasenotes/notes/openai-tools-26f58a981c4066ef.yaml delete mode 100644 test/components/generators/chat/conftest.py delete mode 100644 test/components/generators/test_openai_utils.py diff --git a/haystack/components/generators/chat/azure.py b/haystack/components/generators/chat/azure.py index b74be533dc..ae7787d637 100644 --- a/haystack/components/generators/chat/azure.py +++ b/haystack/components/generators/chat/azure.py @@ -142,6 +142,11 @@ def __init__( # pylint: disable=too-many-positional-arguments self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5)) self.default_headers = default_headers or {} + # This ChatGenerator does not yet supports tools. The following workaround ensures that we do not + # get an error when invoking the run method of the parent class (OpenAIChatGenerator). + self.tools = None + self.tools_strict = False + self.client = AzureOpenAI( api_version=api_version, azure_endpoint=azure_endpoint, diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py index dab61e4d93..9a0cc75906 100644 --- a/haystack/components/generators/chat/hugging_face_api.py +++ b/haystack/components/generators/chat/hugging_face_api.py @@ -163,10 +163,9 @@ def __init__( # pylint: disable=too-many-positional-arguments msg = f"Unknown api_type {api_type}" raise ValueError(msg) - if tools: - if streaming_callback is not None: - raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.") - _check_duplicate_tool_names(tools) + if tools and streaming_callback is not None: + raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.") + _check_duplicate_tool_names(tools) # handle generation kwargs setup generation_kwargs = generation_kwargs.copy() if generation_kwargs else {} @@ -241,10 +240,9 @@ def run( formatted_messages = [convert_message_to_hf_format(message) for message in messages] tools = tools or self.tools - if tools: - if self.streaming_callback: - raise ValueError("Using tools and streaming at the same time is not supported. 
Please choose one.") - _check_duplicate_tool_names(tools) + if tools and self.streaming_callback: + raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.") + _check_duplicate_tool_names(tools) if self.streaming_callback: return self._run_streaming(formatted_messages, generation_kwargs) diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index bcbaeced0e..2662014f9a 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -2,10 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 -import copy import json import os -from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Union from openai import OpenAI, Stream @@ -14,13 +12,16 @@ from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice from haystack import component, default_from_dict, default_to_dict, logging -from haystack.components.generators.openai_utils import _convert_message_to_openai_format -from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall +from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable logger = logging.getLogger(__name__) +StreamingCallbackT = Callable[[StreamingChunk], None] + + @component class OpenAIChatGenerator: """ @@ -68,12 +69,14 @@ def __init__( # pylint: disable=too-many-positional-arguments self, api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"), model: str = "gpt-4o-mini", - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + streaming_callback: Optional[StreamingCallbackT] = None, api_base_url: Optional[str] = None, organization: Optional[str] = None, generation_kwargs: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, max_retries: Optional[int] = None, + tools: Optional[List[Tool]] = None, + tools_strict: bool = False, ): """ Creates an instance of OpenAIChatGenerator. Unless specified otherwise in `model`, uses OpenAI's gpt-4o-mini @@ -117,6 +120,11 @@ def __init__( # pylint: disable=too-many-positional-arguments :param max_retries: Maximum number of retries to contact OpenAI after an internal error. If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable, or set to 5. + :param tools: + A list of tools for which the model can prepare calls. + :param tools_strict: + Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly + the schema provided in the `parameters` field of the tool definition, but this may increase latency. 
""" self.api_key = api_key self.model = model @@ -124,6 +132,12 @@ def __init__( # pylint: disable=too-many-positional-arguments self.streaming_callback = streaming_callback self.api_base_url = api_base_url self.organization = organization + self.timeout = timeout + self.max_retries = max_retries + self.tools = tools + self.tools_strict = tools_strict + + _check_duplicate_tool_names(tools) if timeout is None: timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0)) @@ -160,6 +174,10 @@ def to_dict(self) -> Dict[str, Any]: organization=self.organization, generation_kwargs=self.generation_kwargs, api_key=self.api_key.to_dict(), + timeout=self.timeout, + max_retries=self.max_retries, + tools=[tool.to_dict() for tool in self.tools] if self.tools else None, + tools_strict=self.tools_strict, ) @classmethod @@ -172,6 +190,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenAIChatGenerator": The deserialized component instance. """ deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + deserialize_tools_inplace(data["init_parameters"], key="tools") init_params = data.get("init_parameters", {}) serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: @@ -182,130 +201,195 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenAIChatGenerator": def run( self, messages: List[ChatMessage], - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + streaming_callback: Optional[StreamingCallbackT] = None, generation_kwargs: Optional[Dict[str, Any]] = None, + *, + tools: Optional[List[Tool]] = None, + tools_strict: Optional[bool] = None, ): """ Invokes chat completion based on the provided messages and generation parameters. - :param messages: A list of ChatMessage instances representing the input messages. - :param streaming_callback: A callback function that is called when a new token is received from the stream. - :param generation_kwargs: Additional keyword arguments for text generation. These parameters will - override the parameters passed during component initialization. - For details on OpenAI API parameters, see - [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat/create). + :param messages: + A list of ChatMessage instances representing the input messages. + :param streaming_callback: + A callback function that is called when a new token is received from the stream. + :param generation_kwargs: + Additional keyword arguments for text generation. These parameters will + override the parameters passed during component initialization. + For details on OpenAI API parameters, see [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat/create). + :param tools: + A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set + during component initialization. + :param tools_strict: + Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly + the schema provided in the `parameters` field of the tool definition, but this may increase latency. + If set, it will override the `tools_strict` parameter set during component initialization. :returns: - A list containing the generated responses as ChatMessage instances. + A dictionary with the following key: + - `replies`: A list containing the generated responses as ChatMessage instances. 
""" + if len(messages) == 0: + return {"replies": []} - # update generation kwargs by merging with the generation kwargs passed to the run method - generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - - # check if streaming_callback is passed streaming_callback = streaming_callback or self.streaming_callback - # adapt ChatMessage(s) to the format expected by the OpenAI API - openai_formatted_messages = [_convert_message_to_openai_format(message) for message in messages] - + api_args = self._prepare_api_call( + messages=messages, + streaming_callback=streaming_callback, + generation_kwargs=generation_kwargs, + tools=tools, + tools_strict=tools_strict, + ) chat_completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create( - model=self.model, - messages=openai_formatted_messages, # type: ignore # openai expects list of specific message types - stream=streaming_callback is not None, - **generation_kwargs, + **api_args ) - completions: List[ChatMessage] = [] - # if streaming is enabled, the completion is a Stream of ChatCompletionChunk - if isinstance(chat_completion, Stream): - num_responses = generation_kwargs.pop("n", 1) - if num_responses > 1: - raise ValueError("Cannot stream multiple responses, please set n=1.") - chunks: List[StreamingChunk] = [] - completion_chunk = None - _first_token = True - - # pylint: disable=not-an-iterable - for completion_chunk in chat_completion: - if completion_chunk.choices and streaming_callback: - chunk_delta: StreamingChunk = self._build_chunk(completion_chunk) - if _first_token: - _first_token = False - chunk_delta.meta["completion_start_time"] = datetime.now().isoformat() - chunks.append(chunk_delta) - streaming_callback(chunk_delta) # invoke callback with the chunk_delta - completions = [self._create_message_from_chunks(completion_chunk, chunks)] - # if streaming is disabled, the completion is a ChatCompletion - elif isinstance(chat_completion, ChatCompletion): - completions = [self._build_message(chat_completion, choice) for choice in chat_completion.choices] + is_streaming = isinstance(chat_completion, Stream) + assert is_streaming or streaming_callback is None + + if is_streaming: + completions = self._handle_stream_response( + chat_completion, # type: ignore + streaming_callback, # type: ignore + ) + else: + assert isinstance(chat_completion, ChatCompletion), "Unexpected response type for non-streaming request." 
+            completions = [
+                self._convert_chat_completion_to_chat_message(chat_completion, choice)
+                for choice in chat_completion.choices
+            ]

         # before returning, do post-processing of the completions
         for message in completions:
-            self._check_finish_reason(message)
+            self._check_finish_reason(message.meta)

         return {"replies": completions}

-    def _create_message_from_chunks(
-        self, completion_chunk: ChatCompletionChunk, streamed_chunks: List[StreamingChunk]
-    ) -> ChatMessage:
+    def _prepare_api_call(  # noqa: PLR0913
+        self,
+        *,
+        messages: List[ChatMessage],
+        streaming_callback: Optional[StreamingCallbackT] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        tools: Optional[List[Tool]] = None,
+        tools_strict: Optional[bool] = None,
+    ) -> Dict[str, Any]:
+        # update generation kwargs by merging with the generation kwargs passed to the run method
+        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
+
+        # adapt ChatMessage(s) to the format expected by the OpenAI API
+        openai_formatted_messages = [message.to_openai_dict_format() for message in messages]
+
+        tools = tools or self.tools
+        tools_strict = tools_strict if tools_strict is not None else self.tools_strict
+        _check_duplicate_tool_names(tools)
+
+        openai_tools = None
+        if tools:
+            openai_tools = [
+                {"type": "function", "function": {**t.tool_spec, **({"strict": tools_strict} if tools_strict else {})}}
+                for t in tools
+            ]
+
+        is_streaming = streaming_callback is not None
+        num_responses = generation_kwargs.pop("n", 1)
+        if is_streaming and num_responses > 1:
+            raise ValueError("Cannot stream multiple responses, please set n=1.")
+
+        return {
+            "model": self.model,
+            "messages": openai_formatted_messages,  # type: ignore[arg-type] # openai expects list of specific message types
+            "stream": streaming_callback is not None,
+            "tools": openai_tools,  # type: ignore[arg-type]
+            "n": num_responses,
+            **generation_kwargs,
+        }
+
+    def _handle_stream_response(self, chat_completion: Stream, callback: StreamingCallbackT) -> List[ChatMessage]:
+        chunks: List[StreamingChunk] = []
+        chunk = None
+
+        for chunk in chat_completion:  # pylint: disable=not-an-iterable
+            assert len(chunk.choices) == 1, "Streaming responses should have only one choice."
+            chunk_delta: StreamingChunk = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
+            chunks.append(chunk_delta)
+
+            callback(chunk_delta)
+
+        return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)]
+
+    def _check_finish_reason(self, meta: Dict[str, Any]) -> None:
+        if meta["finish_reason"] == "length":
+            logger.warning(
+                "The completion for index {index} has been truncated before reaching a natural stopping point. "
+                "Increase the max_tokens parameter to allow for longer completions.",
+                index=meta["index"],
+                finish_reason=meta["finish_reason"],
+            )
+        if meta["finish_reason"] == "content_filter":
+            logger.warning(
+                "The completion for index {index} has been truncated due to the content filter.",
+                index=meta["index"],
+                finish_reason=meta["finish_reason"],
+            )
+
+    def _convert_streaming_chunks_to_chat_message(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
         """
-        Creates a single ChatMessage from the streamed chunks. Some data is retrieved from the completion chunk.
+        Connects the streaming chunks into a single ChatMessage.

-        :param completion_chunk: The last completion chunk returned by the OpenAI API.
-        :param streamed_chunks: The list of all chunks returned by the OpenAI API.
+ :param chunk: The last chunk returned by the OpenAI API. + :param chunks: The list of all `StreamingChunk` objects. """ - is_tools_call = bool(streamed_chunks[0].meta.get("tool_calls")) - is_function_call = bool(streamed_chunks[0].meta.get("function_call")) - # if it's a tool call or function call, we need to build the payload dict from all the chunks - if is_tools_call or is_function_call: - tools_len = 1 if is_function_call else len(streamed_chunks[0].meta.get("tool_calls", [])) - # don't change this approach of building payload dicts, otherwise mypy will complain - p_def: Dict[str, Any] = { - "index": 0, - "id": "", - "function": {"arguments": "", "name": ""}, - "type": "function", - } - payloads = [copy.deepcopy(p_def) for _ in range(tools_len)] - for chunk_payload in streamed_chunks: - if is_tools_call: - deltas = chunk_payload.meta.get("tool_calls") or [] - else: - deltas = [chunk_payload.meta["function_call"]] if chunk_payload.meta.get("function_call") else [] + + text = "".join([chunk.content for chunk in chunks]) + tool_calls = [] + + # if it's a tool call , we need to build the payload dict from all the chunks + if bool(chunks[0].meta.get("tool_calls")): + tools_len = len(chunks[0].meta.get("tool_calls", [])) + + payloads = [{"arguments": "", "name": ""} for _ in range(tools_len)] + for chunk_payload in chunks: + deltas = chunk_payload.meta.get("tool_calls") or [] # deltas is a list of ChoiceDeltaToolCall or ChoiceDeltaFunctionCall for i, delta in enumerate(deltas): - payload = payloads[i] - if is_tools_call: - payload["id"] = delta.id or payload["id"] - payload["type"] = delta.type or payload["type"] - if delta.function: - payload["function"]["name"] += delta.function.name or "" - payload["function"]["arguments"] += delta.function.arguments or "" - elif is_function_call: - payload["function"]["name"] += delta.name or "" - payload["function"]["arguments"] += delta.arguments or "" - complete_response = ChatMessage.from_assistant(json.dumps(payloads)) - else: - total_content = "" - total_meta = {} - for streaming_chunk in streamed_chunks: - total_content += streaming_chunk.content - total_meta.update(streaming_chunk.meta) - complete_response = ChatMessage.from_assistant(total_content, meta=total_meta) - finish_reason = streamed_chunks[-1].meta["finish_reason"] - complete_response.meta.update( - { - "model": completion_chunk.model, - "index": 0, - "finish_reason": finish_reason, - # Usage is available when streaming only if the user explicitly requests it - "usage": dict(completion_chunk.usage or {}), - } - ) - return complete_response - - def _build_message(self, completion: ChatCompletion, choice: Choice) -> ChatMessage: + payloads[i]["id"] = delta.id or payloads[i].get("id", "") + if delta.function: + payloads[i]["name"] += delta.function.name or "" + payloads[i]["arguments"] += delta.function.arguments or "" + + for payload in payloads: + arguments_str = payload["arguments"] + try: + arguments = json.loads(arguments_str) + tool_calls.append(ToolCall(id=payload["id"], tool_name=payload["name"], arguments=arguments)) + except json.JSONDecodeError: + logger.warning( + "OpenAI returned a malformed JSON string for tool call arguments. This tool call " + "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. 
" + "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}", + _id=payload["id"], + _name=payload["name"], + _arguments=arguments_str, + ) + + meta = { + "model": chunk.model, + "index": 0, + "finish_reason": chunk.choices[0].finish_reason, + "usage": {}, # we don't have usage data for streaming responses + } + + return ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta) + + def _convert_chat_completion_to_chat_message(self, completion: ChatCompletion, choice: Choice) -> ChatMessage: """ Converts the non-streaming response from the OpenAI API to a ChatMessage. @@ -314,20 +398,26 @@ def _build_message(self, completion: ChatCompletion, choice: Choice) -> ChatMess :return: The ChatMessage. """ message: ChatCompletionMessage = choice.message - content = message.content or "" - if message.function_call: - # here we mimic the tools format response so that if user passes deprecated `functions` parameter - # she'll get the same output as if new `tools` parameter was passed - # use pydantic model dump to serialize the function call - content = json.dumps( - [{"function": message.function_call.model_dump(), "type": "function", "id": completion.id}] - ) - elif message.tool_calls: - # new `tools` parameter was passed, use pydantic model dump to serialize the tool calls - content = json.dumps([tc.model_dump() for tc in message.tool_calls]) - - chat_message = ChatMessage.from_assistant(content) - chat_message.meta.update( + text = message.content + tool_calls = [] + if openai_tool_calls := message.tool_calls: + for openai_tc in openai_tool_calls: + arguments_str = openai_tc.function.arguments + try: + arguments = json.loads(arguments_str) + tool_calls.append(ToolCall(id=openai_tc.id, tool_name=openai_tc.function.name, arguments=arguments)) + except json.JSONDecodeError: + logger.warning( + "OpenAI returned a malformed JSON string for tool call arguments. This tool call " + "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. " + "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}", + _id=openai_tc.id, + _name=openai_tc.function.name, + _arguments=arguments_str, + ) + + chat_message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls) + chat_message._meta.update( { "model": completion.model, "index": choice.index, @@ -337,7 +427,7 @@ def _build_message(self, completion: ChatCompletion, choice: Choice) -> ChatMess ) return chat_message - def _build_chunk(self, chunk: ChatCompletionChunk) -> StreamingChunk: + def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletionChunk) -> StreamingChunk: """ Converts the streaming response chunk from the OpenAI API to a StreamingChunk. @@ -350,35 +440,13 @@ def _build_chunk(self, chunk: ChatCompletionChunk) -> StreamingChunk: content = choice.delta.content or "" chunk_message = StreamingChunk(content) # but save the tool calls and function call in the meta if they are present - # and then connect the chunks in the _connect_chunks method + # and then connect the chunks in the _convert_streaming_chunks_to_chat_message method chunk_message.meta.update( { "model": chunk.model, "index": choice.index, "tool_calls": choice.delta.tool_calls, - "function_call": choice.delta.function_call, "finish_reason": choice.finish_reason, } ) return chunk_message - - def _check_finish_reason(self, message: ChatMessage) -> None: - """ - Check the `finish_reason` returned with the OpenAI completions. - - If the `finish_reason` is `length` or `content_filter`, log a warning. 
- :param message: The message returned by the LLM. - """ - if message.meta["finish_reason"] == "length": - logger.warning( - "The completion for index {index} has been truncated before reaching a natural stopping point. " - "Increase the max_tokens parameter to allow for longer completions.", - index=message.meta["index"], - finish_reason=message.meta["finish_reason"], - ) - if message.meta["finish_reason"] == "content_filter": - logger.warning( - "The completion for index {index} has been truncated due to the content filter.", - index=message.meta["index"], - finish_reason=message.meta["finish_reason"], - ) diff --git a/haystack/components/generators/openai.py b/haystack/components/generators/openai.py index d50b082556..d2f07f9d85 100644 --- a/haystack/components/generators/openai.py +++ b/haystack/components/generators/openai.py @@ -9,7 +9,6 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk from haystack import component, default_from_dict, default_to_dict, logging -from haystack.components.generators.openai_utils import _convert_message_to_openai_format from haystack.dataclasses import ChatMessage, StreamingChunk from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable @@ -207,7 +206,7 @@ def run( streaming_callback = streaming_callback or self.streaming_callback # adapt ChatMessage(s) to the format expected by the OpenAI API - openai_formatted_messages = [_convert_message_to_openai_format(message) for message in messages] + openai_formatted_messages = [message.to_openai_dict_format() for message in messages] completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create( model=self.model, diff --git a/haystack/components/generators/openai_utils.py b/haystack/components/generators/openai_utils.py deleted file mode 100644 index ab6d5e7b1d..0000000000 --- a/haystack/components/generators/openai_utils.py +++ /dev/null @@ -1,23 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 - -from typing import Dict - -from haystack.dataclasses import ChatMessage - - -def _convert_message_to_openai_format(message: ChatMessage) -> Dict[str, str]: - """ - Convert a message to the format expected by OpenAI's Chat API. - - See the [API reference](https://platform.openai.com/docs/api-reference/chat/create) for details. - - :returns: A dictionary with the following keys: - - `role` - - `content` - """ - if message.text is None: - raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {message}") - - return {"role": message.role.value, "content": message.text} diff --git a/haystack/dataclasses/chat_message.py b/haystack/dataclasses/chat_message.py index 5aadb9f752..e4d656b15e 100644 --- a/haystack/dataclasses/chat_message.py +++ b/haystack/dataclasses/chat_message.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import json import warnings from dataclasses import asdict, dataclass, field from enum import Enum @@ -381,3 +382,47 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChatMessage": data["_content"] = content return cls(**data) + + def to_openai_dict_format(self) -> Dict[str, Any]: + """ + Convert a ChatMessage to the dictionary format expected by OpenAI's Chat API. 
+ """ + text_contents = self.texts + tool_calls = self.tool_calls + tool_call_results = self.tool_call_results + + if not text_contents and not tool_calls and not tool_call_results: + raise ValueError( + "A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`." + ) + if len(text_contents) + len(tool_call_results) > 1: + raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.") + + openai_msg: Dict[str, Any] = {"role": self._role.value} + + if tool_call_results: + result = tool_call_results[0] + if result.origin.id is None: + raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with OpenAI.") + openai_msg["content"] = result.result + openai_msg["tool_call_id"] = result.origin.id + # OpenAI does not provide a way to communicate errors in tool invocations, so we ignore the error field + return openai_msg + + if text_contents: + openai_msg["content"] = text_contents[0] + if tool_calls: + openai_tool_calls = [] + for tc in tool_calls: + if tc.id is None: + raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with OpenAI.") + openai_tool_calls.append( + { + "id": tc.id, + "type": "function", + # We disable ensure_ascii so special chars like emojis are not converted + "function": {"name": tc.tool_name, "arguments": json.dumps(tc.arguments, ensure_ascii=False)}, + } + ) + openai_msg["tool_calls"] = openai_tool_calls + return openai_msg diff --git a/haystack/dataclasses/tool.py b/haystack/dataclasses/tool.py index c6606d51e8..833c2796ef 100644 --- a/haystack/dataclasses/tool.py +++ b/haystack/dataclasses/tool.py @@ -216,13 +216,15 @@ def _remove_title_from_schema(schema: Dict[str, Any]): del property_schema[key] -def _check_duplicate_tool_names(tools: List[Tool]) -> None: +def _check_duplicate_tool_names(tools: Optional[List[Tool]]) -> None: """ - Check for duplicate tool names and raises a ValueError if they are found. + Checks for duplicate tool names and raises a ValueError if they are found. :param tools: The list of tools to check. :raises ValueError: If duplicate tool names are found. """ + if tools is None: + return tool_names = [tool.name for tool in tools] duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1} if duplicate_tool_names: diff --git a/releasenotes/notes/openai-tools-26f58a981c4066ef.yaml b/releasenotes/notes/openai-tools-26f58a981c4066ef.yaml new file mode 100644 index 0000000000..d7a6e1779b --- /dev/null +++ b/releasenotes/notes/openai-tools-26f58a981c4066ef.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Add support for Tools in the OpenAI Chat Generator. 
diff --git a/test/components/generators/chat/conftest.py b/test/components/generators/chat/conftest.py deleted file mode 100644 index 842e447b56..0000000000 --- a/test/components/generators/chat/conftest.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -import pytest - -from haystack.dataclasses import ChatMessage - - -@pytest.fixture -def chat_messages(): - return [ - ChatMessage.from_system("You are a helpful assistant speaking A2 level of English"), - ChatMessage.from_user("Tell me about Berlin"), - ] diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index fe5308b7b3..9b01acb134 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -17,6 +17,14 @@ def streaming_callback_handler(x): return x +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("You are a helpful assistant speaking A2 level of English"), + ChatMessage.from_user("Tell me about Berlin"), + ] + + @pytest.fixture def model_info_mock(): with patch( diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index 0461ba3cde..243eb36c89 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -1,17 +1,27 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from unittest.mock import MagicMock, patch +import pytest + +from typing import Iterator import logging import os -from unittest.mock import patch +import json +from datetime import datetime -import pytest from openai import OpenAIError +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageToolCall +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_message_tool_call import Function +from openai.types.chat import chat_completion_chunk +from openai import Stream -from haystack.components.generators.chat import OpenAIChatGenerator from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.dataclasses import StreamingChunk from haystack.utils.auth import Secret +from haystack.dataclasses import ChatMessage, Tool, ToolCall, ChatRole, TextContent +from haystack.components.generators.chat.openai import OpenAIChatGenerator @pytest.fixture @@ -22,6 +32,59 @@ def chat_messages(): ] +@pytest.fixture +def mock_chat_completion_chunk_with_tools(openai_mock_stream): + """ + Mock the OpenAI API completion chunk response and reuse it for tests + """ + + with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + completion = ChatCompletionChunk( + id="foo", + model="gpt-4", + object="chat.completion.chunk", + choices=[ + chat_completion_chunk.Choice( + finish_reason="tool_calls", + logprobs=None, + index=0, + delta=chat_completion_chunk.ChoiceDelta( + role="assistant", + tool_calls=[ + chat_completion_chunk.ChoiceDeltaToolCall( + index=0, + id="123", + type="function", + function=chat_completion_chunk.ChoiceDeltaToolCallFunction( + name="weather", arguments='{"city": "Paris"}' + ), + ) + ], + ), + ) + ], + created=int(datetime.now().timestamp()), + usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + ) + 
mock_chat_completion_create.return_value = openai_mock_stream( + completion, cast_to=None, response=None, client=None + ) + yield mock_chat_completion_create + + +@pytest.fixture +def tools(): + tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters=tool_parameters, + function=lambda x: x, + ) + + return [tool] + + class TestOpenAIChatGenerator: def test_init_default(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") @@ -32,13 +95,24 @@ def test_init_default(self, monkeypatch): assert not component.generation_kwargs assert component.client.timeout == 30 assert component.client.max_retries == 5 + assert component.tools is None + assert not component.tools_strict def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("OPENAI_API_KEY", raising=False) - with pytest.raises(ValueError, match="None of the .* environment variables are set"): + with pytest.raises(ValueError): OpenAIChatGenerator() + def test_init_fail_with_duplicate_tool_names(self, monkeypatch, tools): + monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") + + duplicate_tools = [tools[0], tools[0]] + with pytest.raises(ValueError): + OpenAIChatGenerator(tools=duplicate_tools) + def test_init_with_parameters(self, monkeypatch): + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=lambda x: x) + monkeypatch.setenv("OPENAI_TIMEOUT", "100") monkeypatch.setenv("OPENAI_MAX_RETRIES", "10") component = OpenAIChatGenerator( @@ -49,6 +123,8 @@ def test_init_with_parameters(self, monkeypatch): generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, timeout=40.0, max_retries=1, + tools=[tool], + tools_strict=True, ) assert component.client.api_key == "test-api-key" assert component.model == "gpt-4o-mini" @@ -56,6 +132,8 @@ def test_init_with_parameters(self, monkeypatch): assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} assert component.client.timeout == 40.0 assert component.client.max_retries == 1 + assert component.tools == [tool] + assert component.tools_strict def test_init_with_parameters_and_env_vars(self, monkeypatch): monkeypatch.setenv("OPENAI_TIMEOUT", "100") @@ -87,10 +165,16 @@ def test_to_dict_default(self, monkeypatch): "streaming_callback": None, "api_base_url": None, "generation_kwargs": {}, + "tools": None, + "tools_strict": False, + "max_retries": None, + "timeout": None, }, } def test_to_dict_with_parameters(self, monkeypatch): + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + monkeypatch.setenv("ENV_VAR", "test-api-key") component = OpenAIChatGenerator( api_key=Secret.from_env_var("ENV_VAR"), @@ -98,8 +182,13 @@ def test_to_dict_with_parameters(self, monkeypatch): streaming_callback=print_streaming_chunk, api_base_url="test-base-url", generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + tools=[tool], + tools_strict=True, + max_retries=10, + timeout=100.0, ) data = component.to_dict() + assert data == { "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", "init_parameters": { @@ -107,8 +196,19 @@ def test_to_dict_with_parameters(self, monkeypatch): "model": "gpt-4o-mini", "organization": None, "api_base_url": "test-base-url", + "max_retries": 10, + "timeout": 100.0, "streaming_callback": 
"haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "tools": [ + { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": {"x": {"type": "string"}}, + } + ], + "tools_strict": True, }, } @@ -128,8 +228,12 @@ def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): "model": "gpt-4o-mini", "organization": None, "api_base_url": "test-base-url", + "max_retries": None, + "timeout": None, "streaming_callback": "chat.test_openai.", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "tools": None, + "tools_strict": False, }, } @@ -142,15 +246,34 @@ def test_from_dict(self, monkeypatch): "model": "gpt-4o-mini", "api_base_url": "test-base-url", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "max_retries": 10, + "timeout": 100.0, "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "tools": [ + { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": {"x": {"type": "string"}}, + } + ], + "tools_strict": True, }, } component = OpenAIChatGenerator.from_dict(data) + + assert isinstance(component, OpenAIChatGenerator) assert component.model == "gpt-4o-mini" assert component.streaming_callback is print_streaming_chunk assert component.api_base_url == "test-base-url" assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} assert component.api_key == Secret.from_env_var("OPENAI_API_KEY") + assert component.tools == [ + Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + ] + assert component.tools_strict + assert component.client.timeout == 100.0 + assert component.client.max_retries == 10 def test_from_dict_fail_wo_env_var(self, monkeypatch): monkeypatch.delenv("OPENAI_API_KEY", raising=False) @@ -158,17 +281,17 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", "init_parameters": { "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"}, - "model": "gpt-4o-mini", + "model": "gpt-4", "organization": None, "api_base_url": "test-base-url", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, } - with pytest.raises(ValueError, match="None of the .* environment variables are set"): + with pytest.raises(ValueError): OpenAIChatGenerator.from_dict(data) - def test_run(self, chat_messages, mock_chat_completion): + def test_run(self, chat_messages, openai_mock_chat_completion): component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key")) response = component.run(chat_messages) @@ -179,14 +302,14 @@ def test_run(self, chat_messages, mock_chat_completion): assert len(response["replies"]) == 1 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - def test_run_with_params(self, chat_messages, mock_chat_completion): + def test_run_with_params(self, chat_messages, openai_mock_chat_completion): component = OpenAIChatGenerator( api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5} ) response = component.run(chat_messages) # check that the component calls the OpenAI API with the correct parameters - _, kwargs = mock_chat_completion.call_args + _, kwargs = 
openai_mock_chat_completion.call_args assert kwargs["max_tokens"] == 10 assert kwargs["temperature"] == 0.5 @@ -197,7 +320,7 @@ def test_run_with_params(self, chat_messages, mock_chat_completion): assert len(response["replies"]) == 1 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - def test_run_with_params_streaming(self, chat_messages, mock_chat_completion_chunk): + def test_run_with_params_streaming(self, chat_messages, openai_mock_chat_completion_chunk): streaming_callback_called = False def streaming_callback(chunk: StreamingChunk) -> None: @@ -218,10 +341,9 @@ def streaming_callback(chunk: StreamingChunk) -> None: assert isinstance(response["replies"], list) assert len(response["replies"]) == 1 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - assert "Hello" in response["replies"][0].text # see mock_chat_completion_chunk + assert "Hello" in response["replies"][0].text # see openai_mock_chat_completion_chunk - @patch("haystack.components.generators.chat.openai.datetime") - def test_run_with_streaming_callback_in_run_method(self, mock_datetime, chat_messages, mock_chat_completion_chunk): + def test_run_with_streaming_callback_in_run_method(self, chat_messages, openai_mock_chat_completion_chunk): streaming_callback_called = False def streaming_callback(chunk: StreamingChunk) -> None: @@ -240,14 +362,7 @@ def streaming_callback(chunk: StreamingChunk) -> None: assert isinstance(response["replies"], list) assert len(response["replies"]) == 1 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - assert "Hello" in response["replies"][0].text # see mock_chat_completion_chunk - - assert hasattr(response["replies"][0], "meta") - assert isinstance(response["replies"][0].meta, dict) - assert ( - response["replies"][0].meta["completion_start_time"] - == mock_datetime.now.return_value.isoformat.return_value - ) + assert "Hello" in response["replies"][0].text # see openai_mock_chat_completion_chunk def test_check_abnormal_completions(self, caplog): caplog.set_level(logging.INFO) @@ -260,7 +375,7 @@ def test_check_abnormal_completions(self, caplog): ] for m in messages: - component._check_finish_reason(m) + component._check_finish_reason(m.meta) # check truncation warning message_template = ( @@ -276,6 +391,119 @@ def test_check_abnormal_completions(self, caplog): for index in [0, 2]: assert caplog.records[index].message == message_template.format(index=index) + def test_run_with_tools(self, tools): + with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + completion = ChatCompletion( + id="foo", + model="gpt-4", + object="chat.completion", + choices=[ + Choice( + finish_reason="tool_calls", + logprobs=None, + index=0, + message=ChatCompletionMessage( + role="assistant", + tool_calls=[ + ChatCompletionMessageToolCall( + id="123", + type="function", + function=Function(name="weather", arguments='{"city": "Paris"}'), + ) + ], + ), + ) + ], + created=int(datetime.now().timestamp()), + usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + ) + + mock_chat_completion_create.return_value = completion + + component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"), tools=tools) + response = component.run([ChatMessage.from_user("What's the weather like in Paris?")]) + + assert len(response["replies"]) == 1 + message = response["replies"][0] + + assert not message.texts + assert not message.text + + assert message.tool_calls + tool_call = message.tool_call + 
assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert message.meta["finish_reason"] == "tool_calls" + + def test_run_with_tools_streaming(self, mock_chat_completion_chunk_with_tools, tools): + streaming_callback_called = False + + def streaming_callback(chunk: StreamingChunk) -> None: + nonlocal streaming_callback_called + streaming_callback_called = True + + component = OpenAIChatGenerator( + api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback + ) + chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + response = component.run(chat_messages, tools=tools) + + # check we called the streaming callback + assert streaming_callback_called + + # check that the component still returns the correct response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + message = response["replies"][0] + + assert message.tool_calls + tool_call = message.tool_call + assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert message.meta["finish_reason"] == "tool_calls" + + def test_invalid_tool_call_json(self, tools, caplog): + caplog.set_level(logging.WARNING) + + with patch("openai.resources.chat.completions.Completions.create") as mock_create: + mock_create.return_value = ChatCompletion( + id="test", + model="gpt-4o-mini", + object="chat.completion", + choices=[ + Choice( + finish_reason="tool_calls", + index=0, + message=ChatCompletionMessage( + role="assistant", + tool_calls=[ + ChatCompletionMessageToolCall( + id="1", + type="function", + function=Function(name="weather", arguments='"invalid": "json"'), + ) + ], + ), + ) + ], + created=1234567890, + usage={"prompt_tokens": 50, "completion_tokens": 30, "total_tokens": 80}, + ) + + component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"), tools=tools) + response = component.run([ChatMessage.from_user("What's the weather in Paris?")]) + + assert len(response["replies"]) == 1 + message = response["replies"][0] + assert len(message.tool_calls) == 0 + assert "OpenAI returned a malformed JSON string for tool call arguments" in caplog.text + @pytest.mark.skipif( not os.environ.get("OPENAI_API_KEY", None), reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", @@ -288,7 +516,7 @@ def test_live_run(self): assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] assert "Paris" in message.text - assert "gpt-4o-mini" in message.meta["model"] + assert "gpt-4o" in message.meta["model"] assert message.meta["finish_reason"] == "stop" @pytest.mark.skipif( @@ -324,7 +552,7 @@ def __call__(self, chunk: StreamingChunk) -> None: message: ChatMessage = results["replies"][0] assert "Paris" in message.text - assert "gpt-4o-mini" in message.meta["model"] + assert "gpt-4o" in message.meta["model"] assert message.meta["finish_reason"] == "stop" assert callback.counter > 1 @@ -335,28 +563,18 @@ def __call__(self, chunk: StreamingChunk) -> None: reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", ) @pytest.mark.integration - def test_live_run_streaming_with_include_usage(self): - class Callback: - def __init__(self): - self.responses = "" - self.counter = 0 - - 
def __call__(self, chunk: StreamingChunk) -> None: - self.counter += 1 - self.responses += chunk.content if chunk.content else "" - - callback = Callback() - component = OpenAIChatGenerator( - streaming_callback=callback, generation_kwargs={"stream_options": {"include_usage": True}} - ) - results = component.run([ChatMessage.from_user("What's the capital of France?")]) - + def test_live_run_with_tools(self, tools): + chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + component = OpenAIChatGenerator(tools=tools) + results = component.run(chat_messages) assert len(results["replies"]) == 1 - message: ChatMessage = results["replies"][0] - assert "Paris" in message.text - - assert "gpt-4o-mini" in message.meta["model"] - assert message.meta["finish_reason"] == "stop" - - assert callback.counter > 1 - assert "Paris" in callback.responses + message = results["replies"][0] + + assert not message.texts + assert not message.text + assert message.tool_calls + tool_call = message.tool_call + assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert message.meta["finish_reason"] == "tool_calls" diff --git a/test/components/generators/conftest.py b/test/components/generators/conftest.py index 92ed8feb3a..4aa931451f 100644 --- a/test/components/generators/conftest.py +++ b/test/components/generators/conftest.py @@ -7,8 +7,10 @@ import pytest from openai import Stream -from openai.types.chat import ChatCompletionChunk +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageToolCall +from openai.types.chat import chat_completion_chunk @pytest.fixture @@ -29,21 +31,55 @@ def mock_auto_tokenizer(): yield mock_tokenizer +class OpenAIMockStream(Stream[ChatCompletionChunk]): + def __init__(self, mock_chunk: ChatCompletionChunk, client=None, *args, **kwargs): + client = client or MagicMock() + super().__init__(client=client, *args, **kwargs) + self.mock_chunk = mock_chunk + + def __stream__(self) -> Iterator[ChatCompletionChunk]: + yield self.mock_chunk + + @pytest.fixture -def mock_chat_completion_chunk(): +def openai_mock_stream(): """ - Mock the OpenAI API completion chunk response and reuse it for tests + Fixture that returns a function to create MockStream instances with custom chunks + """ + return OpenAIMockStream + + +@pytest.fixture +def openai_mock_chat_completion(): + """ + Mock the OpenAI API completion response and reuse it for tests """ + with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + completion = ChatCompletion( + id="foo", + model="gpt-4", + object="chat.completion", + choices=[ + { + "finish_reason": "stop", + "logprobs": None, + "index": 0, + "message": {"content": "Hello world!", "role": "assistant"}, + } + ], + created=int(datetime.now().timestamp()), + usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + ) + + mock_chat_completion_create.return_value = completion + yield mock_chat_completion_create - class MockStream(Stream[ChatCompletionChunk]): - def __init__(self, mock_chunk: ChatCompletionChunk, client=None, *args, **kwargs): - client = client or MagicMock() - super().__init__(client=client, *args, **kwargs) - self.mock_chunk = mock_chunk - def __stream__(self) -> 
Iterator[ChatCompletionChunk]: - # Yielding only one ChatCompletionChunk object - yield self.mock_chunk +@pytest.fixture +def openai_mock_chat_completion_chunk(): + """ + Mock the OpenAI API completion chunk response and reuse it for tests + """ with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: completion = ChatCompletionChunk( @@ -51,12 +87,17 @@ def __stream__(self) -> Iterator[ChatCompletionChunk]: model="gpt-4", object="chat.completion.chunk", choices=[ - Choice( - finish_reason="stop", logprobs=None, index=0, delta=ChoiceDelta(content="Hello", role="assistant") + chat_completion_chunk.Choice( + finish_reason="stop", + logprobs=None, + index=0, + delta=chat_completion_chunk.ChoiceDelta(content="Hello", role="assistant"), ) ], created=int(datetime.now().timestamp()), usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, ) - mock_chat_completion_create.return_value = MockStream(completion, cast_to=None, response=None, client=None) + mock_chat_completion_create.return_value = OpenAIMockStream( + completion, cast_to=None, response=None, client=None + ) yield mock_chat_completion_create diff --git a/test/components/generators/test_openai.py b/test/components/generators/test_openai.py index 97c071c809..32628f7c45 100644 --- a/test/components/generators/test_openai.py +++ b/test/components/generators/test_openai.py @@ -148,7 +148,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): with pytest.raises(ValueError, match="None of the .* environment variables are set"): OpenAIGenerator.from_dict(data) - def test_run(self, mock_chat_completion): + def test_run(self, openai_mock_chat_completion): component = OpenAIGenerator(api_key=Secret.from_token("test-api-key")) response = component.run("What's Natural Language Processing?") @@ -159,7 +159,7 @@ def test_run(self, mock_chat_completion): assert len(response["replies"]) == 1 assert [isinstance(reply, str) for reply in response["replies"]] - def test_run_with_params_streaming(self, mock_chat_completion_chunk): + def test_run_with_params_streaming(self, openai_mock_chat_completion_chunk): streaming_callback_called = False def streaming_callback(chunk: StreamingChunk) -> None: @@ -177,9 +177,9 @@ def streaming_callback(chunk: StreamingChunk) -> None: assert "replies" in response assert isinstance(response["replies"], list) assert len(response["replies"]) == 1 - assert "Hello" in response["replies"][0] # see mock_chat_completion_chunk + assert "Hello" in response["replies"][0] # see openai_mock_chat_completion_chunk - def test_run_with_streaming_callback_in_run_method(self, mock_chat_completion_chunk): + def test_run_with_streaming_callback_in_run_method(self, openai_mock_chat_completion_chunk): streaming_callback_called = False def streaming_callback(chunk: StreamingChunk) -> None: @@ -198,16 +198,16 @@ def streaming_callback(chunk: StreamingChunk) -> None: assert "replies" in response assert isinstance(response["replies"], list) assert len(response["replies"]) == 1 - assert "Hello" in response["replies"][0] # see mock_chat_completion_chunk + assert "Hello" in response["replies"][0] # see openai_mock_chat_completion_chunk - def test_run_with_params(self, mock_chat_completion): + def test_run_with_params(self, openai_mock_chat_completion): component = OpenAIGenerator( api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5} ) response = component.run("What's Natural Language Processing?") # check that the component calls the OpenAI 
API with the correct parameters - _, kwargs = mock_chat_completion.call_args + _, kwargs = openai_mock_chat_completion.call_args assert kwargs["max_tokens"] == 10 assert kwargs["temperature"] == 0.5 diff --git a/test/components/generators/test_openai_utils.py b/test/components/generators/test_openai_utils.py deleted file mode 100644 index 916a3e3d70..0000000000 --- a/test/components/generators/test_openai_utils.py +++ /dev/null @@ -1,16 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 - -import pytest - -from haystack.dataclasses import ChatMessage -from haystack.components.generators.openai_utils import _convert_message_to_openai_format - - -def test_convert_message_to_openai_format(): - message = ChatMessage.from_system("You are good assistant") - assert _convert_message_to_openai_format(message) == {"role": "system", "content": "You are good assistant"} - - message = ChatMessage.from_user("I have a question") - assert _convert_message_to_openai_format(message) == {"role": "user", "content": "I have a question"} diff --git a/test/components/tools/test_tool_invoker.py b/test/components/tools/test_tool_invoker.py index 34b1ca9fef..f492b2c0a1 100644 --- a/test/components/tools/test_tool_invoker.py +++ b/test/components/tools/test_tool_invoker.py @@ -6,6 +6,7 @@ from haystack.dataclasses import ChatMessage, ToolCall, ToolCallResult, ChatRole from haystack.dataclasses.tool import Tool, ToolInvocationError from haystack.components.tools.tool_invoker import ToolInvoker, ToolNotFoundException, StringConversionError +from haystack.components.generators.chat.openai import OpenAIChatGenerator def weather_function(location): @@ -218,3 +219,59 @@ def test_from_dict(self, weather_tool): assert invoker._tools_with_names == {"weather_tool": weather_tool} assert invoker.raise_on_failure assert not invoker.convert_result_to_json_string + + def test_serde_in_pipeline(self, invoker, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + + pipeline = Pipeline() + pipeline.add_component("invoker", invoker) + pipeline.add_component("chatgenerator", OpenAIChatGenerator()) + pipeline.connect("invoker", "chatgenerator") + + pipeline_dict = pipeline.to_dict() + assert pipeline_dict == { + "metadata": {}, + "max_runs_per_component": 100, + "components": { + "invoker": { + "type": "haystack.components.tools.tool_invoker.ToolInvoker", + "init_parameters": { + "tools": [ + { + "name": "weather_tool", + "description": "Provides weather information for a given location.", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + "function": "tools.test_tool_invoker.weather_function", + } + ], + "raise_on_failure": True, + "convert_result_to_json_string": False, + }, + }, + "chatgenerator": { + "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", + "init_parameters": { + "model": "gpt-4o-mini", + "streaming_callback": None, + "api_base_url": None, + "organization": None, + "generation_kwargs": {}, + "max_retries": None, + "timeout": None, + "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True}, + "tools": None, + "tools_strict": False, + }, + }, + }, + "connections": [{"sender": "invoker.tool_messages", "receiver": "chatgenerator.messages"}], + } + + pipeline_yaml = pipeline.dumps() + + new_pipeline = Pipeline.loads(pipeline_yaml) + assert new_pipeline == pipeline diff --git a/test/conftest.py b/test/conftest.py index a7282be645..513009d234 100644 
--- a/test/conftest.py +++ b/test/conftest.py @@ -36,32 +36,6 @@ def test_files_path(): return Path(__file__).parent / "test_files" -@pytest.fixture -def mock_chat_completion(): - """ - Mock the OpenAI API completion response and reuse it for tests - """ - with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: - completion = ChatCompletion( - id="foo", - model="gpt-4", - object="chat.completion", - choices=[ - Choice( - finish_reason="stop", - logprobs=None, - index=0, - message=ChatCompletionMessage(content="Hello world!", role="assistant"), - ) - ], - created=int(datetime.now().timestamp()), - usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, - ) - - mock_chat_completion_create.return_value = completion - yield mock_chat_completion_create - - @pytest.fixture(autouse=True) def request_blocker(request: pytest.FixtureRequest, monkeypatch): """ diff --git a/test/dataclasses/test_chat_message.py b/test/dataclasses/test_chat_message.py index 832617e712..2209af998f 100644 --- a/test/dataclasses/test_chat_message.py +++ b/test/dataclasses/test_chat_message.py @@ -3,9 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 import pytest from transformers import AutoTokenizer +import json from haystack.dataclasses.chat_message import ChatMessage, ChatRole, ToolCall, ToolCallResult, TextContent -from haystack.components.generators.openai_utils import _convert_message_to_openai_format def test_tool_call_init(): @@ -239,11 +239,60 @@ def test_chat_message_function_role_deprecated(): ChatMessage(ChatRole.FUNCTION, TextContent("This is a message")) +def test_to_openai_dict_format(): + message = ChatMessage.from_system("You are good assistant") + assert message.to_openai_dict_format() == {"role": "system", "content": "You are good assistant"} + + message = ChatMessage.from_user("I have a question") + assert message.to_openai_dict_format() == {"role": "user", "content": "I have a question"} + + message = ChatMessage.from_assistant(text="I have an answer", meta={"finish_reason": "stop"}) + assert message.to_openai_dict_format() == {"role": "assistant", "content": "I have an answer"} + + message = ChatMessage.from_assistant( + tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})] + ) + assert message.to_openai_dict_format() == { + "role": "assistant", + "tool_calls": [ + {"id": "123", "type": "function", "function": {"name": "weather", "arguments": '{"city": "Paris"}'}} + ], + } + + tool_result = json.dumps({"weather": "sunny", "temperature": "25"}) + message = ChatMessage.from_tool( + tool_result=tool_result, origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) + ) + assert message.to_openai_dict_format() == {"role": "tool", "content": tool_result, "tool_call_id": "123"} + + +def test_to_openai_dict_format_invalid(): + message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[]) + with pytest.raises(ValueError): + message.to_openai_dict_format() + + message = ChatMessage( + _role=ChatRole.ASSISTANT, + _content=[TextContent(text="I have an answer"), TextContent(text="I have another answer")], + ) + with pytest.raises(ValueError): + message.to_openai_dict_format() + + tool_call_null_id = ToolCall(id=None, tool_name="weather", arguments={"city": "Paris"}) + message = ChatMessage.from_assistant(tool_calls=[tool_call_null_id]) + with pytest.raises(ValueError): + message.to_openai_dict_format() + + message = ChatMessage.from_tool(tool_result="result", origin=tool_call_null_id) + with 
pytest.raises(ValueError): + message.to_openai_dict_format() + + @pytest.mark.integration def test_apply_chat_templating_on_chat_message(): messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")] tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta") - formatted_messages = [_convert_message_to_openai_format(m) for m in messages] + formatted_messages = [m.to_openai_dict_format() for m in messages] tokenized_messages = tokenizer.apply_chat_template(formatted_messages, tokenize=False) assert tokenized_messages == "<|system|>\nYou are good assistant\n<|user|>\nI have a question\n" @@ -264,7 +313,7 @@ def test_apply_custom_chat_templating_on_chat_message(): messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")] # could be any tokenizer, let's use the one we already likely have in cache tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta") - formatted_messages = [_convert_message_to_openai_format(m) for m in messages] + formatted_messages = [m.to_openai_dict_format() for m in messages] tokenized_messages = tokenizer.apply_chat_template( formatted_messages, chat_template=anthropic_template, tokenize=False )
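
A minimal usage sketch of the tools support introduced above, assuming an `OPENAI_API_KEY` environment variable and a `weather` tool analogous to the one defined in the tests (the tool function body here is a placeholder):

    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.dataclasses import ChatMessage, Tool

    # Placeholder tool function, used only for illustration.
    def get_weather(city: str) -> str:
        return f"The weather in {city} is sunny."

    weather_tool = Tool(
        name="weather",
        description="useful to determine the weather in a given location",
        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
        function=get_weather,
    )

    # Reads OPENAI_API_KEY from the environment.
    generator = OpenAIChatGenerator(tools=[weather_tool])
    result = generator.run([ChatMessage.from_user("What's the weather like in Paris?")])

    # If the model decides to call the tool, the reply carries ToolCall objects instead of text.
    reply = result["replies"][0]
    print(reply.tool_calls)  # e.g. [ToolCall(tool_name='weather', arguments={'city': 'Paris'}, id='call_...')]

The returned `ToolCall` objects can then be executed (for example via `ToolInvoker`, as exercised in the pipeline serialization test above) and the resulting tool messages fed back into the generator.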