diff --git a/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs b/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs index 3f3c70e..9c05990 100644 --- a/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs +++ b/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs @@ -462,6 +462,42 @@ public class MyService: IServiceA, IServiceB {} Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); } + [Fact] + public void AddNestedTypes() + { + var attribute = "[GenerateServiceRegistrations(AssignableTo = typeof(IService))]"; + var compilation = CreateCompilation(Sources.MethodWithAttribute(attribute), + """ + namespace GeneratorTests; + + public interface IService { } + + public class ParentType1 + { + public class MyService1 : IService { } + public class MyService2 : IService { } + } + + public class ParentType2 + { + public class MyService1 : IService { } + } + """); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var registrations = $""" + return services + .AddTransient() + .AddTransient() + .AddTransient(); + """; + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + } + [Fact] public void AddAsKeyedServices_GenericMethod() { diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs index 60cc5f8..64d9c84 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs @@ -104,19 +104,20 @@ private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assig private static IEnumerable GetTypesFromAssembly(IAssemblySymbol assemblySymbol) { var @namespace = assemblySymbol.GlobalNamespace; - return GetTypesFromNamespace(@namespace); + return GetTypesFromNamespaceOrType(@namespace); - static IEnumerable GetTypesFromNamespace(INamespaceSymbol namespaceSymbol) + static IEnumerable GetTypesFromNamespaceOrType(INamespaceOrTypeSymbol symbol) { - foreach (var member in namespaceSymbol.GetMembers()) + foreach (var member in symbol.GetMembers()) { - if (member is INamedTypeSymbol namedType) + if (member is INamespaceOrTypeSymbol namespaceOrType) { - yield return namedType; - } - else if (member is INamespaceSymbol nestedNamespace) - { - foreach (var type in GetTypesFromNamespace(nestedNamespace)) + if (member is INamedTypeSymbol namedType) + { + yield return namedType; + } + + foreach (var type in GetTypesFromNamespaceOrType(namespaceOrType)) { yield return type; }