Skip to content

Commit

Permalink
Only cache authorization policies if permitted by policy provider
Browse files Browse the repository at this point in the history
  • Loading branch information
tobias-tengler committed Nov 9, 2024
1 parent 984e263 commit 5ff7efc
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,48 +4,21 @@

namespace HotChocolate.AspNetCore.Authorization;

internal sealed class AuthorizationPolicyCache(IAuthorizationPolicyProvider policyProvider)
internal sealed class AuthorizationPolicyCache
{
private readonly ConcurrentDictionary<string, Task<AuthorizationPolicy>> _cache = new();
private readonly ConcurrentDictionary<string, AuthorizationPolicy> _cache = new();

public Task<AuthorizationPolicy> GetOrCreatePolicyAsync(AuthorizeDirective directive)
public AuthorizationPolicy? LookupPolicy(AuthorizeDirective directive)
{
var cacheKey = directive.GetPolicyCacheKey();

return _cache.GetOrAdd(cacheKey, _ => BuildAuthorizationPolicy(directive.Policy, directive.Roles));
return _cache.GetValueOrDefault(cacheKey);
}

private async Task<AuthorizationPolicy> BuildAuthorizationPolicy(
string? policyName,
IReadOnlyList<string>? roles)
public void CachePolicy(AuthorizeDirective directive, AuthorizationPolicy policy)
{
var policyBuilder = new AuthorizationPolicyBuilder();

if (!string.IsNullOrWhiteSpace(policyName))
{
var policy = await policyProvider.GetPolicyAsync(policyName).ConfigureAwait(false);

if (policy is not null)
{
policyBuilder = policyBuilder.Combine(policy);
}
else
{
throw new MissingAuthorizationPolicyException(policyName);
}
}
else
{
var defaultPolicy = await policyProvider.GetDefaultPolicyAsync().ConfigureAwait(false);

policyBuilder = policyBuilder.Combine(defaultPolicy);
}

if (roles is not null)
{
policyBuilder = policyBuilder.RequireRole(roles);
}
var cacheKey = directive.GetPolicyCacheKey();

return policyBuilder.Build();
_cache.TryAdd(cacheKey, policy);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,39 @@ namespace HotChocolate.AspNetCore.Authorization;
internal sealed class DefaultAuthorizationHandler : IAuthorizationHandler
{
private readonly IAuthorizationService _authSvc;
private readonly AuthorizationPolicyCache _policyCache;
private readonly IAuthorizationPolicyProvider _authorizationPolicyProvider;
private readonly AuthorizationPolicyCache _authorizationPolicyCache;
private readonly bool _canCachePolicies;

/// <summary>
/// Initializes a new instance <see cref="DefaultAuthorizationHandler"/>.
/// </summary>
/// <param name="authorizationService">
/// The authorization service.
/// </param>
/// <param name="policyCache">
/// <param name="authorizationPolicyProvider">
/// The authorization policy provider.
/// </param>
/// <param name="authorizationPolicyCache">
/// The authorization policy cache.
/// </param>
/// <exception cref="ArgumentNullException">
/// <paramref name="authorizationService"/> is <c>null</c>.
/// <paramref name="policyCache"/> is <c>null</c>.
/// <paramref name="authorizationPolicyCache"/> is <c>null</c>.
/// </exception>
public DefaultAuthorizationHandler(
IAuthorizationService authorizationService,
AuthorizationPolicyCache policyCache)
IAuthorizationPolicyProvider authorizationPolicyProvider,
AuthorizationPolicyCache authorizationPolicyCache)
{
_authSvc = authorizationService ??
throw new ArgumentNullException(nameof(authorizationService));
_policyCache = policyCache ??
throw new ArgumentNullException(nameof(policyCache));
_authorizationPolicyProvider = authorizationPolicyProvider ??
throw new ArgumentNullException(nameof(authorizationPolicyProvider));
_authorizationPolicyCache = authorizationPolicyCache ??
throw new ArgumentNullException(nameof(authorizationPolicyCache));

_canCachePolicies = _authorizationPolicyProvider.AllowsCachingPolicies;
}

/// <summary>
Expand Down Expand Up @@ -123,9 +133,24 @@ private async ValueTask<AuthorizeResult> AuthorizeAsync(
{
try
{
var combinedPolicy = await _policyCache.GetOrCreatePolicyAsync(directive);
AuthorizationPolicy? authorizationPolicy = null;

if (_canCachePolicies)
{
authorizationPolicy = _authorizationPolicyCache.LookupPolicy(directive);
}

if (authorizationPolicy is null)
{
authorizationPolicy = await BuildAuthorizationPolicy(directive.Policy, directive.Roles);

var result = await _authSvc.AuthorizeAsync(user, context, combinedPolicy).ConfigureAwait(false);
if (_canCachePolicies)
{
_authorizationPolicyCache.CachePolicy(directive, authorizationPolicy);
}
}

var result = await _authSvc.AuthorizeAsync(user, context, authorizationPolicy).ConfigureAwait(false);

return result.Succeeded
? AuthorizeResult.Allowed
Expand All @@ -137,6 +162,40 @@ private async ValueTask<AuthorizeResult> AuthorizeAsync(
}
}

private async Task<AuthorizationPolicy> BuildAuthorizationPolicy(
string? policyName,
IReadOnlyList<string>? roles)
{
var policyBuilder = new AuthorizationPolicyBuilder();

if (!string.IsNullOrWhiteSpace(policyName))
{
var policy = await _authorizationPolicyProvider.GetPolicyAsync(policyName).ConfigureAwait(false);

if (policy is not null)
{
policyBuilder = policyBuilder.Combine(policy);
}
else
{
throw new MissingAuthorizationPolicyException(policyName);
}
}
else
{
var defaultPolicy = await _authorizationPolicyProvider.GetDefaultPolicyAsync().ConfigureAwait(false);

policyBuilder = policyBuilder.Combine(defaultPolicy);
}

if (roles is not null)
{
policyBuilder = policyBuilder.RequireRole(roles);
}

return policyBuilder.Build();
}

private static UserState GetUserState(IDictionary<string, object?> contextData)
{
if (contextData.TryGetValue(WellKnownContextData.UserState, out var value) &&
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
using System.Net;
using System.Security.Claims;
using HotChocolate.AspNetCore.Tests.Utilities;
using HotChocolate.Execution.Configuration;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Authorization.Infrastructure;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection;

namespace HotChocolate.AspNetCore.Authorization;

public class AuthorizationPolicyProviderTess(TestServerFactory serverFactory) : ServerTestBase(serverFactory)
{
[Fact]
public async Task Policies_Are_Cached_If_PolicyProvider_Allows_Caching()
{
// arrange
var policyProvider = new CustomAuthorizationPolicyProvider(allowsCaching: true);

var server = CreateTestServer(
builder =>
{
builder.Services.AddSingleton<IAuthorizationPolicyProvider>(_ => policyProvider);

builder
.AddQueryType<Query>()
.AddAuthorization();
},
context =>
{
var identity = new ClaimsIdentity("testauth");
identity.AddClaim(new Claim(
ClaimTypes.DateOfBirth,
"2013-05-30"));
context.User = new ClaimsPrincipal(identity);
});

// act
var result1 =
await server.PostAsync(new ClientQueryRequest { Query = "{ bar }", });
var result2 =
await server.PostAsync(new ClientQueryRequest { Query = "{ bar }", });

// assert
Assert.Equal(HttpStatusCode.OK, result1.StatusCode);
Assert.Null(result1.Errors);
Assert.Equal(HttpStatusCode.OK, result2.StatusCode);
Assert.Null(result2.Errors);
Assert.Equal(1, policyProvider.InvocationsOfGetPolicyAsync);
}

[Fact]
public async Task Policies_Are_Not_Cached_If_PolicyProvider_Disallows_Caching()
{
// arrange
var policyProvider = new CustomAuthorizationPolicyProvider(allowsCaching: false);

var server = CreateTestServer(
builder =>
{
builder.Services.AddSingleton<IAuthorizationPolicyProvider>(_ => policyProvider);

builder
.AddQueryType<Query>()
.AddAuthorization();
},
context =>
{
var identity = new ClaimsIdentity("testauth");
identity.AddClaim(new Claim(
ClaimTypes.DateOfBirth,
"2013-05-30"));
context.User = new ClaimsPrincipal(identity);
});

// act
var result1 =
await server.PostAsync(new ClientQueryRequest { Query = "{ bar }", });
var result2 =
await server.PostAsync(new ClientQueryRequest { Query = "{ bar }", });

// assert
Assert.Equal(HttpStatusCode.OK, result1.StatusCode);
Assert.Null(result1.Errors);
Assert.Equal(HttpStatusCode.OK, result2.StatusCode);
Assert.Null(result2.Errors);
Assert.Equal(2, policyProvider.InvocationsOfGetPolicyAsync);
}

public class Query
{
[HotChocolate.Authorization.Authorize(Policy = "policy")]
public string Bar() => "bar";
}

private TestServer CreateTestServer(
Action<IRequestExecutorBuilder> build,
Action<HttpContext> configureUser)
{
return ServerFactory.Create(
services =>
{
build(services
.AddRouting()
.AddGraphQLServer()
.AddHttpRequestInterceptor(
(context, requestExecutor, requestBuilder, cancellationToken) =>
{
configureUser(context);
return default;
}));
},
app =>
{
app.UseRouting();
app.UseEndpoints(b => b.MapGraphQL());
});
}

public class CustomAuthorizationPolicyProvider(bool allowsCaching) : IAuthorizationPolicyProvider
{
public int InvocationsOfGetPolicyAsync { get; private set; }

public Task<AuthorizationPolicy?> GetPolicyAsync(string policyName)
{
InvocationsOfGetPolicyAsync++;

var policy = new AuthorizationPolicyBuilder()
.AddRequirements(new DenyAnonymousAuthorizationRequirement())
.Build();

return Task.FromResult<AuthorizationPolicy?>(policy);
}

public Task<AuthorizationPolicy> GetDefaultPolicyAsync() => throw new NotImplementedException();

public Task<AuthorizationPolicy?> GetFallbackPolicyAsync() => throw new NotImplementedException();

public virtual bool AllowsCachingPolicies => allowsCaching;
}
}

0 comments on commit 5ff7efc

Please sign in to comment.