Skip to content

Commit

Permalink
.Net: Fix MistralAI function calling and add image content support (#…
Browse files Browse the repository at this point in the history
…9844)

### Motivation and Context

Tool messages were being rejected as bad requests

Close #9806 

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
  • Loading branch information
markwallace-microsoft authored Dec 5, 2024
1 parent 5e7049b commit 6dc7559
Show file tree
Hide file tree
Showing 13 changed files with 569 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,24 @@ public void ValidateRequiredArguments()
#pragma warning restore CS8625 // Cannot convert null literal to non-nullable reference type.
}

[Fact]
public void ValidateDeserializeChatCompletionMistralChatMessage()
{
var json = "{\"role\":\"assistant\",\"content\":\"Some response.\",\"tool_calls\":null}";

MistralChatMessage? deserializedResponse = JsonSerializer.Deserialize<MistralChatMessage>(json);
Assert.NotNull(deserializedResponse);
}

[Fact]
public void ValidateDeserializeChatCompletionResponse()
{
var json = "{\"id\":\"aee5e73a5ef241be89cd7d3e9c45089a\",\"object\":\"chat.completion\",\"created\":1732882368,\"model\":\"mistral-large-latest\",\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":\"Some response.\",\"tool_calls\":null},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":17,\"total_tokens\":124,\"completion_tokens\":107}}";

ChatCompletionResponse? deserializedResponse = JsonSerializer.Deserialize<ChatCompletionResponse>(json);
Assert.NotNull(deserializedResponse);
}

[Fact]
public async Task ValidateChatMessageRequestAsync()
{
Expand All @@ -62,7 +80,7 @@ public async Task ValidateChatMessageRequestAsync()
Assert.Equal(0.9, chatRequest.Temperature);
Assert.Single(chatRequest.Messages);
Assert.Equal("user", chatRequest.Messages[0].Role);
Assert.Equal("What is the best French cheese?", chatRequest.Messages[0].Content);
Assert.Equal("What is the best French cheese?", chatRequest.Messages[0].Content?.ToString());
}

[Fact]
Expand Down Expand Up @@ -504,6 +522,31 @@ public void ValidateToMistralChatMessages(string roleLabel, string content)
Assert.Single(messages);
}

[Fact]
public void ValidateToMistralChatMessagesWithMultipleContents()
{
// Arrange
using var httpClient = new HttpClient();
var client = new MistralClient("mistral-large-latest", httpClient, "key");
var chatMessage = new ChatMessageContent()
{
Role = AuthorRole.User,
Items =
[
new TextContent("What is the weather like in Paris?"),
new ImageContent(new Uri("https://tripfixers.com/wp-content/uploads/2019/11/eiffel-tower-with-snow.jpeg"))
],
};

// Act
var messages = client.ToMistralChatMessages(chatMessage, default);

// Assert
Assert.NotNull(messages);
Assert.Single(messages);
Assert.IsType<List<ContentChunk>>(messages[0].Content);
}

[Fact]
public void ValidateToMistralChatMessagesWithFunctionCallContent()
{
Expand Down Expand Up @@ -544,6 +587,41 @@ public void ValidateToMistralChatMessagesWithFunctionResultContent()
Assert.Equal(2, messages.Count);
}

[Fact]
public void ValidateCloneMistralAIPromptExecutionSettings()
{
// Arrange
var settings = new MistralAIPromptExecutionSettings
{
MaxTokens = 1024,
Temperature = 0.9,
TopP = 0.9,
FrequencyPenalty = 0.9,
PresencePenalty = 0.9,
Stop = ["stop"],
SafePrompt = true,
RandomSeed = 123,
ResponseFormat = new { format = "json" },
};

// Act
var clonedSettings = settings.Clone();

// Assert
Assert.NotNull(clonedSettings);
Assert.IsType<MistralAIPromptExecutionSettings>(clonedSettings);
var clonedMistralAISettings = clonedSettings as MistralAIPromptExecutionSettings;
Assert.Equal(settings.MaxTokens, clonedMistralAISettings!.MaxTokens);
Assert.Equal(settings.Temperature, clonedMistralAISettings.Temperature);
Assert.Equal(settings.TopP, clonedMistralAISettings.TopP);
Assert.Equal(settings.FrequencyPenalty, clonedMistralAISettings.FrequencyPenalty);
Assert.Equal(settings.PresencePenalty, clonedMistralAISettings.PresencePenalty);
Assert.Equal(settings.Stop, clonedMistralAISettings.Stop);
Assert.Equal(settings.SafePrompt, clonedMistralAISettings.SafePrompt);
Assert.Equal(settings.RandomSeed, clonedMistralAISettings.RandomSeed);
Assert.Equal(settings.ResponseFormat, clonedMistralAISettings.ResponseFormat);
}

public sealed class WeatherPlugin
{
[KernelFunction]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Text.Json;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.MistralAI;
Expand Down Expand Up @@ -68,4 +69,44 @@ public void FromExecutionSettingsWhenSerializedHasPropertiesShouldPopulateSpecia
Assert.True(MistralExecutionSettings.SafePrompt);
Assert.Equal(123, MistralExecutionSettings.RandomSeed);
}

[Fact]
public void FreezeShouldPreventPropertyModification()
{
// Arrange
var settings = new MistralAIPromptExecutionSettings
{
Temperature = 0.7,
TopP = 1,
MaxTokens = 100,
SafePrompt = false,
Stop = ["foo", "bar"]
};

// Act
settings.Freeze();

// Assert
// Try to modify a property after freezing
Assert.Throws<InvalidOperationException>(() => settings.Temperature = 0.8);
Assert.Throws<InvalidOperationException>(() => settings.TopP = 0.9);
Assert.Throws<InvalidOperationException>(() => settings.MaxTokens = 50);
Assert.Throws<InvalidOperationException>(() => settings.SafePrompt = true);
Assert.Throws<NotSupportedException>(() => settings.Stop.Add("baz"));
}

