Skip to content

Commit

Permalink
.Net: Introduce IAIServiceSelector interface to allow AI service and …
Browse files Browse the repository at this point in the history
…request settings to be selected during semantic function execution (#3227)

### Motivation and Context

Fix following issues for Semantic functions:

1. Allow semantic functions to dynamically retrieve the AI service and
request settings to be used when executing the function

Previously semantic functions execution has the following limitations:

1. The AI service was set on the semantic function when it was created.
This meant that an instance of a semantic function could not be used
with multiple different Kernel instances.
1. The AI request settings were set on the semantic function when it was
created. This meant that different request settings could not be
selected when the function was executed.

### Description

This PR add's a new abstraction called `IAIServiceSelector`. An instance
of this is added to a semantic function when it is created. The default
implementation works as follows:

1. If the semantic function only has a single associated model request
setting instance with no service id then the default AI service and the
model request settings are used. This is consistent with the previous
behaviour.
1. If the semantic function only has a single associated model request
setting instance with a service id then the named AI service and the
model request settings are used. If the named AI service is not
available an exception will be thrown. This is consistent with the
previous behaviour.
1. If the semantic function only has multiple associated model request
setting instances then:
    1. The model request setting instances are considered in order
1. The first model request setting instance with no service id is
considered the default
1. For each model request setting instance that has a service id, if the
service exists then it will be used with the associated model request
settings
1. If no matching service can be found and default request settings are
provided the default service will be used
1. If no matching service can be found and no default request settings
are provided an exception will be thrown

### Contribution Checklist

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
markwallace-microsoft authored Oct 23, 2023
1 parent aef2bdd commit 3601022
Show file tree
Hide file tree
Showing 30 changed files with 814 additions and 219 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.Services;
using Microsoft.SemanticKernel.TemplateEngine.Basic.Blocks;
using Moq;
using Xunit;
Expand All @@ -21,6 +22,7 @@ public class CodeBlockTests
private readonly Mock<IReadOnlyFunctionCollection> _functions;
private readonly ILoggerFactory _logger = NullLoggerFactory.Instance;
private readonly Mock<IFunctionRunner> _functionRunner = new();
private readonly Mock<IAIServiceProvider> _serviceProvider = new();

public CodeBlockTests()
{
Expand All @@ -32,7 +34,7 @@ public async Task ItThrowsIfAFunctionDoesntExistAsync()
{
// Arrange
var functionRunner = new Mock<IFunctionRunner>();
var context = new SKContext(this._functionRunner.Object);
var context = new SKContext(this._functionRunner.Object, this._serviceProvider.Object);
var target = new CodeBlock("functionName", this._logger);

this._functionRunner.Setup(r => r.RunAsync(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<ContextVariables>(), It.IsAny<CancellationToken>()))
Expand All @@ -49,7 +51,7 @@ public async Task ItThrowsIfAFunctionDoesntExistAsync()
public async Task ItThrowsIfAFunctionCallThrowsAsync()
{
// Arrange
var context = new SKContext(this._functionRunner.Object, functions: this._functions.Object);
var context = new SKContext(this._functionRunner.Object, this._serviceProvider.Object, functions: this._functions.Object);
var function = new Mock<ISKFunction>();
function
.Setup(x => x.InvokeAsync(It.IsAny<SKContext>(), It.IsAny<AIRequestSettings?>(), It.IsAny<CancellationToken>()))
Expand Down Expand Up @@ -145,7 +147,7 @@ public async Task ItRendersCodeBlockConsistingOfJustAVarBlock1Async()
{
// Arrange
var variables = new ContextVariables { ["varName"] = "foo" };
var context = new SKContext(this._functionRunner.Object, variables, functions: this._functions.Object);
var context = new SKContext(this._functionRunner.Object, this._serviceProvider.Object, variables, functions: this._functions.Object);

// Act
var codeBlock = new CodeBlock("$varName", NullLoggerFactory.Instance);
Expand All @@ -160,7 +162,7 @@ public async Task ItRendersCodeBlockConsistingOfJustAVarBlock2Async()
{
// Arrange
var variables = new ContextVariables { ["varName"] = "bar" };
var context = new SKContext(this._functionRunner.Object, variables, functions: this._functions.Object);
var context = new SKContext(this._functionRunner.Object, this._serviceProvider.Object, variables, functions: this._functions.Object);
var varBlock = new VarBlock("$varName");

// Act
Expand All @@ -175,7 +177,7 @@ public async Task ItRendersCodeBlockConsistingOfJustAVarBlock2Async()
public async Task ItRendersCodeBlockConsistingOfJustAValBlock1Async()
{
// Arrange
var context = new SKContext(this._functionRunner.Object);
var context = new SKContext(this._functionRunner.Object, this._serviceProvider.Object);

// Act
var codeBlock = new CodeBlock("'ciao'", NullLoggerFactory.Instance);
Expand All @@ -190,7 +192,7 @@ public async Task ItRendersCodeBlockConsistingOfJustAValBlock2Async()
{
// Arrange
var kernel = new Mock<IKernel>();
var context = new SKContext(this._functionRunner.Object);
var context = new SKContext(this._functionRunner.Object, this._serviceProvider.Object);
var valBlock = new ValBlock("'arrivederci'");

// Act
Expand All @@ -209,7 +211,7 @@ public async Task ItInvokesFunctionCloningAllVariablesAsync()
const string Plugin = "pluginName";

var variables = new ContextVariables { ["input"] = "zero", ["var1"] = "uno", ["var2"] = "due" };
var context = new SKContext(this._functionRunner.Object, variables, functions: this._functions.Object);
var context = new SKContext(this._functionRunner.Object, this._serviceProvider.Object, variables, functions: this._functions.Object);
var funcId = new FunctionIdBlock(Func);

var canary0 = string.Empty;
Expand Down Expand Up @@ -257,7 +259,7 @@ public async Task ItInvokesFunctionWithCustomVariableAsync()
const string VarValue = "varValue";

var variables = new ContextVariables { [Var] = VarValue };
var context = new SKContext(this._functionRunner.Object, variables, functions: this._functions.Object);
var context = new SKContext(this._functionRunner.Object, this._serviceProvider.Object, variables, functions: this._functions.Object);
var funcId = new FunctionIdBlock(Func);
var varBlock = new VarBlock($"${Var}");

Expand Down Expand Up @@ -290,7 +292,7 @@ public async Task ItInvokesFunctionWithCustomValueAsync()
const string Plugin = "pluginName";
const string Value = "value";

var context = new SKContext(this._functionRunner.Object, variables: null, functions: this._functions.Object);
var context = new SKContext(this._functionRunner.Object, this._serviceProvider.Object, variables: null, functions: this._functions.Object);
var funcId = new FunctionIdBlock(Func);
var valBlock = new ValBlock($"'{Value}'");

Expand Down Expand Up @@ -328,7 +330,7 @@ public async Task ItInvokesFunctionWithNamedArgsAsync()
var variables = new ContextVariables();
variables.Set("bob", BobValue);
variables.Set("input", Value);
var context = new SKContext(this._functionRunner.Object, variables: variables, functions: this._functions.Object);
var context = new SKContext(this._functionRunner.Object, this._serviceProvider.Object, variables: variables, functions: this._functions.Object);
var funcId = new FunctionIdBlock(Func);
var namedArgBlock1 = new NamedArgBlock($"foo='{FooValue}'");
var namedArgBlock2 = new NamedArgBlock("baz=$bob");
Expand Down Expand Up @@ -362,7 +364,7 @@ private void MockFunctionRunner(ISKFunction function)
this._functionRunner.Setup(r => r.RunAsync(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<ContextVariables>(), It.IsAny<CancellationToken>()))
.Returns<string, string, ContextVariables, CancellationToken>((pluginName, functionName, variables, cancellationToken) =>
{
var context = new SKContext(this._functionRunner.Object, variables);
var context = new SKContext(this._functionRunner.Object, this._serviceProvider.Object, variables);
return function.InvokeAsync(context, null, cancellationToken);
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.Services;
using Microsoft.SemanticKernel.TemplateEngine.Basic;
using Microsoft.SemanticKernel.TemplateEngine.Basic.Blocks;
using Moq;
Expand All @@ -29,6 +30,7 @@ public sealed class PromptTemplateEngineTests
private readonly ITestOutputHelper _logger;
private readonly Mock<IKernel> _kernel;
private readonly Mock<IFunctionRunner> _functionRunner;
private readonly Mock<IAIServiceProvider> _serviceProvider;

public PromptTemplateEngineTests(ITestOutputHelper testOutputHelper)
{
Expand All @@ -38,6 +40,7 @@ public PromptTemplateEngineTests(ITestOutputHelper testOutputHelper)
this._functions = new Mock<IReadOnlyFunctionCollection>();
this._kernel = new Mock<IKernel>();
this._functionRunner = new Mock<IFunctionRunner>();
this._serviceProvider = new Mock<IAIServiceProvider>();
}

[Fact]
Expand Down Expand Up @@ -375,7 +378,7 @@ private void MockFunctionRunner(ISKFunction function)
this._functionRunner.Setup(r => r.RunAsync(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<ContextVariables>(), It.IsAny<CancellationToken>()))
.Returns<string, string, ContextVariables, CancellationToken>((pluginName, functionName, variables, cancellationToken) =>
{
var context = new SKContext(this._functionRunner.Object, variables);
var context = new SKContext(this._functionRunner.Object, this._serviceProvider.Object, variables);
return function.InvokeAsync(context, null, cancellationToken);
});
}
Expand All @@ -385,7 +388,7 @@ private void MockFunctionRunner(List<ISKFunction> functions)
this._functionRunner.Setup(r => r.RunAsync(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<ContextVariables>(), It.IsAny<CancellationToken>()))
.Returns<string, string, ContextVariables, CancellationToken>((pluginName, functionName, variables, cancellationToken) =>
{
var context = new SKContext(this._functionRunner.Object, variables);
var context = new SKContext(this._functionRunner.Object, this._serviceProvider.Object, variables);
var function = functions.First(f => f.PluginName == functionName);

return function.InvokeAsync(context, null, cancellationToken);
Expand All @@ -396,6 +399,7 @@ private SKContext MockContext()
{
return new SKContext(
this._functionRunner.Object,
this._serviceProvider.Object,
this._variables,
this._functions.Object);
}
Expand Down
4 changes: 0 additions & 4 deletions dotnet/src/IntegrationTests/Plugins/SamplePluginsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ public void CanLoadSamplePluginsRequestSettings()
{
var function = kernel.Functions.GetFunction(view.PluginName, view.Name);
Assert.NotNull(function);
Assert.NotNull(function.RequestSettings);
Assert.True(function.RequestSettings.ExtensionData.ContainsKey("max_tokens"));
});
}

Expand All @@ -49,8 +47,6 @@ public void CanLoadSampleSkillsCompletions()
{
var function = kernel.Functions.GetFunction(view.PluginName, view.Name);
Assert.NotNull(function);
Assert.NotNull(function.RequestSettings);
Assert.True(function.RequestSettings.ExtensionData.ContainsKey("max_tokens"));
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.Services;
using Moq;
using Xunit;

Expand Down Expand Up @@ -163,11 +164,12 @@ private Mock<IKernel> CreateMockKernelAndFunctionFlowWithTestString(string testP
functions.Setup(x => x.GetFunctionViews()).Returns(new List<FunctionView>());
}
var functionRunner = new Mock<IFunctionRunner>();
var serviceProvider = new Mock<IAIServiceProvider>();
var kernel = new Mock<IKernel>();

var returnContext = new SKContext(functionRunner.Object, new ContextVariables(testPlanString), functions.Object);
var returnContext = new SKContext(functionRunner.Object, serviceProvider.Object, new ContextVariables(testPlanString), functions.Object);

var context = new SKContext(functionRunner.Object, functions: functions.Object);
var context = new SKContext(functionRunner.Object, serviceProvider.Object, functions: functions.Object);

var mockFunctionFlowFunction = new Mock<ISKFunction>();
mockFunctionFlowFunction.Setup(x => x.InvokeAsync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using Microsoft.SemanticKernel.Memory;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.Services;
using Moq;
using Xunit;

Expand Down Expand Up @@ -57,9 +58,10 @@ public async Task CanCallGetAvailableFunctionsWithNoFunctionsAsync(Type t)
.Returns(asyncEnumerable);

var functionRunner = new Mock<IFunctionRunner>();
var serviceProvider = new Mock<IAIServiceProvider>();

// Arrange GetAvailableFunctionsAsync parameters
var context = new SKContext(functionRunner.Object, variables);
var context = new SKContext(functionRunner.Object, serviceProvider.Object, variables);
var config = InitializeConfig(t);
var semanticQuery = "test";

Expand Down Expand Up @@ -136,9 +138,10 @@ public async Task CanCallGetAvailableFunctionsWithFunctionsAsync(Type t)
functions.Setup(x => x.GetFunctionViews()).Returns(functionsView);

var functionRunner = new Mock<IFunctionRunner>();
var serviceProvider = new Mock<IAIServiceProvider>();

// Arrange GetAvailableFunctionsAsync parameters
var context = new SKContext(functionRunner.Object, variables, functions.Object);
var context = new SKContext(functionRunner.Object, serviceProvider.Object, variables, functions.Object);
var config = InitializeConfig(t);
var semanticQuery = "test";

Expand Down Expand Up @@ -206,9 +209,10 @@ public async Task CanCallGetAvailableFunctionsWithFunctionsWithRelevancyAsync(Ty
functions.Setup(x => x.GetFunctionViews()).Returns(functionsView);

var functionRunner = new Mock<IFunctionRunner>();
var serviceProvider = new Mock<IAIServiceProvider>();

// Arrange GetAvailableFunctionsAsync parameters
var context = new SKContext(functionRunner.Object, variables, functions.Object);
var context = new SKContext(functionRunner.Object, serviceProvider.Object, variables, functions.Object);
var config = InitializeConfig(t);
config.SemanticMemoryConfig = new() { RelevancyThreshold = 0.78, Memory = memory.Object };
var semanticQuery = "test";
Expand Down Expand Up @@ -243,6 +247,7 @@ public async Task CanCallGetAvailableFunctionsAsyncWithDefaultRelevancyAsync(Typ
// Arrange
var kernel = new Mock<IKernel>();
var functionRunner = new Mock<IFunctionRunner>();
var serviceProvider = new Mock<IAIServiceProvider>();

var variables = new ContextVariables();
var functions = new FunctionCollection();
Expand All @@ -267,7 +272,7 @@ public async Task CanCallGetAvailableFunctionsAsyncWithDefaultRelevancyAsync(Typ
.Returns(asyncEnumerable);

// Arrange GetAvailableFunctionsAsync parameters
var context = new SKContext(functionRunner.Object, variables);
var context = new SKContext(functionRunner.Object, serviceProvider.Object, variables);
var config = InitializeConfig(t);
config.SemanticMemoryConfig = new() { RelevancyThreshold = 0.78, Memory = memory.Object };
var semanticQuery = "test";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.Services;
using Microsoft.SemanticKernel.TemplateEngine;
using Moq;
using Xunit;
Expand Down Expand Up @@ -38,9 +39,10 @@ private Mock<IKernel> CreateKernelMock(

private SKContext CreateSKContext(
IFunctionRunner functionRunner,
IAIServiceProvider serviceProvider,
ContextVariables? variables = null)
{
return new SKContext(functionRunner, variables);
return new SKContext(functionRunner, serviceProvider, variables);
}

private static Mock<ISKFunction> CreateMockFunction(FunctionView functionView, string result = "")
Expand All @@ -59,12 +61,13 @@ private void CreateKernelAndFunctionCreateMocks(List<(string name, string plugin
kernel = kernelMock.Object;

var functionRunnerMock = new Mock<IFunctionRunner>();
var serviceProviderMock = new Mock<IAIServiceProvider>();

// For Create
kernelMock.Setup(k => k.CreateNewContext(It.IsAny<ContextVariables>(), It.IsAny<IReadOnlyFunctionCollection>(), It.IsAny<ILoggerFactory>(), It.IsAny<CultureInfo>()))
.Returns<ContextVariables, IReadOnlyFunctionCollection, ILoggerFactory, CultureInfo>((contextVariables, skills, loggerFactory, culture) =>
{
return this.CreateSKContext(functionRunnerMock.Object, contextVariables);
return this.CreateSKContext(functionRunnerMock.Object, serviceProviderMock.Object, contextVariables);
});

var functionsView = new List<FunctionView>();
Expand All @@ -77,7 +80,7 @@ private void CreateKernelAndFunctionCreateMocks(List<(string name, string plugin
var mockFunction = CreateMockFunction(functionView);
functionsView.Add(functionView);

var result = this.CreateSKContext(functionRunnerMock.Object);
var result = this.CreateSKContext(functionRunnerMock.Object, serviceProviderMock.Object);
result.Variables.Update(resultString);
mockFunction.Setup(x => x.InvokeAsync(It.IsAny<SKContext>(), null, It.IsAny<CancellationToken>()))
.ReturnsAsync(new FunctionResult(name, pluginName, result));
Expand All @@ -88,7 +91,8 @@ private void CreateKernelAndFunctionCreateMocks(List<(string name, string plugin
It.IsAny<string>(),
It.IsAny<string>(),
It.IsAny<PromptTemplateConfig>(),
It.IsAny<IPromptTemplate>()
It.IsAny<IPromptTemplate>(),
It.IsAny<IAIServiceSelector>()
)).Returns(mockFunction.Object);
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.Services;
using Moq;
using Xunit;

Expand Down Expand Up @@ -60,17 +61,20 @@ public async Task ItCanCreatePlanAsync(string goal)

functions.Setup(x => x.GetFunctionViews()).Returns(functionsView);
var functionRunner = new Mock<IFunctionRunner>();
var serviceProvider = new Mock<IAIServiceProvider>();
kernel.Setup(x => x.LoggerFactory).Returns(new Mock<ILoggerFactory>().Object);

var expectedFunctions = input.Select(x => x.name).ToList();
var expectedPlugins = input.Select(x => x.pluginName).ToList();

var context = new SKContext(
functionRunner.Object,
serviceProvider.Object,
new ContextVariables());

var returnContext = new SKContext(
functionRunner.Object,
serviceProvider.Object,
new ContextVariables());

var planString =
Expand Down Expand Up @@ -147,15 +151,16 @@ public async Task InvalidXMLThrowsAsync()
{
// Arrange
var functionRunner = new Mock<IFunctionRunner>();
var serviceProvider = new Mock<IAIServiceProvider>();
var kernel = new Mock<IKernel>();
var functions = new Mock<IFunctionCollection>();

functions.Setup(x => x.GetFunctionViews()).Returns(new List<FunctionView>());

var planString = "<plan>notvalid<</plan>";
var returnContext = new SKContext(functionRunner.Object, new ContextVariables(planString));
var returnContext = new SKContext(functionRunner.Object, serviceProvider.Object, new ContextVariables(planString));

var context = new SKContext(functionRunner.Object, new ContextVariables());
var context = new SKContext(functionRunner.Object, serviceProvider.Object, new ContextVariables());

var mockFunctionFlowFunction = new Mock<ISKFunction>();
mockFunctionFlowFunction.Setup(x => x.InvokeAsync(
Expand Down
Loading

0 comments on commit 3601022

Please sign in to comment.