From 6ceac5bad330b74793df0ed826402858a6beb84e Mon Sep 17 00:00:00 2001 From: Oleksandr Liakhevych Date: Sun, 6 Oct 2024 20:30:09 +0300 Subject: [PATCH] Minor refactoring --- ...jectionGenerator.FindServicesToRegister.cs | 213 +++++++++++++ ...encyInjectionGenerator.ParseMethodModel.cs | 40 +++ .../DependencyInjectionGenerator.cs | 297 ++---------------- 3 files changed, 286 insertions(+), 264 deletions(-) create mode 100644 ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs create mode 100644 ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs new file mode 100644 index 0000000..6ab62de --- /dev/null +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs @@ -0,0 +1,213 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.RegularExpressions; +using Microsoft.CodeAnalysis; +using ServiceScan.SourceGenerator.Model; +using static ServiceScan.SourceGenerator.DiagnosticDescriptors; + +namespace ServiceScan.SourceGenerator; + +public partial class DependencyInjectionGenerator +{ + private static DiagnosticModel FindServicesToRegister((DiagnosticModel, Compilation) context) + { + var (diagnosticModel, compilation) = context; + var diagnostic = diagnosticModel.Diagnostic; + + if (diagnostic != null) + return diagnostic; + + var (method, attributes) = diagnosticModel.Model; + + var registrations = new List(); + + foreach (var attribute in attributes) + { + bool typesFound = false; + + var containingType = compilation.GetTypeByMetadataName(method.TypeMetadataName); + + var assembly = (attribute.AssemblyOfTypeName is null + ? containingType + : compilation.GetTypeByMetadataName(attribute.AssemblyOfTypeName)).ContainingAssembly; + + var assignableToType = attribute.AssignableToTypeName is null + ? null + : compilation.GetTypeByMetadataName(attribute.AssignableToTypeName); + + var keySelectorMethod = attribute.KeySelector is null + ? null + : containingType.GetMembers().OfType().FirstOrDefault(m => + m.IsStatic && m.Name == attribute.KeySelector); + + if (attribute.KeySelector != null) + { + if (keySelectorMethod is null) + return Diagnostic.Create(KeySelectorMethodNotFound, attribute.Location); + + if (keySelectorMethod.ReturnsVoid) + return Diagnostic.Create(KeySelectorMethodHasIncorrectSignature, attribute.Location); + + var validGenericKeySelector = keySelectorMethod.TypeArguments.Length == 1 && keySelectorMethod.Parameters.Length == 0; + var validNonGenericKeySelector = !keySelectorMethod.IsGenericMethod && keySelectorMethod.Parameters is [{ Type.Name: nameof(Type) }]; + + if (!validGenericKeySelector && !validNonGenericKeySelector) + return Diagnostic.Create(KeySelectorMethodHasIncorrectSignature, attribute.Location); + } + + if (assignableToType != null && attribute.AssignableToGenericArguments != null) + { + var typeArguments = attribute.AssignableToGenericArguments.Value.Select(t => compilation.GetTypeByMetadataName(t)).ToArray(); + assignableToType = assignableToType.Construct(typeArguments); + } + + var types = GetTypesFromAssembly(assembly) + .Where(t => !t.IsAbstract && !t.IsStatic && t.CanBeReferencedByName && t.TypeKind == TypeKind.Class); + + if (attribute.TypeNameFilter != null) + { + var regex = $"^({Regex.Escape(attribute.TypeNameFilter).Replace(@"\*", ".*").Replace(",", "|")})$"; + types = types.Where(t => Regex.IsMatch(t.ToDisplayString(), regex)); + } + + foreach (var t in types) + { + var implementationType = t; + + INamedTypeSymbol matchedType = null; + if (assignableToType != null && !IsAssignableTo(implementationType, assignableToType, out matchedType)) + continue; + + IEnumerable serviceTypes = (attribute.AsSelf, attribute.AsImplementedInterfaces) switch + { + (true, true) => new[] { implementationType }.Concat(implementationType.AllInterfaces), + (false, true) => implementationType.AllInterfaces, + (true, false) => [implementationType], + _ => [matchedType ?? implementationType] + }; + + foreach (var serviceType in serviceTypes) + { + if (implementationType.IsGenericType) + { + var implementationTypeName = implementationType.ConstructUnboundGenericType().ToDisplayString(); + var serviceTypeName = serviceType.IsGenericType + ? serviceType.ConstructUnboundGenericType().ToDisplayString() + : serviceType.ToDisplayString(); + + var registration = new ServiceRegistrationModel( + attribute.Lifetime, + serviceTypeName, + implementationTypeName, + false, + true, + keySelectorMethod?.Name, + keySelectorMethod?.IsGenericMethod); + + registrations.Add(registration); + } + else + { + var shouldResolve = attribute.AsSelf && attribute.AsImplementedInterfaces && !SymbolEqualityComparer.Default.Equals(implementationType, serviceType); + var registration = new ServiceRegistrationModel( + attribute.Lifetime, + serviceType.ToDisplayString(), + implementationType.ToDisplayString(), + shouldResolve, + false, + keySelectorMethod?.Name, + keySelectorMethod?.IsGenericMethod); + registrations.Add(registration); + } + + typesFound = true; + } + } + + if (!typesFound) + diagnostic ??= Diagnostic.Create(NoMatchingTypesFound, attribute.Location); + } + + var implementationModel = new MethodImplementationModel(method, new EquatableArray([.. registrations])); + return new(diagnostic, implementationModel); + } + + private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assignableTo, out INamedTypeSymbol matchedType) + { + if (SymbolEqualityComparer.Default.Equals(type, assignableTo)) + { + matchedType = type; + return true; + } + + if (assignableTo.IsGenericType && assignableTo.IsDefinition) + { + if (assignableTo.TypeKind == TypeKind.Interface) + { + var matchingInterface = type.AllInterfaces.FirstOrDefault(i => i.IsGenericType && SymbolEqualityComparer.Default.Equals(i.OriginalDefinition, assignableTo)); + matchedType = matchingInterface; + return matchingInterface != null; + } + + var baseType = type.BaseType; + while (baseType != null) + { + if (baseType.IsGenericType && SymbolEqualityComparer.Default.Equals(baseType.OriginalDefinition, assignableTo)) + { + matchedType = baseType; + return true; + } + + baseType = baseType.BaseType; + } + } + else + { + if (assignableTo.TypeKind == TypeKind.Interface) + { + matchedType = assignableTo; + return type.AllInterfaces.Contains(assignableTo, SymbolEqualityComparer.Default); + } + + var baseType = type.BaseType; + while (baseType != null) + { + if (SymbolEqualityComparer.Default.Equals(baseType, assignableTo)) + { + matchedType = baseType; + return true; + } + + baseType = baseType.BaseType; + } + } + + matchedType = null; + return false; + } + + private static IEnumerable GetTypesFromAssembly(IAssemblySymbol assemblySymbol) + { + var @namespace = assemblySymbol.GlobalNamespace; + return GetTypesFromNamespace(@namespace); + + static IEnumerable GetTypesFromNamespace(INamespaceSymbol namespaceSymbol) + { + foreach (var member in namespaceSymbol.GetMembers()) + { + if (member is INamedTypeSymbol namedType) + { + yield return namedType; + } + else if (member is INamespaceSymbol nestedNamespace) + { + foreach (var type in GetTypesFromNamespace(nestedNamespace)) + { + yield return type; + } + } + } + } + } +} diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs new file mode 100644 index 0000000..acfdc14 --- /dev/null +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs @@ -0,0 +1,40 @@ +using Microsoft.CodeAnalysis; +using ServiceScan.SourceGenerator.Model; +using static ServiceScan.SourceGenerator.DiagnosticDescriptors; + +namespace ServiceScan.SourceGenerator; + +public partial class DependencyInjectionGenerator +{ + private static DiagnosticModel ParseMethodModel(GeneratorAttributeSyntaxContext context) + { + if (context.TargetSymbol is not IMethodSymbol method) + return null; + + if (!method.IsPartialDefinition) + return Diagnostic.Create(NotPartialDefinition, method.Locations[0]); + + var serviceCollectionType = context.SemanticModel.Compilation.GetTypeByMetadataName("Microsoft.Extensions.DependencyInjection.IServiceCollection"); + + if (!method.ReturnsVoid && !SymbolEqualityComparer.Default.Equals(method.ReturnType, serviceCollectionType)) + return Diagnostic.Create(WrongReturnType, method.Locations[0]); + + if (method.Parameters.Length != 1 || !SymbolEqualityComparer.Default.Equals(method.Parameters[0].Type, serviceCollectionType)) + return Diagnostic.Create(WrongMethodParameters, method.Locations[0]); + + var attributeData = new AttributeModel[context.Attributes.Length]; + for (var i = 0; i < context.Attributes.Length; i++) + { + attributeData[i] = AttributeModel.Create(context.Attributes[i]); + + if (!attributeData[i].HasSearchCriteria) + return Diagnostic.Create(MissingSearchCriteria, attributeData[i].Location); + + if (attributeData[i].HasErrors) + return null; + } + + var model = MethodModel.Create(method, context.TargetNode); + return new MethodWithAttributesModel(model, new EquatableArray(attributeData)); + } +} diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs index 2259373..c9f1fe8 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs @@ -1,13 +1,9 @@ -using System; -using System.Collections.Generic; -using System.Linq; +using System.Linq; using System.Text; -using System.Text.RegularExpressions; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; using ServiceScan.SourceGenerator.Model; -using static ServiceScan.SourceGenerator.DiagnosticDescriptors; namespace ServiceScan.SourceGenerator; @@ -40,291 +36,64 @@ public void Initialize(IncrementalGeneratorInitializationContext context) return; var (method, registrations) = src.Model; - - var sb = new StringBuilder(); - - foreach (var registration in registrations) - { - if (registration.IsOpenGeneric) - { - sb.AppendLine($" .Add{registration.Lifetime}(typeof({registration.ServiceTypeName}), typeof({registration.ImplementationTypeName}))"); - } - else - { - if (registration.ResolveImplementation) - { - sb.AppendLine($" .Add{registration.Lifetime}<{registration.ServiceTypeName}>(s => s.GetRequiredService<{registration.ImplementationTypeName}>())"); - } - else - { - var addMethod = registration.KeySelectorMethodName != null - ? $"AddKeyed{registration.Lifetime}" - : $"Add{registration.Lifetime}"; - - var keyMethodInvocation = registration.KeySelectorMethodGeneric switch - { - true => $"{registration.KeySelectorMethodName}<{registration.ImplementationTypeName}>()", - false => $"{registration.KeySelectorMethodName}(typeof({registration.ImplementationTypeName}))", - null => null - }; - sb.AppendLine($" .{addMethod}<{registration.ServiceTypeName}, {registration.ImplementationTypeName}>({keyMethodInvocation})"); - } - } - } - - var returnType = method.ReturnsVoid ? "void" : "IServiceCollection"; - - var namespaceDeclaration = method.Namespace is null ? "" : $"namespace {method.Namespace};"; - - var source = $$""" - using Microsoft.Extensions.DependencyInjection; - - {{namespaceDeclaration}} - - {{method.TypeModifiers}} class {{method.TypeName}} - { - {{method.MethodModifiers}} {{returnType}} {{method.MethodName}}({{(method.IsExtensionMethod ? "this" : "")}} IServiceCollection {{method.ParameterName}}) - { - {{(method.ReturnsVoid ? "" : "return ")}}{{method.ParameterName}} - {{sb.ToString().Trim()}}; - } - } - """; + string source = GenerateSource(method, registrations); context.AddSource($"{method.TypeName}_{method.MethodName}.Generated.cs", SourceText.From(source, Encoding.UTF8)); }); } - private static DiagnosticModel FindServicesToRegister((DiagnosticModel, Compilation) context) + private static string GenerateSource(MethodModel method, EquatableArray registrations) { - var (diagnosticModel, compilation) = context; - var diagnostic = diagnosticModel.Diagnostic; - - if (diagnostic != null) - return diagnostic; - - var (method, attributes) = diagnosticModel.Model; + var sb = new StringBuilder(); - var registrations = new List(); - - foreach (var attribute in attributes) + foreach (var registration in registrations) { - bool typesFound = false; - - var containingType = compilation.GetTypeByMetadataName(method.TypeMetadataName); - - var assembly = (attribute.AssemblyOfTypeName is null - ? containingType - : compilation.GetTypeByMetadataName(attribute.AssemblyOfTypeName)).ContainingAssembly; - - var assignableToType = attribute.AssignableToTypeName is null - ? null - : compilation.GetTypeByMetadataName(attribute.AssignableToTypeName); - - var keySelectorMethod = attribute.KeySelector is null - ? null - : containingType.GetMembers().OfType().FirstOrDefault(m => - m.IsStatic && m.Name == attribute.KeySelector); - - if (attribute.KeySelector != null) - { - if (keySelectorMethod is null) - return Diagnostic.Create(KeySelectorMethodNotFound, attribute.Location); - - if (keySelectorMethod.ReturnsVoid) - return Diagnostic.Create(KeySelectorMethodHasIncorrectSignature, attribute.Location); - - var validGenericKeySelector = keySelectorMethod.TypeArguments.Length == 1 && keySelectorMethod.Parameters.Length == 0; - var validNonGenericKeySelector = !keySelectorMethod.IsGenericMethod && keySelectorMethod.Parameters is [{ Type.Name: nameof(Type) }]; - - if (!validGenericKeySelector && !validNonGenericKeySelector) - return Diagnostic.Create(KeySelectorMethodHasIncorrectSignature, attribute.Location); - } - - if (assignableToType != null && attribute.AssignableToGenericArguments != null) + if (registration.IsOpenGeneric) { - var typeArguments = attribute.AssignableToGenericArguments.Value.Select(t => compilation.GetTypeByMetadataName(t)).ToArray(); - assignableToType = assignableToType.Construct(typeArguments); + sb.AppendLine($" .Add{registration.Lifetime}(typeof({registration.ServiceTypeName}), typeof({registration.ImplementationTypeName}))"); } - - var types = GetTypesFromAssembly(assembly) - .Where(t => !t.IsAbstract && !t.IsStatic && t.CanBeReferencedByName && t.TypeKind == TypeKind.Class); - - if (attribute.TypeNameFilter != null) + else { - var regex = $"^({Regex.Escape(attribute.TypeNameFilter).Replace(@"\*", ".*").Replace(",", "|")})$"; - types = types.Where(t => Regex.IsMatch(t.ToDisplayString(), regex)); - } - - foreach (var t in types) - { - var implementationType = t; - - INamedTypeSymbol matchedType = null; - if (assignableToType != null && !IsAssignableTo(implementationType, assignableToType, out matchedType)) - continue; - - IEnumerable serviceTypes = (attribute.AsSelf, attribute.AsImplementedInterfaces) switch + if (registration.ResolveImplementation) { - (true, true) => new[] { implementationType }.Concat(implementationType.AllInterfaces), - (false, true) => implementationType.AllInterfaces, - (true, false) => [implementationType], - _ => [matchedType ?? implementationType] - }; - - foreach (var serviceType in serviceTypes) + sb.AppendLine($" .Add{registration.Lifetime}<{registration.ServiceTypeName}>(s => s.GetRequiredService<{registration.ImplementationTypeName}>())"); + } + else { - if (implementationType.IsGenericType) - { - var implementationTypeName = implementationType.ConstructUnboundGenericType().ToDisplayString(); - var serviceTypeName = serviceType.IsGenericType - ? serviceType.ConstructUnboundGenericType().ToDisplayString() - : serviceType.ToDisplayString(); - - var registration = new ServiceRegistrationModel( - attribute.Lifetime, - serviceTypeName, - implementationTypeName, - false, - true, - keySelectorMethod?.Name, - keySelectorMethod?.IsGenericMethod); + var addMethod = registration.KeySelectorMethodName != null + ? $"AddKeyed{registration.Lifetime}" + : $"Add{registration.Lifetime}"; - registrations.Add(registration); - } - else + var keyMethodInvocation = registration.KeySelectorMethodGeneric switch { - var shouldResolve = attribute.AsSelf && attribute.AsImplementedInterfaces && !SymbolEqualityComparer.Default.Equals(implementationType, serviceType); - var registration = new ServiceRegistrationModel( - attribute.Lifetime, - serviceType.ToDisplayString(), - implementationType.ToDisplayString(), - shouldResolve, - false, - keySelectorMethod?.Name, - keySelectorMethod?.IsGenericMethod); - registrations.Add(registration); - } - - typesFound = true; - } - } - - if (!typesFound) - diagnostic ??= Diagnostic.Create(NoMatchingTypesFound, attribute.Location); - } - - var implementationModel = new MethodImplementationModel(method, new EquatableArray([.. registrations])); - return new(diagnostic, implementationModel); - } - - private static DiagnosticModel ParseMethodModel(GeneratorAttributeSyntaxContext context) - { - if (context.TargetSymbol is not IMethodSymbol method) - return null; - - if (!method.IsPartialDefinition) - return Diagnostic.Create(NotPartialDefinition, method.Locations[0]); - - var serviceCollectionType = context.SemanticModel.Compilation.GetTypeByMetadataName("Microsoft.Extensions.DependencyInjection.IServiceCollection"); - - if (!method.ReturnsVoid && !SymbolEqualityComparer.Default.Equals(method.ReturnType, serviceCollectionType)) - return Diagnostic.Create(WrongReturnType, method.Locations[0]); - - if (method.Parameters.Length != 1 || !SymbolEqualityComparer.Default.Equals(method.Parameters[0].Type, serviceCollectionType)) - return Diagnostic.Create(WrongMethodParameters, method.Locations[0]); - - var attributeData = new AttributeModel[context.Attributes.Length]; - for (var i = 0; i < context.Attributes.Length; i++) - { - attributeData[i] = AttributeModel.Create(context.Attributes[i]); - - if (!attributeData[i].HasSearchCriteria) - return Diagnostic.Create(MissingSearchCriteria, attributeData[i].Location); - - if (attributeData[i].HasErrors) - return null; - } - - var model = MethodModel.Create(method, context.TargetNode); - return new MethodWithAttributesModel(model, new EquatableArray(attributeData)); - } - - private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assignableTo, out INamedTypeSymbol matchedType) - { - if (SymbolEqualityComparer.Default.Equals(type, assignableTo)) - { - matchedType = type; - return true; - } - - if (assignableTo.IsGenericType && assignableTo.IsDefinition) - { - if (assignableTo.TypeKind == TypeKind.Interface) - { - var matchingInterface = type.AllInterfaces.FirstOrDefault(i => i.IsGenericType && SymbolEqualityComparer.Default.Equals(i.OriginalDefinition, assignableTo)); - matchedType = matchingInterface; - return matchingInterface != null; - } - - var baseType = type.BaseType; - while (baseType != null) - { - if (baseType.IsGenericType && SymbolEqualityComparer.Default.Equals(baseType.OriginalDefinition, assignableTo)) - { - matchedType = baseType; - return true; + true => $"{registration.KeySelectorMethodName}<{registration.ImplementationTypeName}>()", + false => $"{registration.KeySelectorMethodName}(typeof({registration.ImplementationTypeName}))", + null => null + }; + sb.AppendLine($" .{addMethod}<{registration.ServiceTypeName}, {registration.ImplementationTypeName}>({keyMethodInvocation})"); } - - baseType = baseType.BaseType; } } - else - { - if (assignableTo.TypeKind == TypeKind.Interface) - { - matchedType = assignableTo; - return type.AllInterfaces.Contains(assignableTo, SymbolEqualityComparer.Default); - } - var baseType = type.BaseType; - while (baseType != null) - { - if (SymbolEqualityComparer.Default.Equals(baseType, assignableTo)) - { - matchedType = baseType; - return true; - } + var returnType = method.ReturnsVoid ? "void" : "IServiceCollection"; - baseType = baseType.BaseType; - } - } + var namespaceDeclaration = method.Namespace is null ? "" : $"namespace {method.Namespace};"; - matchedType = null; - return false; - } + var source = $$""" + using Microsoft.Extensions.DependencyInjection; - private static IEnumerable GetTypesFromAssembly(IAssemblySymbol assemblySymbol) - { - var @namespace = assemblySymbol.GlobalNamespace; - return GetTypesFromNamespace(@namespace); + {{namespaceDeclaration}} - static IEnumerable GetTypesFromNamespace(INamespaceSymbol namespaceSymbol) - { - foreach (var member in namespaceSymbol.GetMembers()) - { - if (member is INamedTypeSymbol namedType) - { - yield return namedType; - } - else if (member is INamespaceSymbol nestedNamespace) + {{method.TypeModifiers}} class {{method.TypeName}} { - foreach (var type in GetTypesFromNamespace(nestedNamespace)) + {{method.MethodModifiers}} {{returnType}} {{method.MethodName}}({{(method.IsExtensionMethod ? "this" : "")}} IServiceCollection {{method.ParameterName}}) { - yield return type; + {{(method.ReturnsVoid ? "" : "return ")}}{{method.ParameterName}} + {{sb.ToString().Trim()}}; } } - } - } + """; + + return source; } }