Add samples showing how to use the multiple LLM support
markwallace-microsoft committed Oct 25, 2023
1 parent b0d920f commit 5d4d98f
Showing 9 changed files with 278 additions and 64 deletions.
@@ -12,20 +12,19 @@ informed:

Developers need to be able to use multiple models, e.g., using chat completion together with embeddings.

-<!-- This is an optional element. Feel free to remove. -->
## Use Cases

In scope for Semantic Kernel V1.0

* Select Model Request Settings by Service Id.
  * A Service Id uniquely identifies a registered AI Service and is typically defined in the scope of an application.
-* Select Model Request Settings by Model Id.
-  * A Model Id uniquely identifies a Large Language Model. Multiple AI service providers can support the same LLM.
* Select AI Service and Model Request Settings By Developer Defined Strategy.
  * A Developer Defined Strategy is a code first approach where a developer provides the logic.

Out of scope for V1.0

+* Select Model Request Settings by Model Id.
+  * A Model Id uniquely identifies a Large Language Model. Multiple AI service providers can support the same LLM.
* Select Model Request Settings by Provider Id and Model Id
  * A Provider Id uniquely identifies an AI provider e.g. "Azure OpenAI", "OpenAI", "Hugging Face"

@@ -76,47 +75,22 @@ var func = kernel.CreateSemanticFunction(prompt, config: templateConfig!, "HelloAI");
result = await kernel.RunAsync(func);
```

-### Select Model Request Settings by Model Id
-
-_As a developer using the Semantic Kernel I can configure multiple request settings for a semantic function and associate each one with a model id so that the correct request settings are used when different LLMs are used to execute my semantic function._
-
-In this case the developer configures different settings based on the model id that is used to execute the semantic function.
-In the example below the semantic function is executed with "OpenAIText" using `max_tokens=60` because "OpenAIText" is set as the default AI service.
-If "AzureText" was set as the default AI service, `max_tokens=60` would also be used when the deployment name matches the model id. [For Azure OpenAI the deployment name is used in code to call the model by using the client libraries and the REST APIs](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource?pivots=web-portal#deploy-a-model).
-
-**Note:** For Azure OpenAI, the `modelId` argument is optional (and a new argument) when registering an AI service. If none is provided, the `deploymentName` is used instead. Because the deployment name can be set to anything, this may not work; developers can use the Azure API to query for the `modelId` associated with a deployment so that it is set correctly when registering the AI service.
-
-```csharp
-// Configure a Kernel with multiple LLMs
-IKernel kernel = Kernel.Builder
-    .WithLoggerFactory(ConsoleLogger.LoggerFactory)
-    .WithAzureTextCompletionService(deploymentName: aoai.DeploymentName, modelId: aoai.ModelId, // text-davinci-003
-        endpoint: aoai.Endpoint, serviceId: "AzureText", apiKey: aoai.ApiKey)
-    .WithAzureChatCompletionService(deploymentName: aoai.ChatDeploymentName, modelId: aoai.ModelId, // gpt-35-turbo
-        endpoint: aoai.Endpoint, serviceId: "AzureChat", apiKey: aoai.ApiKey)
-    .WithOpenAITextCompletionService(modelId: oai.ModelId, // text-davinci-003
-        serviceId: "OpenAIText", apiKey: oai.ApiKey, setAsDefault: true)
-    .WithOpenAIChatCompletionService(modelId: oai.ChatModelId, // gpt-3.5-turbo
-        serviceId: "OpenAIChat", apiKey: oai.ApiKey)
-    .Build();
-
-// Configure semantic function with multiple LLM request settings
-string configPayload = @"{
-  ""schema"": 1,
-  ""description"": ""Hello AI, what can you do for me?"",
-  ""models"": [
-    { ""model_id"": ""text-davinci-003"", ""max_tokens"": 60 },
-    { ""model_id"": ""gpt-35-turbo"", ""max_tokens"": 120 },
-    { ""model_id"": ""gpt-3.5-turbo"", ""max_tokens"": 180 }
-  ]
-}";
-var templateConfig = JsonSerializer.Deserialize<PromptTemplateConfig>(configPayload);
-var func = kernel.CreateSemanticFunction(prompt, config: templateConfig!, "HelloAI");
-
-// Semantic function is executed with OpenAIText using max_tokens=60
-result = await kernel.RunAsync(func);
-```
+This works by using the `IAIServiceSelector` interface as the strategy for selecting the AI service and request settings to use when invoking a semantic function.
+The interface is defined as follows:
+
+```csharp
+public interface IAIServiceSelector
+{
+    (T?, AIRequestSettings?) SelectAIService<T>(IAIServiceProvider serviceProvider, IReadOnlyList<AIRequestSettings>? modelSettings) where T : IAIService;
+}
+```

+A default `OrderedIAIServiceSelector` implementation is used, which selects the AI service based on the order of the model request settings defined for the semantic function.
+The implementation checks whether a service exists with the corresponding service id and, if it does, that service and the associated model request settings are used.
+If no model request settings are defined, the default text completion service is used.
+A default set of request settings can be specified by leaving the service id undefined or empty; the first such default will be used.
+If no default is specified and none of the specified services are available, the operation will fail.
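For illustration only, here is a minimal sketch of the ordered-selection behavior described above. It is not the shipped `OrderedIAIServiceSelector` source (an abbreviated diff of which appears at the end of this commit), and the class name `OrderedSelectorSketch` is invented here. It implements the interface exactly as shown above; note that the sample code later in this commit extends the signature with a `renderedPrompt` parameter.

```csharp
// Illustrative sketch only: walk the configured model request settings in
// order and return the first registered service whose service id matches,
// falling back to the first default (settings with no service id).
public class OrderedSelectorSketch : IAIServiceSelector
{
    public (T?, AIRequestSettings?) SelectAIService<T>(
        IAIServiceProvider serviceProvider,
        IReadOnlyList<AIRequestSettings>? modelSettings) where T : IAIService
    {
        if (modelSettings is null || modelSettings.Count == 0)
        {
            // No settings defined: use the default service with no settings.
            return (serviceProvider.GetService<T>(null), null);
        }

        AIRequestSettings? defaultSettings = null;
        foreach (var settings in modelSettings)
        {
            if (string.IsNullOrEmpty(settings.ServiceId))
            {
                // Settings with an empty or null service id act as the default;
                // the first such default wins.
                defaultSettings ??= settings;
                continue;
            }

            // First registered service matching a service id is selected.
            var service = serviceProvider.GetService<T>(settings.ServiceId);
            if (service is not null)
            {
                return (service, settings);
            }
        }

        if (defaultSettings is not null)
        {
            return (serviceProvider.GetService<T>(null), defaultSettings);
        }

        throw new SKException("No matching AI service was found and no default settings were specified.");
    }
}
```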

### Select AI Service and Model Request Settings By Developer Defined Strategy

_As a developer using the Semantic Kernel I can provide factories which select the AI service and request settings used to execute my function so that I can dynamically control which AI service is used to execute my semantic function._
@@ -143,9 +117,9 @@ string configPayload = @"{
""schema"": 1,
""description"": ""Hello AI, what can you do for me?"",
""models"": [
{ ""model_id"": ""text-davinci-003"", ""max_tokens"": 60 },
{ ""model_id"": ""gpt-35-turbo"", ""max_tokens"": 120 },
{ ""model_id"": ""gpt-3.5-turbo"", ""max_tokens"": 180 }
{ ""max_tokens"": 60 },
{ ""max_tokens"": 120 },
{ ""max_tokens"": 180 }
]
}";
var templateConfig = JsonSerializer.Deserialize<PromptTemplateConfig>(configPayload);
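The diff is truncated at this point. Judging from the new `RegisterSemanticFunction` overload and `Example62_CustomAIServiceSelector` later in this commit, the example presumably goes on to wire the deserialized config and a developer-defined selector into a function. A hypothetical sketch, not the literal ADR text:

```csharp
// Hypothetical continuation of the truncated example, modeled on
// Example62_CustomAIServiceSelector later in this commit; MyAIServiceSelector
// is the developer-defined selector shown there.
var promptTemplate = new PromptTemplate(prompt, templateConfig!, kernel);
var serviceSelector = new MyAIServiceSelector();

var func = kernel.RegisterSemanticFunction(
    "HelloAI",
    templateConfig!,
    promptTemplate,
    serviceSelector);

result = await kernel.RunAsync(func);
```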
69 changes: 69 additions & 0 deletions dotnet/samples/KernelSyntaxExamples/Example61_MultipleLLMs.cs
@@ -0,0 +1,69 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.AI;
using RepoUtils;

// ReSharper disable once InconsistentNaming
public static class Example61_MultipleLLMs
{
    /// <summary>
    /// Show how to run a semantic function and specify a specific service to use.
    /// </summary>
    public static async Task RunAsync()
    {
        Console.WriteLine("======== Example61_MultipleLLMs ========");

        string apiKey = TestConfiguration.AzureOpenAI.ApiKey;
        string chatDeploymentName = TestConfiguration.AzureOpenAI.ChatDeploymentName;
        string endpoint = TestConfiguration.AzureOpenAI.Endpoint;

        if (apiKey == null || chatDeploymentName == null || endpoint == null)
        {
            Console.WriteLine("Azure endpoint, apiKey, or deploymentName not found. Skipping example.");
            return;
        }

        string openAIModelId = TestConfiguration.OpenAI.ChatModelId;
        string openAIApiKey = TestConfiguration.OpenAI.ApiKey;

        if (openAIModelId == null || openAIApiKey == null)
        {
            Console.WriteLine("OpenAI credentials not found. Skipping example.");
            return;
        }

        IKernel kernel = new KernelBuilder()
            .WithLoggerFactory(ConsoleLogger.LoggerFactory)
            .WithAzureChatCompletionService(
                deploymentName: chatDeploymentName,
                endpoint: endpoint,
                serviceId: "AzureOpenAIChat",
                apiKey: apiKey)
            .WithOpenAIChatCompletionService(
                modelId: openAIModelId,
                serviceId: "OpenAIChat",
                apiKey: openAIApiKey)
            .Build();

        await RunSemanticFunctionAsync(kernel, "AzureOpenAIChat");
        await RunSemanticFunctionAsync(kernel, "OpenAIChat");
    }

    public static async Task RunSemanticFunctionAsync(IKernel kernel, string serviceId)
    {
        Console.WriteLine($"======== {serviceId} ========");

        var prompt = "Hello AI, what can you do for me?";

        var result = await kernel.InvokeSemanticFunctionAsync(
            prompt,
            requestSettings: new AIRequestSettings()
            {
                ServiceId = serviceId
            });
        Console.WriteLine(result.GetValue<string>());
    }
}
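Routing in this sample is driven entirely by `AIRequestSettings.ServiceId`: the same prompt is run once against each registered connector, with the service presumably resolved by the default `OrderedIAIServiceSelector` strategy described in the ADR above.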
162 changes: 162 additions & 0 deletions dotnet/samples/KernelSyntaxExamples/Example62_CustomAIServiceSelector.cs
@@ -0,0 +1,162 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.ML.Tokenizers;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Services;
using Microsoft.SemanticKernel.TemplateEngine;
using RepoUtils;

// ReSharper disable once InconsistentNaming
public static class Example62_CustomAIServiceSelector
{
    /// <summary>
    /// Show how to use a custom AI service selector to select an AI service based on the model request settings.
    /// </summary>
    public static async Task RunAsync()
    {
        Console.WriteLine("======== Example62_CustomAIServiceSelector ========");

        string apiKey = TestConfiguration.AzureOpenAI.ApiKey;
        string chatDeploymentName = TestConfiguration.AzureOpenAI.ChatDeploymentName;
        string endpoint = TestConfiguration.AzureOpenAI.Endpoint;

        if (apiKey == null || chatDeploymentName == null || endpoint == null)
        {
            Console.WriteLine("Azure endpoint, apiKey, or deploymentName not found. Skipping example.");
            return;
        }

        string openAIModelId = TestConfiguration.OpenAI.ChatModelId;
        string openAIApiKey = TestConfiguration.OpenAI.ApiKey;

        if (openAIModelId == null || openAIApiKey == null)
        {
            Console.WriteLine("OpenAI credentials not found. Skipping example.");
            return;
        }

        IKernel kernel = new KernelBuilder()
            .WithLoggerFactory(ConsoleLogger.LoggerFactory)
            .WithAzureChatCompletionService(
                deploymentName: chatDeploymentName,
                endpoint: endpoint,
                serviceId: "AzureOpenAIChat",
                apiKey: apiKey)
            .WithOpenAIChatCompletionService(
                modelId: openAIModelId,
                serviceId: "OpenAIChat",
                apiKey: openAIApiKey)
            .Build();

        var modelSettings = new List<AIRequestSettings>();
        modelSettings.Add(new OpenAIRequestSettings() { ServiceId = "AzureOpenAIChat", MaxTokens = 400 });
        modelSettings.Add(new OpenAIRequestSettings() { ServiceId = "OpenAIChat", MaxTokens = 200 });

        await RunSemanticFunctionAsync(kernel, "Hello AI, what can you do for me?", modelSettings);
        await RunSemanticFunctionAsync(kernel, "Hello AI, provide an in-depth description of what you can do for me as a bulleted list?", modelSettings);
    }

    public static async Task RunSemanticFunctionAsync(IKernel kernel, string prompt, List<AIRequestSettings> modelSettings)
    {
        Console.WriteLine($"======== {prompt} ========");

        var promptTemplateConfig = new PromptTemplateConfig() { ModelSettings = modelSettings };
        var promptTemplate = new PromptTemplate(prompt, promptTemplateConfig, kernel);
        var serviceSelector = new MyAIServiceSelector();

        var skfunction = kernel.RegisterSemanticFunction(
            "MyFunction",
            promptTemplateConfig,
            promptTemplate,
            serviceSelector);

        var result = await kernel.RunAsync(skfunction);
        Console.WriteLine(result.GetValue<string>());
    }
}

public class MyAIServiceSelector : IAIServiceSelector
{
    private readonly int _defaultMaxTokens = 300;
    private readonly int _minResponseTokens = 150;

    public (T?, AIRequestSettings?) SelectAIService<T>(string renderedPrompt, IAIServiceProvider serviceProvider, IReadOnlyList<AIRequestSettings>? modelSettings) where T : IAIService
    {
        if (modelSettings is null || modelSettings.Count == 0)
        {
            var service = serviceProvider.GetService<T>(null);
            if (service is not null)
            {
                return (service, null);
            }
        }
        else
        {
            var tokens = this.CountTokens(renderedPrompt);

            string? serviceId = null;
            int fewestTokens = 0;
            AIRequestSettings? requestSettings = null;
            AIRequestSettings? defaultRequestSettings = null;
            foreach (var model in modelSettings)
            {
                if (!string.IsNullOrEmpty(model.ServiceId))
                {
                    if (model is OpenAIRequestSettings openAIModel)
                    {
                        var responseTokens = (openAIModel.MaxTokens ?? this._defaultMaxTokens) - tokens;
                        if (serviceId is null || (responseTokens > this._minResponseTokens && responseTokens < fewestTokens))
                        {
                            fewestTokens = responseTokens;
                            serviceId = model.ServiceId;
                            requestSettings = model;
                        }
                    }
                }
                else
                {
                    // First request settings with empty or null service id is the default
                    defaultRequestSettings ??= model;
                }
            }

            if (serviceId is not null)
            {
                Console.WriteLine($"Selected service: {serviceId}");
                var service = serviceProvider.GetService<T>(serviceId);
                if (service is not null)
                {
                    return (service, requestSettings);
                }
            }

            if (defaultRequestSettings is not null)
            {
                var service = serviceProvider.GetService<T>(null);
                if (service is not null)
                {
                    return (service, defaultRequestSettings);
                }
            }
        }

        throw new SKException("Unable to find AI service to handle request.");
    }

    /// <summary>
    /// MicrosoftML token counter implementation.
    /// </summary>
    private int CountTokens(string input)
    {
        Tokenizer tokenizer = new(new Bpe());
        var tokens = tokenizer.Encode(input).Tokens;

        return tokens.Count;
    }
}
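To make the selection concrete, with illustrative (not measured) token counts: if the first prompt renders to about 20 tokens, the Azure service (`MaxTokens = 400`) leaves a response budget of roughly 380 tokens and the OpenAI service (`MaxTokens = 200`) roughly 180. Both budgets clear `_minResponseTokens` (150), so the smaller one wins and "OpenAIChat" is selected. If the longer second prompt renders to about 60 tokens, the OpenAI budget falls to roughly 140, below the minimum, so the selector keeps "AzureOpenAIChat" with its larger budget.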
@@ -18,8 +18,9 @@ public interface IAIServiceSelector
    /// The returned value is a tuple containing instances of <see cref="IAIService"/> and <see cref="AIRequestSettings"/>
    /// </summary>
    /// <typeparam name="T">Type of AI service to return</typeparam>
+    /// <param name="renderedPrompt">Rendered prompt</param>
    /// <param name="serviceProvider">AI service provider</param>
    /// <param name="modelSettings">Collection of model settings</param>
    /// <returns></returns>
-    (T?, AIRequestSettings?) SelectAIService<T>(IAIServiceProvider serviceProvider, IReadOnlyList<AIRequestSettings>? modelSettings) where T : IAIService;
+    (T?, AIRequestSettings?) SelectAIService<T>(string renderedPrompt, IAIServiceProvider serviceProvider, IReadOnlyList<AIRequestSettings>? modelSettings) where T : IAIService;
}
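The new `renderedPrompt` parameter allows a selector to make prompt-aware decisions; `MyAIServiceSelector` above uses it to count the prompt's tokens and pick the service whose response budget fits.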
@@ -38,7 +38,7 @@ public static ISKFunction RegisterSemanticFunction(
    IPromptTemplate promptTemplate,
    IAIServiceSelector? serviceSelector = null)
{
-    return kernel.RegisterSemanticFunction(FunctionCollection.GlobalFunctionsPluginName, functionName, promptTemplateConfig, promptTemplate);
+    return kernel.RegisterSemanticFunction(FunctionCollection.GlobalFunctionsPluginName, functionName, promptTemplateConfig, promptTemplate, serviceSelector);
}

/// <summary>
@@ -62,7 +62,7 @@ public static ISKFunction RegisterSemanticFunction(
    // Future-proofing the name not to contain special chars
    Verify.ValidFunctionName(functionName);

-    ISKFunction function = kernel.CreateSemanticFunction(pluginName, functionName, promptTemplateConfig, promptTemplate);
+    ISKFunction function = kernel.CreateSemanticFunction(pluginName, functionName, promptTemplateConfig, promptTemplate, serviceSelector);
    return kernel.RegisterCustomFunction(function);
}

@@ -18,7 +18,7 @@ internal class DelegatingAIServiceSelector : IAIServiceSelector
    internal AIRequestSettings? RequestSettings { get; set; }

    /// <inheritdoc/>
-    public (T?, AIRequestSettings?) SelectAIService<T>(IAIServiceProvider serviceProvider, IReadOnlyList<AIRequestSettings>? modelSettings) where T : IAIService
+    public (T?, AIRequestSettings?) SelectAIService<T>(string renderedPrompt, IAIServiceProvider serviceProvider, IReadOnlyList<AIRequestSettings>? modelSettings) where T : IAIService
    {
        return ((T?)this.ServiceFactory?.Invoke() ?? serviceProvider.GetService<T>(null), this.RequestSettings ?? modelSettings?[0]);
    }
@@ -15,7 +15,7 @@ namespace Microsoft.SemanticKernel.Functions;
internal class OrderedIAIServiceSelector : IAIServiceSelector
{
    /// <inheritdoc/>
-    public (T?, AIRequestSettings?) SelectAIService<T>(IAIServiceProvider serviceProvider, IReadOnlyList<AIRequestSettings>? modelSettings) where T : IAIService
+    public (T?, AIRequestSettings?) SelectAIService<T>(string renderedPrompt, IAIServiceProvider serviceProvider, IReadOnlyList<AIRequestSettings>? modelSettings) where T : IAIService
    {
        if (modelSettings is null || modelSettings.Count == 0)
        {