diff --git a/.github/workflows/dotnet-build-and-test.yml b/.github/workflows/dotnet-build-and-test.yml
index b537588965eb..2e8f886c6728 100644
--- a/.github/workflows/dotnet-build-and-test.yml
+++ b/.github/workflows/dotnet-build-and-test.yml
@@ -144,6 +144,7 @@ jobs:
           # Azure AI Inference Endpoint
           AzureAIInference__ApiKey: ${{ secrets.AZUREAIINFERENCE__APIKEY }}
           AzureAIInference__Endpoint: ${{ secrets.AZUREAIINFERENCE__ENDPOINT }}
+          AzureAIInference__ChatModelId: ${{ vars.AZUREAIINFERENCE__CHATMODELID }}

       # Generate test reports and check coverage
       - name: Generate test reports
diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props
index 0e80ac0a37ba..05a06f7c9901 100644
--- a/dotnet/Directory.Packages.props
+++ b/dotnet/Directory.Packages.props
@@ -56,6 +56,7 @@
+
diff --git a/dotnet/samples/Concepts/ChatCompletion/AzureAIInference_ChatCompletion.cs b/dotnet/samples/Concepts/ChatCompletion/AzureAIInference_ChatCompletion.cs
index 38f2add47fa6..e42600419a88 100644
--- a/dotnet/samples/Concepts/ChatCompletion/AzureAIInference_ChatCompletion.cs
+++ b/dotnet/samples/Concepts/ChatCompletion/AzureAIInference_ChatCompletion.cs
@@ -1,9 +1,10 @@
 // Copyright (c) Microsoft. All rights reserved.

 using System.Text;
+using Azure.AI.Inference;
+using Microsoft.Extensions.AI;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.ChatCompletion;
-using Microsoft.SemanticKernel.Connectors.AzureAIInference;

 namespace ChatCompletion;
@@ -15,9 +16,13 @@ public async Task ServicePromptAsync()
     {
         Console.WriteLine("======== Azure AI Inference - Chat Completion ========");

-        var chatService = new AzureAIInferenceChatCompletionService(
-            endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint),
-            apiKey: TestConfiguration.AzureAIInference.ApiKey);
+        Assert.NotNull(TestConfiguration.AzureAIInference.ApiKey);
+
+        var chatService = new ChatCompletionsClient(
+                endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint),
+                credential: new Azure.AzureKeyCredential(TestConfiguration.AzureAIInference.ApiKey))
+            .AsChatClient(TestConfiguration.AzureAIInference.ChatModelId)
+            .AsChatCompletionService();

         Console.WriteLine("Chat content:");
         Console.WriteLine("------------------------");
@@ -81,6 +86,7 @@ public async Task ChatPromptAsync()

         var kernel = Kernel.CreateBuilder()
             .AddAzureAIInferenceChatCompletion(
+                modelId: TestConfiguration.AzureAIInference.ChatModelId,
                 endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint),
                 apiKey: TestConfiguration.AzureAIInference.ApiKey)
             .Build();
diff --git a/dotnet/samples/Concepts/ChatCompletion/AzureAIInference_ChatCompletionStreaming.cs b/dotnet/samples/Concepts/ChatCompletion/AzureAIInference_ChatCompletionStreaming.cs
index 62c1fd3dcb11..f7dbe9191167 100644
--- a/dotnet/samples/Concepts/ChatCompletion/AzureAIInference_ChatCompletionStreaming.cs
+++ b/dotnet/samples/Concepts/ChatCompletion/AzureAIInference_ChatCompletionStreaming.cs
@@ -1,9 +1,10 @@
 // Copyright (c) Microsoft. All rights reserved.
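The sample hunks above and below replace the dedicated AzureAIInferenceChatCompletionService with the Microsoft.Extensions.AI adapter chain. A minimal standalone sketch of that consumption pattern; the endpoint, API key, and model id are placeholder assumptions, not values from this PR:

// Sketch: Azure SDK client -> Microsoft.Extensions.AI IChatClient -> Semantic Kernel service.
using Azure;
using Azure.AI.Inference;
using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;

IChatCompletionService chatService = new ChatCompletionsClient(
        endpoint: new Uri("https://your-resource.inference.ai.azure.com"), // placeholder
        credential: new AzureKeyCredential("your-api-key"))                // placeholder
    .AsChatClient("your-model-id")   // adapter from Microsoft.Extensions.AI.AzureAIInference
    .AsChatCompletionService();      // Semantic Kernel bridge to IChatCompletionService

ChatHistory chatHistory = new("You are a librarian, an expert on books.");
chatHistory.AddUserMessage("Hi, I'm looking for book suggestions.");

ChatMessageContent reply = await chatService.GetChatMessageContentAsync(chatHistory);
Console.WriteLine(reply.Content);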
using System.Text; +using Azure.AI.Inference; +using Microsoft.Extensions.AI; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; -using Microsoft.SemanticKernel.Connectors.AzureAIInference; namespace ChatCompletion; @@ -20,9 +21,11 @@ public Task StreamChatAsync() { Console.WriteLine("======== Azure AI Inference - Chat Completion Streaming ========"); - var chatService = new AzureAIInferenceChatCompletionService( - endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint), - apiKey: TestConfiguration.AzureAIInference.ApiKey); + var chatService = new ChatCompletionsClient( + endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint), + credential: new Azure.AzureKeyCredential(TestConfiguration.AzureAIInference.ApiKey!)) + .AsChatClient(TestConfiguration.AzureAIInference.ChatModelId) + .AsChatCompletionService(); return this.StartStreamingChatAsync(chatService); } @@ -42,6 +45,7 @@ public async Task StreamChatPromptAsync() var kernel = Kernel.CreateBuilder() .AddAzureAIInferenceChatCompletion( + modelId: TestConfiguration.AzureAIInference.ChatModelId, endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint), apiKey: TestConfiguration.AzureAIInference.ApiKey) .Build(); @@ -67,9 +71,11 @@ public async Task StreamTextFromChatAsync() Console.WriteLine("======== Stream Text from Chat Content ========"); // Create chat completion service - var chatService = new AzureAIInferenceChatCompletionService( - endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint), - apiKey: TestConfiguration.AzureAIInference.ApiKey); + var chatService = new ChatCompletionsClient( + endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint), + credential: new Azure.AzureKeyCredential(TestConfiguration.AzureAIInference.ApiKey!)) + .AsChatClient(TestConfiguration.AzureAIInference.ChatModelId) + .AsChatCompletionService(); // Create chat history with initial system and user messages ChatHistory chatHistory = new("You are a librarian, an expert on books."); diff --git a/dotnet/samples/Concepts/Concepts.csproj b/dotnet/samples/Concepts/Concepts.csproj index 4cb4a5a91561..942a3eca849b 100644 --- a/dotnet/samples/Concepts/Concepts.csproj +++ b/dotnet/samples/Concepts/Concepts.csproj @@ -25,6 +25,7 @@ + diff --git a/dotnet/samples/Demos/AIModelRouter/Program.cs b/dotnet/samples/Demos/AIModelRouter/Program.cs index 54d996a8786f..d2ca630a8843 100644 --- a/dotnet/samples/Demos/AIModelRouter/Program.cs +++ b/dotnet/samples/Demos/AIModelRouter/Program.cs @@ -101,6 +101,7 @@ private static async Task Main(string[] args) { services.AddAzureAIInferenceChatCompletion( serviceId: "azureai", + modelId: config["AzureAIInference:ChatModelId"]!, endpoint: new Uri(config["AzureAIInference:Endpoint"]!), apiKey: config["AzureAIInference:ApiKey"]); diff --git a/dotnet/samples/Demos/AIModelRouter/README.md b/dotnet/samples/Demos/AIModelRouter/README.md index 30c1057eb4e6..18c556db1e15 100644 --- a/dotnet/samples/Demos/AIModelRouter/README.md +++ b/dotnet/samples/Demos/AIModelRouter/README.md @@ -26,6 +26,9 @@ dotnet user-secrets set "OpenAI:ChatModelId" ".. chat completion model .." (defa dotnet user-secrets set "AzureOpenAI:Endpoint" ".. endpoint .." dotnet user-secrets set "AzureOpenAI:ChatDeploymentName" ".. chat deployment name .." (default: gpt-4o) dotnet user-secrets set "AzureOpenAI:ApiKey" ".. api key .." (default: Authenticate with Azure CLI credential) +dotnet user-secrets set "AzureAIInference:ApiKey" ".. api key .." +dotnet user-secrets set "AzureAIInference:Endpoint" ".. 
endpoint .." +dotnet user-secrets set "AzureAIInference:ChatModelId" ".. chat completion model .." dotnet user-secrets set "LMStudio:Endpoint" ".. endpoint .." (default: http://localhost:1234) dotnet user-secrets set "Ollama:ModelId" ".. model id .." dotnet user-secrets set "Ollama:Endpoint" ".. endpoint .." (default: http://localhost:11434) diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Connectors.AzureAIInference.UnitTests.csproj b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Connectors.AzureAIInference.UnitTests.csproj index acf3f919710f..d7e1f65ec24f 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Connectors.AzureAIInference.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Connectors.AzureAIInference.UnitTests.csproj @@ -24,7 +24,6 @@ all - diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceKernelBuilderExtensionsTests.cs index 8d5b31548b5f..5f1e784f2c72 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceKernelBuilderExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceKernelBuilderExtensionsTests.cs @@ -6,7 +6,6 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; -using Microsoft.SemanticKernel.Connectors.AzureAIInference; using Xunit; namespace SemanticKernel.Connectors.AzureAIInference.UnitTests.Extensions; @@ -37,7 +36,7 @@ public void KernelBuilderAddAzureAIInferenceChatCompletionAddsValidService(Initi // Assert var chatCompletionService = builder.Build().GetRequiredService(); - Assert.True(chatCompletionService is AzureAIInferenceChatCompletionService); + Assert.Equal("ChatClientChatCompletionService", chatCompletionService.GetType().Name); } public enum InitializationType diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceServiceCollectionExtensionsTests.cs index 02b26f12921b..3f6895c9c637 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceServiceCollectionExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceServiceCollectionExtensionsTests.cs @@ -3,10 +3,10 @@ using System; using Azure; using Azure.AI.Inference; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; -using Microsoft.SemanticKernel.Connectors.AzureAIInference; using Xunit; namespace SemanticKernel.Connectors.AzureAIInference.UnitTests.Extensions; @@ -18,11 +18,13 @@ public sealed class AzureAIInferenceServiceCollectionExtensionsTests [Theory] [InlineData(InitializationType.ApiKey)] [InlineData(InitializationType.ClientInline)] + [InlineData(InitializationType.ChatClientInline)] [InlineData(InitializationType.ClientInServiceProvider)] public void ItCanAddChatCompletionService(InitializationType type) { // Arrange var client = new ChatCompletionsClient(this._endpoint, new AzureKeyCredential("key")); + using var chatClient = new AzureAIInferenceChatClient(client, "model-id"); var builder = 
Kernel.CreateBuilder(); builder.Services.AddSingleton(client); @@ -32,19 +34,21 @@ public void ItCanAddChatCompletionService(InitializationType type) { InitializationType.ApiKey => builder.Services.AddAzureAIInferenceChatCompletion("modelId", "api-key", this._endpoint), InitializationType.ClientInline => builder.Services.AddAzureAIInferenceChatCompletion("modelId", client), + InitializationType.ChatClientInline => builder.Services.AddAzureAIInferenceChatCompletion(chatClient), InitializationType.ClientInServiceProvider => builder.Services.AddAzureAIInferenceChatCompletion("modelId", chatClient: null), _ => builder.Services }; // Assert var chatCompletionService = builder.Build().GetRequiredService(); - Assert.True(chatCompletionService is AzureAIInferenceChatCompletionService); + Assert.Equal("ChatClientChatCompletionService", chatCompletionService.GetType().Name); } public enum InitializationType { ApiKey, ClientInline, + ChatClientInline, ClientInServiceProvider, } } diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Services/AzureAIInferenceChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Services/AzureAIInferenceChatCompletionServiceTests.cs index 44bd2c006661..417f32cc545b 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Services/AzureAIInferenceChatCompletionServiceTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Services/AzureAIInferenceChatCompletionServiceTests.cs @@ -21,6 +21,7 @@ namespace SemanticKernel.Connectors.AzureAIInference.UnitTests.Services; /// /// Tests for the class. /// +[Obsolete("Keeping this test until the service is removed from code-base")] public sealed class AzureAIInferenceChatCompletionServiceTests : IDisposable { private readonly Uri _endpoint = new("https://localhost:1234"); @@ -55,11 +56,11 @@ public void ConstructorsWorksAsExpected() // Act & Assert // Endpoint constructor - new AzureAIInferenceChatCompletionService(endpoint: this._endpoint, apiKey: null); // Only the endpoint - new AzureAIInferenceChatCompletionService(httpClient: httpClient, apiKey: null); // Only the HttpClient with a BaseClass defined + new AzureAIInferenceChatCompletionService(modelId: "model", endpoint: this._endpoint, apiKey: null); // Only the endpoint + new AzureAIInferenceChatCompletionService(modelId: "model", httpClient: httpClient, apiKey: null); // Only the HttpClient with a BaseClass defined new AzureAIInferenceChatCompletionService(modelId: "model", endpoint: this._endpoint, apiKey: null); // ModelId and endpoint new AzureAIInferenceChatCompletionService(modelId: "model", apiKey: "api-key", endpoint: this._endpoint); // ModelId, apiKey, and endpoint - new AzureAIInferenceChatCompletionService(endpoint: this._endpoint, apiKey: null, loggerFactory: loggerFactoryMock.Object); // Endpoint and loggerFactory + new AzureAIInferenceChatCompletionService(modelId: "model", endpoint: this._endpoint, apiKey: null, loggerFactory: loggerFactoryMock.Object); // Endpoint and loggerFactory // Breaking Glass constructor new AzureAIInferenceChatCompletionService(modelId: null, chatClient: client); // Client without model @@ -132,14 +133,14 @@ public async Task ItUsesHttpClientBaseAddressWhenNoEndpointIsProvidedAsync() public void ItThrowsIfNoEndpointOrNoHttpClientBaseAddressIsProvided() { // Act & Assert - Assert.Throws(() => new AzureAIInferenceChatCompletionService(endpoint: null, httpClient: this._httpClient)); + Assert.Throws(() => new 
AzureAIInferenceChatCompletionService(modelId: "model", endpoint: null, httpClient: this._httpClient)); } [Fact] public async Task ItGetChatMessageContentsShouldHaveModelIdDefinedAsync() { // Arrange - var chatCompletion = new AzureAIInferenceChatCompletionService(apiKey: "NOKEY", httpClient: this._httpClientWithBaseAddress); + var chatCompletion = new AzureAIInferenceChatCompletionService(modelId: "model", apiKey: "NOKEY", httpClient: this._httpClientWithBaseAddress); this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK) { Content = this.CreateDefaultStringContent() }; @@ -158,7 +159,7 @@ public async Task ItGetChatMessageContentsShouldHaveModelIdDefinedAsync() public async Task GetStreamingChatMessageContentsWorksCorrectlyAsync() { // Arrange - var service = new AzureAIInferenceChatCompletionService(httpClient: this._httpClientWithBaseAddress); + var service = new AzureAIInferenceChatCompletionService(modelId: "model", httpClient: this._httpClientWithBaseAddress); await using var stream = File.OpenRead("TestData/chat_completion_streaming_response.txt"); this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(HttpStatusCode.OK) @@ -174,7 +175,9 @@ public async Task GetStreamingChatMessageContentsWorksCorrectlyAsync() await enumerator.MoveNextAsync(); Assert.Equal("Test content", enumerator.Current.Content); - Assert.Equal("stop", enumerator.Current.Metadata?["FinishReason"]); + Assert.IsType(enumerator.Current.InnerContent); + StreamingChatCompletionsUpdate innerContent = (StreamingChatCompletionsUpdate)enumerator.Current.InnerContent; + Assert.Equal("stop", innerContent.FinishReason); } [Fact] @@ -210,7 +213,7 @@ public async Task GetChatMessageContentsWithChatMessageContentItemCollectionCorr Assert.Equal(3, messages.GetArrayLength()); - Assert.Equal(Prompt, messages[0].GetProperty("content").GetString()); + Assert.Contains(Prompt, messages[0].GetProperty("content").GetRawText()); Assert.Equal("user", messages[0].GetProperty("role").GetString()); Assert.Equal(AssistantMessage, messages[1].GetProperty("content").GetString()); @@ -250,7 +253,7 @@ public async Task GetChatMessageInResponseFormatsAsync(string formatType, string break; } - var sut = new AzureAIInferenceChatCompletionService(httpClient: this._httpClientWithBaseAddress); + var sut = new AzureAIInferenceChatCompletionService("any", httpClient: this._httpClientWithBaseAddress); AzureAIInferencePromptExecutionSettings executionSettings = new() { ResponseFormat = format }; this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(HttpStatusCode.OK) diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Connectors.AzureAIInference.csproj b/dotnet/src/Connectors/Connectors.AzureAIInference/Connectors.AzureAIInference.csproj index 2f87b005fda1..817994449fc3 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference/Connectors.AzureAIInference.csproj +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Connectors.AzureAIInference.csproj @@ -29,6 +29,7 @@ - + + diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Core/ChatClientCore.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Core/ChatClientCore.cs index d559428cbb1b..48679be56f6f 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference/Core/ChatClientCore.cs +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Core/ChatClientCore.cs @@ -2,22 +2,14 @@ using System; using System.Collections.Generic; -using System.Diagnostics.Metrics; -using System.Linq; using System.Net.Http; 
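The streaming unit-test change above stops asserting FinishReason through the metadata dictionary and instead casts InnerContent to the raw SDK update type. A consumer-side sketch of the same access pattern, as exercised by the updated test; chatService and chatHistory are assumed to exist (built as in the samples earlier in this diff), and StreamingChatCompletionsUpdate comes from Azure.AI.Inference:

// Sketch: read the finish reason from the raw Azure.AI.Inference update carried
// in InnerContent, mirroring GetStreamingChatMessageContentsWorksCorrectlyAsync above.
await foreach (var update in chatService.GetStreamingChatMessageContentsAsync(chatHistory))
{
    Console.Write(update.Content);

    if (update.InnerContent is StreamingChatCompletionsUpdate rawUpdate &&
        rawUpdate.FinishReason is not null)
    {
        Console.WriteLine($"{Environment.NewLine}Finish reason: {rawUpdate.FinishReason}");
    }
}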
-using System.Runtime.CompilerServices; -using System.Text; -using System.Text.Json; using System.Threading; -using System.Threading.Tasks; using Azure; using Azure.AI.Inference; using Azure.Core; using Azure.Core.Pipeline; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using Microsoft.SemanticKernel.ChatCompletion; -using Microsoft.SemanticKernel.Diagnostics; using Microsoft.SemanticKernel.Http; using Microsoft.SemanticKernel.Services; @@ -161,194 +153,8 @@ internal void AddAttribute(string key, string? value) } } - /// - /// Get chat multiple chat content choices for the prompt and settings. - /// - /// - /// This should be used when the settings request for more than one choice. - /// - /// The chat history context. - /// The AI execution settings (optional). - /// The containing services, plugins, and other state for use throughout the operation. - /// The to monitor for cancellation requests. The default is . - /// List of different chat results generated by the remote model - internal async Task> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) - { - Verify.NotNull(chatHistory); - - // Convert the incoming execution settings to specialized settings. - AzureAIInferencePromptExecutionSettings chatExecutionSettings = AzureAIInferencePromptExecutionSettings.FromExecutionSettings(executionSettings); - - ValidateMaxTokens(chatExecutionSettings.MaxTokens); - - // Create the SDK ChatCompletionOptions instance from all available information. - ChatCompletionsOptions chatOptions = this.CreateChatCompletionsOptions(chatExecutionSettings, chatHistory, kernel, this.ModelId); - - // Make the request. - ChatCompletions? responseData = null; - var extraParameters = chatExecutionSettings.ExtraParameters; - - List responseContent; - using (var activity = ModelDiagnostics.StartCompletionActivity(this.Endpoint, this.ModelId ?? string.Empty, ModelProvider, chatHistory, chatExecutionSettings)) - { - try - { - responseData = (await RunRequestAsync(() => this.Client!.CompleteAsync(chatOptions, cancellationToken)).ConfigureAwait(false)).Value; - - this.LogUsage(responseData.Usage); - if (responseData is null) - { - throw new KernelException("Chat completions not found"); - } - } - catch (Exception ex) when (activity is not null) - { - activity.SetError(ex); - if (responseData != null) - { - // Capture available metadata even if the operation failed. - activity - .SetResponseId(responseData.Id) - .SetPromptTokenUsage(responseData.Usage.PromptTokens) - .SetCompletionTokenUsage(responseData.Usage.CompletionTokens); - } - throw; - } - - responseContent = [this.GetChatMessage(responseData)]; - activity?.SetCompletionResponse(responseContent, responseData.Usage.PromptTokens, responseData.Usage.CompletionTokens); - } - - return responseContent; - } - - /// - /// Get streaming chat contents for the chat history provided using the specified settings. - /// - /// Throws if the specified type is not the same or fail to cast - /// The chat history to complete. - /// The AI execution settings (optional). - /// The containing services, plugins, and other state for use throughout the operation. - /// The to monitor for cancellation requests. The default is . 
- /// Streaming list of different completion streaming string updates generated by the remote model - internal async IAsyncEnumerable GetStreamingChatMessageContentsAsync( - ChatHistory chatHistory, - PromptExecutionSettings? executionSettings = null, - Kernel? kernel = null, - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - Verify.NotNull(chatHistory); - - AzureAIInferencePromptExecutionSettings chatExecutionSettings = AzureAIInferencePromptExecutionSettings.FromExecutionSettings(executionSettings); - - ValidateMaxTokens(chatExecutionSettings.MaxTokens); - - var chatOptions = this.CreateChatCompletionsOptions(chatExecutionSettings, chatHistory, kernel, this.ModelId); - StringBuilder? contentBuilder = null; - - // Reset state - contentBuilder?.Clear(); - - // Stream the response. - IReadOnlyDictionary? metadata = null; - ChatRole? streamedRole = default; - CompletionsFinishReason finishReason = default; - - using var activity = ModelDiagnostics.StartCompletionActivity(this.Endpoint, this.ModelId ?? string.Empty, ModelProvider, chatHistory, chatExecutionSettings); - StreamingResponse response; - try - { - response = await RunRequestAsync(() => this.Client.CompleteStreamingAsync(chatOptions, cancellationToken)).ConfigureAwait(false); - } - catch (Exception ex) when (activity is not null) - { - activity.SetError(ex); - throw; - } - - var responseEnumerator = response.ConfigureAwait(false).GetAsyncEnumerator(); - List? streamedContents = activity is not null ? [] : null; - try - { - while (true) - { - try - { - if (!await responseEnumerator.MoveNextAsync()) - { - break; - } - } - catch (Exception ex) when (activity is not null) - { - activity.SetError(ex); - throw; - } - - StreamingChatCompletionsUpdate update = responseEnumerator.Current; - metadata = GetResponseMetadata(update); - streamedRole ??= update.Role; - finishReason = update.FinishReason ?? default; - - AuthorRole? role = null; - if (streamedRole.HasValue) - { - role = new AuthorRole(streamedRole.Value.ToString()); - } - - StreamingChatMessageContent streamingChatMessageContent = - new(role: update.Role.HasValue ? new AuthorRole(update.Role.ToString()!) : null, content: update.ContentUpdate, innerContent: update, modelId: update.Model, metadata: metadata) - { - Role = role, - Metadata = metadata, - }; - - streamedContents?.Add(streamingChatMessageContent); - yield return streamingChatMessageContent; - } - } - finally - { - activity?.EndStreaming(streamedContents, null); - await responseEnumerator.DisposeAsync(); - } - } - #region Private - private const string ModelProvider = "azure-ai-inference"; - /// - /// Instance of for metrics. - /// - private static readonly Meter s_meter = new("Microsoft.SemanticKernel.Connectors.AzureAIInference"); - - /// - /// Instance of to keep track of the number of prompt tokens used. - /// - private static readonly Counter s_promptTokensCounter = - s_meter.CreateCounter( - name: "semantic_kernel.connectors.azure-ai-inference.tokens.prompt", - unit: "{token}", - description: "Number of prompt tokens used"); - - /// - /// Instance of to keep track of the number of completion tokens used. - /// - private static readonly Counter s_completionTokensCounter = - s_meter.CreateCounter( - name: "semantic_kernel.connectors.azure-ai-inference.tokens.completion", - unit: "{token}", - description: "Number of completion tokens used"); - - /// - /// Instance of to keep track of the total number of tokens used. 
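The deletion above also removes the connector's built-in token-usage counters. If equivalent metrics are still needed, they can be recreated on the consumer side with System.Diagnostics.Metrics; this is a sketch, not part of the PR: the meter and instrument names are assumptions, and where you obtain a CompletionsUsage instance depends on how your pipeline surfaces the raw Azure.AI.Inference response:

// Sketch: consumer-side replacement for the removed connector counters.
using System.Diagnostics.Metrics;
using Azure.AI.Inference;

internal static class InferenceTokenMetrics
{
    private static readonly Meter s_meter = new("MyApp.AzureAIInference"); // assumed meter name

    private static readonly Counter<int> s_promptTokens = s_meter.CreateCounter<int>(
        name: "myapp.azure-ai-inference.tokens.prompt", unit: "{token}", description: "Number of prompt tokens used");

    private static readonly Counter<int> s_completionTokens = s_meter.CreateCounter<int>(
        name: "myapp.azure-ai-inference.tokens.completion", unit: "{token}", description: "Number of completion tokens used");

    private static readonly Counter<int> s_totalTokens = s_meter.CreateCounter<int>(
        name: "myapp.azure-ai-inference.tokens.total", unit: "{token}", description: "Number of tokens used");

    // Record usage from a raw Azure.AI.Inference response, e.g. one reached via InnerContent.
    public static void Record(CompletionsUsage? usage)
    {
        if (usage is null)
        {
            return;
        }

        s_promptTokens.Add(usage.PromptTokens);
        s_completionTokens.Add(usage.CompletionTokens);
        s_totalTokens.Add(usage.TotalTokens);
    }
}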
- /// - private static readonly Counter s_totalTokensCounter = - s_meter.CreateCounter( - name: "semantic_kernel.connectors.azure-ai-inference.tokens.total", - unit: "{token}", - description: "Number of tokens used"); - /// /// Single space constant. /// @@ -378,266 +184,5 @@ private static AzureAIInferenceClientOptions GetClientOptions(HttpClient? httpCl return options; } - /// - /// Invokes the specified request and handles exceptions. - /// - /// Type of the response. - /// Request to invoke. - /// Returns the response. - private static async Task RunRequestAsync(Func> request) - { - try - { - return await request.Invoke().ConfigureAwait(false); - } - catch (RequestFailedException e) - { - throw e.ToHttpOperationException(); - } - } - - /// - /// Checks if the maximum tokens value is valid. - /// - /// Maximum tokens value. - /// Throws if the maximum tokens value is invalid. - private static void ValidateMaxTokens(int? maxTokens) - { - if (maxTokens.HasValue && maxTokens < 1) - { - throw new ArgumentException($"MaxTokens {maxTokens} is not valid, the value must be greater than zero"); - } - } - - /// - /// Creates a new instance of based on the provided settings. - /// - /// The execution settings. - /// The chat history. - /// Kernel instance. - /// Model ID. - /// Create a new instance of . - private ChatCompletionsOptions CreateChatCompletionsOptions( - AzureAIInferencePromptExecutionSettings executionSettings, - ChatHistory chatHistory, - Kernel? kernel, - string? modelId) - { - if (this.Logger.IsEnabled(LogLevel.Trace)) - { - this.Logger.LogTrace("ChatHistory: {ChatHistory}, Settings: {Settings}", - JsonSerializer.Serialize(chatHistory), - JsonSerializer.Serialize(executionSettings)); - } - - var options = new ChatCompletionsOptions - { - MaxTokens = executionSettings.MaxTokens, - Temperature = executionSettings.Temperature, - NucleusSamplingFactor = executionSettings.NucleusSamplingFactor, - FrequencyPenalty = executionSettings.FrequencyPenalty, - PresencePenalty = executionSettings.PresencePenalty, - Model = modelId, - Seed = executionSettings.Seed, - }; - - switch (executionSettings.ResponseFormat) - { - case ChatCompletionsResponseFormat formatObject: - // If the response format is an Azure SDK ChatCompletionsResponseFormat, just pass it along. - options.ResponseFormat = formatObject; - break; - - case string formatString: - // If the response format is a string, map the ones we know about, and ignore the rest. - switch (formatString) - { - case "json_object": - options.ResponseFormat = new ChatCompletionsResponseFormatJSON(); - break; - - case "text": - options.ResponseFormat = new ChatCompletionsResponseFormatText(); - break; - } - break; - - case JsonElement formatElement: - // This is a workaround for a type mismatch when deserializing a JSON into an object? type property. - // Handling only string formatElement. - if (formatElement.ValueKind == JsonValueKind.String) - { - string formatString = formatElement.GetString() ?? 
""; - switch (formatString) - { - case "json_object": - options.ResponseFormat = new ChatCompletionsResponseFormatJSON(); - break; - - case "text": - options.ResponseFormat = new ChatCompletionsResponseFormatText(); - break; - } - } - break; - } - - if (executionSettings.StopSequences is { Count: > 0 }) - { - foreach (var s in executionSettings.StopSequences) - { - options.StopSequences.Add(s); - } - } - - foreach (var message in chatHistory) - { - options.Messages.AddRange(GetRequestMessages(message)); - } - - return options; - } - - /// - /// Create request messages based on the chat message content. - /// - /// Chat message content. - /// A list of . - /// When the message role is not supported. - private static List GetRequestMessages(ChatMessageContent message) - { - if (message.Role == AuthorRole.System) - { - return [new ChatRequestSystemMessage(message.Content)]; - } - - if (message.Role == AuthorRole.User) - { - if (message.Items is { Count: 1 } && message.Items.FirstOrDefault() is TextContent textContent) - { - // Name removed temporarily as the Azure AI Inference service does not support it ATM. - // Issue: https://github.com/Azure/azure-sdk-for-net/issues/45415 - return [new ChatRequestUserMessage(textContent.Text) /*{ Name = message.AuthorName }*/ ]; - } - - return [new ChatRequestUserMessage(message.Items.Select(static (KernelContent item) => (ChatMessageContentItem)(item switch - { - TextContent textContent => new ChatMessageTextContentItem(textContent.Text), - ImageContent imageContent => GetImageContentItem(imageContent), - _ => throw new NotSupportedException($"Unsupported chat message content type '{item.GetType()}'.") - }))) - - // Name removed temporarily as the Azure AI Inference service does not support it ATM. - // Issue: https://github.com/Azure/azure-sdk-for-net/issues/45415 - /*{ Name = message.AuthorName }*/]; - } - - if (message.Role == AuthorRole.Assistant) - { - // Name removed temporarily as the Azure AI Inference service does not support it ATM. - // Issue: https://github.com/Azure/azure-sdk-for-net/issues/45415 - return [new ChatRequestAssistantMessage(message.Content) { /* Name = message.AuthorName */ }]; - } - - throw new NotSupportedException($"Role {message.Role} is not supported."); - } - - /// - /// Create a new instance of based on the provided - /// - /// Target . - /// new instance of - /// When the does not have Data or Uri. - private static ChatMessageImageContentItem GetImageContentItem(ImageContent imageContent) - { - if (imageContent.Data is { IsEmpty: false } data) - { - return new ChatMessageImageContentItem(BinaryData.FromBytes(data), imageContent.MimeType); - } - - if (imageContent.Uri is not null) - { - return new ChatMessageImageContentItem(imageContent.Uri); - } - - throw new ArgumentException($"{nameof(ImageContent)} must have either Data or a Uri."); - } - - /// - /// Captures usage details, including token information. - /// - /// Instance of with usage details. - private void LogUsage(CompletionsUsage usage) - { - if (usage is null) - { - this.Logger.LogDebug("Token usage information unavailable."); - return; - } - - if (this.Logger.IsEnabled(LogLevel.Information)) - { - this.Logger.LogInformation( - "Prompt tokens: {PromptTokens}. Completion tokens: {CompletionTokens}. 
Total tokens: {TotalTokens}.", - usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens); - } - - s_promptTokensCounter.Add(usage.PromptTokens); - s_completionTokensCounter.Add(usage.CompletionTokens); - s_totalTokensCounter.Add(usage.TotalTokens); - } - - /// - /// Create a new based on the provided and . - /// - /// The object containing the response data. - /// A new object. - private ChatMessageContent GetChatMessage(ChatCompletions responseData) - { - var message = new ChatMessageContent( - new AuthorRole(responseData.Role.ToString()), - responseData.Content, - responseData.Model, - innerContent: responseData, - metadata: GetChatChoiceMetadata(responseData) - ); - return message; - } - - /// - /// Create the metadata dictionary based on the provided and . - /// - /// The object containing the response data. - /// A new dictionary with metadata. - private static Dictionary GetChatChoiceMetadata(ChatCompletions completions) - { - return new Dictionary(5) - { - { nameof(completions.Id), completions.Id }, - { nameof(completions.Created), completions.Created }, - { nameof(completions.Usage), completions.Usage }, - - // Serialization of this struct behaves as an empty object {}, need to cast to string to avoid it. - { nameof(completions.FinishReason), completions.FinishReason?.ToString() }, - }; - } - - /// - /// Create the metadata dictionary based on the provided . - /// - /// The object containing the response data. - /// A new dictionary with metadata. - private static Dictionary GetResponseMetadata(StreamingChatCompletionsUpdate completions) - { - return new Dictionary(3) - { - { nameof(completions.Id), completions.Id }, - { nameof(completions.Created), completions.Created }, - - // Serialization of this struct behaves as an empty object {}, need to cast to string to avoid it. - { nameof(completions.FinishReason), completions.FinishReason?.ToString() }, - }; - } - #endregion } diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceKernelBuilderExtensions.cs index c1760d4ac316..d234d66846a2 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceKernelBuilderExtensions.cs @@ -17,7 +17,7 @@ public static class AzureAIInferenceKernelBuilderExtensions /// Adds the to the . /// /// The instance to augment. - /// Target Model Id for endpoints supporting more than one model + /// Target Model Id /// API Key /// Endpoint / Target URI /// Custom for HTTP requests. @@ -25,7 +25,7 @@ public static class AzureAIInferenceKernelBuilderExtensions /// The same instance as . public static IKernelBuilder AddAzureAIInferenceChatCompletion( this IKernelBuilder builder, - string? modelId = null, + string modelId, string? apiKey = null, Uri? endpoint = null, HttpClient? httpClient = null, @@ -42,7 +42,7 @@ public static IKernelBuilder AddAzureAIInferenceChatCompletion( /// Adds the to the . /// /// The instance to augment. - /// Target Model Id for endpoints supporting more than one model + /// Target Model Id /// Token credential, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc. /// Endpoint / Target URI /// Custom for HTTP requests. @@ -50,7 +50,7 @@ public static IKernelBuilder AddAzureAIInferenceChatCompletion( /// The same instance as . 
public static IKernelBuilder AddAzureAIInferenceChatCompletion( this IKernelBuilder builder, - string? modelId, + string modelId, TokenCredential credential, Uri? endpoint = null, HttpClient? httpClient = null, diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs index b508b38537d3..387d9b89a62a 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs @@ -4,10 +4,11 @@ using System.Net.Http; using Azure.AI.Inference; using Azure.Core; +using Azure.Core.Pipeline; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel.ChatCompletion; -using Microsoft.SemanticKernel.Connectors.AzureAIInference; using Microsoft.SemanticKernel.Http; namespace Microsoft.SemanticKernel; @@ -18,10 +19,10 @@ namespace Microsoft.SemanticKernel; public static class AzureAIInferenceServiceCollectionExtensions { /// - /// Adds the to the . + /// Adds an Azure AI Inference to the . /// /// The instance to augment. - /// Target Model Id for endpoints supporting more than one model + /// Target Model Id /// API Key /// Endpoint / Target URI /// Custom for HTTP requests. @@ -29,7 +30,7 @@ public static class AzureAIInferenceServiceCollectionExtensions /// The same instance as . public static IServiceCollection AddAzureAIInferenceChatCompletion( this IServiceCollection services, - string? modelId = null, + string modelId, string? apiKey = null, Uri? endpoint = null, HttpClient? httpClient = null, @@ -37,23 +38,41 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion( { Verify.NotNull(services); - AzureAIInferenceChatCompletionService Factory(IServiceProvider serviceProvider, object? _) => - new(modelId, - apiKey, - endpoint, - HttpClientProvider.GetHttpClient(httpClient, serviceProvider), - serviceProvider.GetService()); + services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + { + var chatClientBuilder = new ChatClientBuilder() + .UseFunctionInvocation(config => + config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + + var logger = serviceProvider.GetService()?.CreateLogger(); + if (logger is not null) + { + chatClientBuilder.UseLogging(logger); + } + + var options = new AzureAIInferenceClientOptions(); + if (httpClient is not null) + { + options.Transport = new HttpClientTransport(HttpClientProvider.GetHttpClient(httpClient, serviceProvider)); + } - services.AddKeyedSingleton(serviceId, (Func)Factory); + return + chatClientBuilder.Use( + new Microsoft.Extensions.AI.AzureAIInferenceChatClient( + modelId: modelId, + chatCompletionsClient: new Azure.AI.Inference.ChatCompletionsClient(endpoint, new Azure.AzureKeyCredential(apiKey ?? SingleSpace), options) + ) + ).AsChatCompletionService(); + }); return services; } /// - /// Adds the to the . + /// Adds an Azure AI Inference to the . /// /// The instance to augment. - /// Target Model Id for endpoints supporting more than one model + /// Target Model Id /// Token credential, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc. /// Endpoint / Target URI /// Custom for HTTP requests. 
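The reworked extension above now registers a keyed IChatCompletionService built from a ChatClientBuilder pipeline (function invocation plus optional logging). A usage sketch of the API-key overload with placeholder values:

// Sketch: register the Azure AI Inference chat completion service in a service
// collection and resolve it by its service id. All values are placeholders.
using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;

var services = new ServiceCollection();

services.AddAzureAIInferenceChatCompletion(
    modelId: "your-model-id",
    apiKey: "your-api-key",
    endpoint: new Uri("https://your-resource.inference.ai.azure.com"),
    serviceId: "azureai");

using var provider = services.BuildServiceProvider();

// The extension registers a keyed singleton, so resolve with the same service id.
var chatService = provider.GetRequiredKeyedService<IChatCompletionService>("azureai");

Because the registration wires in UseFunctionInvocation with a capped MaximumIterationsPerRequest, auto function calling works on the resolved service without extra setup.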
@@ -61,7 +80,7 @@ AzureAIInferenceChatCompletionService Factory(IServiceProvider serviceProvider, /// The same instance as . public static IServiceCollection AddAzureAIInferenceChatCompletion( this IServiceCollection services, - string? modelId, + string modelId, TokenCredential credential, Uri? endpoint = null, HttpClient? httpClient = null, @@ -69,20 +88,38 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion( { Verify.NotNull(services); - AzureAIInferenceChatCompletionService Factory(IServiceProvider serviceProvider, object? _) => - new(modelId, - credential, - endpoint, - HttpClientProvider.GetHttpClient(httpClient, serviceProvider), - serviceProvider.GetService()); + services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + { + var chatClientBuilder = new ChatClientBuilder() + .UseFunctionInvocation(config => + config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + + var logger = serviceProvider.GetService()?.CreateLogger(); + if (logger is not null) + { + chatClientBuilder.UseLogging(logger); + } - services.AddKeyedSingleton(serviceId, (Func)Factory); + var options = new AzureAIInferenceClientOptions(); + if (httpClient is not null) + { + options.Transport = new HttpClientTransport(HttpClientProvider.GetHttpClient(httpClient, serviceProvider)); + } + + return + chatClientBuilder.Use( + new Microsoft.Extensions.AI.AzureAIInferenceChatClient( + modelId: modelId, + chatCompletionsClient: new Azure.AI.Inference.ChatCompletionsClient(endpoint, credential, options) + ) + ).AsChatCompletionService(); + }); return services; } /// - /// Adds the to the . + /// Adds an Azure AI Inference to the . /// /// The instance to augment. /// Azure AI Inference model id @@ -96,11 +133,88 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion(this IService { Verify.NotNull(services); - AzureAIInferenceChatCompletionService Factory(IServiceProvider serviceProvider, object? _) => - new(modelId, chatClient ?? serviceProvider.GetRequiredService(), serviceProvider.GetService()); + services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + { + chatClient ??= serviceProvider.GetRequiredService(); + + var chatClientBuilder = new ChatClientBuilder() + .UseFunctionInvocation(config => + config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); - services.AddKeyedSingleton(serviceId, (Func)Factory); + var logger = serviceProvider.GetService()?.CreateLogger(); + if (logger is not null) + { + chatClientBuilder.UseLogging(logger); + } + + return chatClientBuilder + .Use(new Microsoft.Extensions.AI.AzureAIInferenceChatClient(chatClient, modelId)) + .AsChatCompletionService(); + }); return services; } + + /// + /// Adds an Azure AI Inference to the . + /// + /// The instance to augment. + /// to use for the service. If null, one must be available in the service provider when this service is resolved. + /// A local identifier for the given AI service + /// The same instance as . + public static IServiceCollection AddAzureAIInferenceChatCompletion(this IServiceCollection services, + AzureAIInferenceChatClient? chatClient = null, + string? 
serviceId = null) + { + Verify.NotNull(services); + + services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + { + chatClient ??= serviceProvider.GetRequiredService(); + + var chatClientBuilder = new ChatClientBuilder() + .UseFunctionInvocation(config => + config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + + var logger = serviceProvider.GetService()?.CreateLogger(); + if (logger is not null) + { + chatClientBuilder.UseLogging(logger); + } + + return chatClientBuilder + .Use(chatClient) + .AsChatCompletionService(); + }); + + return services; + } + + #region Private + + /// + /// The maximum number of auto-invokes that can be in-flight at any given time as part of the current + /// asynchronous chain of execution. + /// + /// + /// This is a fail-safe mechanism. If someone accidentally manages to set up execution settings in such a way that + /// auto-invocation is invoked recursively, and in particular where a prompt function is able to auto-invoke itself, + /// we could end up in an infinite loop. This const is a backstop against that happening. We should never come close + /// to this limit, but if we do, auto-invoke will be disabled for the current flow in order to prevent runaway execution. + /// With the current setup, the way this could possibly happen is if a prompt function is configured with built-in + /// execution settings that opt-in to auto-invocation of everything in the kernel, in which case the invocation of that + /// prompt function could advertize itself as a candidate for auto-invocation. We don't want to outright block that, + /// if that's something a developer has asked to do (e.g. it might be invoked with different arguments than its parent + /// was invoked with), but we do want to limit it. This limit is arbitrary and can be tweaked in the future and/or made + /// configurable should need arise. + /// + private const int MaxInflightAutoInvokes = 128; + + /// + /// When using Azure AI Inference against Gateway APIs that don't require an API key, + /// this single space is used to avoid breaking the client. + /// + private const string SingleSpace = " "; + + #endregion } diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs index 0b55ac3cd696..392f93b47147 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs @@ -7,6 +7,7 @@ using System.Threading.Tasks; using Azure.AI.Inference; using Azure.Core; +using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.AzureAIInference.Core; @@ -16,9 +17,11 @@ namespace Microsoft.SemanticKernel.Connectors.AzureAIInference; /// /// Chat completion service for Azure AI Inference. /// +[Obsolete("Dedicated AzureAIInferenceChatCompletionService is deprecated. Use Azure.AI.Inference.ChatCompletionsClient.AsChatClient().AsChatCompletionService() instead.")] public sealed class AzureAIInferenceChatCompletionService : IChatCompletionService { private readonly ChatClientCore _core; + private readonly IChatCompletionService _chatService; /// /// Initializes a new instance of the class. 
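Among the changes above is a new overload that accepts a prebuilt Microsoft.Extensions.AI.AzureAIInferenceChatClient, mirroring the ChatClientInline unit-test case. A sketch with placeholder endpoint, key, and model id:

// Sketch: hand the DI extension a prebuilt AzureAIInferenceChatClient.
using Azure;
using Azure.AI.Inference;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;

var client = new ChatCompletionsClient(
    new Uri("https://your-resource.inference.ai.azure.com"), // placeholder
    new AzureKeyCredential("your-api-key"));                 // placeholder

using var chatClient = new AzureAIInferenceChatClient(client, "your-model-id");

var services = new ServiceCollection();

// Passing null instead would require an AzureAIInferenceChatClient to be
// resolvable from the service provider when this registration is materialized.
services.AddAzureAIInferenceChatCompletion(chatClient);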
@@ -29,18 +32,32 @@ public sealed class AzureAIInferenceChatCompletionService : IChatCompletionServi /// Custom for HTTP requests. /// The to use for logging. If null, no logging will be performed. public AzureAIInferenceChatCompletionService( - string? modelId = null, + string modelId, string? apiKey = null, Uri? endpoint = null, HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null) { + var logger = loggerFactory?.CreateLogger(typeof(AzureAIInferenceChatCompletionService)); this._core = new( modelId, apiKey, endpoint, httpClient, - loggerFactory?.CreateLogger(typeof(AzureAIInferenceChatCompletionService))); + logger); + + var builder = new ChatClientBuilder() + .UseFunctionInvocation(config => + config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + + if (logger is not null) + { + builder = builder.UseLogging(logger); + } + + this._chatService = builder + .Use(this._core.Client.AsChatClient(modelId)) + .AsChatCompletionService(); } /// @@ -58,12 +75,26 @@ public AzureAIInferenceChatCompletionService( HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null) { + var logger = loggerFactory?.CreateLogger(typeof(AzureAIInferenceChatCompletionService)); this._core = new( modelId, credential, endpoint, httpClient, - loggerFactory?.CreateLogger(typeof(AzureAIInferenceChatCompletionService))); + logger); + + var builder = new ChatClientBuilder() + .UseFunctionInvocation(config => + config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + + if (logger is not null) + { + builder = builder.UseLogging(logger); + } + + this._chatService = builder + .Use(this._core.Client.AsChatClient(modelId)) + .AsChatCompletionService(); } /// @@ -77,10 +108,24 @@ public AzureAIInferenceChatCompletionService( ChatCompletionsClient chatClient, ILoggerFactory? loggerFactory = null) { + var logger = loggerFactory?.CreateLogger(typeof(AzureAIInferenceChatCompletionService)); this._core = new( modelId, chatClient, - loggerFactory?.CreateLogger(typeof(AzureAIInferenceChatCompletionService))); + logger); + + var builder = new ChatClientBuilder() + .UseFunctionInvocation(config => + config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + + if (logger is not null) + { + builder = builder.UseLogging(logger); + } + + this._chatService = builder + .Use(this._core.Client.AsChatClient(modelId)) + .AsChatCompletionService(); } /// @@ -88,9 +133,15 @@ public AzureAIInferenceChatCompletionService( /// public Task> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) - => this._core.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken); + => this._chatService.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken); /// public IAsyncEnumerable GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? 
kernel = null, CancellationToken cancellationToken = default) - => this._core.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken); + => this._chatService.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken); + + #region Private + + private const int MaxInflightAutoInvokes = 128; + + #endregion } diff --git a/dotnet/src/IntegrationTests/Connectors/AzureAIInference/AzureAIInferenceChatCompletionServiceTests.cs b/dotnet/src/IntegrationTests/Connectors/AzureAIInference/AzureAIInferenceChatCompletionServiceTests.cs index 140e16fc97cc..97fedbc73f0b 100644 --- a/dotnet/src/IntegrationTests/Connectors/AzureAIInference/AzureAIInferenceChatCompletionServiceTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/AzureAIInference/AzureAIInferenceChatCompletionServiceTests.cs @@ -6,6 +6,7 @@ using System.Net.Http; using System.Text; using System.Threading.Tasks; +using Azure; using Azure.AI.Inference; using Azure.Identity; using Microsoft.Extensions.Configuration; @@ -42,16 +43,7 @@ public async Task InvokeGetChatMessageContentsAsync(string prompt, string expect var config = this._configuration.GetSection("AzureAIInference").Get(); Assert.NotNull(config); - var sut = (config.ApiKey is not null) - ? new AzureAIInferenceChatCompletionService( - endpoint: config.Endpoint, - apiKey: config.ApiKey, - loggerFactory: this._loggerFactory) - : new AzureAIInferenceChatCompletionService( - modelId: null, - endpoint: config.Endpoint, - credential: new AzureCliCredential(), - loggerFactory: this._loggerFactory); + IChatCompletionService sut = this.CreateChatService(config); ChatHistory chatHistory = [ new ChatMessageContent(AuthorRole.User, prompt) @@ -73,16 +65,7 @@ public async Task InvokeGetStreamingChatMessageContentsAsync(string prompt, stri var config = this._configuration.GetSection("AzureAIInference").Get(); Assert.NotNull(config); - var sut = (config.ApiKey is not null) - ? 
new AzureAIInferenceChatCompletionService( - endpoint: config.Endpoint, - apiKey: config.ApiKey, - loggerFactory: this._loggerFactory) - : new AzureAIInferenceChatCompletionService( - modelId: null, - endpoint: config.Endpoint, - credential: new AzureCliCredential(), - loggerFactory: this._loggerFactory); + IChatCompletionService sut = this.CreateChatService(config); ChatHistory chatHistory = [ new ChatMessageContent(AuthorRole.User, prompt) @@ -150,10 +133,11 @@ public async Task ItHttpRetryPolicyTestAsync() var config = this._configuration.GetSection("AzureAIInference").Get(); Assert.NotNull(config); Assert.NotNull(config.Endpoint); + Assert.NotNull(config.ChatModelId); var kernelBuilder = Kernel.CreateBuilder(); - kernelBuilder.AddAzureAIInferenceChatCompletion(endpoint: config.Endpoint, apiKey: null); + kernelBuilder.AddAzureAIInferenceChatCompletion(modelId: config.ChatModelId, endpoint: config.Endpoint, apiKey: null); kernelBuilder.Services.ConfigureHttpClientDefaults(c => { @@ -176,11 +160,11 @@ public async Task ItHttpRetryPolicyTestAsync() var prompt = "Where is the most famous fish market in Seattle, Washington, USA?"; // Act - var exception = await Assert.ThrowsAsync(() => target.InvokeAsync(plugins["SummarizePlugin"]["Summarize"], new() { [InputParameterName] = prompt })); + var exception = await Assert.ThrowsAsync(() => target.InvokeAsync(plugins["SummarizePlugin"]["Summarize"], new() { [InputParameterName] = prompt })); // Assert Assert.All(statusCodes, s => Assert.Equal(HttpStatusCode.Unauthorized, s)); - Assert.Equal(HttpStatusCode.Unauthorized, ((HttpOperationException)exception).StatusCode); + Assert.Equal((int)HttpStatusCode.Unauthorized, ((RequestFailedException)exception).Status); } [Fact] @@ -235,10 +219,12 @@ private Kernel CreateAndInitializeKernel(HttpClient? httpClient = null) Assert.NotNull(config); Assert.NotNull(config.ApiKey); Assert.NotNull(config.Endpoint); + Assert.NotNull(config.ChatModelId); var kernelBuilder = base.CreateKernelBuilder(); kernelBuilder.AddAzureAIInferenceChatCompletion( + config.ChatModelId, endpoint: config.Endpoint, apiKey: config.ApiKey, serviceId: config.ServiceId, @@ -247,6 +233,33 @@ private Kernel CreateAndInitializeKernel(HttpClient? httpClient = null) return kernelBuilder.Build(); } + private IChatCompletionService CreateChatService(AzureAIInferenceConfiguration config) + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(this._loggerFactory); + + Assert.NotNull(config.ChatModelId); + + if (config.ApiKey is not null) + { + serviceCollection.AddAzureAIInferenceChatCompletion( + modelId: config.ChatModelId, + endpoint: config.Endpoint, + apiKey: config.ApiKey); + } + else + { + serviceCollection.AddAzureAIInferenceChatCompletion( + modelId: config.ChatModelId, + endpoint: config.Endpoint, + credential: new AzureCliCredential()); + } + + var serviceProvider = serviceCollection.BuildServiceProvider(); + + return serviceProvider.GetRequiredService(); + } + public void Dispose() { this._loggerFactory.Dispose(); diff --git a/dotnet/src/IntegrationTests/Connectors/AzureAIInference/AzureAIInferenceChatCompletion_FunctionCallingTests.cs b/dotnet/src/IntegrationTests/Connectors/AzureAIInference/AzureAIInferenceChatCompletion_FunctionCallingTests.cs new file mode 100644 index 000000000000..f43ba4ec8fa6 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/AzureAIInference/AzureAIInferenceChatCompletion_FunctionCallingTests.cs @@ -0,0 +1,362 @@ +// Copyright (c) Microsoft. All rights reserved. 
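The CreateChatService helper above falls back to AzureCliCredential when no API key is configured. The same token-credential path through the kernel builder, sketched with placeholder endpoint and model id:

// Sketch: key-less authentication via a TokenCredential (here AzureCliCredential,
// as in the integration tests). Endpoint and model id are placeholders.
using Azure.Identity;
using Microsoft.SemanticKernel;

var kernel = Kernel.CreateBuilder()
    .AddAzureAIInferenceChatCompletion(
        modelId: "your-model-id",
        credential: new AzureCliCredential(),
        endpoint: new Uri("https://your-resource.inference.ai.azure.com"))
    .Build();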
+ +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.Configuration; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; + +using ChatMessageContent = Microsoft.SemanticKernel.ChatMessageContent; + +namespace SemanticKernel.IntegrationTests.Connectors.AzureAIInference; + +public sealed class AzureAIInferenceChatCompletionFunctionCallingTests : BaseIntegrationTest +{ + // Complex parameters currently don't work (tested against llama3.2 model) + [Fact(Skip = "For manual verification only")] + public async Task CanAutoInvokeKernelFunctionsWithComplexTypeParametersAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + kernel.ImportPluginFromFunctions("HelperFunctions", + [ + kernel.CreateFunctionFromMethod((WeatherParameters parameters) => + { + if (parameters.City.Name == "Dublin" && (parameters.City.Country == "Ireland" || parameters.City.Country == "IE")) + { + return Task.FromResult(42.8); // 42.8 Fahrenheit. + } + + throw new NotSupportedException($"Weather in {parameters.City.Name} ({parameters.City.Country}) is not supported."); + }, "Get_Current_Temperature", "Get current temperature."), + ]); + + PromptExecutionSettings settings = new() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + + // Act + var result = await kernel.InvokePromptAsync("What is the current temperature in Dublin, Ireland, in Fahrenheit?", new(settings)); + + // Assert + Assert.NotNull(result); + Assert.Contains("42.8", result.GetValue(), StringComparison.InvariantCulture); // The WeatherPlugin always returns 42.8 for Dublin, Ireland. + } + + [Fact(Skip = "For manual verification only")] + public async Task CanAutoInvokeKernelFunctionsWithPrimitiveTypeParametersAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true); + + PromptExecutionSettings settings = new() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + + // Act + var result = await kernel.InvokePromptAsync("Convert 50 degrees Fahrenheit to Celsius.", new(settings)); + + // Assert + Assert.NotNull(result); + Assert.Contains("10", result.GetValue(), StringComparison.InvariantCulture); + } + + [Fact(Skip = "For manual verification only")] + public async Task CanAutoInvokeKernelFunctionsWithEnumTypeParametersAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true); + + PromptExecutionSettings settings = new() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + + // Act + var result = await kernel.InvokePromptAsync("Given the current time of day and weather, what is the likely color of the sky in Boston?", new(settings)); + + // Assert + Assert.NotNull(result); + Assert.Contains("rain", result.GetValue(), StringComparison.OrdinalIgnoreCase); + } + + [Fact(Skip = "For manual verification only")] + public async Task CanAutoInvokeKernelFunctionFromPromptAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var promptFunction = KernelFunctionFactory.CreateFromPrompt( + "Your role is always to return this text - 'A Game-Changer for the Transportation Industry'. 
Don't ask for more details or context.", + functionName: "FindLatestNews", + description: "Searches for the latest news."); + + kernel.Plugins.Add(KernelPluginFactory.CreateFromFunctions( + "NewsProvider", + "Delivers up-to-date news content.", + [promptFunction])); + + PromptExecutionSettings settings = new() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + + // Act + var result = await kernel.InvokePromptAsync("Show me the latest news as they are.", new(settings)); + + // Assert + Assert.NotNull(result); + Assert.Contains("Transportation", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); + } + + [Fact(Skip = "For manual verification only")] + public async Task ConnectorAgnosticFunctionCallingModelClassesCanBeUsedForManualFunctionCallingAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true); + + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Given the current time of day and weather, what is the likely color of the sky in Boston?"); + + var settings = new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Required() }; + + var sut = kernel.GetRequiredService(); + + // Act + var messageContent = await sut.GetChatMessageContentAsync(chatHistory, settings, kernel); + + var functionCalls = FunctionCallContent.GetFunctionCalls(messageContent).ToArray(); + + while (functionCalls.Length != 0) + { + // Adding function call from LLM to chat history + chatHistory.Add(messageContent); + + // Iterating over the requested function calls and invoking them + foreach (var functionCall in functionCalls) + { + var result = await functionCall.InvokeAsync(kernel); + + chatHistory.Add(result.ToChatMessage()); + } + + // Sending the functions invocation results to the LLM to get the final response + messageContent = await sut.GetChatMessageContentAsync(chatHistory, settings, kernel); + functionCalls = FunctionCallContent.GetFunctionCalls(messageContent).ToArray(); + } + + // Assert + Assert.Contains("rain", messageContent.Content, StringComparison.InvariantCultureIgnoreCase); + } + + [Fact(Skip = "For manual verification only")] + public async Task ConnectorAgnosticFunctionCallingModelClassesCanPassFunctionExceptionToConnectorAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true); + + var chatHistory = new ChatHistory(); + chatHistory.AddSystemMessage("Add the \"Error\" keyword to the response, if you are unable to answer a question or an error has happen."); + chatHistory.AddUserMessage("Given the current time of day and weather, what is the likely color of the sky in Boston?"); + + var settings = new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Required() }; + + var completionService = kernel.GetRequiredService(); + + // Act + var messageContent = await completionService.GetChatMessageContentAsync(chatHistory, settings, kernel); + + var functionCalls = FunctionCallContent.GetFunctionCalls(messageContent).ToArray(); + + while (functionCalls.Length != 0) + { + // Adding function call from LLM to chat history + chatHistory.Add(messageContent); + + // Iterating over the requested function calls and invoking them + foreach (var functionCall in functionCalls) + { + // Simulating an exception + var exception = new OperationCanceledException("The operation was canceled due to timeout."); + + chatHistory.Add(new FunctionResultContent(functionCall, exception).ToChatMessage()); + } + + // Sending the functions execution results back 
+            messageContent = await completionService.GetChatMessageContentAsync(chatHistory, settings, kernel);
+            functionCalls = FunctionCallContent.GetFunctionCalls(messageContent).ToArray();
+        }
+
+        // Assert
+        Assert.NotNull(messageContent.Content);
+
+        TestHelpers.AssertChatErrorExcuseMessage(messageContent.Content);
+    }
+
+    [Fact(Skip = "For manual verification only")]
+    public async Task ConnectorAgnosticFunctionCallingModelClassesSupportSimulatedFunctionCallsAsync()
+    {
+        // Arrange
+        var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true);
+
+        var chatHistory = new ChatHistory();
+        chatHistory.AddUserMessage("What is the weather in Boston?");
+
+        var settings = new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };
+
+        var completionService = kernel.GetRequiredService<IChatCompletionService>();
+
+        // Act
+        // Adding a simulated function call to the connector response message
+        var simulatedFunctionCall = new FunctionCallContent("weather-alert", id: "call_123");
+        var messageContent = new ChatMessageContent(AuthorRole.Assistant, [simulatedFunctionCall]);
+        chatHistory.Add(messageContent);
+
+        // Adding a simulated function result to chat history
+        var simulatedFunctionResult = "A Tornado Watch has been issued, with potential for severe thunderstorms causing unusual sky colors like green, yellow, or dark gray. Stay informed and follow safety instructions from authorities.";
+        chatHistory.Add(new FunctionResultContent(simulatedFunctionCall, simulatedFunctionResult).ToChatMessage());
+
+        // Sending the function invocation results back to the LLM to get the final response
+        messageContent = await completionService.GetChatMessageContentAsync(chatHistory, settings, kernel);
+
+        // Assert
+        Assert.Contains("tornado", messageContent.Content, StringComparison.InvariantCultureIgnoreCase);
+    }
+
+    [Fact(Skip = "For manual verification only")]
+    public async Task ConnectorAgnosticFunctionCallingModelClassesCanBeUsedForAutoFunctionCallingAsync()
+    {
+        // Arrange
+        var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true);
+
+        var chatHistory = new ChatHistory();
+        chatHistory.AddUserMessage("Given the current time of day and weather, what is the likely color of the sky in Boston?");
+
+        PromptExecutionSettings settings = new() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };
+
+        var sut = kernel.GetRequiredService<IChatCompletionService>();
+
+        // Act
+        await sut.GetChatMessageContentAsync(chatHistory, settings, kernel);
+
+        // Assert
+        var userMessage = chatHistory[0];
+        Assert.Equal(AuthorRole.User, userMessage.Role);
+
+        // LLM requested the functions to call.
+        var getParallelFunctionCallRequestMessage = chatHistory[1];
+        Assert.Equal(AuthorRole.Assistant, getParallelFunctionCallRequestMessage.Role);
+
+        // Parallel Function Calls in the same request
+        var functionCalls = getParallelFunctionCallRequestMessage.Items.OfType<FunctionCallContent>().ToArray();
+
+        FunctionCallContent getWeatherForCityFunctionCallRequest;
+        ChatMessageContent getWeatherForCityFunctionCallResultMessage;
+
+        // Assert
+        // LLM requested the weather for the city.
+        getWeatherForCityFunctionCallRequest = functionCalls[0];
+
+        // Connector invoked the Get_Weather_For_City function and added result to chat history.
+        getWeatherForCityFunctionCallResultMessage = chatHistory[2];
+
+        Assert.Equal("HelperFunctions-Get_Weather_For_City", getWeatherForCityFunctionCallRequest.FunctionName);
+        Assert.NotNull(getWeatherForCityFunctionCallRequest.Id);
+
+        Assert.Equal(AuthorRole.Tool, getWeatherForCityFunctionCallResultMessage.Role);
+        Assert.Single(getWeatherForCityFunctionCallResultMessage.Items.OfType<TextContent>()); // Current function calling model adds TextContent item representing the result of the function call.
+
+        var getWeatherForCityFunctionCallResult = getWeatherForCityFunctionCallResultMessage.Items.OfType<FunctionResultContent>().Single();
+        Assert.Equal("HelperFunctions-Get_Weather_For_City", getWeatherForCityFunctionCallResult.FunctionName);
+        Assert.Equal(getWeatherForCityFunctionCallRequest.Id, getWeatherForCityFunctionCallResult.CallId);
+        Assert.NotNull(getWeatherForCityFunctionCallResult.Result);
+    }
+
+    [Fact(Skip = "For manual verification only")]
+    public async Task SubsetOfFunctionsCanBeUsedForFunctionCallingAsync()
+    {
+        // Arrange
+        var kernel = this.CreateAndInitializeKernel();
+
+        var function = kernel.CreateFunctionFromMethod(() => DayOfWeek.Friday.ToString(), "GetDayOfWeek", "Retrieves the current day of the week.");
+        kernel.ImportPluginFromFunctions("HelperFunctions", [function]);
+
+        var chatHistory = new ChatHistory();
+        chatHistory.AddUserMessage("What day is today?");
+
+        PromptExecutionSettings settings = new() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };
+
+        var sut = kernel.GetRequiredService<IChatCompletionService>();
+
+        // Act
+        var result = await sut.GetChatMessageContentAsync(chatHistory, settings, kernel);
+
+        // Assert
+        Assert.NotNull(result);
+        Assert.Contains("Friday", result.Content, StringComparison.InvariantCulture);
+    }
+
+    [Fact(Skip = "For manual verification only")]
+    public async Task RequiredFunctionShouldBeCalledAsync()
+    {
+        // Arrange
+        var kernel = this.CreateAndInitializeKernel();
+
+        var function = kernel.CreateFunctionFromMethod(() => DayOfWeek.Friday.ToString(), "GetDayOfWeek", "Retrieves the current day of the week.");
+        kernel.ImportPluginFromFunctions("HelperFunctions", [function]);
+
+        var chatHistory = new ChatHistory();
+        chatHistory.AddUserMessage("What day is today?");
+
+        PromptExecutionSettings settings = new() { FunctionChoiceBehavior = FunctionChoiceBehavior.Required() };
+
+        var sut = kernel.GetRequiredService<IChatCompletionService>();
+
+        // Act
+        var result = await sut.GetChatMessageContentAsync(chatHistory, settings, kernel);
+
+        // Assert
+        Assert.NotNull(result);
+        Assert.Contains("Friday", result.Content, StringComparison.InvariantCulture);
+    }
+
+    private Kernel CreateAndInitializeKernel(bool importHelperPlugin = false)
+    {
+        var config = this._configuration.GetSection("AzureAIInference").Get<AzureAIInferenceConfiguration>();
+        Assert.NotNull(config);
+        Assert.NotNull(config.Endpoint);
+        Assert.NotNull(config.ApiKey);
+        Assert.NotNull(config.ChatModelId);
+
+        var kernelBuilder = Kernel.CreateBuilder();
+
+        kernelBuilder.AddAzureAIInferenceChatCompletion(modelId: config.ChatModelId!, endpoint: config.Endpoint, apiKey: config.ApiKey);
+
+        var kernel = kernelBuilder.Build();
+
+        if (importHelperPlugin)
+        {
+            kernel.ImportPluginFromFunctions("HelperFunctions",
+                [
+                    kernel.CreateFunctionFromMethod(() => DateTime.UtcNow.ToString("R"), "GetCurrentUtcTime", "Retrieves the current time in UTC."),
+                    kernel.CreateFunctionFromMethod((string cityName) =>
+                    {
+                        return cityName switch
+                        {
+                            "Boston" => "61 and rainy",
+                            _ => "31 and snowing",
+                        };
+                    }, "Get_Weather_For_City", "Gets the current weather for the specified city"),
+                ]);
+        }
+
+        return kernel;
+    }
+
+    public record WeatherParameters(City City);
+
+    public class City
+    {
+        public string Name { get; set; } = string.Empty;
+        public string Country { get; set; } = string.Empty;
+    }
+
+    private readonly IConfigurationRoot _configuration = new ConfigurationBuilder()
+        .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true)
+        .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true)
+        .AddEnvironmentVariables()
+        .AddUserSecrets()
+        .Build();
+}
diff --git a/dotnet/src/IntegrationTests/TestSettings/AzureAIInferenceConfiguration.cs b/dotnet/src/IntegrationTests/TestSettings/AzureAIInferenceConfiguration.cs
index 664effc9e3a5..12f177802398 100644
--- a/dotnet/src/IntegrationTests/TestSettings/AzureAIInferenceConfiguration.cs
+++ b/dotnet/src/IntegrationTests/TestSettings/AzureAIInferenceConfiguration.cs
@@ -7,9 +7,10 @@ namespace SemanticKernel.IntegrationTests.TestSettings;
 
 [SuppressMessage("Performance", "CA1812:Internal class that is apparently never instantiated", Justification = "Configuration classes are instantiated through IConfiguration.")]
-internal sealed class AzureAIInferenceConfiguration(Uri endpoint, string apiKey, string? serviceId = null)
+internal sealed class AzureAIInferenceConfiguration(Uri endpoint, string apiKey, string? serviceId = null, string? chatModelId = null)
 {
     public Uri Endpoint { get; set; } = endpoint;
     public string? ApiKey { get; set; } = apiKey;
     public string? ServiceId { get; set; } = serviceId;
+    public string? ChatModelId { get; set; } = chatModelId;
 }
diff --git a/dotnet/src/IntegrationTests/testsettings.json b/dotnet/src/IntegrationTests/testsettings.json
index e4bd00c302b6..8153da9efa9d 100644
--- a/dotnet/src/IntegrationTests/testsettings.json
+++ b/dotnet/src/IntegrationTests/testsettings.json
@@ -8,7 +8,8 @@
   "AzureAIInference": {
     "ServiceId": "azure-ai-inference",
     "Endpoint": "",
-    "ApiKey": ""
+    "ApiKey": "",
+    "ChatModelId": "phi3"
   },
   "AzureOpenAI": {
     "ServiceId": "azure-gpt",
diff --git a/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs b/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs
index 9a8cdb974902..c99721c45e69 100644
--- a/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs
+++ b/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs
@@ -79,6 +79,7 @@ public class AzureAIInferenceConfig
     public string ServiceId { get; set; }
     public string Endpoint { get; set; }
     public string? ApiKey { get; set; }
+    public string ChatModelId { get; set; }
 }
 
 public class OnnxConfig