# .Net: Sample showing function calling using NexusRaven (#7256)

### Motivation and Context

Closes #7190 

### Description

This PR adds a new concept sample, `NexusRaven_FunctionCalling.cs`, showing how to use the Semantic Kernel HuggingFace connector with the Nexus Raven model to implement function calling over plain text generation. The available kernel functions are rendered into the prompt with a Handlebars template, using Python-style signatures that Nexus Raven can reason over to decide which functions to call. To support this scenario, the HuggingFace connector now accepts an endpoint without a model id (requests are routed to the bare endpoint when no model id is configured), and the `ReturnFullText` and `DoSample` execution settings are now forwarded to the text generation request.

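As a minimal sketch of the new endpoint-only registration (using the public demo endpoint from the sample; no model id is required once this change is in), a kernel can be built like this:

```csharp
using Microsoft.SemanticKernel;

// Endpoint-only registration: requests are routed to the bare endpoint
// because no model id is configured.
Kernel kernel = Kernel.CreateBuilder()
    .AddHuggingFaceTextGeneration(endpoint: new Uri("http://nexusraven.nexusflow.ai"))
    .Build();

var result = await kernel.InvokePromptAsync("What is deep learning?");
Console.WriteLine(result);
```
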
### Contribution Checklist


- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄

markwallace-microsoft authored Jul 16, 2024 · 1 parent f2bd420 · commit a335d55

Showing 12 changed files with 436 additions and 32 deletions.
146 changes: 146 additions & 0 deletions dotnet/samples/Concepts/FunctionCalling/NexusRaven_FunctionCalling.cs
@@ -0,0 +1,146 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.HuggingFace;
using Microsoft.SemanticKernel.PromptTemplates.Handlebars;
using Microsoft.SemanticKernel.TextGeneration;

namespace FunctionCalling;

/// <summary>
/// The following example shows how to use Semantic Kernel with the HuggingFace <see cref="HuggingFaceTextGenerationService"/>
/// to implement function calling with the Nexus Raven model.
/// </summary>
/// <param name="output">The test output helper.</param>
public class NexusRaven_FunctionCalling(ITestOutputHelper output) : BaseTest(output)
{
/// <summary>
/// Nexus Raven endpoint
/// </summary>
private Uri RavenEndpoint => new("http://nexusraven.nexusflow.ai");

/// <summary>
/// Invokes the Nexus Raven model using Text Generation.
/// </summary>
[Fact]
public async Task InvokeTextGenerationAsync()
{
Kernel kernel = Kernel.CreateBuilder()
.AddHuggingFaceTextGeneration(endpoint: RavenEndpoint)
.Build();

var textGeneration = kernel.GetRequiredService<ITextGenerationService>();
var prompt = "What is deep learning?";

var result = await textGeneration.GetTextContentsAsync(prompt);

Console.WriteLine(result[0].ToString());
}

/// <summary>
/// Invokes the Nexus Raven model with Function Calling.
/// </summary>
[Fact]
public async Task InvokeTextGenerationWithFunctionCallingAsync()
{
using var handler = new LoggingHandler(new HttpClientHandler(), this.Output);
using var httpClient = new HttpClient(handler);

Kernel kernel = Kernel.CreateBuilder()
.AddHuggingFaceTextGeneration(
endpoint: RavenEndpoint,
httpClient: httpClient)
.Build();
var plugin = ImportFunctions(kernel);
var textGeneration = kernel.GetRequiredService<ITextGenerationService>();

// This Handlebars template is used to format the available KernelFunctions so
// they can be understood by the NexusRaven model. The function name, signature and
// description must be provided. NexusRaven can reason over the list of functions and
// determine which ones need to be called for the current query.
var template =
""""
{{#each (functions)}}
Function:
{{Name}}{{Signature}}
"""
{{Description}}
"""
{{/each}}
User Query:{{prompt}}<human_end>
"""";

var prompt = "What is the weather like in Dublin?";
var functions = plugin.Select(f => new FunctionDefinition { Name = f.Name, Description = f.Description, Signature = CreateSignature(f) }).ToList();
var executionSettings = new HuggingFacePromptExecutionSettings { Temperature = 0.001F, MaxNewTokens = 1024, ReturnFullText = false, DoSample = false }; // , Stop = ["<bot_end>"]
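// The near-zero Temperature and DoSample = false above keep the generated
// function call deterministic and easier to parse.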
KernelArguments arguments = new(executionSettings) { { "prompt", prompt }, { "functions", functions } };

var factory = new HandlebarsPromptTemplateFactory();
var promptTemplate = factory.Create(new PromptTemplateConfig(template) { TemplateFormat = "handlebars" });
var rendered = await promptTemplate.RenderAsync(kernel, arguments);

Console.WriteLine(" Prompt:\n====================");
Console.WriteLine(rendered);

var function = kernel.CreateFunctionFromPrompt(template, templateFormat: "handlebars", promptTemplateFactory: factory);

var result = await kernel.InvokeAsync(function, arguments);

Console.WriteLine("\n Response:\n====================");
Console.WriteLine(result.ToString());
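
// Nexus Raven typically answers with a Python-style call that the caller can
// parse and execute, e.g. something like: Call: GetWeatherForCity(cityName='Dublin')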
}

// The signature must be Python compliant and currently only supports primitive values.
// Parameters are joined with ", " so functions with multiple parameters render correctly.
private static string CreateSignature(KernelFunction function)
{
var parameters = function.Metadata.Parameters;
var signature = new StringBuilder();
signature.Append('(');
signature.Append(string.Join(", ", parameters.Select(p => $"{p.Name}:{GetType(p)}")));
signature.Append(')');
return signature.ToString();
}
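
// Example: for the sample's GetWeatherForCity function, CreateSignature returns
// "(cityName:string)", the Python-style format Nexus Raven expects.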

private static string GetType(KernelParameterMetadata parameter)
{
if (parameter.Schema is not null)
{
var rootElement = parameter.Schema.RootElement;
if (rootElement.TryGetProperty("type", out var type))
{
return type.GetString() ?? string.Empty;
}
}
return string.Empty;
}

private static KernelPlugin ImportFunctions(Kernel kernel)
{
return kernel.ImportPluginFromFunctions("WeatherPlugin",
[
kernel.CreateFunctionFromMethod(
(string cityName) => "12°C\nWind: 11 KMPH\nHumidity: 48%\nMostly cloudy",
"GetWeatherForCity",
"Gets the current weather for the specified city",
new List<KernelParameterMetadata>
{
new("cityName") { Description = "The city name", ParameterType = string.Empty.GetType() }
}),
]);
}

/// <summary>
/// Function definition for use with Nexus Raven.
/// </summary>
private sealed class FunctionDefinition
{
public required string Name { get; init; }
public required string Signature { get; init; }
public required string Description { get; init; }
}
}
1 change: 1 addition & 0 deletions dotnet/samples/Concepts/README.md
@@ -25,6 +25,7 @@ Down below you can find the code snippets that demonstrate the usage of many Sem

- [Gemini_FunctionCalling](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/FunctionCalling/Gemini_FunctionCalling.cs)
- [OpenAI_FunctionCalling](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/FunctionCalling/OpenAI_FunctionCalling.cs)
+ - [NexusRaven_HuggingFaceTextGeneration](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/FunctionCalling/NexusRaven_FunctionCalling.cs)

## Caching - Examples of caching implementations

@@ -23,22 +23,26 @@ internal sealed class HuggingFaceClient
private readonly HttpClient _httpClient;

internal string ModelProvider => "huggingface";
- internal string ModelId { get; }
+ internal string? ModelId { get; }
internal string? ApiKey { get; }
internal Uri Endpoint { get; }
internal string Separator { get; }
internal ILogger Logger { get; }

internal HuggingFaceClient(
- string modelId,
HttpClient httpClient,
+ string? modelId = null,
Uri? endpoint = null,
string? apiKey = null,
ILogger? logger = null)
{
- Verify.NotNullOrWhiteSpace(modelId);
Verify.NotNull(httpClient);

+ if (string.IsNullOrWhiteSpace(modelId) && endpoint is null)
+ {
+     throw new InvalidOperationException("A valid model id or endpoint must be provided.");
+ }

endpoint ??= new Uri("https://api-inference.huggingface.co");
this.Separator = endpoint.AbsolutePath.EndsWith("/", StringComparison.InvariantCulture) ? string.Empty : "/";
this.Endpoint = endpoint;
@@ -130,13 +134,13 @@ public async Task<IReadOnlyList<TextContent>> GenerateTextAsync(
PromptExecutionSettings? executionSettings,
CancellationToken cancellationToken)
{
- string modelId = executionSettings?.ModelId ?? this.ModelId;
+ string? modelId = executionSettings?.ModelId ?? this.ModelId;
var endpoint = this.GetTextGenerationEndpoint(modelId);

var huggingFaceExecutionSettings = HuggingFacePromptExecutionSettings.FromExecutionSettings(executionSettings);
var request = this.CreateTextRequest(prompt, huggingFaceExecutionSettings);

- using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId, this.ModelProvider, prompt, huggingFaceExecutionSettings);
+ using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId ?? string.Empty, this.ModelProvider, prompt, huggingFaceExecutionSettings);
using var httpRequestMessage = this.CreatePost(request, endpoint, this.ApiKey);

TextGenerationResponse response;
@@ -166,14 +170,14 @@ public async IAsyncEnumerable<StreamingTextContent> StreamGenerateTextAsync(
PromptExecutionSettings? executionSettings,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
- string modelId = executionSettings?.ModelId ?? this.ModelId;
+ string? modelId = executionSettings?.ModelId ?? this.ModelId;
var endpoint = this.GetTextGenerationEndpoint(modelId);

var huggingFaceExecutionSettings = HuggingFacePromptExecutionSettings.FromExecutionSettings(executionSettings);
var request = this.CreateTextRequest(prompt, huggingFaceExecutionSettings);
request.Stream = true;

- using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId, this.ModelProvider, prompt, huggingFaceExecutionSettings);
+ using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId ?? string.Empty, this.ModelProvider, prompt, huggingFaceExecutionSettings);
HttpResponseMessage? httpResponseMessage = null;
Stream? responseStream = null;
try
@@ -223,7 +227,7 @@ public async IAsyncEnumerable<StreamingTextContent> StreamGenerateTextAsync(
}
}

- private async IAsyncEnumerable<StreamingTextContent> ProcessTextResponseStreamAsync(Stream stream, string modelId, [EnumeratorCancellation] CancellationToken cancellationToken)
+ private async IAsyncEnumerable<StreamingTextContent> ProcessTextResponseStreamAsync(Stream stream, string? modelId, [EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (var content in this.ParseTextResponseStreamAsync(stream, cancellationToken).ConfigureAwait(false))
{
@@ -234,7 +238,7 @@ private async IAsyncEnumerable<StreamingTextContent> ProcessTextResponseStreamAs
private IAsyncEnumerable<TextGenerationStreamResponse> ParseTextResponseStreamAsync(Stream responseStream, CancellationToken cancellationToken)
=> SseJsonParser.ParseAsync<TextGenerationStreamResponse>(responseStream, cancellationToken);

- private static StreamingTextContent GetStreamingTextContentFromStreamResponse(TextGenerationStreamResponse response, string modelId)
+ private static StreamingTextContent GetStreamingTextContentFromStreamResponse(TextGenerationStreamResponse response, string? modelId)
=> new(
text: response.Token?.Text,
modelId: modelId,
@@ -250,10 +254,10 @@ private TextGenerationRequest CreateTextRequest(
return request;
}

- private static List<TextContent> GetTextContentsFromResponse(TextGenerationResponse response, string modelId)
+ private static List<TextContent> GetTextContentsFromResponse(TextGenerationResponse response, string? modelId)
=> response.Select(r => new TextContent(r.GeneratedText, modelId, r, Encoding.UTF8, new HuggingFaceTextGenerationMetadata(response))).ToList();

- private static List<TextContent> GetTextContentsFromResponse(ImageToTextGenerationResponse response, string modelId)
+ private static List<TextContent> GetTextContentsFromResponse(ImageToTextGenerationResponse response, string? modelId)
=> response.Select(r => new TextContent(r.GeneratedText, modelId, r, Encoding.UTF8)).ToList();

private void LogTextGenerationUsage(HuggingFacePromptExecutionSettings executionSettings)
@@ -265,8 +269,8 @@ private void LogTextGenerationUsage(HuggingFacePromptExecutionSettings execution
executionSettings.ModelId ?? this.ModelId);
}
}
- private Uri GetTextGenerationEndpoint(string modelId)
-     => new($"{this.Endpoint}{this.Separator}models/{modelId}");
+ private Uri GetTextGenerationEndpoint(string? modelId)
+     => string.IsNullOrWhiteSpace(modelId) ? this.Endpoint : new($"{this.Endpoint}{this.Separator}models/{modelId}");

#endregion

@@ -300,8 +304,8 @@ public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(
return response.ToList()!;
}

- private Uri GetEmbeddingGenerationEndpoint(string modelId)
-     => new($"{this.Endpoint}{this.Separator}pipeline/feature-extraction/{modelId}");
+ private Uri GetEmbeddingGenerationEndpoint(string? modelId)
+     => string.IsNullOrWhiteSpace(modelId) ? this.Endpoint : new($"{this.Endpoint}{this.Separator}pipeline/feature-extraction/{modelId}");

#endregion

@@ -337,8 +341,8 @@ private HttpRequestMessage CreateImageToTextRequest(ImageContent content, Prompt
return request;
}

- private Uri GetImageToTextGenerationEndpoint(string modelId)
-     => new($"{this.Endpoint}{this.Separator}models/{modelId}");
+ private Uri GetImageToTextGenerationEndpoint(string? modelId)
+     => string.IsNullOrWhiteSpace(modelId) ? this.Endpoint : new($"{this.Endpoint}{this.Separator}models/{modelId}");

#endregion
}
@@ -62,15 +62,15 @@ internal sealed class HuggingFaceMessageApiClient
description: "Number of total tokens used");

internal HuggingFaceMessageApiClient(
- string modelId,
HttpClient httpClient,
+ string? modelId = null,
Uri? endpoint = null,
string? apiKey = null,
ILogger? logger = null)
{
this._clientCore = new(
- modelId,
httpClient,
+ modelId,
endpoint,
apiKey,
logger);
@@ -81,15 +81,15 @@ internal async IAsyncEnumerable<StreamingChatMessageContent> StreamCompleteChatM
PromptExecutionSettings? executionSettings,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
- string modelId = executionSettings?.ModelId ?? this._clientCore.ModelId;
+ string? modelId = executionSettings?.ModelId ?? this._clientCore.ModelId;
var endpoint = this.GetChatGenerationEndpoint();

var huggingFaceExecutionSettings = HuggingFacePromptExecutionSettings.FromExecutionSettings(executionSettings);

var request = this.CreateChatRequest(chatHistory, huggingFaceExecutionSettings, modelId);
request.Stream = true;

- using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId, this._clientCore.ModelProvider, chatHistory, huggingFaceExecutionSettings);
+ using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId ?? string.Empty, this._clientCore.ModelProvider, chatHistory, huggingFaceExecutionSettings);
HttpResponseMessage? httpResponseMessage = null;
Stream? responseStream = null;
try
@@ -144,13 +144,13 @@ internal async Task<IReadOnlyList<ChatMessageContent>> CompleteChatMessageAsync(
PromptExecutionSettings? executionSettings,
CancellationToken cancellationToken)
{
- string modelId = executionSettings?.ModelId ?? this._clientCore.ModelId;
+ string? modelId = executionSettings?.ModelId ?? this._clientCore.ModelId;
var endpoint = this.GetChatGenerationEndpoint();

var huggingFaceExecutionSettings = HuggingFacePromptExecutionSettings.FromExecutionSettings(executionSettings);
var request = this.CreateChatRequest(chatHistory, huggingFaceExecutionSettings, modelId);

- using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId, this._clientCore.ModelProvider, chatHistory, huggingFaceExecutionSettings);
+ using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId ?? string.Empty, this._clientCore.ModelProvider, chatHistory, huggingFaceExecutionSettings);
using var httpRequestMessage = this._clientCore.CreatePost(request, endpoint, this._clientCore.ApiKey);

ChatCompletionResponse response;
@@ -198,7 +198,7 @@ private void LogChatCompletionUsage(HuggingFacePromptExecutionSettings execution
s_totalTokensCounter.Add(chatCompletionResponse.Usage.TotalTokens);
}

- private static List<ChatMessageContent> GetChatMessageContentsFromResponse(ChatCompletionResponse response, string modelId)
+ private static List<ChatMessageContent> GetChatMessageContentsFromResponse(ChatCompletionResponse response, string? modelId)
{
var chatMessageContents = new List<ChatMessageContent>();

@@ -230,7 +230,7 @@ private static List<ChatMessageContent> GetChatMessageContentsFromResponse(ChatC
return chatMessageContents;
}

- private static StreamingChatMessageContent GetStreamingChatMessageContentFromStreamResponse(ChatCompletionStreamResponse response, string modelId)
+ private static StreamingChatMessageContent GetStreamingChatMessageContentFromStreamResponse(ChatCompletionStreamResponse response, string? modelId)
{
var choice = response.Choices?.FirstOrDefault();
if (choice is not null)
@@ -264,7 +264,7 @@ private static StreamingChatMessageContent GetStreamingChatMessageContentFromStr
};
}

- private async IAsyncEnumerable<StreamingChatMessageContent> ProcessChatResponseStreamAsync(Stream stream, string modelId, [EnumeratorCancellation] CancellationToken cancellationToken)
+ private async IAsyncEnumerable<StreamingChatMessageContent> ProcessChatResponseStreamAsync(Stream stream, string? modelId, [EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (var content in this.ParseChatResponseStreamAsync(stream, cancellationToken).ConfigureAwait(false))
{
@@ -275,7 +275,7 @@ private ChatCompletionRequest CreateChatRequest(
private ChatCompletionRequest CreateChatRequest(
ChatHistory chatHistory,
HuggingFacePromptExecutionSettings huggingFaceExecutionSettings,
- string modelId)
+ string? modelId)
{
HuggingFaceClient.ValidateMaxTokens(huggingFaceExecutionSettings.MaxTokens);

@@ -104,7 +104,7 @@ internal sealed class ChatCompletionRequest
/// <param name="executionSettings">Execution settings to be used for the request.</param>
/// <param name="modelId">Model id to use if value in prompt execution settings is not set.</param>
/// <returns>ChatCompletionRequest object.</returns>
- internal static ChatCompletionRequest FromChatHistoryAndExecutionSettings(ChatHistory chatHistory, HuggingFacePromptExecutionSettings executionSettings, string modelId)
+ internal static ChatCompletionRequest FromChatHistoryAndExecutionSettings(ChatHistory chatHistory, HuggingFacePromptExecutionSettings executionSettings, string? modelId)
{
return new ChatCompletionRequest
{
Expand Down
@@ -56,7 +56,9 @@ internal static TextGenerationRequest FromPromptAndExecutionSettings(string prom
RepetitionPenalty = executionSettings.RepetitionPenalty,
MaxTime = executionSettings.MaxTime,
NumReturnSequences = executionSettings.ResultsPerPrompt,
- Details = executionSettings.Details
+ Details = executionSettings.Details,
+ ReturnFullText = executionSettings.ReturnFullText,
+ DoSample = executionSettings.DoSample,
},
Options = new()
{
@@ -124,7 +126,8 @@ internal sealed class HuggingFaceTextParameters
/// (Default: True). If set to False, the return results will not contain the original query making it easier for prompting.
/// </summary>
[JsonPropertyName("return_full_text")]
- public bool ReturnFullText { get; set; } = true;
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public bool? ReturnFullText { get; set; } = true;

/// <summary>
/// (Default: 1). The number of proposition you want to be returned.