Skip to content

Commit

Permalink
Move several IKernel methods to a KernelExtensions class
Browse files Browse the repository at this point in the history
`IKernel` declares 7 overloads of `RunAsync`, and `Kernel` dutifully implements all 7, but 6 of them are one-liner convenience methods that just delegate to the 7th.  Rather than having all of these on the interface and making any implementer implement them all, 6 of them can just be extension methods.

`IKernel` also declares an `ImportFunctions` method, which is almost entirely reflection-based logic unrelated to the kernel instance: the only thing it actually needs from the kernel is to be able to add functions to its functions collection, and `RegisterCustomFunction` provides that. So `ImportFunctions` can also be made into an extension method, such that any `IKernel` implementation gets it for free, and the interface is less cluttered.

Also fixed the nullable annotation on Kernel's ctor's ILoggerFactory parameter.
  • Loading branch information
stephentoub committed Oct 15, 2023
1 parent bdc517d commit 5522897
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 176 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ public async Task ItCanCreatePlanAsync(string goal)
// Arrange
var kernel = new Mock<IKernel>();
kernel.Setup(x => x.LoggerFactory).Returns(new Mock<ILoggerFactory>().Object);
kernel.Setup(x => x.RunAsync(It.IsAny<ISKFunction>(), It.IsAny<ContextVariables>(), It.IsAny<CancellationToken>()))
.Returns<ISKFunction, ContextVariables, CancellationToken>(async (function, vars, cancellationToken) =>
kernel.Setup(x => x.RunAsync(It.IsAny<ContextVariables>(), It.IsAny<CancellationToken>(), It.IsAny<ISKFunction>()))
.Returns<ContextVariables, CancellationToken, ISKFunction[]>(async (vars, cancellationToken, functions) =>
{
var functionResult = await function.InvokeAsync(kernel.Object, vars, cancellationToken: cancellationToken);
var functionResult = await functions[0].InvokeAsync(kernel.Object, vars, cancellationToken: cancellationToken);
return KernelResult.FromFunctionResults(functionResult.GetValue<string>(), new List<FunctionResult> { functionResult });
});

Expand Down Expand Up @@ -168,10 +168,10 @@ public async Task InvalidXMLThrowsAsync()

// Mock Plugins
kernel.Setup(x => x.Functions).Returns(functions.Object);
kernel.Setup(x => x.RunAsync(It.IsAny<ISKFunction>(), It.IsAny<ContextVariables>(), It.IsAny<CancellationToken>()))
.Returns<ISKFunction, ContextVariables, CancellationToken>(async (function, vars, cancellationToken) =>
kernel.Setup(x => x.RunAsync(It.IsAny<ContextVariables>(), It.IsAny<CancellationToken>(), It.IsAny<ISKFunction>()))
.Returns<ContextVariables, CancellationToken, ISKFunction[]>(async (vars, cancellationToken, functions) =>
{
var functionResult = await function.InvokeAsync(kernel.Object, vars, cancellationToken: cancellationToken);
var functionResult = await functions[0].InvokeAsync(kernel.Object, vars, cancellationToken: cancellationToken);
return KernelResult.FromFunctionResults(functionResult.GetValue<string>(), new List<FunctionResult> { functionResult });
});

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Moq;
using Xunit;

Expand All @@ -25,7 +25,7 @@ public void WhenInputIsFinalAnswerReturnsFinalAnswer(string input, string expect
{
// Arrange
var kernel = new Mock<IKernel>();
kernel.Setup(x => x.LoggerFactory).Returns(new Mock<ILoggerFactory>().Object);
kernel.Setup(x => x.LoggerFactory).Returns(NullLoggerFactory.Instance);

var planner = new StepwisePlanner(kernel.Object);

Expand Down Expand Up @@ -77,7 +77,7 @@ public void ParseActionReturnsAction(string input, string expectedThought, strin

// Arrange
var kernel = new Mock<IKernel>();
kernel.Setup(x => x.LoggerFactory).Returns(new Mock<ILoggerFactory>().Object);
kernel.Setup(x => x.LoggerFactory).Returns(NullLoggerFactory.Instance);

var planner = new StepwisePlanner(kernel.Object);

Expand Down
73 changes: 1 addition & 72 deletions dotnet/src/SemanticKernel.Abstractions/IKernel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace Microsoft.SemanticKernel;
public interface IKernel
{
/// <summary>
/// App logger
/// The ILoggerFactory used to create a logger for logging.
/// </summary>
ILoggerFactory LoggerFactory { get; }

Expand All @@ -48,77 +48,6 @@ public interface IKernel
/// <returns>A C# function wrapping the function execution logic.</returns>
ISKFunction RegisterCustomFunction(ISKFunction customFunction);

/// <summary>
/// Import a set of functions as a plugin from the given object instance. Only the functions that have the `SKFunction` attribute will be included in the plugin.
/// Once these functions are imported, the prompt templates can use functions to import content at runtime.
/// </summary>
/// <param name="functionsInstance">Instance of a class containing functions</param>
/// <param name="pluginName">Name of the plugin for function collection and prompt templates. If the value is empty functions are registered in the global namespace.</param>
/// <returns>A list of all the semantic functions found in the directory, indexed by function name.</returns>
IDictionary<string, ISKFunction> ImportFunctions(object functionsInstance, string? pluginName = null);

/// <summary>
/// Run a single synchronous or asynchronous <see cref="ISKFunction"/>.
/// </summary>
/// <param name="skFunction">A Semantic Kernel function to run</param>
/// <param name="variables">Input to process</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>Result of the function</returns>
Task<KernelResult> RunAsync(
ISKFunction skFunction,
ContextVariables? variables = null,
CancellationToken cancellationToken = default);

/// <summary>
/// Run a pipeline composed of synchronous and asynchronous functions.
/// </summary>
/// <param name="pipeline">List of functions</param>
/// <returns>Result of the function composition</returns>
Task<KernelResult> RunAsync(
params ISKFunction[] pipeline);

/// <summary>
/// Run a pipeline composed of synchronous and asynchronous functions.
/// </summary>
/// <param name="input">Input to process</param>
/// <param name="pipeline">List of functions</param>
/// <returns>Result of the function composition</returns>
Task<KernelResult> RunAsync(
string input,
params ISKFunction[] pipeline);

/// <summary>
/// Run a pipeline composed of synchronous and asynchronous functions.
/// </summary>
/// <param name="variables">Input to process</param>
/// <param name="pipeline">List of functions</param>
/// <returns>Result of the function composition</returns>
Task<KernelResult> RunAsync(
ContextVariables variables,
params ISKFunction[] pipeline);

/// <summary>
/// Run a pipeline composed of synchronous and asynchronous functions.
/// </summary>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <param name="pipeline">List of functions</param>
/// <returns>Result of the function composition</returns>
Task<KernelResult> RunAsync(
CancellationToken cancellationToken,
params ISKFunction[] pipeline);

/// <summary>
/// Run a pipeline composed of synchronous and asynchronous functions.
/// </summary>
/// <param name="input">Input to process</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <param name="pipeline">List of functions</param>
/// <returns>Result of the function composition</returns>
Task<KernelResult> RunAsync(
string input,
CancellationToken cancellationToken,
params ISKFunction[] pipeline);

/// <summary>
/// Run a pipeline composed of synchronous and asynchronous functions.
/// </summary>
Expand Down
101 changes: 6 additions & 95 deletions dotnet/src/SemanticKernel.Core/Kernel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
using System.Collections.Generic;
using System.ComponentModel;
using System.Globalization;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Events;
using Microsoft.SemanticKernel.Http;
Expand Down Expand Up @@ -47,7 +47,7 @@ public sealed class Kernel : IKernel, IDisposable
public static KernelBuilder Builder => new();

/// <inheritdoc/>
public IDelegatingHandlerFactory HttpHandlerFactory => this._httpHandlerFactory;
public IDelegatingHandlerFactory HttpHandlerFactory { get; }

/// <inheritdoc/>
public event EventHandler<FunctionInvokingEventArgs>? FunctionInvoking;
Expand All @@ -70,10 +70,12 @@ public Kernel(
IPromptTemplateEngine promptTemplateEngine,
ISemanticTextMemory memory,
IDelegatingHandlerFactory httpHandlerFactory,
ILoggerFactory loggerFactory)
ILoggerFactory? loggerFactory)
{
loggerFactory ??= NullLoggerFactory.Instance;

this.LoggerFactory = loggerFactory;
this._httpHandlerFactory = httpHandlerFactory;
this.HttpHandlerFactory = httpHandlerFactory;
this.PromptTemplateEngine = promptTemplateEngine;
this._memory = memory;
this._aiServiceProvider = aiServiceProvider;
Expand All @@ -83,36 +85,6 @@ public Kernel(
this._logger = loggerFactory.CreateLogger(typeof(Kernel));
}

/// <inheritdoc/>
public IDictionary<string, ISKFunction> ImportFunctions(object functionsInstance, string? pluginName = null)
{
Verify.NotNull(functionsInstance);

if (string.IsNullOrWhiteSpace(pluginName))
{
pluginName = FunctionCollection.GlobalFunctionsPluginName;
this._logger.LogTrace("Importing functions from {0} to the global plugin namespace", functionsInstance.GetType().FullName);
}
else
{
this._logger.LogTrace("Importing functions from {0} to the {1} namespace", functionsInstance.GetType().FullName, pluginName);
}

Dictionary<string, ISKFunction> functions = ImportFunctions(
functionsInstance,
pluginName!,
this._logger,
this.LoggerFactory
);
foreach (KeyValuePair<string, ISKFunction> f in functions)
{
f.Value.SetDefaultFunctionCollection(this.Functions);
this._functionCollection.AddFunction(f.Value);
}

return functions;
}

/// <inheritdoc/>
public ISKFunction RegisterCustomFunction(ISKFunction customFunction)
{
Expand All @@ -124,32 +96,6 @@ public ISKFunction RegisterCustomFunction(ISKFunction customFunction)
return customFunction;
}

/// <inheritdoc/>
public Task<KernelResult> RunAsync(ISKFunction skFunction,
ContextVariables? variables = null,
CancellationToken cancellationToken = default)
=> this.RunAsync(variables ?? new(), cancellationToken, skFunction);

/// <inheritdoc/>
public Task<KernelResult> RunAsync(params ISKFunction[] pipeline)
=> this.RunAsync(new ContextVariables(), pipeline);

/// <inheritdoc/>
public Task<KernelResult> RunAsync(string input, params ISKFunction[] pipeline)
=> this.RunAsync(new ContextVariables(input), pipeline);

/// <inheritdoc/>
public Task<KernelResult> RunAsync(ContextVariables variables, params ISKFunction[] pipeline)
=> this.RunAsync(variables, CancellationToken.None, pipeline);

/// <inheritdoc/>
public Task<KernelResult> RunAsync(CancellationToken cancellationToken, params ISKFunction[] pipeline)
=> this.RunAsync(new ContextVariables(), cancellationToken, pipeline);

/// <inheritdoc/>
public Task<KernelResult> RunAsync(string input, CancellationToken cancellationToken, params ISKFunction[] pipeline)
=> this.RunAsync(new ContextVariables(input), cancellationToken, pipeline);

/// <inheritdoc/>
public async Task<KernelResult> RunAsync(ContextVariables variables, CancellationToken cancellationToken, params ISKFunction[] pipeline)
{
Expand Down Expand Up @@ -266,7 +212,6 @@ public void Dispose()
private readonly IPromptTemplateEngine _promptTemplateEngine;
private readonly IAIServiceProvider _aiServiceProvider;
private readonly ILogger _logger;
private readonly IDelegatingHandlerFactory _httpHandlerFactory;

/// <summary>
/// Execute the OnFunctionInvoking event handlers.
Expand Down Expand Up @@ -306,40 +251,6 @@ public void Dispose()
return null;
}

/// <summary>
/// Import a native functions into the kernel function collection, so that semantic functions and pipelines can consume its functions.
/// </summary>
/// <param name="pluginInstance">Class instance from which to import available native functions</param>
/// <param name="pluginName">Plugin name, used to group functions under a shared namespace</param>
/// <param name="logger">Application logger</param>
/// <param name="loggerFactory">The <see cref="ILoggerFactory"/> to use for logging. If null, no logging will be performed.</param>
/// <returns>Dictionary of functions imported from the given class instance, case-insensitively indexed by name.</returns>
private static Dictionary<string, ISKFunction> ImportFunctions(object pluginInstance, string pluginName, ILogger logger, ILoggerFactory loggerFactory)
{
MethodInfo[] methods = pluginInstance.GetType().GetMethods(BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public);
logger.LogTrace("Importing plugin name: {0}. Potential methods found: {1}", pluginName, methods.Length);

// Filter out non-SKFunctions and fail if two functions have the same name
Dictionary<string, ISKFunction> result = new(StringComparer.OrdinalIgnoreCase);
foreach (MethodInfo method in methods)
{
if (method.GetCustomAttribute<SKFunctionAttribute>() is not null)
{
ISKFunction function = SKFunction.FromNativeMethod(method, pluginInstance, pluginName, loggerFactory);
if (result.ContainsKey(function.Name))
{
throw new SKException("Function overloads are not supported, please differentiate function names");
}

result.Add(function.Name, function);
}
}

logger.LogTrace("Methods imported {0}", result.Count);

return result;
}

#endregion

#region Obsolete ===============================================================================
Expand Down
Loading

0 comments on commit 5522897

Please sign in to comment.