[Fact]
public void FreezeShouldNotAllowMultipleFreezes()
{
// Arrange
var settings = new MistralAIPromptExecutionSettings();
settings.Freeze(); // First freeze

// Act
settings.Freeze(); // Second freeze (should not throw)

// Assert
// No exception should be thrown
Assert.True(settings.IsFrozen); // Assuming IsFrozen is a property indicating the freeze state
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,22 @@ internal sealed class ChatCompletionRequest
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? RandomSeed { get; set; }

[JsonPropertyName("response_format")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public object? ResponseFormat { get; set; }

[JsonPropertyName("frequency_penalty")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public double? FrequencyPenalty { get; set; }

[JsonPropertyName("presence_penalty")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public double? PresencePenalty { get; set; }

[JsonPropertyName("stop")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public IList<string>? Stop { get; set; }

/// <summary>
/// Construct an instance of <see cref="ChatCompletionRequest"/>.
/// </summary>
Expand Down
13 changes: 13 additions & 0 deletions dotnet/src/Connectors/Connectors.MistralAI/Client/ContentChunk.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json.Serialization;

namespace Microsoft.SemanticKernel.Connectors.MistralAI.Client;

[JsonDerivedType(typeof(TextChunk))]
[JsonDerivedType(typeof(ImageUrlChunk))]
internal abstract class ContentChunk(ContentChunkType type)
{
[JsonPropertyName("type")]
public string Type { get; set; } = type.ToString();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json.Serialization;

namespace Microsoft.SemanticKernel.Connectors.MistralAI.Client;

internal readonly struct ContentChunkType : IEquatable<ContentChunkType>
{
public static ContentChunkType Text { get; } = new("text");

public static ContentChunkType ImageUrl { get; } = new("image_url");

public string Type { get; }

/// <summary>
/// Creates a new <see cref="ContentChunkType"/> instance with the provided type.
/// </summary>
/// <param name="type">The label to associate with this <see cref="ContentChunkType"/>.</param>
[JsonConstructor]
public ContentChunkType(string type)
{
Verify.NotNullOrWhiteSpace(type, nameof(type));
this.Type = type!;
}

/// <summary>
/// Returns a value indicating whether two <see cref="ContentChunkType"/> instances are equivalent, as determined by a
/// case-insensitive comparison of their labels.
/// </summary>
/// <param name="left"> the first <see cref="ContentChunkType"/> instance to compare </param>
/// <param name="right"> the second <see cref="ContentChunkType"/> instance to compare </param>
/// <returns> true if left and right are both null or have equivalent labels; false otherwise </returns>
public static bool operator ==(ContentChunkType left, ContentChunkType right)
=> left.Equals(right);

/// <summary>
/// Returns a value indicating whether two <see cref="ContentChunkType"/> instances are not equivalent, as determined by a
/// case-insensitive comparison of their labels.
/// </summary>
/// <param name="left"> the first <see cref="ContentChunkType"/> instance to compare </param>
/// <param name="right"> the second <see cref="ContentChunkType"/> instance to compare </param>
/// <returns> false if left and right are both null or have equivalent labels; true otherwise </returns>
public static bool operator !=(ContentChunkType left, ContentChunkType right)
=> !left.Equals(right);

/// <inheritdoc/>
public override bool Equals([NotNullWhen(true)] object? obj)
=> obj is ContentChunkType otherRole && this == otherRole;

/// <inheritdoc/>
public bool Equals(ContentChunkType other)
=> string.Equals(this.Type, other.Type, StringComparison.OrdinalIgnoreCase);

/// <inheritdoc/>
public override int GetHashCode()
=> StringComparer.OrdinalIgnoreCase.GetHashCode(this.Type);

/// <inheritdoc/>
public override string ToString() => this.Type;
}
11 changes: 11 additions & 0 deletions dotnet/src/Connectors/Connectors.MistralAI/Client/ImageUrlChunk.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Text.Json.Serialization;

namespace Microsoft.SemanticKernel.Connectors.MistralAI.Client;
internal class ImageUrlChunk(Uri imageUrl) : ContentChunk(ContentChunkType.ImageUrl)
{
[JsonPropertyName("image_url")]
public string ImageUrl { get; set; } = imageUrl.ToString();
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ internal sealed class MistralChatCompletionChunk

internal string? GetRole(int index) => this.Choices?[index]?.Delta?.Role;

internal string? GetContent(int index) => this.Choices?[index]?.Delta?.Content;
internal string? GetContent(int index) => this.Choices?[index]?.Delta?.Content?.ToString();

internal int GetChoiceIndex(int index) => this.Choices?[index]?.Index ?? -1;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,16 @@ internal sealed class MistralChatMessage
public string? Role { get; set; }

[JsonPropertyName("content")]
public string? Content { get; set; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public object? Content { get; set; }

[JsonPropertyName("name")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Name { get; set; }

[JsonPropertyName("tool_call_id")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? ToolCallId { get; set; }

[JsonPropertyName("tool_calls")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
Expand All @@ -27,7 +36,7 @@ internal sealed class MistralChatMessage
/// <param name="role">If provided must be one of: system, user, assistant</param>
/// <param name="content">Content of the chat message</param>
[JsonConstructor]
internal MistralChatMessage(string? role, string? content)
internal MistralChatMessage(string? role, object? content)
{
if (role is not null and not "system" and not "user" and not "assistant" and not "tool")
{
Expand Down
Loading

0 comments on commit 6dc7559

Please sign in to comment.