diff --git a/client_authentication.go b/client_authentication.go index 7cb00c408..d65b1a090 100644 --- a/client_authentication.go +++ b/client_authentication.go @@ -21,6 +21,9 @@ import ( "github.com/ory/fosite/token/jwt" ) +// CanSkipClientAuthenticationStrategy provides a method signature for checking if client authentication can be skipped. +type CanSkipClientAuthenticationStrategy func(context.Context, AccessRequester) bool + // ClientAuthenticationStrategy provides a method signature for authenticating a client request type ClientAuthenticationStrategy func(context.Context, *http.Request, url.Values) (Client, error) diff --git a/config.go b/config.go index 983a35940..e4088f9bc 100644 --- a/config.go +++ b/config.go @@ -150,6 +150,12 @@ type GrantTypeJWTBearerCanSkipClientAuthProvider interface { GetGrantTypeJWTBearerCanSkipClientAuth(ctx context.Context) bool } +// GrantTypeTokenExchangeCanSkipClientAuthProvider returns the provider for configuring the grant type Token Exchange can skip client auth. +type GrantTypeTokenExchangeCanSkipClientAuthProvider interface { + // GetGrantTypeTokenExchangeCanSkipClientAuth returns the grant type Token Exchange can skip client auth. + GetGrantTypeTokenExchangeCanSkipClientAuth(ctx context.Context) CanSkipClientAuthenticationStrategy +} + // GrantTypeJWTBearerIDOptionalProvider returns the provider for configuring the grant type JWT bearer ID optional. type GrantTypeJWTBearerIDOptionalProvider interface { // GetGrantTypeJWTBearerIDOptional returns the grant type JWT bearer ID optional. diff --git a/config_default.go b/config_default.go index df6fa2a50..127b4ad22 100644 --- a/config_default.go +++ b/config_default.go @@ -23,45 +23,46 @@ const ( ) var ( - _ AuthorizeCodeLifespanProvider = (*Config)(nil) - _ RefreshTokenLifespanProvider = (*Config)(nil) - _ AccessTokenLifespanProvider = (*Config)(nil) - _ ScopeStrategyProvider = (*Config)(nil) - _ AudienceStrategyProvider = (*Config)(nil) - _ RedirectSecureCheckerProvider = (*Config)(nil) - _ RefreshTokenScopesProvider = (*Config)(nil) - _ DisableRefreshTokenValidationProvider = (*Config)(nil) - _ AccessTokenIssuerProvider = (*Config)(nil) - _ JWTScopeFieldProvider = (*Config)(nil) - _ AllowedPromptsProvider = (*Config)(nil) - _ OmitRedirectScopeParamProvider = (*Config)(nil) - _ MinParameterEntropyProvider = (*Config)(nil) - _ SanitationAllowedProvider = (*Config)(nil) - _ EnforcePKCEForPublicClientsProvider = (*Config)(nil) - _ EnablePKCEPlainChallengeMethodProvider = (*Config)(nil) - _ EnforcePKCEProvider = (*Config)(nil) - _ GrantTypeJWTBearerCanSkipClientAuthProvider = (*Config)(nil) - _ GrantTypeJWTBearerIDOptionalProvider = (*Config)(nil) - _ GrantTypeJWTBearerIssuedDateOptionalProvider = (*Config)(nil) - _ GetJWTMaxDurationProvider = (*Config)(nil) - _ IDTokenLifespanProvider = (*Config)(nil) - _ IDTokenIssuerProvider = (*Config)(nil) - _ JWKSFetcherStrategyProvider = (*Config)(nil) - _ ClientAuthenticationStrategyProvider = (*Config)(nil) - _ SendDebugMessagesToClientsProvider = (*Config)(nil) - _ ResponseModeHandlerExtensionProvider = (*Config)(nil) - _ MessageCatalogProvider = (*Config)(nil) - _ FormPostHTMLTemplateProvider = (*Config)(nil) - _ TokenURLProvider = (*Config)(nil) - _ GetSecretsHashingProvider = (*Config)(nil) - _ HTTPClientProvider = (*Config)(nil) - _ HMACHashingProvider = (*Config)(nil) - _ AuthorizeEndpointHandlersProvider = (*Config)(nil) - _ TokenEndpointHandlersProvider = (*Config)(nil) - _ TokenIntrospectionHandlersProvider = (*Config)(nil) - _ RevocationHandlersProvider = (*Config)(nil) - _ PushedAuthorizeRequestHandlersProvider = (*Config)(nil) - _ PushedAuthorizeRequestConfigProvider = (*Config)(nil) + _ AuthorizeCodeLifespanProvider = (*Config)(nil) + _ RefreshTokenLifespanProvider = (*Config)(nil) + _ AccessTokenLifespanProvider = (*Config)(nil) + _ ScopeStrategyProvider = (*Config)(nil) + _ AudienceStrategyProvider = (*Config)(nil) + _ RedirectSecureCheckerProvider = (*Config)(nil) + _ RefreshTokenScopesProvider = (*Config)(nil) + _ DisableRefreshTokenValidationProvider = (*Config)(nil) + _ AccessTokenIssuerProvider = (*Config)(nil) + _ JWTScopeFieldProvider = (*Config)(nil) + _ AllowedPromptsProvider = (*Config)(nil) + _ OmitRedirectScopeParamProvider = (*Config)(nil) + _ MinParameterEntropyProvider = (*Config)(nil) + _ SanitationAllowedProvider = (*Config)(nil) + _ EnforcePKCEForPublicClientsProvider = (*Config)(nil) + _ EnablePKCEPlainChallengeMethodProvider = (*Config)(nil) + _ EnforcePKCEProvider = (*Config)(nil) + _ GrantTypeTokenExchangeCanSkipClientAuthProvider = (*Config)(nil) + _ GrantTypeJWTBearerCanSkipClientAuthProvider = (*Config)(nil) + _ GrantTypeJWTBearerIDOptionalProvider = (*Config)(nil) + _ GrantTypeJWTBearerIssuedDateOptionalProvider = (*Config)(nil) + _ GetJWTMaxDurationProvider = (*Config)(nil) + _ IDTokenLifespanProvider = (*Config)(nil) + _ IDTokenIssuerProvider = (*Config)(nil) + _ JWKSFetcherStrategyProvider = (*Config)(nil) + _ ClientAuthenticationStrategyProvider = (*Config)(nil) + _ SendDebugMessagesToClientsProvider = (*Config)(nil) + _ ResponseModeHandlerExtensionProvider = (*Config)(nil) + _ MessageCatalogProvider = (*Config)(nil) + _ FormPostHTMLTemplateProvider = (*Config)(nil) + _ TokenURLProvider = (*Config)(nil) + _ GetSecretsHashingProvider = (*Config)(nil) + _ HTTPClientProvider = (*Config)(nil) + _ HMACHashingProvider = (*Config)(nil) + _ AuthorizeEndpointHandlersProvider = (*Config)(nil) + _ TokenEndpointHandlersProvider = (*Config)(nil) + _ TokenIntrospectionHandlersProvider = (*Config)(nil) + _ RevocationHandlersProvider = (*Config)(nil) + _ PushedAuthorizeRequestHandlersProvider = (*Config)(nil) + _ PushedAuthorizeRequestConfigProvider = (*Config)(nil) ) type Config struct { @@ -148,6 +149,9 @@ type Config struct { // GrantTypeJWTBearerMaxDuration sets the maximum time after JWT issued date, during which the JWT is considered valid. GrantTypeJWTBearerMaxDuration time.Duration + // GrantTypeTokenExchangeCanSkipClientAuth indicates the stretegy to check if client authentication can be skipped. + GrantTypeTokenExchangeCanSkipClientAuth CanSkipClientAuthenticationStrategy + // ClientAuthenticationStrategy indicates the Strategy to authenticate client requests ClientAuthenticationStrategy ClientAuthenticationStrategy @@ -299,6 +303,12 @@ func (c *Config) GetGrantTypeJWTBearerCanSkipClientAuth(ctx context.Context) boo return c.GrantTypeJWTBearerCanSkipClientAuth } +// GetGrantTypeTokenExchangeCanSkipClientAuth returns the GrantTypeTokenExchangeCanSkipClientAuth field. +// Defaults to nil, in which case TokenExchange follows the default behavior. +func (c *Config) GetGrantTypeTokenExchangeCanSkipClientAuth(ctx context.Context) CanSkipClientAuthenticationStrategy { + return c.GrantTypeTokenExchangeCanSkipClientAuth +} + // GetEnforcePKCE If set to true, public clients must use PKCE. func (c *Config) GetEnforcePKCE(ctx context.Context) bool { return c.EnforcePKCE diff --git a/handler/rfc8693/handler.go b/handler/rfc8693/handler.go index 937c9a2d4..774505068 100644 --- a/handler/rfc8693/handler.go +++ b/handler/rfc8693/handler.go @@ -10,7 +10,6 @@ import ( "time" "github.com/ory/fosite" - "github.com/ory/fosite/compose" "github.com/ory/fosite/handler/oauth2" "github.com/ory/fosite/token/jwt" "github.com/ory/x/errorsx" @@ -22,18 +21,17 @@ const ( tokenTypeAT = "urn:ietf:params:oauth:token-type:access_token" ) -func TokenExchangeGrantFactory(config *compose.CommonStrategy, storage, strategy interface{}) interface{} { - return nil -} - type Handler struct { - Storage RFC8693Storage - Strategy ClientAuthenticationStrategy - ScopeStrategy fosite.ScopeStrategy - AudienceMatchingStrategy fosite.AudienceMatchingStrategy - RefreshTokenStrategy oauth2.RefreshTokenStrategy - RefreshTokenStorage oauth2.RefreshTokenStorage - fosite.RefreshTokenScopesProvider + Storage RFC8693Storage + RefreshTokenStorage oauth2.RefreshTokenStorage + RefreshTokenStrategy oauth2.RefreshTokenStrategy + + Config interface { + fosite.GrantTypeTokenExchangeCanSkipClientAuthProvider + fosite.ScopeStrategyProvider + fosite.AudienceStrategyProvider + fosite.RefreshTokenScopesProvider + } *oauth2.HandleHelper } @@ -136,14 +134,14 @@ func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, requester fosi // Check and grant scope. for _, scope := range requester.GetRequestedScopes() { - if !c.ScopeStrategy(client.GetScopes(), scope) { + if !c.Config.GetScopeStrategy(ctx)(client.GetScopes(), scope) { return errorsx.WithStack(fosite.ErrInvalidScope.WithHintf("The OAuth 2.0 Client is not allowed to request scope '%s'.", scope)) } requester.GrantScope(scope) } // Check and grant audience. - if err := c.AudienceMatchingStrategy(client.GetAudience(), requester.GetRequestedAudience()); err != nil { + if err := c.Config.GetAudienceStrategy(ctx)(client.GetAudience(), requester.GetRequestedAudience()); err != nil { return errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("audience not match: %v", err)) } for _, audience := range requester.GetRequestedAudience() { @@ -164,7 +162,7 @@ func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, requester fosi requester.SetSession(&fosite.DefaultSession{ Subject: subject, }) - requester.GetSession().SetExpiresAt(fosite.AccessToken, time.Now().UTC().Add(c.Config.GetAccessTokenLifespan(ctx))) + requester.GetSession().SetExpiresAt(fosite.AccessToken, time.Now().UTC().Add(c.HandleHelper.Config.GetAccessTokenLifespan(ctx))) return nil case tokenTypeAT: or, err := c.verifyAccessTokenAsSubjectToken(ctx, client.GetID(), params) @@ -189,14 +187,13 @@ func (c *Handler) PopulateTokenEndpointResponse(ctx context.Context, requester f return errorsx.WithStack(fosite.ErrUnauthorizedClient.WithHintf("The OAuth 2.0 Client is not allowed to use authorization grant '%s'.", fosite.GrantTypeTokenExchange)) } - atLifespan := fosite.GetEffectiveLifespan(requester.GetClient(), fosite.GrantTypeTokenExchange, fosite.AccessToken, c.Config.GetAccessTokenLifespan(ctx)) + atLifespan := fosite.GetEffectiveLifespan(requester.GetClient(), fosite.GrantTypeTokenExchange, fosite.AccessToken, c.HandleHelper.Config.GetAccessTokenLifespan(ctx)) if err := c.IssueAccessToken(ctx, atLifespan, requester, responder); err != nil { return err } if canIssueRefreshToken(ctx, c, requester) { - fmt.Println(requester) refresh, refreshSignature, err := c.RefreshTokenStrategy.GenerateRefreshToken(ctx, requester) if err != nil { return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) @@ -211,7 +208,7 @@ func (c *Handler) PopulateTokenEndpointResponse(ctx context.Context, requester f } func canIssueRefreshToken(ctx context.Context, c *Handler, requester fosite.Requester) bool { - scope := c.GetRefreshTokenScopes(ctx) + scope := c.Config.GetRefreshTokenScopes(ctx) // Require one of the refresh token scopes, if set. if len(scope) > 0 && !requester.GetGrantedScopes().HasOneOf(scope...) { return false @@ -223,8 +220,12 @@ func canIssueRefreshToken(ctx context.Context, c *Handler, requester fosite.Requ return true } -func (c *Handler) CanSkipClientAuth(requester fosite.AccessRequester) bool { - return c.Strategy.CanSkipClientAuth(requester) +func (c *Handler) CanSkipClientAuth(ctx context.Context, requester fosite.AccessRequester) bool { + if s := c.Config.GetGrantTypeTokenExchangeCanSkipClientAuth(ctx); s != nil { + return s(ctx, requester) + } + + return false } func (c *Handler) keyFunc(ctx context.Context) jwt.Keyfunc { diff --git a/handler/rfc8693/handler_test.go b/handler/rfc8693/handler_test.go index 89a5cf7a7..0d3ad9502 100644 --- a/handler/rfc8693/handler_test.go +++ b/handler/rfc8693/handler_test.go @@ -30,6 +30,7 @@ func TestTokenExchange_HandleTokenEndpointRequest(t *testing.T) { h := Handler{ Storage: teStore, + Config: &fosite.Config{}, HandleHelper: &fositeOAuth2.HandleHelper{ AccessTokenStorage: atStore, AccessTokenStrategy: chgen, @@ -37,9 +38,7 @@ func TestTokenExchange_HandleTokenEndpointRequest(t *testing.T) { AccessTokenLifespan: time.Hour, }, }, - ScopeStrategy: fosite.HierarchicScopeStrategy, - AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, - RefreshTokenStorage: rtStore, + RefreshTokenStorage: rtStore, } for _, c := range []struct { @@ -286,13 +285,9 @@ func TestTokenExchange_PopulateTokenEndpointResponse(t *testing.T) { AccessTokenLifespan: time.Hour, }, }, - ScopeStrategy: fosite.HierarchicScopeStrategy, - AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, - RefreshTokenStrategy: rtStrategy, - RefreshTokenStorage: rtStore, - RefreshTokenScopesProvider: &fosite.Config{ - RefreshTokenScopes: []string{"offline", "offline_access"}, - }, + Config: &fosite.Config{}, + RefreshTokenStrategy: rtStrategy, + RefreshTokenStorage: rtStore, } for _, c := range []struct { name string diff --git a/handler/rfc8693/strategy.go b/handler/rfc8693/strategy.go deleted file mode 100644 index df2051628..000000000 --- a/handler/rfc8693/strategy.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright © 2022 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -package rfc8693 - -//go:generate mockgen -source=strategy.go -destination=../../internal/oauth2_token_exchange_strategy.go -package=internal - -import "github.com/ory/fosite" - -type ClientAuthenticationStrategy interface { - CanSkipClientAuth(requester fosite.AccessRequester) bool -} - -// DefaultClientAuthenticationStrategy enforces client authentication for all the cases. -type DefaultClientAuthenticationStrategy struct{} - -func (s *DefaultClientAuthenticationStrategy) CanSkipClientAuth(requester fosite.Requester) bool { - return false -} diff --git a/internal/oauth2_token_exchange_strategy.go b/internal/oauth2_token_exchange_strategy.go deleted file mode 100644 index 66875129f..000000000 --- a/internal/oauth2_token_exchange_strategy.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright © 2022 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -// Code generated by MockGen. DO NOT EDIT. -// Source: strategy.go - -// Package internal is a generated GoMock package. -package internal - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" -) - -// MockClientAuthenticationStrategy is a mock of ClientAuthenticationStrategy interface. -type MockClientAuthenticationStrategy struct { - ctrl *gomock.Controller - recorder *MockClientAuthenticationStrategyMockRecorder -} - -// MockClientAuthenticationStrategyMockRecorder is the mock recorder for MockClientAuthenticationStrategy. -type MockClientAuthenticationStrategyMockRecorder struct { - mock *MockClientAuthenticationStrategy -} - -// NewMockClientAuthenticationStrategy creates a new mock instance. -func NewMockClientAuthenticationStrategy(ctrl *gomock.Controller) *MockClientAuthenticationStrategy { - mock := &MockClientAuthenticationStrategy{ctrl: ctrl} - mock.recorder = &MockClientAuthenticationStrategyMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockClientAuthenticationStrategy) EXPECT() *MockClientAuthenticationStrategyMockRecorder { - return m.recorder -} - -// CanSkipClientAuth mocks base method. -func (m *MockClientAuthenticationStrategy) CanSkipClientAuth(requester fosite.AccessRequester) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CanSkipClientAuth", requester) - ret0, _ := ret[0].(bool) - return ret0 -} - -// CanSkipClientAuth indicates an expected call of CanSkipClientAuth. -func (mr *MockClientAuthenticationStrategyMockRecorder) CanSkipClientAuth(requester interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanSkipClientAuth", reflect.TypeOf((*MockClientAuthenticationStrategy)(nil).CanSkipClientAuth), requester) -}