From 6dc75590fa39c3e03b04829e43b5530edc11b46e Mon Sep 17 00:00:00 2001 From: Mark Wallace <127216156+markwallace-microsoft@users.noreply.github.com> Date: Thu, 5 Dec 2024 21:35:09 +0000 Subject: [PATCH] .Net: Fix MistralAI function calling and add image content support (#9844) ### Motivation and Context Tool messages were being rejected as bad requests Close #9806 ### Description ### Contribution Checklist - [ ] The code builds clean without any errors or warnings - [ ] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [ ] All unit tests pass, and I have added new tests where possible - [ ] I didn't break anyone :smile: --- .../Client/MistralClientTests.cs | 80 ++++++- .../MistralAIPromptExecutionSettingsTests.cs | 41 ++++ .../Client/ChatCompletionRequest.cs | 16 ++ .../Client/ContentChunk.cs | 13 ++ .../Client/ContentChunkType.cs | 62 +++++ .../Client/ImageUrlChunk.cs | 11 + .../Client/MistralChatCompletionChunk.cs | 2 +- .../Client/MistralChatMessage.cs | 13 +- .../Client/MistralClient.cs | 59 +++-- .../Connectors.MistralAI/Client/TextChunk.cs | 10 + .../MistralAIPromptExecutionSettings.cs | 98 ++++++++ .../MistralAIChatCompletionTests.cs | 212 +++++++++++++++--- dotnet/src/IntegrationTests/testsettings.json | 6 + 13 files changed, 569 insertions(+), 54 deletions(-) create mode 100644 dotnet/src/Connectors/Connectors.MistralAI/Client/ContentChunk.cs create mode 100644 dotnet/src/Connectors/Connectors.MistralAI/Client/ContentChunkType.cs create mode 100644 dotnet/src/Connectors/Connectors.MistralAI/Client/ImageUrlChunk.cs create mode 100644 dotnet/src/Connectors/Connectors.MistralAI/Client/TextChunk.cs diff --git a/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs b/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs index 37e00ec56154..4fd7cabd2987 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs @@ -37,6 +37,24 @@ public void ValidateRequiredArguments() #pragma warning restore CS8625 // Cannot convert null literal to non-nullable reference type. } + [Fact] + public void ValidateDeserializeChatCompletionMistralChatMessage() + { + var json = "{\"role\":\"assistant\",\"content\":\"Some response.\",\"tool_calls\":null}"; + + MistralChatMessage? deserializedResponse = JsonSerializer.Deserialize(json); + Assert.NotNull(deserializedResponse); + } + + [Fact] + public void ValidateDeserializeChatCompletionResponse() + { + var json = "{\"id\":\"aee5e73a5ef241be89cd7d3e9c45089a\",\"object\":\"chat.completion\",\"created\":1732882368,\"model\":\"mistral-large-latest\",\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":\"Some response.\",\"tool_calls\":null},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":17,\"total_tokens\":124,\"completion_tokens\":107}}"; + + ChatCompletionResponse? deserializedResponse = JsonSerializer.Deserialize(json); + Assert.NotNull(deserializedResponse); + } + [Fact] public async Task ValidateChatMessageRequestAsync() { @@ -62,7 +80,7 @@ public async Task ValidateChatMessageRequestAsync() Assert.Equal(0.9, chatRequest.Temperature); Assert.Single(chatRequest.Messages); Assert.Equal("user", chatRequest.Messages[0].Role); - Assert.Equal("What is the best French cheese?", chatRequest.Messages[0].Content); + Assert.Equal("What is the best French cheese?", chatRequest.Messages[0].Content?.ToString()); } [Fact] @@ -504,6 +522,31 @@ public void ValidateToMistralChatMessages(string roleLabel, string content) Assert.Single(messages); } + [Fact] + public void ValidateToMistralChatMessagesWithMultipleContents() + { + // Arrange + using var httpClient = new HttpClient(); + var client = new MistralClient("mistral-large-latest", httpClient, "key"); + var chatMessage = new ChatMessageContent() + { + Role = AuthorRole.User, + Items = + [ + new TextContent("What is the weather like in Paris?"), + new ImageContent(new Uri("https://tripfixers.com/wp-content/uploads/2019/11/eiffel-tower-with-snow.jpeg")) + ], + }; + + // Act + var messages = client.ToMistralChatMessages(chatMessage, default); + + // Assert + Assert.NotNull(messages); + Assert.Single(messages); + Assert.IsType>(messages[0].Content); + } + [Fact] public void ValidateToMistralChatMessagesWithFunctionCallContent() { @@ -544,6 +587,41 @@ public void ValidateToMistralChatMessagesWithFunctionResultContent() Assert.Equal(2, messages.Count); } + [Fact] + public void ValidateCloneMistralAIPromptExecutionSettings() + { + // Arrange + var settings = new MistralAIPromptExecutionSettings + { + MaxTokens = 1024, + Temperature = 0.9, + TopP = 0.9, + FrequencyPenalty = 0.9, + PresencePenalty = 0.9, + Stop = ["stop"], + SafePrompt = true, + RandomSeed = 123, + ResponseFormat = new { format = "json" }, + }; + + // Act + var clonedSettings = settings.Clone(); + + // Assert + Assert.NotNull(clonedSettings); + Assert.IsType(clonedSettings); + var clonedMistralAISettings = clonedSettings as MistralAIPromptExecutionSettings; + Assert.Equal(settings.MaxTokens, clonedMistralAISettings!.MaxTokens); + Assert.Equal(settings.Temperature, clonedMistralAISettings.Temperature); + Assert.Equal(settings.TopP, clonedMistralAISettings.TopP); + Assert.Equal(settings.FrequencyPenalty, clonedMistralAISettings.FrequencyPenalty); + Assert.Equal(settings.PresencePenalty, clonedMistralAISettings.PresencePenalty); + Assert.Equal(settings.Stop, clonedMistralAISettings.Stop); + Assert.Equal(settings.SafePrompt, clonedMistralAISettings.SafePrompt); + Assert.Equal(settings.RandomSeed, clonedMistralAISettings.RandomSeed); + Assert.Equal(settings.ResponseFormat, clonedMistralAISettings.ResponseFormat); + } + public sealed class WeatherPlugin { [KernelFunction] diff --git a/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/MistralAIPromptExecutionSettingsTests.cs b/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/MistralAIPromptExecutionSettingsTests.cs index 4422740da6c8..8a4b3c31594d 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/MistralAIPromptExecutionSettingsTests.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/MistralAIPromptExecutionSettingsTests.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Text.Json; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.MistralAI; @@ -68,4 +69,44 @@ public void FromExecutionSettingsWhenSerializedHasPropertiesShouldPopulateSpecia Assert.True(MistralExecutionSettings.SafePrompt); Assert.Equal(123, MistralExecutionSettings.RandomSeed); } + + [Fact] + public void FreezeShouldPreventPropertyModification() + { + // Arrange + var settings = new MistralAIPromptExecutionSettings + { + Temperature = 0.7, + TopP = 1, + MaxTokens = 100, + SafePrompt = false, + Stop = ["foo", "bar"] + }; + + // Act + settings.Freeze(); + + // Assert + // Try to modify a property after freezing + Assert.Throws(() => settings.Temperature = 0.8); + Assert.Throws(() => settings.TopP = 0.9); + Assert.Throws(() => settings.MaxTokens = 50); + Assert.Throws(() => settings.SafePrompt = true); + Assert.Throws(() => settings.Stop.Add("baz")); + } + + [Fact] + public void FreezeShouldNotAllowMultipleFreezes() + { + // Arrange + var settings = new MistralAIPromptExecutionSettings(); + settings.Freeze(); // First freeze + + // Act + settings.Freeze(); // Second freeze (should not throw) + + // Assert + // No exception should be thrown + Assert.True(settings.IsFrozen); // Assuming IsFrozen is a property indicating the freeze state + } } diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Client/ChatCompletionRequest.cs b/dotnet/src/Connectors/Connectors.MistralAI/Client/ChatCompletionRequest.cs index e1fc8dbfe996..cf5a3258ea9a 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/Client/ChatCompletionRequest.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/Client/ChatCompletionRequest.cs @@ -44,6 +44,22 @@ internal sealed class ChatCompletionRequest [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public int? RandomSeed { get; set; } + [JsonPropertyName("response_format")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public object? ResponseFormat { get; set; } + + [JsonPropertyName("frequency_penalty")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public double? FrequencyPenalty { get; set; } + + [JsonPropertyName("presence_penalty")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public double? PresencePenalty { get; set; } + + [JsonPropertyName("stop")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public IList? Stop { get; set; } + /// /// Construct an instance of . /// diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Client/ContentChunk.cs b/dotnet/src/Connectors/Connectors.MistralAI/Client/ContentChunk.cs new file mode 100644 index 000000000000..9701570bc678 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.MistralAI/Client/ContentChunk.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json.Serialization; + +namespace Microsoft.SemanticKernel.Connectors.MistralAI.Client; + +[JsonDerivedType(typeof(TextChunk))] +[JsonDerivedType(typeof(ImageUrlChunk))] +internal abstract class ContentChunk(ContentChunkType type) +{ + [JsonPropertyName("type")] + public string Type { get; set; } = type.ToString(); +} diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Client/ContentChunkType.cs b/dotnet/src/Connectors/Connectors.MistralAI/Client/ContentChunkType.cs new file mode 100644 index 000000000000..95125f23f191 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.MistralAI/Client/ContentChunkType.cs @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; + +namespace Microsoft.SemanticKernel.Connectors.MistralAI.Client; + +internal readonly struct ContentChunkType : IEquatable +{ + public static ContentChunkType Text { get; } = new("text"); + + public static ContentChunkType ImageUrl { get; } = new("image_url"); + + public string Type { get; } + + /// + /// Creates a new instance with the provided type. + /// + /// The label to associate with this . + [JsonConstructor] + public ContentChunkType(string type) + { + Verify.NotNullOrWhiteSpace(type, nameof(type)); + this.Type = type!; + } + + /// + /// Returns a value indicating whether two instances are equivalent, as determined by a + /// case-insensitive comparison of their labels. + /// + /// the first instance to compare + /// the second instance to compare + /// true if left and right are both null or have equivalent labels; false otherwise + public static bool operator ==(ContentChunkType left, ContentChunkType right) + => left.Equals(right); + + /// + /// Returns a value indicating whether two instances are not equivalent, as determined by a + /// case-insensitive comparison of their labels. + /// + /// the first instance to compare + /// the second instance to compare + /// false if left and right are both null or have equivalent labels; true otherwise + public static bool operator !=(ContentChunkType left, ContentChunkType right) + => !left.Equals(right); + + /// + public override bool Equals([NotNullWhen(true)] object? obj) + => obj is ContentChunkType otherRole && this == otherRole; + + /// + public bool Equals(ContentChunkType other) + => string.Equals(this.Type, other.Type, StringComparison.OrdinalIgnoreCase); + + /// + public override int GetHashCode() + => StringComparer.OrdinalIgnoreCase.GetHashCode(this.Type); + + /// + public override string ToString() => this.Type; +} diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Client/ImageUrlChunk.cs b/dotnet/src/Connectors/Connectors.MistralAI/Client/ImageUrlChunk.cs new file mode 100644 index 000000000000..01d44e363403 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.MistralAI/Client/ImageUrlChunk.cs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json.Serialization; + +namespace Microsoft.SemanticKernel.Connectors.MistralAI.Client; +internal class ImageUrlChunk(Uri imageUrl) : ContentChunk(ContentChunkType.ImageUrl) +{ + [JsonPropertyName("image_url")] + public string ImageUrl { get; set; } = imageUrl.ToString(); +} diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralChatCompletionChunk.cs b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralChatCompletionChunk.cs index 6ae497ca0180..b163b8401605 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralChatCompletionChunk.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralChatCompletionChunk.cs @@ -43,7 +43,7 @@ internal sealed class MistralChatCompletionChunk internal string? GetRole(int index) => this.Choices?[index]?.Delta?.Role; - internal string? GetContent(int index) => this.Choices?[index]?.Delta?.Content; + internal string? GetContent(int index) => this.Choices?[index]?.Delta?.Content?.ToString(); internal int GetChoiceIndex(int index) => this.Choices?[index]?.Index ?? -1; diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralChatMessage.cs b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralChatMessage.cs index 6efdb6e0ac5c..e587ac8f5c95 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralChatMessage.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralChatMessage.cs @@ -15,7 +15,16 @@ internal sealed class MistralChatMessage public string? Role { get; set; } [JsonPropertyName("content")] - public string? Content { get; set; } + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public object? Content { get; set; } + + [JsonPropertyName("name")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? Name { get; set; } + + [JsonPropertyName("tool_call_id")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? ToolCallId { get; set; } [JsonPropertyName("tool_calls")] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] @@ -27,7 +36,7 @@ internal sealed class MistralChatMessage /// If provided must be one of: system, user, assistant /// Content of the chat message [JsonConstructor] - internal MistralChatMessage(string? role, string? content) + internal MistralChatMessage(string? role, object? content) { if (role is not null and not "system" and not "user" and not "assistant" and not "tool") { diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs index 783f4565c9a1..daff20926bdf 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs @@ -505,7 +505,7 @@ private async IAsyncEnumerable StreamChatMessageCon var endpoint = this.GetEndpoint(executionSettings, path: "chat/completions"); using var httpRequestMessage = this.CreatePost(chatRequest, endpoint, this._apiKey, stream: true); using var response = await this.SendStreamingRequestAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); - using var responseStream = await response.Content.ReadAsStreamAndTranslateExceptionAsync(cancellationToken).ConfigureAwait(false); + var responseStream = await response.Content.ReadAsStreamAndTranslateExceptionAsync(cancellationToken).ConfigureAwait(false); await foreach (var streamingChatContent in this.ProcessChatResponseStreamAsync(responseStream, modelId, cancellationToken).ConfigureAwait(false)) { yield return streamingChatContent; @@ -693,7 +693,11 @@ private ChatCompletionRequest CreateChatCompletionRequest(string modelId, bool s TopP = executionSettings.TopP, MaxTokens = executionSettings.MaxTokens, SafePrompt = executionSettings.SafePrompt, - RandomSeed = executionSettings.RandomSeed + RandomSeed = executionSettings.RandomSeed, + ResponseFormat = executionSettings.ResponseFormat, + FrequencyPenalty = executionSettings.FrequencyPenalty, + PresencePenalty = executionSettings.PresencePenalty, + Stop = executionSettings.Stop, }; executionSettings.ToolCallBehavior?.ConfigureRequest(kernel, request); @@ -701,14 +705,14 @@ private ChatCompletionRequest CreateChatCompletionRequest(string modelId, bool s return request; } - internal List ToMistralChatMessages(ChatMessageContent content, MistralAIToolCallBehavior? toolCallBehavior) + internal List ToMistralChatMessages(ChatMessageContent chatMessage, MistralAIToolCallBehavior? toolCallBehavior) { - if (content.Role == AuthorRole.Assistant) + if (chatMessage.Role == AuthorRole.Assistant) { // Handling function calls supplied via ChatMessageContent.Items collection elements of the FunctionCallContent type. - var message = new MistralChatMessage(content.Role.ToString(), content.Content ?? string.Empty); + var message = new MistralChatMessage(chatMessage.Role.ToString(), chatMessage.Content ?? string.Empty); Dictionary toolCalls = []; - foreach (var item in content.Items) + foreach (var item in chatMessage.Items) { if (item is not FunctionCallContent callRequest) { @@ -740,10 +744,10 @@ internal List ToMistralChatMessages(ChatMessageContent conte return [message]; } - if (content.Role == AuthorRole.Tool) + if (chatMessage.Role == AuthorRole.Tool) { List? messages = null; - foreach (var item in content.Items) + foreach (var item in chatMessage.Items) { if (item is not FunctionResultContent resultContent) { @@ -753,7 +757,12 @@ internal List ToMistralChatMessages(ChatMessageContent conte messages ??= []; var stringResult = ProcessFunctionResult(resultContent.Result ?? string.Empty, toolCallBehavior); - messages.Add(new MistralChatMessage(content.Role.ToString(), stringResult)); + var name = $"{resultContent.PluginName}-{resultContent.FunctionName}"; + messages.Add(new MistralChatMessage(chatMessage.Role.ToString(), stringResult) + { + Name = name, + ToolCallId = resultContent.CallId + }); } if (messages is not null) { @@ -763,7 +772,29 @@ internal List ToMistralChatMessages(ChatMessageContent conte throw new NotSupportedException("No function result provided in the tool message."); } - return [new MistralChatMessage(content.Role.ToString(), content.Content ?? string.Empty)]; + if (chatMessage.Items.Count == 1 && chatMessage.Items[0] is TextContent text) + { + return [new MistralChatMessage(chatMessage.Role.ToString(), text.Text)]; + } + + List content = []; + foreach (var item in chatMessage.Items) + { + if (item is TextContent textContent && !string.IsNullOrEmpty(textContent.Text)) + { + content.Add(new TextChunk(textContent.Text!)); + } + else if (item is ImageContent imageContent && imageContent.Uri is not null) + { + content.Add(new ImageUrlChunk(imageContent.Uri)); + } + else + { + throw new NotSupportedException("Invalid message content, only text and image url are supported."); + } + } + + return [new MistralChatMessage(chatMessage.Role.ToString(), content)]; } private HttpRequestMessage CreatePost(object requestData, Uri endpoint, string apiKey, bool stream) @@ -842,7 +873,7 @@ private List ToChatMessageContent(string modelId, ChatComple private ChatMessageContent ToChatMessageContent(string modelId, ChatCompletionResponse response, MistralChatChoice chatChoice) { - var message = new ChatMessageContent(new AuthorRole(chatChoice.Message!.Role!), chatChoice.Message!.Content, modelId, chatChoice, Encoding.UTF8, GetChatChoiceMetadata(response, chatChoice)); + var message = new ChatMessageContent(new AuthorRole(chatChoice.Message!.Role!), chatChoice.Message!.Content?.ToString(), modelId, chatChoice, Encoding.UTF8, GetChatChoiceMetadata(response, chatChoice)); if (chatChoice.IsToolCall) { @@ -857,7 +888,7 @@ private ChatMessageContent ToChatMessageContent(string modelId, ChatCompletionRe private ChatMessageContent ToChatMessageContent(string modelId, string streamedRole, MistralChatCompletionChunk chunk, MistralChatCompletionChoice chatChoice) { - var message = new ChatMessageContent(new AuthorRole(streamedRole), chatChoice.Delta!.Content, modelId, chatChoice, Encoding.UTF8, GetChatChoiceMetadata(chunk, chatChoice)); + var message = new ChatMessageContent(new AuthorRole(streamedRole), chatChoice.Delta!.Content?.ToString(), modelId, chatChoice, Encoding.UTF8, GetChatChoiceMetadata(chunk, chatChoice)); if (chatChoice.IsToolCall) { @@ -982,7 +1013,7 @@ private void AddResponseMessage(ChatHistory chat, MistralToolCall toolCall, stri /// The result of the function call. /// The ToolCallBehavior object containing optional settings like JsonSerializerOptions.TypeInfoResolver. /// A string representation of the function result. - private static string? ProcessFunctionResult(object functionResult, MistralAIToolCallBehavior? toolCallBehavior) + private static string ProcessFunctionResult(object functionResult, MistralAIToolCallBehavior? toolCallBehavior) { if (functionResult is string stringResult) { @@ -1000,7 +1031,7 @@ private void AddResponseMessage(ChatHistory chat, MistralToolCall toolCall, stri // a corresponding JsonTypeInfoResolver should be provided via the JsonSerializerOptions.TypeInfoResolver property. // For more details about the polymorphic serialization, see the article at: // https://learn.microsoft.com/en-us/dotnet/standard/serialization/system-text-json/polymorphism?pivots=dotnet-8-0 - return JsonSerializer.Serialize(functionResult, toolCallBehavior?.ToolCallResultSerializerOptions); + return JsonSerializer.Serialize(functionResult, toolCallBehavior?.ToolCallResultSerializerOptions) ?? string.Empty; } /// diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Client/TextChunk.cs b/dotnet/src/Connectors/Connectors.MistralAI/Client/TextChunk.cs new file mode 100644 index 000000000000..9c8c7a2d6cd9 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.MistralAI/Client/TextChunk.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json.Serialization; + +namespace Microsoft.SemanticKernel.Connectors.MistralAI.Client; +internal class TextChunk(string text) : ContentChunk(ContentChunkType.Text) +{ + [JsonPropertyName("text")] + public string Text { get; set; } = text; +} diff --git a/dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs b/dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs index 9e136d0e089f..5ea83a482756 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Text.Json; using System.Text.Json.Serialization; using Microsoft.SemanticKernel.ChatCompletion; @@ -155,6 +156,90 @@ public MistralAIToolCallBehavior? ToolCallBehavior } } + /// + /// Gets or sets the response format to use for the completion. + /// + /// + /// An object specifying the format that the model must output. + /// Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON. + /// When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message. + /// + [JsonPropertyName("response_format")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public object? ResponseFormat + { + get => this._responseFormat; + + set + { + this.ThrowIfFrozen(); + this._responseFormat = value; + } + } + + /// + /// Gets or sets the stop sequences to use for the completion. + /// + /// + /// Stop generation if this token is detected. Or if one of these tokens is detected when providing an array + /// + [JsonPropertyName("stop")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public IList? Stop + { + get => this._stop; + + set + { + this.ThrowIfFrozen(); + this._stop = value; + } + } + + /// + /// Number between -2.0 and 2.0. Positive values penalize new tokens + /// based on whether they appear in the text so far, increasing the + /// model's likelihood to talk about new topics. + /// + /// + /// presence_penalty determines how much the model penalizes the repetition of words or phrases. + /// A higher presence penalty encourages the model to use a wider variety of words and phrases, making the output more diverse and creative. + /// + [JsonPropertyName("presence_penalty")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public double? PresencePenalty + { + get => this._presencePenalty; + + set + { + this.ThrowIfFrozen(); + this._presencePenalty = value; + } + } + + /// + /// Number between -2.0 and 2.0. Positive values penalize new tokens + /// based on their existing frequency in the text so far, decreasing + /// the model's likelihood to repeat the same line verbatim. + /// + /// + /// frequency_penalty penalizes the repetition of words based on their frequency in the generated text. + /// A higher frequency penalty discourages the model from repeating words that have already appeared frequently in the output, promoting diversity and reducing repetition. + /// + [JsonPropertyName("frequency_penalty")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public double? FrequencyPenalty + { + get => this._frequencyPenalty; + + set + { + this.ThrowIfFrozen(); + this._frequencyPenalty = value; + } + } + /// public override void Freeze() { @@ -163,6 +248,11 @@ public override void Freeze() return; } + if (this._stop is not null) + { + this._stop = new ReadOnlyCollection(this._stop); + } + base.Freeze(); } @@ -180,6 +270,10 @@ public override PromptExecutionSettings Clone() RandomSeed = this.RandomSeed, ApiVersion = this.ApiVersion, ToolCallBehavior = this.ToolCallBehavior, + ResponseFormat = this.ResponseFormat, + FrequencyPenalty = this.FrequencyPenalty, + PresencePenalty = this.PresencePenalty, + Stop = this.Stop is not null ? new List(this.Stop) : null, }; } @@ -215,6 +309,10 @@ public static MistralAIPromptExecutionSettings FromExecutionSettings(PromptExecu private int? _randomSeed; private string _apiVersion = "v1"; private MistralAIToolCallBehavior? _toolCallBehavior; + private object? _responseFormat; + private double? _presencePenalty; + private double? _frequencyPenalty; + private IList? _stop; #endregion } diff --git a/dotnet/src/IntegrationTests/Connectors/MistralAI/ChatCompletion/MistralAIChatCompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/MistralAI/ChatCompletion/MistralAIChatCompletionTests.cs index def078a53799..da3315fa3816 100644 --- a/dotnet/src/IntegrationTests/Connectors/MistralAI/ChatCompletion/MistralAIChatCompletionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/MistralAI/ChatCompletion/MistralAIChatCompletionTests.cs @@ -3,8 +3,11 @@ using System; using System.Collections.Generic; using System.ComponentModel; +using System.Net.Http; using System.Text; +using System.Text.Json; using System.Text.Json.Serialization; +using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Configuration; using Microsoft.SemanticKernel; @@ -12,19 +15,27 @@ using Microsoft.SemanticKernel.Connectors.MistralAI; using Microsoft.SemanticKernel.Connectors.MistralAI.Client; using Xunit; +using Xunit.Abstractions; namespace SemanticKernel.IntegrationTests.Connectors.MistralAI; /// /// Integration tests for . /// -public sealed class MistralAIChatCompletionTests +public sealed class MistralAIChatCompletionTests : IDisposable { + private readonly ITestOutputHelper _output; private readonly IConfigurationRoot _configuration; private readonly MistralAIPromptExecutionSettings _executionSettings; + private readonly HttpClientHandler _httpClientHandler; + private readonly HttpMessageHandler _httpMessageHandler; + private readonly HttpClient _httpClient; + private bool _disposedValue; - public MistralAIChatCompletionTests() + public MistralAIChatCompletionTests(ITestOutputHelper output) { + this._output = output; + // Load configuration this._configuration = new ConfigurationBuilder() .AddJsonFile(path: "testsettings.json", optional: false, reloadOnChange: true) @@ -37,15 +48,38 @@ public MistralAIChatCompletionTests() { MaxTokens = 500, }; + + this._httpClientHandler = new HttpClientHandler(); + this._httpMessageHandler = new LoggingHandler(this._httpClientHandler, this._output); + this._httpClient = new HttpClient(this._httpMessageHandler); + } + private void Dispose(bool disposing) + { + if (!this._disposedValue) + { + if (disposing) + { + this._httpClientHandler.Dispose(); + this._httpMessageHandler.Dispose(); + this._httpClient.Dispose(); + } + this._disposedValue = true; + } + } + + public void Dispose() + { + this.Dispose(disposing: true); + GC.SuppressFinalize(this); } [Fact(Skip = "This test is for manual verification.")] public async Task ValidateGetChatMessageContentsAsync() { // Arrange - var model = this._configuration["MistralAI:ChatModel"]; + var model = this._configuration["MistralAI:ChatModelId"]; var apiKey = this._configuration["MistralAI:ApiKey"]; - var service = new MistralAIChatCompletionService(model!, apiKey!); + var service = new MistralAIChatCompletionService(model!, apiKey!, httpClient: this._httpClient); // Act var chatHistory = new ChatHistory @@ -65,9 +99,9 @@ public async Task ValidateGetChatMessageContentsAsync() public async Task ValidateGetChatMessageContentsWithUsageAsync() { // Arrange - var model = this._configuration["MistralAI:ChatModel"]; + var model = this._configuration["MistralAI:ChatModelId"]; var apiKey = this._configuration["MistralAI:ApiKey"]; - var service = new MistralAIChatCompletionService(model!, apiKey!); + var service = new MistralAIChatCompletionService(model!, apiKey!, httpClient: this._httpClient); // Act var chatHistory = new ChatHistory @@ -89,11 +123,69 @@ public async Task ValidateGetChatMessageContentsWithUsageAsync() Assert.True(usage?.TotalTokens > 0); } + [Fact(Skip = "This test is for manual verification.")] + public async Task ValidateGetChatMessageContentsWithImageAsync() + { + // Arrange + var model = this._configuration["MistralAI:ImageModelId"]; + var apiKey = this._configuration["MistralAI:ApiKey"]; + var service = new MistralAIChatCompletionService(model!, apiKey!, httpClient: this._httpClient); + + // Act + var chatHistory = new ChatHistory + { + new ChatMessageContent(AuthorRole.User, "Describe the image"), + new ChatMessageContent(AuthorRole.User, [new ImageContent(new Uri("https://tripfixers.com/wp-content/uploads/2019/11/eiffel-tower-with-snow.jpeg"))]) + }; + var response = await service.GetChatMessageContentsAsync(chatHistory, this._executionSettings); + + // Assert + Assert.NotNull(response); + Assert.Single(response); + Assert.Contains("Paris", response[0].Content, System.StringComparison.InvariantCultureIgnoreCase); + Assert.Contains("Eiffel Tower", response[0].Content, System.StringComparison.InvariantCultureIgnoreCase); + Assert.Contains("Snow", response[0].Content, System.StringComparison.InvariantCultureIgnoreCase); + } + + [Fact(Skip = "This test is for manual verification.")] + public async Task ValidateGetChatMessageContentsWithImageAndJsonFormatAsync() + { + // Arrange + var model = this._configuration["MistralAI:ImageModelId"]; + var apiKey = this._configuration["MistralAI:ApiKey"]; + var service = new MistralAIChatCompletionService(model!, apiKey!, httpClient: this._httpClient); + + // Act + var systemMessage = "Return the answer in a JSON object with the next structure: " + + "{\"elements\": [{\"element\": \"some name of element1\", " + + "\"description\": \"some description of element 1\"}, " + + "{\"element\": \"some name of element2\", \"description\": " + + "\"some description of element 2\"}]}"; + var chatHistory = new ChatHistory(systemMessage) + { + new ChatMessageContent(AuthorRole.User, "Describe the image"), + new ChatMessageContent(AuthorRole.User, [new ImageContent(new Uri("https://tripfixers.com/wp-content/uploads/2019/11/eiffel-tower-with-snow.jpeg"))]) + }; + var executionSettings = new MistralAIPromptExecutionSettings + { + MaxTokens = 500, + ResponseFormat = new { type = "json_object" }, + }; + var response = await service.GetChatMessageContentsAsync(chatHistory, executionSettings); + + // Assert + Assert.NotNull(response); + Assert.Single(response); + Assert.Contains("Paris", response[0].Content, System.StringComparison.InvariantCultureIgnoreCase); + Assert.Contains("Eiffel Tower", response[0].Content, System.StringComparison.InvariantCultureIgnoreCase); + Assert.Contains("Snow", response[0].Content, System.StringComparison.InvariantCultureIgnoreCase); + } + [Fact(Skip = "This test is for manual verification.")] public async Task ValidateInvokeChatPromptAsync() { // Arrange - var model = this._configuration["MistralAI:ChatModel"]; + var model = this._configuration["MistralAI:ChatModelId"]; var apiKey = this._configuration["MistralAI:ApiKey"]; var kernel = Kernel.CreateBuilder() .AddMistralChatCompletion(model!, apiKey!) @@ -117,9 +209,9 @@ public async Task ValidateInvokeChatPromptAsync() public async Task ValidateGetStreamingChatMessageContentsAsync() { // Arrange - var model = this._configuration["MistralAI:ChatModel"]; + var model = this._configuration["MistralAI:ChatModelId"]; var apiKey = this._configuration["MistralAI:ApiKey"]; - var service = new MistralAIChatCompletionService(model!, apiKey!); + var service = new MistralAIChatCompletionService(model!, apiKey!, httpClient: this._httpClient); // Act var chatHistory = new ChatHistory @@ -146,9 +238,9 @@ public async Task ValidateGetStreamingChatMessageContentsAsync() public async Task ValidateGetChatMessageContentsHasToolCallsResponseAsync() { // Arrange - var model = this._configuration["MistralAI:ChatModel"]; + var model = this._configuration["MistralAI:ChatModelId"]; var apiKey = this._configuration["MistralAI:ApiKey"]; - var service = new MistralAIChatCompletionService(model!, apiKey!); + var service = new MistralAIChatCompletionService(model!, apiKey!, httpClient: this._httpClient); var kernel = new Kernel(); kernel.Plugins.AddFromType(); @@ -170,9 +262,9 @@ public async Task ValidateGetChatMessageContentsHasToolCallsResponseAsync() public async Task ValidateGetChatMessageContentsHasRequiredToolCallResponseAsync() { // Arrange - var model = this._configuration["MistralAI:ChatModel"]; + var model = this._configuration["MistralAI:ChatModelId"]; var apiKey = this._configuration["MistralAI:ApiKey"]; - var service = new MistralAIChatCompletionService(model!, apiKey!); + var service = new MistralAIChatCompletionService(model!, apiKey!, httpClient: this._httpClient); var kernel = new Kernel(); var plugin = kernel.Plugins.AddFromType(); @@ -197,9 +289,9 @@ public async Task ValidateGetChatMessageContentsHasRequiredToolCallResponseAsync public async Task ValidateGetChatMessageContentsWithAutoInvokeAsync() { // Arrange - var model = this._configuration["MistralAI:ChatModel"]; + var model = this._configuration["MistralAI:ChatModelId"]; var apiKey = this._configuration["MistralAI:ApiKey"]; - var service = new MistralAIChatCompletionService(model!, apiKey!); + var service = new MistralAIChatCompletionService(model!, apiKey!, httpClient: this._httpClient); var executionSettings = new MistralAIPromptExecutionSettings { ToolCallBehavior = MistralAIToolCallBehavior.AutoInvokeKernelFunctions }; var kernel = new Kernel(); kernel.Plugins.AddFromType(); @@ -214,16 +306,17 @@ public async Task ValidateGetChatMessageContentsWithAutoInvokeAsync() // Assert Assert.NotNull(response); Assert.Single(response); - Assert.Contains("sunny", response[0].Content, System.StringComparison.Ordinal); + Assert.Contains("Paris", response[0].Content, System.StringComparison.Ordinal); + Assert.Contains("12°C", response[0].Content, System.StringComparison.Ordinal); } [Fact(Skip = "This test is for manual verification.")] public async Task ValidateGetChatMessageContentsWithNoFunctionsAsync() { // Arrange - var model = this._configuration["MistralAI:ChatModel"]; + var model = this._configuration["MistralAI:ChatModelId"]; var apiKey = this._configuration["MistralAI:ApiKey"]; - var service = new MistralAIChatCompletionService(model!, apiKey!); + var service = new MistralAIChatCompletionService(model!, apiKey!, httpClient: this._httpClient); var executionSettings = new MistralAIPromptExecutionSettings { ToolCallBehavior = MistralAIToolCallBehavior.NoKernelFunctions }; var kernel = new Kernel(); kernel.Plugins.AddFromType(); @@ -238,16 +331,16 @@ public async Task ValidateGetChatMessageContentsWithNoFunctionsAsync() // Assert Assert.NotNull(response); Assert.Single(response); - Assert.Contains("GetWeather", response[0].Content, System.StringComparison.Ordinal); + Assert.Contains("weather", response[0].Content, System.StringComparison.Ordinal); } [Fact(Skip = "This test is for manual verification.")] public async Task ValidateGetChatMessageContentsWithAutoInvokeReturnsFunctionCallContentAsync() { // Arrange - var model = this._configuration["MistralAI:ChatModel"]; + var model = this._configuration["MistralAI:ChatModelId"]; var apiKey = this._configuration["MistralAI:ApiKey"]; - var service = new MistralAIChatCompletionService(model!, apiKey!); + var service = new MistralAIChatCompletionService(model!, apiKey!, httpClient: this._httpClient); var executionSettings = new MistralAIPromptExecutionSettings { ToolCallBehavior = MistralAIToolCallBehavior.AutoInvokeKernelFunctions }; var kernel = new Kernel(); kernel.Plugins.AddFromType(); @@ -272,9 +365,9 @@ public async Task ValidateGetChatMessageContentsWithAutoInvokeReturnsFunctionCal public async Task ValidateGetChatMessageContentsWithAutoInvokeAndFunctionFilterAsync() { // Arrange - var model = this._configuration["MistralAI:ChatModel"]; + var model = this._configuration["MistralAI:ChatModelId"]; var apiKey = this._configuration["MistralAI:ApiKey"]; - var service = new MistralAIChatCompletionService(model!, apiKey!); + var service = new MistralAIChatCompletionService(model!, apiKey!, httpClient: this._httpClient); var kernel = new Kernel(); kernel.Plugins.AddFromType(); var invokedFunctions = new List(); @@ -296,7 +389,6 @@ public async Task ValidateGetChatMessageContentsWithAutoInvokeAndFunctionFilterA // Assert Assert.NotNull(response); Assert.Single(response); - Assert.Contains("sunny", response[0].Content, System.StringComparison.Ordinal); Assert.Contains("GetWeather", invokedFunctions); } @@ -304,9 +396,9 @@ public async Task ValidateGetChatMessageContentsWithAutoInvokeAndFunctionFilterA public async Task ValidateGetStreamingChatMessageContentsWithAutoInvokeAndFunctionFilterAsync() { // Arrange - var model = this._configuration["MistralAI:ChatModel"]; + var model = this._configuration["MistralAI:ChatModelId"]; var apiKey = this._configuration["MistralAI:ApiKey"]; - var service = new MistralAIChatCompletionService(model!, apiKey!); + var service = new MistralAIChatCompletionService(model!, apiKey!, httpClient: this._httpClient); var kernel = new Kernel(); kernel.Plugins.AddFromType(); @@ -338,7 +430,6 @@ public async Task ValidateGetStreamingChatMessageContentsWithAutoInvokeAndFuncti // Assert Assert.NotNull(content); - Assert.Contains("sunny", content.ToString()); Assert.Contains("GetWeather", invokedFunctions); } @@ -346,9 +437,9 @@ public async Task ValidateGetStreamingChatMessageContentsWithAutoInvokeAndFuncti public async Task ValidateGetChatMessageContentsWithAutoInvokeAndFunctionInvocationFilterAsync() { // Arrange - var model = this._configuration["MistralAI:ChatModel"]; + var model = this._configuration["MistralAI:ChatModelId"]; var apiKey = this._configuration["MistralAI:ApiKey"]; - var service = new MistralAIChatCompletionService(model!, apiKey!); + var service = new MistralAIChatCompletionService(model!, apiKey!, httpClient: this._httpClient); var kernel = new Kernel(); kernel.Plugins.AddFromType(); var invokedFunctions = new List(); @@ -371,8 +462,8 @@ public async Task ValidateGetChatMessageContentsWithAutoInvokeAndFunctionInvocat // Assert Assert.NotNull(response); Assert.Single(response); - Assert.StartsWith("Weather in Paris", response[0].Content); - Assert.EndsWith("is sunny and 18 Celsius", response[0].Content); + Assert.Contains("Paris", response[0].Content, System.StringComparison.Ordinal); + Assert.Contains("12°C", response[0].Content, System.StringComparison.Ordinal); Assert.Contains("GetWeather", invokedFunctions); } @@ -380,9 +471,9 @@ public async Task ValidateGetChatMessageContentsWithAutoInvokeAndFunctionInvocat public async Task ValidateGetChatMessageContentsWithAutoInvokeAndMultipleCallsAsync() { // Arrange - var model = this._configuration["MistralAI:ChatModel"]; + var model = this._configuration["MistralAI:ChatModelId"]; var apiKey = this._configuration["MistralAI:ApiKey"]; - var service = new MistralAIChatCompletionService(model!, apiKey!); + var service = new MistralAIChatCompletionService(model!, apiKey!, httpClient: this._httpClient); var kernel = new Kernel(); kernel.Plugins.AddFromType(); @@ -400,8 +491,8 @@ public async Task ValidateGetChatMessageContentsWithAutoInvokeAndMultipleCallsAs // Assert Assert.NotNull(result2); Assert.Single(result2); - Assert.Contains("Marseille", result2[0].Content); - Assert.Contains("sunny", result2[0].Content); + Assert.Contains("Marseille", result2[0].Content, System.StringComparison.Ordinal); + Assert.Contains("12°C", result2[0].Content, System.StringComparison.Ordinal); } public sealed class WeatherPlugin @@ -410,7 +501,7 @@ public sealed class WeatherPlugin [Description("Get the current weather in a given location.")] public string GetWeather( [Description("The city and department, e.g. Marseille, 13")] string location - ) => $"Weather in {location} is sunny and 18 Celsius"; + ) => $"12°C\nWind: 11 KMPH\nHumidity: 48%\nMostly cloudy\nLocation: {location}"; } public sealed class AnonymousPlugin @@ -439,4 +530,53 @@ private sealed class FakeAutoFunctionFilter( public Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next) => this._onAutoFunctionInvocation?.Invoke(context, next) ?? Task.CompletedTask; } + + private sealed class LoggingHandler(HttpMessageHandler innerHandler, ITestOutputHelper output) : DelegatingHandler(innerHandler) + { + private static readonly JsonSerializerOptions s_jsonSerializerOptions = new() { WriteIndented = true }; + + private readonly ITestOutputHelper _output = output; + + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + // Log the request details + if (request.Content is not null) + { + var content = await request.Content.ReadAsStringAsync(cancellationToken); + this._output.WriteLine("=== REQUEST ==="); + try + { + string formattedContent = JsonSerializer.Serialize(JsonSerializer.Deserialize(content), s_jsonSerializerOptions); + this._output.WriteLine(formattedContent); + } + catch (JsonException) + { + this._output.WriteLine(content); + } + this._output.WriteLine(string.Empty); + } + + // Call the next handler in the pipeline + var response = await base.SendAsync(request, cancellationToken); + + if (response.Content is not null) + { + // Log the response details + var responseContent = await response.Content.ReadAsStringAsync(cancellationToken); + this._output.WriteLine("=== RESPONSE ==="); + try + { + string formattedContent = JsonSerializer.Serialize(JsonSerializer.Deserialize(responseContent), s_jsonSerializerOptions); + this._output.WriteLine(formattedContent); + } + catch (JsonException) + { + this._output.WriteLine(responseContent); + } + this._output.WriteLine(string.Empty); + } + + return response; + } + } } diff --git a/dotnet/src/IntegrationTests/testsettings.json b/dotnet/src/IntegrationTests/testsettings.json index 8153da9efa9d..c2396c7c0419 100644 --- a/dotnet/src/IntegrationTests/testsettings.json +++ b/dotnet/src/IntegrationTests/testsettings.json @@ -78,6 +78,12 @@ "VisionModelId": "gemini-1.5-flash" } }, + "MistralAI": { + "EmbeddingModelId": "mistral-embed", + "ChatModelId": "mistral-large-latest", + "ImageModelId": "pixtral-12b-2409", + "ApiKey": "" + }, "Bing": { "ApiKey": "" },