From 85ee34b924996105970bdf014572ce5aa8f4bb62 Mon Sep 17 00:00:00 2001 From: Eduard van Valkenburg Date: Thu, 12 Oct 2023 19:34:22 +0200 Subject: [PATCH 1/2] Python: Implement Function calling for Chat (#2356) ### Motivation and Context I built support for function calling into SK! Related to/fixes: - #2315 - #2175 - #1450 ### Description This implementation builds on top of all the existing pieces, but did require some major work, so feel free to comment on where that is or is not appropriate. - Added a `ChatMessage` class to capture the relevant pieces of function calling (name and content) - Added a `complete_chat_with_functions_async` into OpenAIChatCompletions class - Added a `function_call` field to ChatRequestSettings class - Added several helper functions and smaller changes - Added a sample with updated core_skill that uses function calling to demonstrate - Added a second sample that shows how to use function_calling with non-sk functions. ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: --------- Co-authored-by: Abby Harrison --- python/.vscode/settings.json | 5 + .../chat_gpt_api_function_calling.py | 121 ++++++++++++ .../openai_function_calling.py | 109 +++++++++++ .../ai/chat_completion_client_base.py | 15 +- .../connectors/ai/chat_request_settings.py | 8 +- .../ai/open_ai/models/chat/function_call.py | 36 ++++ .../models/chat/open_ai_chat_message.py | 14 ++ .../open_ai_chat_prompt_template.py | 103 ++++++++++ .../services/open_ai_chat_completion.py | 164 +++++++++++----- .../connectors/ai/open_ai/utils.py | 184 ++++++++++++++++++ .../semantic_kernel/core_skills/math_skill.py | 4 + python/semantic_kernel/kernel.py | 3 +- .../models/chat/chat_message.py | 40 ++++ .../orchestration/sk_context.py | 14 +- .../orchestration/sk_function.py | 103 ++++++---- .../orchestration/sk_function_base.py | 3 +- .../chat_prompt_template.py | 87 ++++++--- .../semantic_functions/prompt_template.py | 10 +- .../prompt_template_config.py | 15 +- .../skill_definition/parameter_view.py | 12 +- ...sk_function_context_parameter_decorator.py | 12 +- .../open_ai/models/chat/test_function_call.py | 28 +++ .../services/test_azure_chat_completion.py | 2 +- .../unit/models/chat/test_chat_message.py | 43 ++++ .../orchestration/test_native_function.py | 9 + .../skill_definition/test_prompt_templates.py | 10 +- python/tests/unit/test_serialization.py | 10 +- samples/skills/FunSkill/Joke/config.json | 2 +- 28 files changed, 1023 insertions(+), 143 deletions(-) create mode 100644 python/samples/kernel-syntax-examples/chat_gpt_api_function_calling.py create mode 100644 python/samples/kernel-syntax-examples/openai_function_calling.py create mode 100644 python/semantic_kernel/connectors/ai/open_ai/models/chat/function_call.py create mode 100644 python/semantic_kernel/connectors/ai/open_ai/models/chat/open_ai_chat_message.py create mode 100644 python/semantic_kernel/connectors/ai/open_ai/semantic_functions/open_ai_chat_prompt_template.py create mode 100644 python/semantic_kernel/connectors/ai/open_ai/utils.py create mode 100644 python/semantic_kernel/models/chat/chat_message.py create mode 100644 
python/tests/unit/ai/open_ai/models/chat/test_function_call.py create mode 100644 python/tests/unit/models/chat/test_chat_message.py diff --git a/python/.vscode/settings.json b/python/.vscode/settings.json index 4db6521e2588..4e6c0ad8e15e 100644 --- a/python/.vscode/settings.json +++ b/python/.vscode/settings.json @@ -18,4 +18,9 @@ "OPENAI", "skfunction" ], + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true, } \ No newline at end of file diff --git a/python/samples/kernel-syntax-examples/chat_gpt_api_function_calling.py b/python/samples/kernel-syntax-examples/chat_gpt_api_function_calling.py new file mode 100644 index 000000000000..df0d3d2e9c21 --- /dev/null +++ b/python/samples/kernel-syntax-examples/chat_gpt_api_function_calling.py @@ -0,0 +1,121 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import os +from typing import Tuple + +import semantic_kernel as sk +import semantic_kernel.connectors.ai.open_ai as sk_oai +from semantic_kernel.connectors.ai.open_ai.semantic_functions.open_ai_chat_prompt_template import ( + OpenAIChatPromptTemplate, +) +from semantic_kernel.connectors.ai.open_ai.utils import ( + chat_completion_with_function_call, + get_function_calling_object, +) +from semantic_kernel.core_skills import MathSkill + +system_message = """ +You are a chat bot. Your name is Mosscap and +you have one goal: figure out what people need. +Your full name, should you need to know it, is +Splendid Speckled Mosscap. You communicate +effectively, but you tend to answer with long +flowery prose. You are also a math wizard, +especially for adding and subtracting. +You also excel at joke telling, where your tone is often sarcastic. +Once you have the answer I am looking for, +you will return a full answer to me as soon as possible. +""" + +kernel = sk.Kernel() + +deployment_name, api_key, endpoint = sk.azure_openai_settings_from_dot_env() +api_version = "2023-07-01-preview" +kernel.add_chat_service( + "chat-gpt", + sk_oai.AzureChatCompletion( + deployment_name, + endpoint, + api_key, + api_version=api_version, + ), +) + +skills_directory = os.path.join(__file__, "../../../../samples/skills") +# adding skills to the kernel +# the joke skill in the FunSkills is a semantic skill and has the function calling disabled. +kernel.import_semantic_skill_from_directory(skills_directory, "FunSkill") +# the math skill is a core skill and has the function calling enabled. +kernel.import_skill(MathSkill(), skill_name="math") + +# enabling or disabling function calling is done by setting the function_call parameter for the completion. +# when the function_call parameter is set to "auto" the model will decide which function to use, if any. +# if you only want to use a specific function, set the name of that function in this parameter, +# the format for that is 'SkillName-FunctionName', (i.e. 'math-Add'). +# if the model or api version do not support this you will get an error. +prompt_config = sk.PromptTemplateConfig.from_completion_parameters( + max_tokens=2000, + temperature=0.7, + top_p=0.8, + function_call="auto", + chat_system_prompt=system_message, +) +prompt_template = OpenAIChatPromptTemplate( + "{{$user_input}}", kernel.prompt_template_engine, prompt_config +) +prompt_template.add_user_message("Hi there, who are you?") +prompt_template.add_assistant_message( + "I am Mosscap, a chat bot. I'm trying to figure out what people need." 
+)
+
+function_config = sk.SemanticFunctionConfig(prompt_config, prompt_template)
+chat_function = kernel.register_semantic_function("ChatBot", "Chat", function_config)
+
+# when calling the chat function, you could add an overloaded version of the settings here,
+# to enable or disable function calling or set the function calling to a specific skill.
+# see the openai_function_calling example for how to use this with an unrelated function definition.
+filter = {"exclude_skill": ["ChatBot"]}
+functions = get_function_calling_object(kernel, filter)
+
+
+async def chat(context: sk.SKContext) -> Tuple[bool, sk.SKContext]:
+    try:
+        user_input = input("User:> ")
+        context.variables["user_input"] = user_input
+    except KeyboardInterrupt:
+        print("\n\nExiting chat...")
+        return False, None
+    except EOFError:
+        print("\n\nExiting chat...")
+        return False, None
+
+    if user_input == "exit":
+        print("\n\nExiting chat...")
+        return False, None
+
+    context = await chat_completion_with_function_call(
+        kernel,
+        chat_skill_name="ChatBot",
+        chat_function_name="Chat",
+        context=context,
+        functions=functions,
+    )
+    print(f"Mosscap:> {context.result}")
+    return True, context
+
+
+async def main() -> None:
+    chatting = True
+    context = kernel.create_new_context()
+    print(
+        "Welcome to the chat bot!\
+\n  Type 'exit' to exit.\
+\n  Try a math question to see the function calling in action (e.g. what is 3+3?)."
+    )
+    while chatting:
+        chatting, context = await chat(context)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/python/samples/kernel-syntax-examples/openai_function_calling.py b/python/samples/kernel-syntax-examples/openai_function_calling.py
new file mode 100644
index 000000000000..ecce09438df2
--- /dev/null
+++ b/python/samples/kernel-syntax-examples/openai_function_calling.py
@@ -0,0 +1,109 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+import asyncio
+import os
+
+import semantic_kernel as sk
+import semantic_kernel.connectors.ai.open_ai as sk_oai
+from semantic_kernel.core_skills import MathSkill
+
+system_message = """
+You are a chat bot. Your name is Mosscap and
+you have one goal: figure out what people need.
+Your full name, should you need to know it, is
+Splendid Speckled Mosscap. You communicate
+effectively, but you tend to answer with long
+flowery prose. You are also a math wizard,
+especially for adding and subtracting.
+You also excel at joke telling, where your tone is often sarcastic.
+Once you have the answer I am looking for,
+you will return a full answer to me as soon as possible.
+"""
+
+kernel = sk.Kernel()
+
+deployment_name, api_key, endpoint = sk.azure_openai_settings_from_dot_env()
+api_version = "2023-07-01-preview"
+kernel.add_chat_service(
+    "chat-gpt",
+    sk_oai.AzureChatCompletion(
+        deployment_name,
+        endpoint,
+        api_key,
+        api_version=api_version,
+    ),
+)
+
+skills_directory = os.path.join(__file__, "../../../../samples/skills")
+# adding skills to the kernel
+# the Joke function in FunSkill is a semantic function and has function calling disabled.
+kernel.import_semantic_skill_from_directory(skills_directory, "FunSkill")
+# the math skill is a core skill and has function calling enabled.
+kernel.import_skill(MathSkill(), skill_name="math")
+
+# enabling or disabling function calling is done by setting the function_call parameter for the completion.
+# when the function_call parameter is set to "auto" the model will decide which function to use, if any.
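+# as an illustration (hypothetical values, not in the original sample): with "auto",
+# asking "what is 3+3?" can come back as a FunctionCall such as:
+#   FunctionCall(name="math-Add", arguments='{"input": "3", "Amount": "3"}')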
+# if you only want to use a specific function, set the name of that function in this parameter, +# the format for that is 'SkillName-FunctionName', (i.e. 'math-Add'). +# if the model or api version do not support this you will get an error. +prompt_config = sk.PromptTemplateConfig.from_completion_parameters( + max_tokens=2000, + temperature=0.7, + top_p=0.8, + function_call="auto", + chat_system_prompt=system_message, +) +prompt_template = sk.ChatPromptTemplate( + "{{$user_input}}", kernel.prompt_template_engine, prompt_config +) +prompt_template.add_user_message("Hi there, who are you?") +prompt_template.add_assistant_message( + "I am Mosscap, a chat bot. I'm trying to figure out what people need." +) + +function_config = sk.SemanticFunctionConfig(prompt_config, prompt_template) +chat_function = kernel.register_semantic_function("ChatBot", "Chat", function_config) +# define the functions available +functions = [ + { + "name": "search_hotels", + "description": "Retrieves hotels from the search index based on the parameters provided", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location of the hotel (i.e. Seattle, WA)", + }, + "max_price": { + "type": "number", + "description": "The maximum price for the hotel", + }, + "features": { + "type": "string", + "description": "A comma separated list of features (i.e. beachfront, free wifi, etc.)", + }, + }, + "required": ["location"], + }, + } +] + + +async def main() -> None: + context = kernel.create_new_context() + context.variables[ + "user_input" + ] = "I want to find a hotel in Seattle with free wifi and a pool." + + context = await chat_function.invoke_async(context=context, functions=functions) + if function_call := context.pop_function_call(): + print(f"Function to be called: {function_call.name}") + print(f"Function parameters: \n{function_call.arguments}") + return + print("No function was called") + print(f"Output was: {str(context)}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/semantic_kernel/connectors/ai/chat_completion_client_base.py b/python/semantic_kernel/connectors/ai/chat_completion_client_base.py index 798103e3c89c..118c84c8ac7f 100644 --- a/python/semantic_kernel/connectors/ai/chat_completion_client_base.py +++ b/python/semantic_kernel/connectors/ai/chat_completion_client_base.py @@ -2,17 +2,18 @@ from abc import ABC, abstractmethod from logging import Logger -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Union if TYPE_CHECKING: from semantic_kernel.connectors.ai.chat_request_settings import ChatRequestSettings + from semantic_kernel.models.chat.chat_message import ChatMessage class ChatCompletionClientBase(ABC): @abstractmethod async def complete_chat_async( self, - messages: List[Tuple[str, str]], + messages: List["ChatMessage"], settings: "ChatRequestSettings", logger: Optional[Logger] = None, ) -> Union[str, List[str]]: @@ -20,8 +21,8 @@ async def complete_chat_async( This is the method that is called from the kernel to get a response from a chat-optimized LLM. Arguments: - messages {List[Tuple[str, str]]} -- A list of tuples, where each tuple is - comprised of a speaker ID and a message. + messages {List[ChatMessage]} -- A list of chat messages, that can be rendered into a + set of messages, from system, user, assistant and function. settings {ChatRequestSettings} -- Settings for the request. logger {Logger} -- A logger to use for logging. 
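For orientation, a minimal sketch of the new message shape (not part of the patch; it assumes pydantic accepts the `content` alias for `fixed_content` at construction):

    from semantic_kernel.models.chat.chat_message import ChatMessage

    # the interface now carries rich message objects instead of (role, message) tuples
    msg = ChatMessage(role="user", content="What is 3+3?")
    # connectors ultimately consume the rendered dict form produced by as_dict()
    print(msg.as_dict())  # {'role': 'user', 'content': 'What is 3+3?'}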
@@ -33,7 +34,7 @@ async def complete_chat_async( @abstractmethod async def complete_chat_stream_async( self, - messages: List[Tuple[str, str]], + messages: List["ChatMessage"], settings: "ChatRequestSettings", logger: Optional[Logger] = None, ): @@ -41,8 +42,8 @@ async def complete_chat_stream_async( This is the method that is called from the kernel to get a stream response from a chat-optimized LLM. Arguments: - messages {List[Tuple[str, str]]} -- A list of tuples, where each tuple is - comprised of a speaker ID and a message. + messages {List[ChatMessage]} -- A list of chat messages, that can be rendered into a + set of messages, from system, user, assistant and function. settings {ChatRequestSettings} -- Settings for the request. logger {Logger} -- A logger to use for logging. diff --git a/python/semantic_kernel/connectors/ai/chat_request_settings.py b/python/semantic_kernel/connectors/ai/chat_request_settings.py index 09bf53715a8f..5718f28fbe30 100644 --- a/python/semantic_kernel/connectors/ai/chat_request_settings.py +++ b/python/semantic_kernel/connectors/ai/chat_request_settings.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING, Dict, List, Optional if TYPE_CHECKING: from semantic_kernel.semantic_functions.prompt_template_config import ( @@ -19,6 +19,7 @@ class ChatRequestSettings: max_tokens: int = 256 token_selection_biases: Dict[int, int] = field(default_factory=dict) stop_sequences: List[str] = field(default_factory=list) + function_call: Optional[str] = None def update_from_completion_config( self, completion_config: "PromptTemplateConfig.CompletionConfig" @@ -31,6 +32,11 @@ def update_from_completion_config( self.presence_penalty = completion_config.presence_penalty self.frequency_penalty = completion_config.frequency_penalty self.token_selection_biases = completion_config.token_selection_biases + self.function_call = ( + completion_config.function_call + if hasattr(completion_config, "function_call") + else None + ) @staticmethod def from_completion_config( diff --git a/python/semantic_kernel/connectors/ai/open_ai/models/chat/function_call.py b/python/semantic_kernel/connectors/ai/open_ai/models/chat/function_call.py new file mode 100644 index 000000000000..6a645ba89e51 --- /dev/null +++ b/python/semantic_kernel/connectors/ai/open_ai/models/chat/function_call.py @@ -0,0 +1,36 @@ +"""Class to hold chat messages.""" +import json +from typing import Dict, Tuple + +from semantic_kernel.orchestration.context_variables import ContextVariables +from semantic_kernel.sk_pydantic import SKBaseModel + + +class FunctionCall(SKBaseModel): + """Class to hold a function call response.""" + + name: str + arguments: str + + def parse_arguments(self) -> Dict[str, str]: + """Parse the arguments into a dictionary.""" + try: + return json.loads(self.arguments) + except json.JSONDecodeError: + return None + + def to_context_variables(self) -> ContextVariables: + """Return the arguments as a ContextVariables instance.""" + args = self.parse_arguments() + return ContextVariables(variables={k.lower(): v for k, v in args.items()}) + + def split_name(self) -> Tuple[str, str]: + """Split the name into a skill and function name.""" + if "-" not in self.name: + return None, self.name + return self.name.split("-") + + def split_name_dict(self) -> dict: + """Split the name into a skill and function name.""" + parts = self.split_name() + return {"skill_name": parts[0], 
"function_name": parts[1]} diff --git a/python/semantic_kernel/connectors/ai/open_ai/models/chat/open_ai_chat_message.py b/python/semantic_kernel/connectors/ai/open_ai/models/chat/open_ai_chat_message.py new file mode 100644 index 000000000000..4e0d90c2088c --- /dev/null +++ b/python/semantic_kernel/connectors/ai/open_ai/models/chat/open_ai_chat_message.py @@ -0,0 +1,14 @@ +"""Class to hold chat messages.""" +from typing import Optional + +from semantic_kernel.connectors.ai.open_ai.models.chat.function_call import ( + FunctionCall, +) +from semantic_kernel.models.chat.chat_message import ChatMessage + + +class OpenAIChatMessage(ChatMessage): + """Class to hold openai chat messages, which might include name and function_call fields.""" + + name: Optional[str] = None + function_call: Optional[FunctionCall] = None diff --git a/python/semantic_kernel/connectors/ai/open_ai/semantic_functions/open_ai_chat_prompt_template.py b/python/semantic_kernel/connectors/ai/open_ai/semantic_functions/open_ai_chat_prompt_template.py new file mode 100644 index 000000000000..6b4b4d234f07 --- /dev/null +++ b/python/semantic_kernel/connectors/ai/open_ai/semantic_functions/open_ai_chat_prompt_template.py @@ -0,0 +1,103 @@ +# Copyright (c) Microsoft. All rights reserved. + +from logging import Logger +from typing import Any, Dict, List, Optional + +from semantic_kernel.connectors.ai.open_ai.models.chat.function_call import FunctionCall +from semantic_kernel.connectors.ai.open_ai.models.chat.open_ai_chat_message import ( + OpenAIChatMessage, +) +from semantic_kernel.semantic_functions.chat_prompt_template import ChatPromptTemplate +from semantic_kernel.semantic_functions.prompt_template import PromptTemplate +from semantic_kernel.semantic_functions.prompt_template_config import ( + PromptTemplateConfig, +) +from semantic_kernel.template_engine.protocols.prompt_templating_engine import ( + PromptTemplatingEngine, +) + + +class OpenAIChatPromptTemplate(ChatPromptTemplate): + def add_function_response_message(self, name: str, content: Any) -> None: + """Add a function response message to the chat template.""" + self._messages.append( + OpenAIChatMessage(role="function", name=name, fixed_content=str(content)) + ) + + def add_message( + self, role: str, message: Optional[str] = None, **kwargs: Any + ) -> None: + """Add a message to the chat template. + + Arguments: + role: The role of the message, one of "user", "assistant", "system", "function" + message: The message to add, can include templating components. + kwargs: can be used by inherited classes. 
+ name: the name of the function that was used, to be used with role: function + function_call: the function call that is specified, to be used with role: assistant + """ + name = kwargs.get("name") + if name is not None and role != "function": + self._log.warning("name is only used with role: function, ignoring") + name = None + function_call = kwargs.get("function_call") + if function_call is not None and role != "assistant": + self._log.warning( + "function_call is only used with role: assistant, ignoring" + ) + function_call = None + if function_call and not isinstance(function_call, FunctionCall): + self._log.warning( + "function_call is not a FunctionCall, ignoring: %s", function_call + ) + function_call = None + self._messages.append( + OpenAIChatMessage( + role=role, + content_template=PromptTemplate( + message, self._template_engine, self._prompt_config + ), + name=name, + function_call=function_call, + ) + ) + + @classmethod + def restore( + cls, + messages: List[Dict[str, str]], + template: str, + template_engine: PromptTemplatingEngine, + prompt_config: PromptTemplateConfig, + log: Optional[Logger] = None, + ) -> "OpenAIChatPromptTemplate": + """Restore a ChatPromptTemplate from a list of role and message pairs. + + If there is a chat_system_prompt in the prompt_config.completion settings, + that takes precedence over the first message in the list of messages, + if that is a system message. + """ + chat_template = cls(template, template_engine, prompt_config, log) + if ( + prompt_config.completion.chat_system_prompt + and messages[0]["role"] == "system" + ): + existing_system_message = messages.pop(0) + if ( + existing_system_message["message"] + != prompt_config.completion.chat_system_prompt + ): + chat_template._log.info( + "Overriding system prompt with chat_system_prompt, old system message: %s, new system message: %s", + existing_system_message["message"], + prompt_config.completion.chat_system_prompt, + ) + for message in messages: + chat_template.add_message( + message["role"], + message["message"], + name=message["name"], + function_call=message["function_call"], + ) + + return chat_template diff --git a/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_chat_completion.py b/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_chat_completion.py index f4be4c42259c..5d7689f6f80c 100644 --- a/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_chat_completion.py +++ b/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_chat_completion.py @@ -1,10 +1,15 @@ # Copyright (c) Microsoft. All rights reserved. from logging import Logger -from typing import Any, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import openai +from semantic_kernel.connectors.ai.open_ai.models.chat.function_call import FunctionCall + +if TYPE_CHECKING: + from openai.openai_object import OpenAIObject + from semantic_kernel.connectors.ai.ai_exception import AIException from semantic_kernel.connectors.ai.chat_completion_client_base import ( ChatCompletionClientBase, @@ -64,24 +69,49 @@ def __init__( async def complete_chat_async( self, - messages: List[Tuple[str, str]], + messages: List[Dict[str, str]], request_settings: ChatRequestSettings, logger: Optional[Logger] = None, ) -> Union[str, List[str]]: - response = await self._send_chat_request(messages, request_settings, False) + # TODO: tracking on token counts/etc. 
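+        # functions are deliberately passed as None here;
+        # plain chat completion does not attach any function definitions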
+ response = await self._send_chat_request( + messages, request_settings, False, None + ) if len(response.choices) == 1: return response.choices[0].message.content + return [choice.message.content for choice in response.choices] + + async def complete_chat_with_functions_async( + self, + messages: List[Dict[str, str]], + functions: List[Dict[str, Any]], + request_settings: ChatRequestSettings, + logger: Optional[Logger] = None, + ) -> Union[ + Tuple[Optional[str], Optional[FunctionCall]], + List[Tuple[Optional[str], Optional[FunctionCall]]], + ]: + # TODO: tracking on token counts/etc. + + response = await self._send_chat_request( + messages, request_settings, False, functions + ) + + if len(response.choices) == 1: + return _parse_message(response.choices[0].message, self._log) else: - return [choice.message.content for choice in response.choices] + return [ + _parse_message(choice.message, self._log) for choice in response.choices + ] async def complete_chat_stream_async( self, - messages: List[Tuple[str, str]], + messages: List[Dict[str, str]], request_settings: ChatRequestSettings, - logger: Optional[Logger] = None, ): - response = await self._send_chat_request(messages, request_settings, True) + # TODO: enable function calling + response = await self._send_chat_request(messages, request_settings, True, None) # parse the completion text(s) and yield them async for chunk in response: @@ -111,9 +141,8 @@ async def complete_async( Returns: str -- The completed text. """ - prompt_to_message = [("user", prompt)] + prompt_to_message = [{"role": "user", "content": prompt}] chat_settings = ChatRequestSettings.from_completion_config(request_settings) - response = await self._send_chat_request( prompt_to_message, chat_settings, False ) @@ -129,7 +158,7 @@ async def complete_stream_async( request_settings: CompleteRequestSettings, logger: Optional[Logger] = None, ): - prompt_to_message = [("user", prompt)] + prompt_to_message = [{"role": "user", "content": prompt}] chat_settings = ChatRequestSettings( temperature=request_settings.temperature, top_p=request_settings.top_p, @@ -159,13 +188,16 @@ async def _send_chat_request( messages: List[Tuple[str, str]], request_settings: ChatRequestSettings, stream: bool, + functions: Optional[List[Dict[str, Any]]] = None, ): """ Completes the given user message with an asynchronous stream. Arguments: - user_message {str} -- The message (from a user) to respond to. + messages {List[Tuple[str,str]]} -- The messages (from a user) to respond to. request_settings {ChatRequestSettings} -- The request settings. + stream {bool} -- Whether to stream the response. + functions {List[Dict[str, Any]]} -- The functions available to the api. Returns: str -- The completed text. 
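A sketch of driving the new function-calling path directly (illustrative only: `service` stands in for an OpenAIChatCompletion instance, `messages` for a rendered chat history, and the schema values are made up to mirror MathSkill):

    functions = [
        {
            "name": "math-Add",
            "description": "Adds an amount to a value",
            "parameters": {
                "type": "object",
                "properties": {
                    "input": {"type": "number", "description": "the initial value"},
                    "Amount": {"type": "number", "description": "the amount to add"},
                },
                "required": ["input", "Amount"],
            },
        }
    ]
    settings = ChatRequestSettings(function_call="auto")
    # inside an async function:
    content, function_call = await service.complete_chat_with_functions_async(
        messages, functions, settings
    )
    if function_call is not None:
        print(function_call.name, function_call.arguments)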
@@ -186,57 +218,62 @@ async def _send_chat_request( "To complete a chat you need at least one message", ) - if messages[-1][0] != "user": + if messages[-1]["role"] in ["assistant", "system"]: raise AIException( AIException.ErrorCodes.InvalidRequest, - "The last message must be from the user", + "The last message must be from the user or a function output", ) - model_args = {} - if self._api_type in ["azure", "azure_ad"]: - model_args["engine"] = self._model_id - else: - model_args["model"] = self._model_id - - formatted_messages = [ - {"role": role, "content": message} for role, message in messages - ] + model_args = { + "api_key": self._api_key, + "api_type": self._api_type, + "api_base": self._endpoint, + "api_version": self._api_version, + "organization": self._org_id, + "engine" + if self._api_type in ["azure", "azure_ad"] + else "model": self._model_id, + "messages": messages, + "temperature": request_settings.temperature, + "top_p": request_settings.top_p, + "n": request_settings.number_of_responses, + "stream": stream, + "stop": ( + request_settings.stop_sequences + if request_settings.stop_sequences is not None + and len(request_settings.stop_sequences) > 0 + else None + ), + "max_tokens": request_settings.max_tokens, + "presence_penalty": request_settings.presence_penalty, + "frequency_penalty": request_settings.frequency_penalty, + "logit_bias": ( + request_settings.token_selection_biases + if request_settings.token_selection_biases is not None + and len(request_settings.token_selection_biases) > 0 + else {} + ), + } + + if functions and request_settings.function_call is not None: + model_args["function_call"] = request_settings.function_call + if request_settings.function_call != "auto": + model_args["functions"] = [ + func + for func in functions + if func["name"] == request_settings.function_call + ] + else: + model_args["functions"] = functions try: - response: Any = await openai.ChatCompletion.acreate( - **model_args, - api_key=self._api_key, - api_type=self._api_type, - api_base=self._endpoint, - api_version=self._api_version, - organization=self._org_id, - messages=formatted_messages, - temperature=request_settings.temperature, - top_p=request_settings.top_p, - n=request_settings.number_of_responses, - stream=stream, - stop=( - request_settings.stop_sequences - if request_settings.stop_sequences is not None - and len(request_settings.stop_sequences) > 0 - else None - ), - max_tokens=request_settings.max_tokens, - presence_penalty=request_settings.presence_penalty, - frequency_penalty=request_settings.frequency_penalty, - logit_bias=( - request_settings.token_selection_biases - if request_settings.token_selection_biases is not None - and len(request_settings.token_selection_biases) > 0 - else {} - ), - ) + response: Any = await openai.ChatCompletion.acreate(**model_args) except Exception as ex: raise AIException( AIException.ErrorCodes.ServiceError, "OpenAI service failed to complete the chat", ex, - ) + ) from ex # streaming does not have usage info, therefore checking the type of the response if not stream and "usage" in response: @@ -266,6 +303,31 @@ def _parse_choices(chunk): message += chunk.choices[0].delta.role + ": " if "content" in chunk.choices[0].delta: message += chunk.choices[0].delta.content + if "function_call" in chunk.choices[0].delta: + message += chunk.choices[0].delta.function_call index = chunk.choices[0].index return message, index + + +def _parse_message( + message: "OpenAIObject", logger: Optional[Logger] = None +) -> Tuple[Optional[str], 
Optional[FunctionCall]]:
+    """
+    Parses the message.
+
+    Arguments:
+        message {OpenAIObject} -- The message to parse.
+
+    Returns:
+        Tuple[Optional[str], Optional[FunctionCall]] -- The parsed message.
+    """
+    content = message.content if hasattr(message, "content") else None
+    function_call = message.function_call if hasattr(message, "function_call") else None
+    if function_call:
+        function_call = FunctionCall(
+            name=function_call.name,
+            arguments=function_call.arguments,
+        )
+
+    return (content, function_call)
diff --git a/python/semantic_kernel/connectors/ai/open_ai/utils.py b/python/semantic_kernel/connectors/ai/open_ai/utils.py
new file mode 100644
index 000000000000..9b6da59fc6e6
--- /dev/null
+++ b/python/semantic_kernel/connectors/ai/open_ai/utils.py
@@ -0,0 +1,184 @@
+from logging import Logger
+from typing import Any, Dict, List, Optional
+
+from semantic_kernel import Kernel, SKContext
+from semantic_kernel.connectors.ai.open_ai.models.chat.function_call import FunctionCall
+from semantic_kernel.connectors.ai.open_ai.semantic_functions.open_ai_chat_prompt_template import (
+    OpenAIChatPromptTemplate,
+)
+from semantic_kernel.orchestration.sk_function_base import SKFunctionBase
+
+
+def _describe_function(function: SKFunctionBase) -> Dict[str, str]:
+    """Create the object used for function_calling.
+
+    Assumes that arguments for semantic functions are optional; for native functions they are required.
+    """
+    func_view = function.describe()
+    return {
+        "name": f"{func_view.skill_name}-{func_view.name}",
+        "description": func_view.description,
+        "parameters": {
+            "type": "object",
+            "properties": {
+                param.name: {"description": param.description, "type": param.type_}
+                for param in func_view.parameters
+            },
+            "required": [p.name for p in func_view.parameters if p.required],
+        },
+    }
+
+
+def get_function_calling_object(
+    kernel: Kernel, filter: Dict[str, List[str]]
+) -> List[Dict[str, str]]:
+    """Create the object used for function_calling.
+
+    args:
+        kernel: the kernel.
+        filter: a dictionary with one or more of the keys
+            exclude_skill, include_skill, exclude_function and include_function,
+            each mapped to a list of names to filter on.
+            Function names should be in the format "skill_name-function_name".
+            Using exclude_skill and include_skill at the same time will raise an error.
+            Using exclude_function and include_function at the same time will raise an error.
+            Using include_* implies that all other skills or functions are excluded.
+            Example:
+                filter = {
+                    "exclude_skill": ["skill1", "skill2"],
+                    "include_function": ["skill3-function1", "skill4-function2"],
+                }
+                will return only skill3-function1 and skill4-function2.
+                filter = {
+                    "exclude_function": ["skill1-function1", "skill2-function2"],
+                }
+                will return all functions except skill1-function1 and skill2-function2.
+    returns:
+        a filtered list of dictionaries of the functions in the kernel that can be passed to the function-calling API.
+    """
+    include_skill = filter.get("include_skill", None)
+    exclude_skill = filter.get("exclude_skill", [])
+    include_function = filter.get("include_function", None)
+    exclude_function = filter.get("exclude_function", [])
+    if include_skill and exclude_skill:
+        raise ValueError(
+            "Cannot use both include_skill and exclude_skill at the same time."
+        )
+    if include_function and exclude_function:
+        raise ValueError(
+            "Cannot use both include_function and exclude_function at the same time."
+        )
+    if include_skill:
+        include_skill = [skill.lower() for skill in include_skill]
+    if exclude_skill:
+        exclude_skill = [skill.lower() for skill in exclude_skill]
+    if include_function:
+        include_function = [function.lower() for function in include_function]
+    if exclude_function:
+        exclude_function = [function.lower() for function in exclude_function]
+    result = []
+    for (
+        skill_name,
+        skill,
+    ) in kernel.skills.data.items():
+        if skill_name in exclude_skill or (
+            include_skill and skill_name not in include_skill
+        ):
+            continue
+        for function_name, function in skill.items():
+            current_name = f"{skill_name}-{function_name}"
+            if current_name in exclude_function or (
+                include_function and current_name not in include_function
+            ):
+                continue
+            result.append(_describe_function(function))
+    return result
+
+
+async def execute_function_call(
+    kernel: Kernel, function_call: FunctionCall, log: Optional[Logger] = None
+) -> str:
+    result = await kernel.run_async(
+        kernel.func(**function_call.split_name_dict()),
+        input_vars=function_call.to_context_variables(),
+    )
+    if log:
+        log.info(f"Function call result: {result}")
+    return str(result)
+
+
+async def chat_completion_with_function_call(
+    kernel: Kernel,
+    context: SKContext,
+    functions: List[Dict[str, str]] = [],
+    chat_skill_name: Optional[str] = None,
+    chat_function_name: Optional[str] = None,
+    chat_function: Optional[SKFunctionBase] = None,
+    *,
+    log: Optional[Logger] = None,
+    **kwargs: Dict[str, Any],
+) -> SKContext:
+    """Perform a chat completion with auto-executing function calling.
+
+    This function is recursive: the chat function is invoked at least once to get
+    a first completion. If a function_call is returned, that call is executed
+    (using the execute_function_call method), the result is added to the chat
+    prompt template, and another completion is requested. This repeats, executing
+    any returned function_call, until the maximum number of function calls is
+    reached, at which point a final completion is requested without functions.
+
+    args:
+        kernel: the kernel to use.
+        context: the context to use.
+        functions: the function-calling object;
+            use the get_function_calling_object method to create it.
+    Optional arguments:
+        chat_skill_name: the skill name of the chat function.
+        chat_function_name: the function name of the chat function.
+        chat_function: the chat function; if not provided, it is retrieved from the kernel.
+            Make sure to provide either the chat_function or both the chat_skill_name and chat_function_name.
+        log: the logger to use.
+        max_function_calls: the maximum number of function calls to execute, defaults to 5.
+        current_call_count: the current number of function calls executed.
+
+    returns:
+        the context with the result of the chat completion, just like a regular invoke_async/run_async.
+    """
+    # check the number of function calls
+    max_function_calls = kwargs.get("max_function_calls", 5)
+    current_call_count = kwargs.get("current_call_count", 0)
+    # get the chat function
+    if chat_function is None:
+        chat_function = kernel.func(
+            skill_name=chat_skill_name, function_name=chat_function_name
+        )
+    assert isinstance(
+        chat_function._chat_prompt_template, OpenAIChatPromptTemplate
+    ), "Please make sure to initialize your chat function with the OpenAIChatPromptTemplate class."
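+    # the OpenAIChatPromptTemplate is required so that each executed function's
+    # result can be appended below via add_function_response_message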
+ context = await chat_function.invoke_async( + context=context, + # when the maximum number of function calls is reached, execute the chat function without Functions. + functions=[] if current_call_count >= max_function_calls else functions, + ) + function_call = context.objects.pop("function_call", None) + # if there is no function_call or if the content is not a FunctionCall object, return the context + if function_call is None or not isinstance(function_call, FunctionCall): + return context + result = await execute_function_call(kernel, function_call, log=log) + # add the result to the chat prompt template + chat_function._chat_prompt_template.add_function_response_message( + name=function_call.name, content=str(result) + ) + # request another completion + return await chat_completion_with_function_call( + kernel, + chat_function=chat_function, + functions=functions, + context=context, + log=log, + max_function_calls=max_function_calls, + current_call_count=current_call_count + 1, + ) diff --git a/python/semantic_kernel/core_skills/math_skill.py b/python/semantic_kernel/core_skills/math_skill.py index a8daba2fc89c..d0533feadbd8 100644 --- a/python/semantic_kernel/core_skills/math_skill.py +++ b/python/semantic_kernel/core_skills/math_skill.py @@ -27,6 +27,8 @@ class MathSkill(PydanticField): @sk_function_context_parameter( name="Amount", description="Amount to add", + type="number", + required=True, ) def add(self, initial_value_text: str, context: "SKContext") -> str: """ @@ -46,6 +48,8 @@ def add(self, initial_value_text: str, context: "SKContext") -> str: @sk_function_context_parameter( name="Amount", description="Amount to subtract", + type="number", + required=True, ) def subtract(self, initial_value_text: str, context: "SKContext") -> str: """ diff --git a/python/semantic_kernel/kernel.py b/python/semantic_kernel/kernel.py index 1d921da05945..9257086465e1 100644 --- a/python/semantic_kernel/kernel.py +++ b/python/semantic_kernel/kernel.py @@ -248,6 +248,7 @@ async def run_async( input_context: Optional[SKContext] = None, input_vars: Optional[ContextVariables] = None, input_str: Optional[str] = None, + **kwargs: Dict[str, Any], ) -> SKContext: # if the user passed in a context, prioritize it, but merge with any other inputs if input_context is not None: @@ -300,7 +301,7 @@ async def run_async( pipeline_step += 1 try: - context = await func.invoke_async(input=None, context=context) + context = await func.invoke_async(input=None, context=context, **kwargs) if context.error_occurred: self._log.error( diff --git a/python/semantic_kernel/models/chat/chat_message.py b/python/semantic_kernel/models/chat/chat_message.py new file mode 100644 index 000000000000..1f132e730f6e --- /dev/null +++ b/python/semantic_kernel/models/chat/chat_message.py @@ -0,0 +1,40 @@ +"""Class to hold chat messages.""" +from typing import TYPE_CHECKING, Dict, Optional + +from pydantic import Field + +from semantic_kernel.semantic_functions.prompt_template import PromptTemplate +from semantic_kernel.sk_pydantic import SKBaseModel + +if TYPE_CHECKING: + from semantic_kernel.orchestration.sk_context import SKContext + + +class ChatMessage(SKBaseModel): + """Class to hold chat messages.""" + + role: Optional[str] = "assistant" + fixed_content: Optional[str] = Field(default=None, init=False, alias="content") + content_template: Optional[PromptTemplate] = Field( + default=None, init=True, repr=False + ) + + @property + def content(self) -> Optional[str]: + """Return the content of the message.""" + return 
self.fixed_content + + async def render_message_async(self, context: "SKContext") -> None: + """Render the message. + The first time this is called for a given message, + it will render the message with the context at that time. + Subsequent calls will do nothing. + """ + if self.fixed_content is None: + self.fixed_content = await self.content_template.render_async(context) + + def as_dict(self) -> Dict[str, str]: + """Return the message as a dict. + Make sure to call render_message_async first to embed the context in the content. + """ + return self.dict(exclude_none=True, by_alias=True, exclude={"content_template"}) diff --git a/python/semantic_kernel/orchestration/sk_context.py b/python/semantic_kernel/orchestration/sk_context.py index f9b989ba093d..d8a0e5018fbb 100644 --- a/python/semantic_kernel/orchestration/sk_context.py +++ b/python/semantic_kernel/orchestration/sk_context.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. from logging import Logger -from typing import Any, Generic, Literal, Optional, Tuple, Union +from typing import Any, Dict, Generic, Literal, Optional, Tuple, Union import pydantic as pdt @@ -25,9 +25,11 @@ class SKContext(SKGenericModel, Generic[SemanticTextMemoryT]): memory: SemanticTextMemoryT variables: ContextVariables + # This field can be used to hold anything that is not a string skill_collection: ReadOnlySkillCollection = pdt.Field( default_factory=ReadOnlySkillCollection ) + _objects: Dict[str, Any] = pdt.PrivateAttr(default_factory=dict) _error_occurred: bool = pdt.PrivateAttr(False) _last_exception: Optional[Exception] = pdt.PrivateAttr(None) _last_error_description: str = pdt.PrivateAttr("") @@ -118,6 +120,16 @@ def last_exception(self) -> Optional[Exception]: """ return self._last_exception + @property + def objects(self) -> Dict[str, Any]: + """ + The objects dictionary. + + Returns: + Dict[str, Any] -- The objects dictionary. 
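+        Used to hold non-string values, for example the FunctionCall
+        returned from a function-calling completion.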
+ """ + return self._objects + @property def skills(self) -> ReadOnlySkillCollectionBase: """ diff --git a/python/semantic_kernel/orchestration/sk_function.py b/python/semantic_kernel/orchestration/sk_function.py index 7584d36caa22..c39d558d0b6f 100644 --- a/python/semantic_kernel/orchestration/sk_function.py +++ b/python/semantic_kernel/orchestration/sk_function.py @@ -6,7 +6,7 @@ import threading from enum import Enum from logging import Logger -from typing import TYPE_CHECKING, Any, Callable, List, Optional, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional from semantic_kernel.connectors.ai.chat_completion_client_base import ( ChatCompletionClientBase, @@ -58,6 +58,7 @@ class SKFunction(SKFunctionBase): _ai_request_settings: CompleteRequestSettings _chat_service: Optional[ChatCompletionClientBase] _chat_request_settings: ChatRequestSettings + _chat_prompt_template: ChatPromptTemplate @staticmethod def from_native_method(method, skill_name="", log=None) -> "SKFunction": @@ -77,7 +78,11 @@ def from_native_method(method, skill_name="", log=None) -> "SKFunction": parameters.append( ParameterView( - param["name"], param["description"], param["default_value"] + name=param["name"], + description=param["description"], + default_value=param["default_value"], + type=param.get("type", "string"), + required=param.get("required", False), ) ) @@ -87,9 +92,11 @@ def from_native_method(method, skill_name="", log=None) -> "SKFunction": and method.__sk_function_input_description__ != "" ): input_param = ParameterView( - "input", - method.__sk_function_input_description__, - method.__sk_function_input_default_value__, + name="input", + description=method.__sk_function_input_description__, + default_value=method.__sk_function_input_default_value__, + type="string", + required=False, ) parameters = [input_param] + parameters @@ -115,41 +122,59 @@ def from_semantic_config( if function_config is None: raise ValueError("Function configuration cannot be `None`") - async def _local_func(client, request_settings, context): + async def _local_func(client, request_settings, context: "SKContext", **kwargs): if client is None: raise ValueError("AI LLM service cannot be `None`") try: - if function_config.has_chat_prompt: - as_chat_prompt = cast( - ChatPromptTemplate, function_config.prompt_template - ) - - # Similar to non-chat, render prompt (which renders to a - # list of messages) - messages = await as_chat_prompt.render_messages_async(context) - completion = await client.complete_chat_async( - messages, request_settings - ) - - # Add the last message from the rendered chat prompt - # (which will be the user message) and the response - # from the model (the assistant message) - _, content = messages[-1] - as_chat_prompt.add_user_message(content) - as_chat_prompt.add_assistant_message(completion) - - # Update context - context.variables.update(completion) - else: + if not function_config.has_chat_prompt: prompt = await function_config.prompt_template.render_async(context) completion = await client.complete_async(prompt, request_settings) context.variables.update(completion) + return context except Exception as e: # TODO: "critical exceptions" context.fail(str(e), e) + return context - return context + as_chat_prompt = function_config.prompt_template + # Similar to non-chat, render prompt (which renders to a + # dict of messages) + messages = await as_chat_prompt.render_messages_async(context) + + functions = ( + kwargs.get("functions") + if request_settings.function_call is not None + else 
None + ) + if request_settings.function_call is not None and functions is None: + log.warning("Function call is not None, but functions is None") + try: + if functions and hasattr(client, "complete_chat_with_functions_async"): + ( + completion, + function_call, + ) = await client.complete_chat_with_functions_async( + messages, functions, request_settings + ) + as_chat_prompt.add_message( + "assistant", message=completion, function_call=function_call + ) + if completion is not None: + context.variables.update(completion) + if function_call is not None: + context.objects["function_call"] = function_call + else: + completion = await client.complete_chat_async( + messages, request_settings + ) + as_chat_prompt.add_assistant_message(completion) + context.variables.update(completion) + except Exception as exc: + # TODO: "critical exceptions" + context.fail(str(exc), exc) + finally: + return context async def _local_stream_func(client, request_settings, context): if client is None: @@ -157,9 +182,7 @@ async def _local_stream_func(client, request_settings, context): try: if function_config.has_chat_prompt: - as_chat_prompt = cast( - ChatPromptTemplate, function_config.prompt_template - ) + as_chat_prompt = function_config.prompt_template # Similar to non-chat, render prompt (which renders to a # list of messages) @@ -204,6 +227,9 @@ async def _local_stream_func(client, request_settings, context): function_name=function_name, is_semantic=True, log=log, + chat_prompt_template=function_config.prompt_template + if function_config.has_chat_prompt + else None, ) @property @@ -245,6 +271,7 @@ def __init__( is_semantic: bool, log: Optional[Logger] = None, delegate_stream_function: Optional[Callable[..., Any]] = None, + **kwargs: Dict[str, Any], ) -> None: self._delegate_type = delegate_type self._function = delegate_function @@ -260,6 +287,7 @@ def __init__( self._ai_request_settings = CompleteRequestSettings() self._chat_service = None self._chat_request_settings = ChatRequestSettings() + self._chat_prompt_template = kwargs.get("chat_prompt_template", None) def set_default_skill_collection( self, skills: ReadOnlySkillCollectionBase @@ -382,6 +410,7 @@ async def invoke_async( memory: Optional[SemanticTextMemoryBase] = None, settings: Optional[CompleteRequestSettings] = None, log: Optional[Logger] = None, + **kwargs: Dict[str, Any], ) -> "SKContext": from semantic_kernel.orchestration.sk_context import SKContext @@ -406,14 +435,14 @@ async def invoke_async( try: if self.is_semantic: - return await self._invoke_semantic_async(context, settings) + return await self._invoke_semantic_async(context, settings, **kwargs) else: - return await self._invoke_native_async(context) + return await self._invoke_native_async(context, **kwargs) except Exception as e: context.fail(str(e), e) return context - async def _invoke_semantic_async(self, context, settings): + async def _invoke_semantic_async(self, context: "SKContext", settings, **kwargs): self._verify_is_semantic() self._ensure_context_has_skills(context) @@ -432,7 +461,9 @@ async def _invoke_semantic_async(self, context, settings): service = ( self._ai_service if self._ai_service is not None else self._chat_service ) - new_context = await self._function(service, settings, context) + new_context = await self._function( + service, settings, context, functions=kwargs.get("functions", None) + ) context.variables.merge_or_overwrite(new_context.variables) return context diff --git a/python/semantic_kernel/orchestration/sk_function_base.py 
b/python/semantic_kernel/orchestration/sk_function_base.py index 2633b452f7eb..1e0f0e411dd0 100644 --- a/python/semantic_kernel/orchestration/sk_function_base.py +++ b/python/semantic_kernel/orchestration/sk_function_base.py @@ -2,7 +2,7 @@ from abc import abstractmethod from logging import Logger -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional from semantic_kernel.connectors.ai.complete_request_settings import ( CompleteRequestSettings, @@ -133,6 +133,7 @@ async def invoke_async( memory: Optional[SemanticTextMemoryBase] = None, settings: Optional[CompleteRequestSettings] = None, log: Optional[Logger] = None, + **kwargs: Dict[str, Any], ) -> "SKContext": """ Invokes the function with an explicit string input diff --git a/python/semantic_kernel/semantic_functions/chat_prompt_template.py b/python/semantic_kernel/semantic_functions/chat_prompt_template.py index c6da9676bb1a..8aa999e948ec 100644 --- a/python/semantic_kernel/semantic_functions/chat_prompt_template.py +++ b/python/semantic_kernel/semantic_functions/chat_prompt_template.py @@ -1,8 +1,10 @@ # Copyright (c) Microsoft. All rights reserved. +import asyncio from logging import Logger -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar +from semantic_kernel.models.chat.chat_message import ChatMessage from semantic_kernel.semantic_functions.prompt_template import PromptTemplate from semantic_kernel.semantic_functions.prompt_template_config import ( PromptTemplateConfig, @@ -14,9 +16,11 @@ if TYPE_CHECKING: from semantic_kernel.orchestration.sk_context import SKContext +ChatMessageT = TypeVar("ChatMessageT", bound=ChatMessage) -class ChatPromptTemplate(PromptTemplate): - _messages: List[Tuple[str, PromptTemplate]] + +class ChatPromptTemplate(PromptTemplate, Generic[ChatMessageT]): + _messages: List[ChatMessageT] def __init__( self, @@ -37,40 +41,52 @@ async def render_async(self, context: "SKContext") -> str: ) def add_system_message(self, message: str) -> None: + """Add a system message to the chat template.""" self.add_message("system", message) def add_user_message(self, message: str) -> None: + """Add a user message to the chat template.""" self.add_message("user", message) def add_assistant_message(self, message: str) -> None: + """Add an assistant message to the chat template.""" self.add_message("assistant", message) - def add_message(self, role: str, message: str) -> None: + def add_message( + self, role: str, message: Optional[str] = None, **kwargs: Any + ) -> None: + """Add a message to the chat template. + + Arguments: + role: The role of the message, one of "user", "assistant", "system". + message: The message to add, can include templating components. + kwargs: can be used by inherited classes. 
+ """ self._messages.append( - (role, PromptTemplate(message, self._template_engine, self._prompt_config)) + ChatMessage( + role=role, + content_template=PromptTemplate( + message, self._template_engine, self._prompt_config + ), + ) ) - async def render_messages_async( - self, context: "SKContext" - ) -> List[Tuple[str, str]]: - rendered_messages = [] - for role, message in self._messages: - rendered_messages.append((role, await message.render_async(context))) - - latest_user_message = await self._template_engine.render_async( - self._template, context + async def render_messages_async(self, context: "SKContext") -> List[Dict[str, str]]: + """Render the content of the message in the chat template, based on the context.""" + if len(self._messages) == 0 or self._messages[-1].role in [ + "assistant", + "system", + ]: + self.add_user_message(message=self._template) + await asyncio.gather( + *[message.render_message_async(context) for message in self._messages] ) - rendered_messages.append(("user", latest_user_message)) - - return rendered_messages + return [message.as_dict() for message in self._messages] @property def messages(self) -> List[Dict[str, str]]: - """Return the messages as a list of tuples of role and message.""" - return [ - {"role": role, "message": message._template} - for role, message in self._messages - ] + """Return the messages as a list of dicts with role, content, name.""" + return [message.as_dict() for message in self._messages] @classmethod def restore( @@ -81,14 +97,27 @@ def restore( prompt_config: PromptTemplateConfig, log: Optional[Logger] = None, ) -> "ChatPromptTemplate": - """Restore a ChatPromptTemplate from a list of role and message pairs.""" - chat_template = cls(template, template_engine, prompt_config, log) - - if prompt_config.chat_system_prompt: - chat_template.add_system_message( - prompt_config.completion.chat_system_prompt - ) + """Restore a ChatPromptTemplate from a list of role and message pairs. + If there is a chat_system_prompt in the prompt_config.completion settings, + that takes precedence over the first message in the list of messages, + if that is a system message. 
+ """ + chat_template = cls(template, template_engine, prompt_config, log) + if ( + prompt_config.completion.chat_system_prompt + and messages[0]["role"] == "system" + ): + existing_system_message = messages.pop(0) + if ( + existing_system_message["message"] + != prompt_config.completion.chat_system_prompt + ): + chat_template._log.info( + "Overriding system prompt with chat_system_prompt, old system message: %s, new system message: %s", + existing_system_message["message"], + prompt_config.completion.chat_system_prompt, + ) for message in messages: chat_template.add_message(message["role"], message["message"]) diff --git a/python/semantic_kernel/semantic_functions/prompt_template.py b/python/semantic_kernel/semantic_functions/prompt_template.py index 2fe943e0a547..b8ead27d112e 100644 --- a/python/semantic_kernel/semantic_functions/prompt_template.py +++ b/python/semantic_kernel/semantic_functions/prompt_template.py @@ -46,7 +46,11 @@ def get_parameters(self) -> List[ParameterView]: continue result.append( - ParameterView(param.name, param.description, param.default_value) + ParameterView( + name=param.name, + description=param.description, + default_value=param.default_value, + ) ) seen.add(param.name) @@ -62,7 +66,9 @@ def get_parameters(self) -> List[ParameterView]: if var_block.name in seen: continue - result.append(ParameterView(var_block.name, "", "")) + result.append( + ParameterView(name=var_block.name, description="", default_value="") + ) seen.add(var_block.name) diff --git a/python/semantic_kernel/semantic_functions/prompt_template_config.py b/python/semantic_kernel/semantic_functions/prompt_template_config.py index 9ce324e1a465..1d996d2054cd 100644 --- a/python/semantic_kernel/semantic_functions/prompt_template_config.py +++ b/python/semantic_kernel/semantic_functions/prompt_template_config.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. from dataclasses import dataclass, field -from typing import Dict, List +from typing import Dict, List, Optional @dataclass @@ -17,12 +17,17 @@ class CompletionConfig: stop_sequences: List[str] = field(default_factory=list) token_selection_biases: Dict[int, int] = field(default_factory=dict) chat_system_prompt: str = None + # the function_call should be 'auto' or the name of a specific function in order to leverage function calling + # when not using auto, the format is 'SkillName-FunctionName', e.g. 
'Weather-GetWeather' + function_call: Optional[str] = None @dataclass class InputParameter: name: str = "" description: str = "" default_value: str = "" + type_: str = "string" + required: bool = True @dataclass class InputConfig: @@ -61,6 +66,7 @@ def from_dict(data: dict) -> "PromptTemplateConfig": "token_selection_biases", "default_services", "chat_system_prompt", + "function_call", ] for comp_key in completion_keys: if comp_key in completion_dict: @@ -92,11 +98,16 @@ def from_dict(data: dict) -> "PromptTemplateConfig": f"Input parameter '{name}' doesn't have a default value (function: {config.description})" ) + type_ = parameter.get("type") + required = parameter.get("required") + config.input.parameters.append( PromptTemplateConfig.InputParameter( name, description, defaultValue, + type_, + required, ) ) return config @@ -123,6 +134,7 @@ def from_completion_parameters( stop_sequences: List[str] = [], token_selection_biases: Dict[int, int] = {}, chat_system_prompt: str = None, + function_call: Optional[str] = None, ) -> "PromptTemplateConfig": config = PromptTemplateConfig() config.completion.temperature = temperature @@ -134,4 +146,5 @@ def from_completion_parameters( config.completion.stop_sequences = stop_sequences config.completion.token_selection_biases = token_selection_biases config.completion.chat_system_prompt = chat_system_prompt + config.completion.function_call = function_call return config diff --git a/python/semantic_kernel/skill_definition/parameter_view.py b/python/semantic_kernel/skill_definition/parameter_view.py index 8aaa552d554e..38b3c794d730 100644 --- a/python/semantic_kernel/skill_definition/parameter_view.py +++ b/python/semantic_kernel/skill_definition/parameter_view.py @@ -1,5 +1,8 @@ # Copyright (c) Microsoft. All rights reserved. + +from pydantic import Field, validator + from semantic_kernel.sk_pydantic import SKBaseModel from semantic_kernel.utils.validation import validate_function_param_name @@ -8,9 +11,10 @@ class ParameterView(SKBaseModel): name: str description: str default_value: str + type_: str = Field(default="string", alias="type") + required: bool = False - def __init__(self, name: str, description: str, default_value: str) -> None: + @validator("name") + def validate_name(cls, name: str): validate_function_param_name(name) - super().__init__( - name=name, description=description, default_value=default_value - ) + return name diff --git a/python/semantic_kernel/skill_definition/sk_function_context_parameter_decorator.py b/python/semantic_kernel/skill_definition/sk_function_context_parameter_decorator.py index d5aa6ad19dbc..c7eb2d670614 100644 --- a/python/semantic_kernel/skill_definition/sk_function_context_parameter_decorator.py +++ b/python/semantic_kernel/skill_definition/sk_function_context_parameter_decorator.py @@ -2,7 +2,12 @@ def sk_function_context_parameter( - *, name: str, description: str, default_value: str = "" + *, + name: str, + description: str, + default_value: str = "", + type: str = "string", + required: bool = False ): """ Decorator for SK function context parameters. 
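A usage sketch for the extended decorator (hypothetical skill; the new `type` and `required` arguments feed the function-calling schema built in utils.py):

    from semantic_kernel.skill_definition import (
        sk_function,
        sk_function_context_parameter,
    )

    class WeatherSkill:
        @sk_function(description="Gets a forecast for a city", name="GetForecast")
        @sk_function_context_parameter(
            name="city",
            description="The city to look up",
            type="string",  # becomes the JSON-schema "type" of the parameter
            required=True,  # lists the parameter under "required" in the schema
        )
        def get_forecast(self, context) -> str:
            return f"Sunny in {context.variables['city']}"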
@@ -11,6 +16,9 @@ def sk_function_context_parameter( name -- The name of the context parameter description -- The description of the context parameter default_value -- The default value of the context parameter + type -- The type of the context parameter, used for function calling + required -- Whether the context parameter is required + """ def decorator(func): @@ -22,6 +30,8 @@ def decorator(func): "name": name, "description": description, "default_value": default_value, + "type": type, + "required": required, } ) return func diff --git a/python/tests/unit/ai/open_ai/models/chat/test_function_call.py b/python/tests/unit/ai/open_ai/models/chat/test_function_call.py new file mode 100644 index 000000000000..e0200ff37d1f --- /dev/null +++ b/python/tests/unit/ai/open_ai/models/chat/test_function_call.py @@ -0,0 +1,28 @@ +import pytest + +from semantic_kernel.connectors.ai.open_ai.models.chat.function_call import FunctionCall +from semantic_kernel.orchestration.context_variables import ContextVariables + + +def test_function_call(): + # Test initialization with default values + fc = FunctionCall(name="Test-Function", arguments="""{"input": "world"}""") + assert fc.name == "Test-Function" + assert fc.arguments == """{"input": "world"}""" + + +@pytest.mark.asyncio +async def test_function_call_to_content_variables(create_kernel): + # Test parsing arguments to variables + kernel = create_kernel + + func_call = FunctionCall( + name="Test-Function", + arguments="""{"input": "world", "input2": "world2"}""", + ) + context = kernel.create_new_context() + assert isinstance(func_call.to_context_variables(), ContextVariables) + + context.variables.merge_or_overwrite(func_call.to_context_variables()) + assert context.variables.input == "world" + assert context.variables["input2"] == "world2" diff --git a/python/tests/unit/ai/open_ai/services/test_azure_chat_completion.py b/python/tests/unit/ai/open_ai/services/test_azure_chat_completion.py index fefcb3c432d2..8ef3429232e1 100644 --- a/python/tests/unit/ai/open_ai/services/test_azure_chat_completion.py +++ b/python/tests/unit/ai/open_ai/services/test_azure_chat_completion.py @@ -240,12 +240,12 @@ async def test_azure_chat_completion_call_with_parameters_and_Stop_Defined() -> await azure_chat_completion.complete_async(prompt, complete_request_settings) mock_openai.ChatCompletion.acreate.assert_called_once_with( - engine=deployment_name, api_key=api_key, api_type=api_type, api_base=endpoint, api_version=api_version, organization=None, + engine=deployment_name, messages=messages, temperature=complete_request_settings.temperature, top_p=complete_request_settings.top_p, diff --git a/python/tests/unit/models/chat/test_chat_message.py b/python/tests/unit/models/chat/test_chat_message.py new file mode 100644 index 000000000000..375e412725e1 --- /dev/null +++ b/python/tests/unit/models/chat/test_chat_message.py @@ -0,0 +1,43 @@ +import pytest + +from semantic_kernel.models.chat.chat_message import ChatMessage +from semantic_kernel.semantic_functions.prompt_template import PromptTemplate +from semantic_kernel.semantic_functions.prompt_template_config import ( + PromptTemplateConfig, +) + + +def test_chat_message(): + # Test initialization with default values + message = ChatMessage() + assert message.role == "assistant" + assert message.fixed_content is None + assert message.content is None + assert message.content_template is None + + +@pytest.mark.asyncio +async def test_chat_message_rendering(create_kernel): + # Test initialization with custom values + kernel = 
create_kernel + expected_content = "Hello, world!" + prompt_config = PromptTemplateConfig.from_completion_parameters( + max_tokens=2000, temperature=0.7, top_p=0.8 + ) + content_template = PromptTemplate( + "Hello, {{$input}}!", kernel.prompt_template_engine, prompt_config + ) + + message = ChatMessage( + role="user", + content_template=content_template, + ) + context = kernel.create_new_context() + context.variables["input"] = "world" + await message.render_message_async(context) + assert message.role == "user" + assert message.fixed_content == expected_content + assert message.content_template == content_template + + # Test content property + assert message.content == expected_content diff --git a/python/tests/unit/orchestration/test_native_function.py b/python/tests/unit/orchestration/test_native_function.py index 4fb263d925a7..edb7a32c8c4d 100644 --- a/python/tests/unit/orchestration/test_native_function.py +++ b/python/tests/unit/orchestration/test_native_function.py @@ -34,9 +34,13 @@ def mock_function(input: str, context: "SKContext") -> None: assert native_function._parameters[0].name == "input" assert native_function._parameters[0].description == "Mock input description" assert native_function._parameters[0].default_value == "default_input_value" + assert native_function._parameters[0].type_ == "string" + assert native_function._parameters[0].required is False assert native_function._parameters[1].name == "param1" assert native_function._parameters[1].description == "Param 1 description" assert native_function._parameters[1].default_value == "default_param1_value" + assert native_function._parameters[1].type_ == "string" + assert native_function._parameters[1].required is False def test_init_native_function_without_input_description(): @@ -51,6 +55,7 @@ def mock_function(context: "SKContext") -> None: "name": "param1", "description": "Param 1 description", "default_value": "default_param1_value", + "required": True, } ] @@ -62,6 +67,8 @@ def mock_function(context: "SKContext") -> None: assert native_function._parameters[0].name == "param1" assert native_function._parameters[0].description == "Param 1 description" assert native_function._parameters[0].default_value == "default_param1_value" + assert native_function._parameters[0].type_ == "string" + assert native_function._parameters[0].required is True def test_init_native_function_from_sk_function_decorator(): @@ -90,6 +97,8 @@ def decorated_function() -> None: assert native_function._parameters[0].name == "input" assert native_function._parameters[0].description == "Test input description" assert native_function._parameters[0].default_value == "test_default_value" + assert native_function._parameters[0].type_ == "string" + assert native_function._parameters[0].required is False def test_init_native_function_from_sk_function_decorator_defaults(): diff --git a/python/tests/unit/skill_definition/test_prompt_templates.py b/python/tests/unit/skill_definition/test_prompt_templates.py index d8d51dc222e9..82c03875ba00 100644 --- a/python/tests/unit/skill_definition/test_prompt_templates.py +++ b/python/tests/unit/skill_definition/test_prompt_templates.py @@ -166,8 +166,10 @@ def test_chat_prompt_template_with_system_prompt(): None, prompt_config=prompt_template_config, ) - - print(chat_prompt_template.messages) + print(chat_prompt_template._messages) assert len(chat_prompt_template.messages) == 1 - assert chat_prompt_template.messages[0]["role"] == "system" - assert chat_prompt_template.messages[0]["message"] == "Custom system prompt." 
+    assert chat_prompt_template._messages[0].role == "system"
+    assert (
+        chat_prompt_template._messages[0].content_template._template
+        == "Custom system prompt."
+    )
diff --git a/python/tests/unit/test_serialization.py b/python/tests/unit/test_serialization.py
index 1267b9f8d023..53d2d760daaa 100644
--- a/python/tests/unit/test_serialization.py
+++ b/python/tests/unit/test_serialization.py
@@ -129,12 +129,18 @@ def create_skill_collection() -> SkillCollection:
     CodeTokenizer: CodeTokenizer(log=logging.getLogger("test")),
     PromptTemplateEngine: PromptTemplateEngine(logger=logging.getLogger("test")),
     TemplateTokenizer: TemplateTokenizer(log=logging.getLogger("test")),
-    ParameterView: ParameterView("foo", "bar", default_value="baz"),
+    ParameterView: ParameterView(
+        name="foo",
+        description="bar",
+        default_value="baz",
+        type="string",
+        required=True,
+    ),
     FunctionView: FunctionView(
         "foo",
         "bar",
         "baz",
-        [ParameterView("qux", "bar", "baz")],
+        [ParameterView(name="qux", description="bar", default_value="baz")],
         True,
         False,
     ),
diff --git a/samples/skills/FunSkill/Joke/config.json b/samples/skills/FunSkill/Joke/config.json
index f712ee36de82..5ec9e5fe44ff 100644
--- a/samples/skills/FunSkill/Joke/config.json
+++ b/samples/skills/FunSkill/Joke/config.json
@@ -23,4 +23,4 @@
             }
         ]
     }
-}
+}
\ No newline at end of file

From f5dc51e567230d127fa99024a6a7541bbe8dfa88 Mon Sep 17 00:00:00 2001
From: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com>
Date: Thu, 12 Oct 2023 18:48:51 +0100
Subject: [PATCH 2/2] .Net Fix for double retries (#3141)

### Motivation and Context
SK has two mechanisms that should be used to enable/configure a retry policy for accessing REST APIs:
- A custom HTTP client supplied as an argument of all `KernelBuilder.With*Service` extension methods.
- A retry handler supplied through the `KernelBuilder.WithHttpHandlerFactory`, `KernelBuilder.WithRetryPolly` or `KernelBuilder.WithRetryBasic` extension methods.

### Description
This PR disables the default retry policy provided by the Azure SDK library to avoid "double" retries, and uses the one configured by the SK consumer instead.
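For illustration only, a minimal usage sketch that is not part of this diff: with this change, the retry behavior configured through SK applies exactly once, because the Azure SDK client underneath no longer retries on its own. The builder calls mirror the integration tests updated below; the deployment name, endpoint, and API key are placeholders, and the `BasicRetryConfig.MaxRetryCount` property and the `Microsoft.SemanticKernel.Reliability.Basic` namespace are assumptions about the basic reliability package, not something this diff touches.

```csharp
using System.Net;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Reliability.Basic; // assumed namespace for BasicRetryConfig

// Configure retries once, at the SK level. With the connector now setting
// RetryPolicy(maxRetries: 0) on the Azure SDK client, this handler is the
// only place a failed request is retried.
var retryConfig = new BasicRetryConfig { MaxRetryCount = 3 }; // MaxRetryCount is an assumed property name
retryConfig.RetryableStatusCodes.Add(HttpStatusCode.Unauthorized);

IKernel kernel = Kernel.Builder
    .WithRetryBasic(retryConfig)
    .WithAzureTextCompletionService(
        deploymentName: "my-deployment",                   // placeholder
        endpoint: "https://my-resource.openai.azure.com/", // placeholder
        apiKey: "...")                                     // placeholder
    .Build();
```

The same reasoning applies when a custom `HttpClient` is passed to a `With*Service` call: the connector now zeroes out the SDK's retries so that the client's own handler policy is authoritative.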
Related issue - https://github.com/microsoft/semantic-kernel/issues/2486

### Contribution Checklist
- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone :smile:
---
 .../AzureSdk/AzureOpenAIClientBase.cs         | 25 +++++-----
 .../AzureSdk/OpenAIClientBase.cs              | 21 +++++----
 .../OpenAIKernelBuilderExtensions.cs          |  1 +
 .../OpenAI/OpenAICompletionTests.cs           | 46 +++++++------------
 .../IntegrationTests/Planners/PlanTests.cs    |  4 +-
 .../SequentialPlanParserTests.cs              |  1 +
 .../SequentialPlannerTests.cs                 |  1 +
 .../StepwisePlanner/StepwisePlannerTests.cs   |  4 +-
 8 files changed, 52 insertions(+), 51 deletions(-)

diff --git a/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/AzureOpenAIClientBase.cs b/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/AzureOpenAIClientBase.cs
index 5a0e88d89bfe..71f193e032c3 100644
--- a/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/AzureOpenAIClientBase.cs
+++ b/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/AzureOpenAIClientBase.cs
@@ -41,11 +41,7 @@ private protected AzureOpenAIClientBase(
         Verify.StartsWith(endpoint, "https://", "The Azure OpenAI endpoint must start with 'https://'");
         Verify.NotNullOrWhiteSpace(apiKey);
 
-        var options = GetClientOptions();
-        if (httpClient != null)
-        {
-            options.Transport = new HttpClientTransport(httpClient);
-        }
+        var options = GetClientOptions(httpClient);
 
         this.ModelId = modelId;
         this.Client = new OpenAIClient(new Uri(endpoint), new AzureKeyCredential(apiKey), options);
@@ -70,11 +66,7 @@ private protected AzureOpenAIClientBase(
         Verify.NotNullOrWhiteSpace(endpoint);
         Verify.StartsWith(endpoint, "https://", "The Azure OpenAI endpoint must start with 'https://'");
 
-        var options = GetClientOptions();
-        if (httpClient != null)
-        {
-            options.Transport = new HttpClientTransport(httpClient);
-        }
+        var options = GetClientOptions(httpClient);
 
         this.ModelId = modelId;
         this.Client = new OpenAIClient(new Uri(endpoint), credential, options);
@@ -103,10 +95,11 @@ private protected AzureOpenAIClientBase(
     /// <summary>
     /// Options used by the Azure OpenAI client, e.g. User Agent.
     /// </summary>
+    /// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
     /// <returns>An instance of <see cref="OpenAIClientOptions"/>.</returns>
-    private static OpenAIClientOptions GetClientOptions()
+    private static OpenAIClientOptions GetClientOptions(HttpClient? httpClient)
     {
-        return new OpenAIClientOptions
+        var options = new OpenAIClientOptions
         {
             Diagnostics =
             {
@@ -114,6 +107,14 @@ private static OpenAIClientOptions GetClientOptions()
                 ApplicationId = Telemetry.HttpUserAgent,
             }
         };
+
+        if (httpClient != null)
+        {
+            options.Transport = new HttpClientTransport(httpClient);
+            options.RetryPolicy = new RetryPolicy(maxRetries: 0); //Disabling Azure SDK retry policy to use the one provided by the custom HTTP client.
+        }
+
+        return options;
     }
 
     /// <summary>
diff --git a/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/OpenAIClientBase.cs b/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/OpenAIClientBase.cs
index 0b68910e9f7c..ccd33cb90a4e 100644
--- a/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/OpenAIClientBase.cs
+++ b/dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/OpenAIClientBase.cs
@@ -40,11 +40,7 @@ private protected OpenAIClientBase(
 
         this.ModelId = modelId;
 
-        var options = GetClientOptions();
-        if (httpClient != null)
-        {
-            options.Transport = new HttpClientTransport(httpClient);
-        }
+        var options = GetClientOptions(httpClient);
 
         if (!string.IsNullOrWhiteSpace(organization))
         {
@@ -86,10 +82,11 @@ private protected void LogActionDetails([CallerMemberName] string? callerMemberN
     /// <summary>
     /// Options used by the OpenAI client, e.g. User Agent.
     /// </summary>
-    /// <returns>An instance of <see cref="OpenAIClientOptions"/> with the configured options.</returns>
-    private static OpenAIClientOptions GetClientOptions()
+    /// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
+    /// <returns>An instance of <see cref="OpenAIClientOptions"/>.</returns>
+    private static OpenAIClientOptions GetClientOptions(HttpClient? httpClient)
     {
-        return new OpenAIClientOptions
+        var options = new OpenAIClientOptions
         {
             Diagnostics =
             {
@@ -97,5 +94,13 @@ private static OpenAIClientOptions GetClientOptions()
                 ApplicationId = Telemetry.HttpUserAgent,
             }
         };
+
+        if (httpClient != null)
+        {
+            options.Transport = new HttpClientTransport(httpClient);
+            options.RetryPolicy = new RetryPolicy(maxRetries: 0); //Disabling Azure SDK retry policy to use the one provided by the custom HTTP client.
+        }
+
+        return options;
     }
 }
diff --git a/dotnet/src/Connectors/Connectors.AI.OpenAI/OpenAIKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.AI.OpenAI/OpenAIKernelBuilderExtensions.cs
index 31bdcfcd6197..8f39afe34bd6 100644
--- a/dotnet/src/Connectors/Connectors.AI.OpenAI/OpenAIKernelBuilderExtensions.cs
+++ b/dotnet/src/Connectors/Connectors.AI.OpenAI/OpenAIKernelBuilderExtensions.cs
@@ -551,6 +551,7 @@ private static OpenAIClientOptions CreateOpenAIClientOptions(ILoggerFactory logg
         OpenAIClientOptions options = new();
 #pragma warning disable CA2000 // Dispose objects before losing scope
         options.Transport = new HttpClientTransport(HttpClientProvider.GetHttpClient(httpHandlerFactory, httpClient, loggerFactory));
+        options.RetryPolicy = new RetryPolicy(maxRetries: 0); //Disabling Azure SDK retry policy to use the one provided by the delegating handler factory or the HTTP client.
#pragma warning restore CA2000 // Dispose objects before losing scope
 
         return options;
diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAICompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAICompletionTests.cs
index aa034b99159f..b3aaed385d99 100644
--- a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAICompletionTests.cs
+++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAICompletionTests.cs
@@ -22,6 +22,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.OpenAI;
 
 public sealed class OpenAICompletionTests : IDisposable
 {
+    private readonly KernelBuilder _kernelBuilder;
     private readonly IConfigurationRoot _configuration;
 
     public OpenAICompletionTests(ITestOutputHelper output)
@@ -37,6 +38,9 @@ public OpenAICompletionTests(ITestOutputHelper output)
             .AddEnvironmentVariables()
             .AddUserSecrets<OpenAICompletionTests>()
             .Build();
+
+        this._kernelBuilder = new KernelBuilder();
+        this._kernelBuilder.WithRetryBasic();
     }
 
     [Theory(Skip = "OpenAI will often throttle requests. This test is for manual verification.")]
@@ -47,7 +51,7 @@ public async Task OpenAITestAsync(string prompt, string expectedAnswerContains)
         var openAIConfiguration = this._configuration.GetSection("OpenAI").Get<OpenAIConfiguration>();
         Assert.NotNull(openAIConfiguration);
 
-        IKernel target = Kernel.Builder
+        IKernel target = this._kernelBuilder
             .WithLoggerFactory(this._logger)
             .WithOpenAITextCompletionService(
                 serviceId: openAIConfiguration.ServiceId,
@@ -70,7 +74,7 @@ public async Task OpenAITestAsync(string prompt, string expectedAnswerContains)
     public async Task OpenAIChatAsTextTestAsync(string prompt, string expectedAnswerContains)
     {
         // Arrange
-        KernelBuilder builder = Kernel.Builder.WithLoggerFactory(this._logger);
+        KernelBuilder builder = this._kernelBuilder.WithLoggerFactory(this._logger);
 
         this.ConfigureChatOpenAI(builder);
 
@@ -89,7 +93,7 @@ public async Task OpenAIChatAsTextTestAsync(string prompt, string expectedAnswer
     public async Task CanUseOpenAiChatForTextCompletionAsync()
     {
         // Note: we use OpenAi Chat Completion and GPT 3.5 Turbo
-        KernelBuilder builder = Kernel.Builder.WithLoggerFactory(this._logger);
+        KernelBuilder builder = this._kernelBuilder.WithLoggerFactory(this._logger);
         this.ConfigureChatOpenAI(builder);
         IKernel target = builder.Build();
 
@@ -110,7 +114,7 @@ public async Task CanUseOpenAiChatForTextCompletionAsync()
     public async Task AzureOpenAITestAsync(bool useChatModel, string prompt, string expectedAnswerContains)
     {
         // Arrange
-        var builder = Kernel.Builder.WithLoggerFactory(this._logger);
+        var builder = this._kernelBuilder.WithLoggerFactory(this._logger);
 
         if (useChatModel)
         {
@@ -145,7 +149,7 @@ public async Task OpenAIHttpRetryPolicyTestAsync(string prompt, string expectedO
         OpenAIConfiguration? openAIConfiguration = this._configuration.GetSection("OpenAI").Get<OpenAIConfiguration>();
         Assert.NotNull(openAIConfiguration);
 
-        IKernel target = Kernel.Builder
+        IKernel target = this._kernelBuilder
             .WithLoggerFactory(this._testOutputHelper)
             .WithRetryBasic(retryConfig)
             .WithOpenAITextCompletionService(
@@ -173,7 +177,7 @@ public async Task AzureOpenAIHttpRetryPolicyTestAsync(string prompt, string expe
         var retryConfig = new BasicRetryConfig();
         retryConfig.RetryableStatusCodes.Add(HttpStatusCode.Unauthorized);
 
-        KernelBuilder builder = Kernel.Builder
+        KernelBuilder builder = this._kernelBuilder
             .WithLoggerFactory(this._testOutputHelper)
             .WithRetryBasic(retryConfig);
 
@@ -205,7 +209,7 @@ public async Task OpenAIHttpInvalidKeyShouldReturnErrorDetailAsync()
         Assert.NotNull(openAIConfiguration);
 
         // Use an invalid API key to force a 401 Unauthorized response
-        IKernel target = Kernel.Builder
+        IKernel target = this._kernelBuilder
             .WithOpenAITextCompletionService(
                 modelId: openAIConfiguration.ModelId,
                 apiKey: "INVALID_KEY",
@@ -227,7 +231,7 @@ public async Task AzureOpenAIHttpInvalidKeyShouldReturnErrorDetailAsync()
         var azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAI").Get<AzureOpenAIConfiguration>();
         Assert.NotNull(azureOpenAIConfiguration);
 
-        IKernel target = Kernel.Builder
+        IKernel target = this._kernelBuilder
             .WithLoggerFactory(this._testOutputHelper)
             .WithAzureTextCompletionService(
                 deploymentName: azureOpenAIConfiguration.DeploymentName,
@@ -251,7 +255,7 @@ public async Task AzureOpenAIHttpExceededMaxTokensShouldReturnErrorDetailAsync()
         Assert.NotNull(azureOpenAIConfiguration);
 
         // Arrange
-        IKernel target = Kernel.Builder
+        IKernel target = this._kernelBuilder
             .WithLoggerFactory(this._testOutputHelper)
             .WithAzureTextCompletionService(
                 deploymentName: azureOpenAIConfiguration.DeploymentName,
@@ -282,7 +286,7 @@ public async Task CompletionWithDifferentLineEndingsAsync(string lineEnding, AIS
         const string ExpectedAnswerContains = "John";
 
-        IKernel target = Kernel.Builder.WithLoggerFactory(this._logger).Build();
+        IKernel target = this._kernelBuilder.WithLoggerFactory(this._logger).Build();
 
         this._serviceConfiguration[service](target);
 
@@ -299,7 +303,7 @@ public async Task CompletionWithDifferentLineEndingsAsync(string lineEnding, AIS
     public async Task AzureOpenAIInvokePromptTestAsync()
     {
         // Arrange
-        var builder = Kernel.Builder.WithLoggerFactory(this._logger);
+        var builder = this._kernelBuilder.WithLoggerFactory(this._logger);
         this.ConfigureAzureOpenAI(builder);
         IKernel target = builder.Build();
 
@@ -316,7 +320,7 @@ public async Task AzureOpenAIInvokePromptTestAsync()
     public async Task AzureOpenAIDefaultValueTestAsync()
     {
         // Arrange
-        var builder = Kernel.Builder.WithLoggerFactory(this._logger);
+        var builder = this._kernelBuilder.WithLoggerFactory(this._logger);
         this.ConfigureAzureOpenAI(builder);
         IKernel target = builder.Build();
 
@@ -333,7 +337,7 @@ public async Task AzureOpenAIDefaultValueTestAsync()
     public async Task MultipleServiceLoadPromptConfigTestAsync()
     {
         // Arrange
-        var builder = Kernel.Builder.WithLoggerFactory(this._logger);
+        var builder = this._kernelBuilder.WithLoggerFactory(this._logger);
         this.ConfigureAzureOpenAI(builder);
         this.ConfigureInvalidAzureOpenAI(builder);
 
@@ -397,22 +401,6 @@ private void Dispose(bool disposing)
         }
     }
 
-    private void ConfigureOpenAI(KernelBuilder kernelBuilder)
-    {
-        var openAIConfiguration = this._configuration.GetSection("OpenAI").Get<OpenAIConfiguration>();
-
-        Assert.NotNull(openAIConfiguration);
-        Assert.NotNull(openAIConfiguration.ModelId);
-        Assert.NotNull(openAIConfiguration.ApiKey);
-        Assert.NotNull(openAIConfiguration.ServiceId);
-
-        kernelBuilder.WithOpenAITextCompletionService(
-            modelId: openAIConfiguration.ModelId,
-            apiKey: openAIConfiguration.ApiKey,
-            serviceId: openAIConfiguration.ServiceId,
-            setAsDefault: true);
-    }
-
     private void ConfigureChatOpenAI(KernelBuilder kernelBuilder)
     {
         var openAIConfiguration = this._configuration.GetSection("OpenAI").Get<OpenAIConfiguration>();
diff --git a/dotnet/src/IntegrationTests/Planners/PlanTests.cs b/dotnet/src/IntegrationTests/Planners/PlanTests.cs
index 8ca1c786c44e..b4954fb91e0b 100644
--- a/dotnet/src/IntegrationTests/Planners/PlanTests.cs
+++ b/dotnet/src/IntegrationTests/Planners/PlanTests.cs
@@ -507,7 +507,9 @@ private IKernel InitializeKernel(bool useEmbeddings = false, bool useChatModel =
         AzureOpenAIConfiguration? azureOpenAIEmbeddingsConfiguration = this._configuration.GetSection("AzureOpenAIEmbeddings").Get<AzureOpenAIConfiguration>();
         Assert.NotNull(azureOpenAIEmbeddingsConfiguration);
 
-        var builder = Kernel.Builder.WithLoggerFactory(this._loggerFactory);
+        var builder = Kernel.Builder
+            .WithLoggerFactory(this._loggerFactory)
+            .WithRetryBasic();
 
         if (useChatModel)
         {
diff --git a/dotnet/src/IntegrationTests/Planners/SequentialPlanner/SequentialPlanParserTests.cs b/dotnet/src/IntegrationTests/Planners/SequentialPlanner/SequentialPlanParserTests.cs
index 5aba20da170c..b894666e9481 100644
--- a/dotnet/src/IntegrationTests/Planners/SequentialPlanner/SequentialPlanParserTests.cs
+++ b/dotnet/src/IntegrationTests/Planners/SequentialPlanner/SequentialPlanParserTests.cs
@@ -33,6 +33,7 @@ public void CanCallToPlanFromXml()
         Assert.NotNull(azureOpenAIConfiguration);
 
         IKernel kernel = Kernel.Builder
+            .WithRetryBasic()
             .WithAzureTextCompletionService(
                 deploymentName: azureOpenAIConfiguration.DeploymentName,
                 endpoint: azureOpenAIConfiguration.Endpoint,
diff --git a/dotnet/src/IntegrationTests/Planners/SequentialPlanner/SequentialPlannerTests.cs b/dotnet/src/IntegrationTests/Planners/SequentialPlanner/SequentialPlannerTests.cs
index a9e4e2ac6516..6236f64e4c12 100644
--- a/dotnet/src/IntegrationTests/Planners/SequentialPlanner/SequentialPlannerTests.cs
+++ b/dotnet/src/IntegrationTests/Planners/SequentialPlanner/SequentialPlannerTests.cs
@@ -118,6 +118,7 @@ private IKernel InitializeKernel(bool useEmbeddings = false, bool useChatModel =
         Assert.NotNull(azureOpenAIEmbeddingsConfiguration);
 
         var builder = Kernel.Builder.WithLoggerFactory(this._logger);
+        builder.WithRetryBasic();
 
         if (useChatModel)
         {
diff --git a/dotnet/src/IntegrationTests/Planners/StepwisePlanner/StepwisePlannerTests.cs b/dotnet/src/IntegrationTests/Planners/StepwisePlanner/StepwisePlannerTests.cs
index e6869e8d42bf..c99ab62ec729 100644
--- a/dotnet/src/IntegrationTests/Planners/StepwisePlanner/StepwisePlannerTests.cs
+++ b/dotnet/src/IntegrationTests/Planners/StepwisePlanner/StepwisePlannerTests.cs
@@ -147,7 +147,9 @@ private IKernel InitializeKernel(bool useEmbeddings = false, bool useChatModel =
         AzureOpenAIConfiguration? azureOpenAIEmbeddingsConfiguration = this._configuration.GetSection("AzureOpenAIEmbeddings").Get<AzureOpenAIConfiguration>();
         Assert.NotNull(azureOpenAIEmbeddingsConfiguration);
 
-        var builder = Kernel.Builder.WithLoggerFactory(this._loggerFactory);
+        var builder = Kernel.Builder
+            .WithLoggerFactory(this._loggerFactory)
+            .WithRetryBasic();
 
         if (useChatModel)
         {