diff --git a/src/HotChocolate/AspNetCore/src/AspNetCore.Authorization/AuthorizationPolicyCache.cs b/src/HotChocolate/AspNetCore/src/AspNetCore.Authorization/AuthorizationPolicyCache.cs index 258288c701d..591c326aa1b 100644 --- a/src/HotChocolate/AspNetCore/src/AspNetCore.Authorization/AuthorizationPolicyCache.cs +++ b/src/HotChocolate/AspNetCore/src/AspNetCore.Authorization/AuthorizationPolicyCache.cs @@ -4,48 +4,21 @@ namespace HotChocolate.AspNetCore.Authorization; -internal sealed class AuthorizationPolicyCache(IAuthorizationPolicyProvider policyProvider) +internal sealed class AuthorizationPolicyCache { - private readonly ConcurrentDictionary> _cache = new(); + private readonly ConcurrentDictionary _cache = new(); - public Task GetOrCreatePolicyAsync(AuthorizeDirective directive) + public AuthorizationPolicy? LookupPolicy(AuthorizeDirective directive) { var cacheKey = directive.GetPolicyCacheKey(); - return _cache.GetOrAdd(cacheKey, _ => BuildAuthorizationPolicy(directive.Policy, directive.Roles)); + return _cache.GetValueOrDefault(cacheKey); } - private async Task BuildAuthorizationPolicy( - string? policyName, - IReadOnlyList? roles) + public void CachePolicy(AuthorizeDirective directive, AuthorizationPolicy policy) { - var policyBuilder = new AuthorizationPolicyBuilder(); - - if (!string.IsNullOrWhiteSpace(policyName)) - { - var policy = await policyProvider.GetPolicyAsync(policyName).ConfigureAwait(false); - - if (policy is not null) - { - policyBuilder = policyBuilder.Combine(policy); - } - else - { - throw new MissingAuthorizationPolicyException(policyName); - } - } - else - { - var defaultPolicy = await policyProvider.GetDefaultPolicyAsync().ConfigureAwait(false); - - policyBuilder = policyBuilder.Combine(defaultPolicy); - } - - if (roles is not null) - { - policyBuilder = policyBuilder.RequireRole(roles); - } + var cacheKey = directive.GetPolicyCacheKey(); - return policyBuilder.Build(); + _cache.TryAdd(cacheKey, policy); } } diff --git a/src/HotChocolate/AspNetCore/src/AspNetCore.Authorization/DefaultAuthorizationHandler.cs b/src/HotChocolate/AspNetCore/src/AspNetCore.Authorization/DefaultAuthorizationHandler.cs index c36ae4fb598..bcc15389477 100644 --- a/src/HotChocolate/AspNetCore/src/AspNetCore.Authorization/DefaultAuthorizationHandler.cs +++ b/src/HotChocolate/AspNetCore/src/AspNetCore.Authorization/DefaultAuthorizationHandler.cs @@ -12,7 +12,9 @@ namespace HotChocolate.AspNetCore.Authorization; internal sealed class DefaultAuthorizationHandler : IAuthorizationHandler { private readonly IAuthorizationService _authSvc; - private readonly AuthorizationPolicyCache _policyCache; + private readonly IAuthorizationPolicyProvider _authorizationPolicyProvider; + private readonly AuthorizationPolicyCache _authorizationPolicyCache; + private readonly bool _canCachePolicies; /// /// Initializes a new instance . @@ -20,21 +22,29 @@ internal sealed class DefaultAuthorizationHandler : IAuthorizationHandler /// /// The authorization service. /// - /// + /// + /// The authorization policy provider. + /// + /// /// The authorization policy cache. /// /// /// is null. - /// is null. + /// is null. /// public DefaultAuthorizationHandler( IAuthorizationService authorizationService, - AuthorizationPolicyCache policyCache) + IAuthorizationPolicyProvider authorizationPolicyProvider, + AuthorizationPolicyCache authorizationPolicyCache) { _authSvc = authorizationService ?? throw new ArgumentNullException(nameof(authorizationService)); - _policyCache = policyCache ?? - throw new ArgumentNullException(nameof(policyCache)); + _authorizationPolicyProvider = authorizationPolicyProvider ?? + throw new ArgumentNullException(nameof(authorizationPolicyProvider)); + _authorizationPolicyCache = authorizationPolicyCache ?? + throw new ArgumentNullException(nameof(authorizationPolicyCache)); + + _canCachePolicies = _authorizationPolicyProvider.AllowsCachingPolicies; } /// @@ -123,9 +133,24 @@ private async ValueTask AuthorizeAsync( { try { - var combinedPolicy = await _policyCache.GetOrCreatePolicyAsync(directive); + AuthorizationPolicy? authorizationPolicy = null; + + if (_canCachePolicies) + { + authorizationPolicy = _authorizationPolicyCache.LookupPolicy(directive); + } + + if (authorizationPolicy is null) + { + authorizationPolicy = await BuildAuthorizationPolicy(directive.Policy, directive.Roles); - var result = await _authSvc.AuthorizeAsync(user, context, combinedPolicy).ConfigureAwait(false); + if (_canCachePolicies) + { + _authorizationPolicyCache.CachePolicy(directive, authorizationPolicy); + } + } + + var result = await _authSvc.AuthorizeAsync(user, context, authorizationPolicy).ConfigureAwait(false); return result.Succeeded ? AuthorizeResult.Allowed @@ -137,6 +162,40 @@ private async ValueTask AuthorizeAsync( } } + private async Task BuildAuthorizationPolicy( + string? policyName, + IReadOnlyList? roles) + { + var policyBuilder = new AuthorizationPolicyBuilder(); + + if (!string.IsNullOrWhiteSpace(policyName)) + { + var policy = await _authorizationPolicyProvider.GetPolicyAsync(policyName).ConfigureAwait(false); + + if (policy is not null) + { + policyBuilder = policyBuilder.Combine(policy); + } + else + { + throw new MissingAuthorizationPolicyException(policyName); + } + } + else + { + var defaultPolicy = await _authorizationPolicyProvider.GetDefaultPolicyAsync().ConfigureAwait(false); + + policyBuilder = policyBuilder.Combine(defaultPolicy); + } + + if (roles is not null) + { + policyBuilder = policyBuilder.RequireRole(roles); + } + + return policyBuilder.Build(); + } + private static UserState GetUserState(IDictionary contextData) { if (contextData.TryGetValue(WellKnownContextData.UserState, out var value) && diff --git a/src/HotChocolate/AspNetCore/test/AspNetCore.Authorization.Tests/AuthorizationPolicyProviderTess.cs b/src/HotChocolate/AspNetCore/test/AspNetCore.Authorization.Tests/AuthorizationPolicyProviderTess.cs new file mode 100644 index 00000000000..6e587b0dce3 --- /dev/null +++ b/src/HotChocolate/AspNetCore/test/AspNetCore.Authorization.Tests/AuthorizationPolicyProviderTess.cs @@ -0,0 +1,143 @@ +using System.Net; +using System.Security.Claims; +using HotChocolate.AspNetCore.Tests.Utilities; +using HotChocolate.Execution.Configuration; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Authorization.Infrastructure; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.TestHost; +using Microsoft.Extensions.DependencyInjection; + +namespace HotChocolate.AspNetCore.Authorization; + +public class AuthorizationPolicyProviderTess(TestServerFactory serverFactory) : ServerTestBase(serverFactory) +{ + [Fact] + public async Task Policies_Are_Cached_If_PolicyProvider_Allows_Caching() + { + // arrange + var policyProvider = new CustomAuthorizationPolicyProvider(allowsCaching: true); + + var server = CreateTestServer( + builder => + { + builder.Services.AddSingleton(_ => policyProvider); + + builder + .AddQueryType() + .AddAuthorization(); + }, + context => + { + var identity = new ClaimsIdentity("testauth"); + identity.AddClaim(new Claim( + ClaimTypes.DateOfBirth, + "2013-05-30")); + context.User = new ClaimsPrincipal(identity); + }); + + // act + var result1 = + await server.PostAsync(new ClientQueryRequest { Query = "{ bar }", }); + var result2 = + await server.PostAsync(new ClientQueryRequest { Query = "{ bar }", }); + + // assert + Assert.Equal(HttpStatusCode.OK, result1.StatusCode); + Assert.Null(result1.Errors); + Assert.Equal(HttpStatusCode.OK, result2.StatusCode); + Assert.Null(result2.Errors); + Assert.Equal(1, policyProvider.InvocationsOfGetPolicyAsync); + } + + [Fact] + public async Task Policies_Are_Not_Cached_If_PolicyProvider_Disallows_Caching() + { + // arrange + var policyProvider = new CustomAuthorizationPolicyProvider(allowsCaching: false); + + var server = CreateTestServer( + builder => + { + builder.Services.AddSingleton(_ => policyProvider); + + builder + .AddQueryType() + .AddAuthorization(); + }, + context => + { + var identity = new ClaimsIdentity("testauth"); + identity.AddClaim(new Claim( + ClaimTypes.DateOfBirth, + "2013-05-30")); + context.User = new ClaimsPrincipal(identity); + }); + + // act + var result1 = + await server.PostAsync(new ClientQueryRequest { Query = "{ bar }", }); + var result2 = + await server.PostAsync(new ClientQueryRequest { Query = "{ bar }", }); + + // assert + Assert.Equal(HttpStatusCode.OK, result1.StatusCode); + Assert.Null(result1.Errors); + Assert.Equal(HttpStatusCode.OK, result2.StatusCode); + Assert.Null(result2.Errors); + Assert.Equal(2, policyProvider.InvocationsOfGetPolicyAsync); + } + + public class Query + { + [HotChocolate.Authorization.Authorize(Policy = "policy")] + public string Bar() => "bar"; + } + + private TestServer CreateTestServer( + Action build, + Action configureUser) + { + return ServerFactory.Create( + services => + { + build(services + .AddRouting() + .AddGraphQLServer() + .AddHttpRequestInterceptor( + (context, requestExecutor, requestBuilder, cancellationToken) => + { + configureUser(context); + return default; + })); + }, + app => + { + app.UseRouting(); + app.UseEndpoints(b => b.MapGraphQL()); + }); + } + + public class CustomAuthorizationPolicyProvider(bool allowsCaching) : IAuthorizationPolicyProvider + { + public int InvocationsOfGetPolicyAsync { get; private set; } + + public Task GetPolicyAsync(string policyName) + { + InvocationsOfGetPolicyAsync++; + + var policy = new AuthorizationPolicyBuilder() + .AddRequirements(new DenyAnonymousAuthorizationRequirement()) + .Build(); + + return Task.FromResult(policy); + } + + public Task GetDefaultPolicyAsync() => throw new NotImplementedException(); + + public Task GetFallbackPolicyAsync() => throw new NotImplementedException(); + + public virtual bool AllowsCachingPolicies => allowsCaching; + } +}