diff --git a/handler/rfc8693/handler.go b/handler/rfc8693/handler.go new file mode 100644 index 000000000..937c9a2d4 --- /dev/null +++ b/handler/rfc8693/handler.go @@ -0,0 +1,268 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package rfc8693 + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/ory/fosite" + "github.com/ory/fosite/compose" + "github.com/ory/fosite/handler/oauth2" + "github.com/ory/fosite/token/jwt" + "github.com/ory/x/errorsx" +) + +// #nosec G101 +const ( + tokenTypeIDToken = "urn:ietf:params:oauth:token-type:id_token" + tokenTypeAT = "urn:ietf:params:oauth:token-type:access_token" +) + +func TokenExchangeGrantFactory(config *compose.CommonStrategy, storage, strategy interface{}) interface{} { + return nil +} + +type Handler struct { + Storage RFC8693Storage + Strategy ClientAuthenticationStrategy + ScopeStrategy fosite.ScopeStrategy + AudienceMatchingStrategy fosite.AudienceMatchingStrategy + RefreshTokenStrategy oauth2.RefreshTokenStrategy + RefreshTokenStorage oauth2.RefreshTokenStorage + fosite.RefreshTokenScopesProvider + + *oauth2.HandleHelper +} + +type tokenExchangeParams struct { + subjectToken string + subjectTokenType string +} + +func parseRequestParameter(requester fosite.AccessRequester) (*tokenExchangeParams, error) { + form := requester.GetRequestForm() + + // From https://tools.ietf.org/html/rfc8693#section-2.1: + // + // subject_token + // REQUIRED. A security token that represents the identity of the + // party on behalf of whom the request is being made. Typically, the + // subject of this token will be the subject of the security token + // issued in response to the request. + subjectToken := form.Get("subject_token") + if subjectToken == "" { + return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("subject_token is missing")) + } + + // From https://tools.ietf.org/html/rfc8693#section-2.1: + // + // subject_token_type + // REQUIRED. An identifier, that indicates the type of the + // security token in the "subject_token" parameter. + subjectTokenType := form.Get("subject_token_type") + switch subjectTokenType { + case tokenTypeIDToken, tokenTypeAT: + default: + return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("unsupported or missing subject_token_type %s", subjectTokenType)) + } + + // From https://tools.ietf.org/html/rfc8693#section-2.1: + // + // requested_token_type + // OPTIONAL. An identifier, for the type of the requested security token. + // If the requested type is unspecified, + // the issued token type is at the discretion of the authorization server and + // may be dictated by knowledge of the requirements of the service or + // resource indicated by the resource or audience parameter. + requestedTokenType := form.Get("requested_token_type") + switch requestedTokenType { + case tokenTypeAT, "": + default: + return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("unsupported requested_token_type %s", requestedTokenType)) + } + + // From https://tools.ietf.org/html/rfc8693#section-2.1: + // + // actor_token + // OPTIONAL . A security token that represents the identity of the acting party. + // Typically, this will be the party that is authorized to use the requested security + // token and act on behalf of the subject. + actorToken := form.Get("actor_token") + if actorToken != "" { + return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("'actor_token' was provided but delegation is currently not supported.")) + } + + // From https://tools.ietf.org/html/rfc8693#section-2.1: + // + // actor_token_type + // An identifier, as described in Section 3, that indicates the type of the security token + // in the actor_token parameter. This is REQUIRED when the actor_token parameter is present + // in the request but MUST NOT be included otherwise. + actorTokenType := form.Get("actor_token_type") + if actorTokenType != "" { + return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("'actor_token_type' was provided but delegation is currently not supported.")) + } + + return &tokenExchangeParams{ + subjectToken: subjectToken, + subjectTokenType: subjectTokenType, + }, nil +} + +func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, requester fosite.AccessRequester) error { + if !c.CanHandleTokenEndpointRequest(requester) { + return errorsx.WithStack(fosite.ErrUnknownRequest) + } + + client := requester.GetClient() + if client.GetID() == "" { + return errorsx.WithStack(fosite.ErrUnauthorizedClient.WithHint("unauthenticated client")) + } + + // Check whether client is allowed to use token exchange. + if !client.GetGrantTypes().Has(string(fosite.GrantTypeTokenExchange)) { + return errorsx.WithStack(fosite.ErrUnauthorizedClient.WithHint("the client is not allowed to use token-exchange")) + } + + // Get request parameter related token exchange. + params, err := parseRequestParameter(requester) + if err != nil { + return err + } + + // Check and grant scope. + for _, scope := range requester.GetRequestedScopes() { + if !c.ScopeStrategy(client.GetScopes(), scope) { + return errorsx.WithStack(fosite.ErrInvalidScope.WithHintf("The OAuth 2.0 Client is not allowed to request scope '%s'.", scope)) + } + requester.GrantScope(scope) + } + + // Check and grant audience. + if err := c.AudienceMatchingStrategy(client.GetAudience(), requester.GetRequestedAudience()); err != nil { + return errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("audience not match: %v", err)) + } + for _, audience := range requester.GetRequestedAudience() { + requester.GrantAudience(audience) + } + + // Verify subject token. + switch params.subjectTokenType { + case tokenTypeIDToken: + claims := jwt.MapClaims{} + if _, err := jwt.ParseWithClaims(params.subjectToken, claims, c.keyFunc(ctx)); err != nil { + return errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("failed to verify JWT: %v", err)) + } + subject, err := c.Storage.GetImpersonateSubject(ctx, claims, requester) + if err != nil { + return errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("not allowed to token exchange by jwt: %v", err)) + } + requester.SetSession(&fosite.DefaultSession{ + Subject: subject, + }) + requester.GetSession().SetExpiresAt(fosite.AccessToken, time.Now().UTC().Add(c.Config.GetAccessTokenLifespan(ctx))) + return nil + case tokenTypeAT: + or, err := c.verifyAccessTokenAsSubjectToken(ctx, client.GetID(), params) + if err != nil { + return errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("not allowed to token exchange by at: %v", err)) + } + requester.SetSession(or.GetSession().Clone()) + // When the subject_type is AT, the expiration time is same with subject_token. + // Therefore, we don't need to set the expiresAt. + return nil + default: + return errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("unsupported subject_type %s", params.subjectTokenType)) + } +} + +func (c *Handler) PopulateTokenEndpointResponse(ctx context.Context, requester fosite.AccessRequester, responder fosite.AccessResponder) error { + if !c.CanHandleTokenEndpointRequest(requester) { + return errorsx.WithStack(fosite.ErrUnknownRequest) + } + + if !requester.GetClient().GetGrantTypes().Has(string(fosite.GrantTypeTokenExchange)) { + return errorsx.WithStack(fosite.ErrUnauthorizedClient.WithHintf("The OAuth 2.0 Client is not allowed to use authorization grant '%s'.", fosite.GrantTypeTokenExchange)) + } + + atLifespan := fosite.GetEffectiveLifespan(requester.GetClient(), fosite.GrantTypeTokenExchange, fosite.AccessToken, c.Config.GetAccessTokenLifespan(ctx)) + + if err := c.IssueAccessToken(ctx, atLifespan, requester, responder); err != nil { + return err + } + + if canIssueRefreshToken(ctx, c, requester) { + fmt.Println(requester) + refresh, refreshSignature, err := c.RefreshTokenStrategy.GenerateRefreshToken(ctx, requester) + if err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + if err := c.RefreshTokenStorage.CreateRefreshTokenSession(ctx, refreshSignature, requester); err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithDebug(err.Error())) + } + + responder.SetExtra("refresh_token", refresh) + } + return nil +} + +func canIssueRefreshToken(ctx context.Context, c *Handler, requester fosite.Requester) bool { + scope := c.GetRefreshTokenScopes(ctx) + // Require one of the refresh token scopes, if set. + if len(scope) > 0 && !requester.GetGrantedScopes().HasOneOf(scope...) { + return false + } + // Do not issue a refresh token to clients that cannot use the refresh token grant type. + if !requester.GetClient().GetGrantTypes().Has("refresh_token") { + return false + } + return true +} + +func (c *Handler) CanSkipClientAuth(requester fosite.AccessRequester) bool { + return c.Strategy.CanSkipClientAuth(requester) +} + +func (c *Handler) keyFunc(ctx context.Context) jwt.Keyfunc { + return jwt.Keyfunc(func(t *jwt.Token) (interface{}, error) { + kid, ok := t.Header["kid"].(string) + if !ok { + return nil, errors.New("invalid kid") + } + iss, ok := t.Claims["iss"].(string) + if !ok { + return nil, errors.New("invalid iss") + } + return c.Storage.GetIDTokenPublicKey(ctx, iss, kid) + }) +} + +func (c *Handler) verifyAccessTokenAsSubjectToken(ctx context.Context, clientID string, params *tokenExchangeParams) (fosite.Requester, error) { + sig := c.HandleHelper.AccessTokenStrategy.AccessTokenSignature(ctx, params.subjectToken) + or, err := c.HandleHelper.AccessTokenStorage.GetAccessTokenSession(ctx, sig, nil) + if err != nil { + return nil, errorsx.WithStack(fosite.ErrRequestUnauthorized.WithWrap(err).WithDebug(err.Error())) + } else if err := c.AccessTokenStrategy.ValidateAccessToken(ctx, or, params.subjectToken); err != nil { + return nil, err + } + + allowClientIDs, err := c.Storage.GetAllowedClientIDs(ctx, clientID) + if err != nil { + return nil, err + } + + for _, cid := range allowClientIDs { + if or.GetClient().GetID() == cid { + return or, nil + } + } + return nil, fmt.Errorf("this access_token is not allowed to use token exchange based on AT: original_client:%s, request_client:%s ", or.GetClient().GetID(), clientID) +} + +func (c *Handler) CanHandleTokenEndpointRequest(requester fosite.AccessRequester) bool { + return requester.GetGrantTypes().ExactOne(string(fosite.GrantTypeTokenExchange)) +} diff --git a/handler/rfc8693/handler_test.go b/handler/rfc8693/handler_test.go new file mode 100644 index 000000000..89a5cf7a7 --- /dev/null +++ b/handler/rfc8693/handler_test.go @@ -0,0 +1,354 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package rfc8693 + +import ( + "context" + "net/http" + "net/url" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/ory/fosite" + fositeOAuth2 "github.com/ory/fosite/handler/oauth2" + "github.com/ory/fosite/internal" + "github.com/ory/fosite/token/jwt" + "github.com/stretchr/testify/require" + "gopkg.in/square/go-jose.v2" +) + +func TestTokenExchange_HandleTokenEndpointRequest(t *testing.T) { + ctrl := gomock.NewController(t) + teStore := internal.NewMockRFC8693Storage(ctrl) + atStore := internal.NewMockAccessTokenStorage(ctrl) + rtStore := internal.NewMockRefreshTokenGrantStorage(ctrl) + chgen := internal.NewMockAccessTokenStrategy(ctrl) + areq := internal.NewMockAccessRequester(ctrl) + defer ctrl.Finish() + + h := Handler{ + Storage: teStore, + HandleHelper: &fositeOAuth2.HandleHelper{ + AccessTokenStorage: atStore, + AccessTokenStrategy: chgen, + Config: &fosite.Config{ + AccessTokenLifespan: time.Hour, + }, + }, + ScopeStrategy: fosite.HierarchicScopeStrategy, + AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, + RefreshTokenStorage: rtStore, + } + + for _, c := range []struct { + name string + mock func() + req *http.Request + expectErr error + }{ + { + name: "should fail because granttype is missing", + expectErr: fosite.ErrUnknownRequest, + mock: func() { + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{""}) + }, + }, + { + name: "should fail because invalid client_id", + expectErr: fosite.ErrUnauthorizedClient, + mock: func() { + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}) + areq.EXPECT().GetClient().Return(&fosite.DefaultClient{}) + }, + }, + { + name: "should fail because grant_type is not valid", + expectErr: fosite.ErrUnauthorizedClient, + mock: func() { + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}) + areq.EXPECT().GetClient().Return(&fosite.DefaultClient{ + ID: "client", + GrantTypes: fosite.Arguments{""}, + }) + }, + }, + { + name: "should fail because no subject_token", + expectErr: fosite.ErrInvalidRequest, + mock: func() { + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}) + areq.EXPECT().GetClient().Return(&fosite.DefaultClient{ + ID: "client", + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}, + }) + areq.EXPECT().GetRequestForm().Return(url.Values{ + "subject_token": []string{""}, + }) + }, + }, + { + name: "should fail because unsupported subject_token_type", + expectErr: fosite.ErrInvalidRequest, + mock: func() { + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}) + areq.EXPECT().GetClient().Return(&fosite.DefaultClient{ + ID: "client", + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}, + }) + areq.EXPECT().GetRequestForm().Return(url.Values{ + "subject_token": []string{"subject_token"}, + "subject_token_type": []string{"unsupported_subject_token_type"}, + }) + }, + }, + { + name: "should fail because scope not valid", + expectErr: fosite.ErrInvalidScope, + mock: func() { + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}) + areq.EXPECT().GetClient().Return(&fosite.DefaultClient{ + ID: "client", + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}, + Scopes: []string{"none"}, + }) + areq.EXPECT().GetRequestForm().Return(url.Values{ + "subject_token": []string{"subject_token"}, + "subject_token_type": []string{"urn:ietf:params:oauth:token-type:access_token"}, + "requested_token_type": []string{"urn:ietf:params:oauth:token-type:access_token"}, + }) + areq.EXPECT().GetRequestedScopes().Return([]string{"foo"}) + }, + }, + { + name: "should pass as AT", + mock: func() { + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}) + areq.EXPECT().GetClient().Return(&fosite.DefaultClient{ + ID: "client", + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}, + Scopes: []string{"foo"}, + }) + areq.EXPECT().GetRequestForm().Return(url.Values{ + "subject_token": []string{"subject_token"}, + "subject_token_type": []string{"urn:ietf:params:oauth:token-type:access_token"}, + "requested_token_type": []string{"urn:ietf:params:oauth:token-type:access_token"}, + }) + + // scope and audience. + areq.EXPECT().GetRequestedScopes().Return([]string{"foo"}) + areq.EXPECT().GrantScope("foo") + areq.EXPECT().GetRequestedAudience().Return([]string{}) + areq.EXPECT().GetRequestedAudience().Return([]string{}) + chgen.EXPECT().AccessTokenSignature(gomock.Any(), gomock.Any()).Return("signature") + + // original request. + ar := internal.NewMockAccessRequester(ctrl) + atStore.EXPECT().GetAccessTokenSession(gomock.Any(), "signature", nil).Return(ar, nil) + chgen.EXPECT().ValidateAccessToken(gomock.Any(), ar, gomock.Any()).Return(nil) + + teStore.EXPECT().GetAllowedClientIDs(gomock.Any(), "client").Return([]string{"client2"}, nil) + ar.EXPECT().GetClient().Return(&fosite.DefaultClient{ + ID: "client2", + }) + ar.EXPECT().GetSession().Return(new(fosite.DefaultSession)) + areq.EXPECT().SetSession(gomock.Any()) + }, + }, + { + name: "should fail because of different key", + expectErr: fosite.ErrInvalidRequest, + mock: func() { + // ID Token JWT. + key := []byte("aabbbbccccddddddd") + token := jwt.Token{ + Header: map[string]interface{}{ + "kid": "12asd4q34daf", + }, + Claims: jwt.MapClaims{ + "sub": "foo", + "exp": time.Now().Add(time.Hour).Unix(), + "iss": "bar", + "jti": "12345", + "aud": "token-url", + }, + Method: jose.HS256, + } + tokenString, err := token.SignedString(key) + require.NoError(t, err) + + // request. + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}) + areq.EXPECT().GetClient().Return(&fosite.DefaultClient{ + ID: "client", + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}, + Scopes: []string{"foo"}, + }) + areq.EXPECT().GetRequestForm().Return(url.Values{ + "subject_token": []string{tokenString}, + "subject_token_type": []string{"urn:ietf:params:oauth:token-type:id_token"}, + "requested_token_type": []string{"urn:ietf:params:oauth:token-type:access_token"}, + }) + + // scope and audience. + areq.EXPECT().GetRequestedScopes().Return([]string{"foo"}) + areq.EXPECT().GrantScope("foo") + areq.EXPECT().GetRequestedAudience().Return([]string{}) + areq.EXPECT().GetRequestedAudience().Return([]string{}) + + // verify IDToken. + teStore.EXPECT().GetIDTokenPublicKey(gomock.Any(), "bar", "12asd4q34daf").Return(&jose.JSONWebKey{ + Key: []byte("differnet_key"), + }, nil) + }, + }, + { + name: "should pass as JWT", + mock: func() { + // ID Token JWT. + key := []byte("aaabbbbcccddd") + token := jwt.Token{ + Header: map[string]interface{}{ + "kid": "12asd4q34daf", + }, + Claims: jwt.MapClaims{ + "sub": "foo", + "exp": time.Now().Add(time.Hour).Unix(), + "iss": "bar", + "jti": "12345", + "aud": "token-url", + }, + Method: jose.HS256, + } + tokenString, err := token.SignedString(key) + require.NoError(t, err) + + // request. + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}) + areq.EXPECT().GetClient().Return(&fosite.DefaultClient{ + ID: "client", + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}, + Scopes: []string{"foo"}, + }) + areq.EXPECT().GetRequestForm().Return(url.Values{ + "subject_token": []string{tokenString}, + "subject_token_type": []string{"urn:ietf:params:oauth:token-type:id_token"}, + "requested_token_type": []string{"urn:ietf:params:oauth:token-type:access_token"}, + }) + + // scope and audience. + areq.EXPECT().GetRequestedScopes().Return([]string{"foo"}) + areq.EXPECT().GrantScope("foo") + areq.EXPECT().GetRequestedAudience().Return([]string{}) + areq.EXPECT().GetRequestedAudience().Return([]string{}) + + // verify IDToken. + teStore.EXPECT().GetIDTokenPublicKey(gomock.Any(), "bar", "12asd4q34daf").Return(&jose.JSONWebKey{ + Key: key, + }, nil) + teStore.EXPECT().GetImpersonateSubject(gomock.Any(), gomock.Any(), gomock.Any()).Return("client", nil) + + areq.EXPECT().SetSession(gomock.Any()) + areq.EXPECT().GetSession().Return(new(fosite.DefaultSession)) + }, + }, + } { + t.Run(c.name, func(t *testing.T) { + c.mock() + err := h.HandleTokenEndpointRequest(context.TODO(), areq) + if c.expectErr != nil { + require.EqualError(t, err, c.expectErr.Error()) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestTokenExchange_PopulateTokenEndpointResponse(t *testing.T) { + ctrl := gomock.NewController(t) + atStore := internal.NewMockAccessTokenStorage(ctrl) + chgen := internal.NewMockAccessTokenStrategy(ctrl) + + areq := fosite.NewAccessRequest(new(fosite.DefaultSession)) + aresp := fosite.NewAccessResponse() + rtStrategy := internal.NewMockRefreshTokenStrategy(ctrl) + rtStore := internal.NewMockRefreshTokenGrantStorage(ctrl) + + defer ctrl.Finish() + + h := Handler{ + HandleHelper: &fositeOAuth2.HandleHelper{ + AccessTokenStorage: atStore, + AccessTokenStrategy: chgen, + Config: &fosite.Config{ + AccessTokenLifespan: time.Hour, + }, + }, + ScopeStrategy: fosite.HierarchicScopeStrategy, + AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, + RefreshTokenStrategy: rtStrategy, + RefreshTokenStorage: rtStore, + RefreshTokenScopesProvider: &fosite.Config{ + RefreshTokenScopes: []string{"offline", "offline_access"}, + }, + } + for _, c := range []struct { + name string + mock func() + req *http.Request + expectErr error + }{ + { + name: "should fail because not responsible", + expectErr: fosite.ErrUnknownRequest, + mock: func() { + areq.GrantTypes = fosite.Arguments{""} + }, + }, + { + name: "should fail because grant_type not allowed", + expectErr: fosite.ErrUnauthorizedClient, + mock: func() { + areq.GrantTypes = fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"} + areq.Client = &fosite.DefaultClient{GrantTypes: fosite.Arguments{"authorization_code"}} + }, + }, + { + name: "should pass", + mock: func() { + areq.GrantTypes = fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"} + areq.Session = &fosite.DefaultSession{} + areq.Client = &fosite.DefaultClient{GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}} + chgen.EXPECT().GenerateAccessToken(gomock.Any(), areq).Return("tokenfoo.bar", "bar", nil) + atStore.EXPECT().CreateAccessTokenSession(gomock.Any(), "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil) + }, + }, + { + name: "should populate both AT and RT", + mock: func() { + areq.GrantedScope = fosite.Arguments{"offline_access"} + areq.GrantTypes = fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"} + areq.Session = &fosite.DefaultSession{} + areq.Client = &fosite.DefaultClient{ + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange", "refresh_token"}, + } + chgen.EXPECT().GenerateAccessToken(gomock.Any(), areq).Return("tokenfoo.bar", "bar", nil) + atStore.EXPECT().CreateAccessTokenSession(gomock.Any(), "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil) + rtStrategy.EXPECT().GenerateRefreshToken(gomock.Any(), gomock.Any()).Return("refresh_token", "refresh_token_signature", nil) + rtStore.EXPECT().CreateRefreshTokenSession(gomock.Any(), "refresh_token_signature", gomock.Eq(areq)).Return(nil) + }, + }, + } { + t.Run(c.name, func(t *testing.T) { + c.mock() + err := h.PopulateTokenEndpointResponse(context.TODO(), areq, aresp) + if c.expectErr != nil { + require.EqualError(t, err, c.expectErr.Error()) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/handler/rfc8693/storage.go b/handler/rfc8693/storage.go new file mode 100644 index 000000000..528f720ba --- /dev/null +++ b/handler/rfc8693/storage.go @@ -0,0 +1,28 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package rfc8693 + +//go:generate mockgen -source=storage.go -destination=../../internal/oauth2_token_exchange_storage.go -package=internal + +import ( + "context" + + "github.com/ory/fosite" + "github.com/ory/fosite/token/jwt" +) + +// RFC8693Storage hold information needed to perform token exchange. +type RFC8693Storage interface { + // GetAllowedClientIDs returns clientIDs that can be used for subject_token. + // The subject token is a security token that represents the identity of + // the party on behalf of whom the request is being made. + // https://datatracker.ietf.org/doc/html/rfc8693#section-2.1 + GetAllowedClientIDs(ctx context.Context, clientID string) ([]string, error) + + // GetIDTokenPublicKey returns the public key that can be used to verify ID Token. + GetIDTokenPublicKey(ctx context.Context, iss, kid string) (interface{}, error) + + // GetImpersonateSubject returns subject value to use the token based on a JWT. + GetImpersonateSubject(ctx context.Context, claims jwt.MapClaims, req fosite.Requester) (string, error) +} diff --git a/handler/rfc8693/strategy.go b/handler/rfc8693/strategy.go new file mode 100644 index 000000000..df2051628 --- /dev/null +++ b/handler/rfc8693/strategy.go @@ -0,0 +1,19 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package rfc8693 + +//go:generate mockgen -source=strategy.go -destination=../../internal/oauth2_token_exchange_strategy.go -package=internal + +import "github.com/ory/fosite" + +type ClientAuthenticationStrategy interface { + CanSkipClientAuth(requester fosite.AccessRequester) bool +} + +// DefaultClientAuthenticationStrategy enforces client authentication for all the cases. +type DefaultClientAuthenticationStrategy struct{} + +func (s *DefaultClientAuthenticationStrategy) CanSkipClientAuth(requester fosite.Requester) bool { + return false +} diff --git a/internal/oauth2_token_exchange_storage.go b/internal/oauth2_token_exchange_storage.go new file mode 100644 index 000000000..a8db68cc6 --- /dev/null +++ b/internal/oauth2_token_exchange_storage.go @@ -0,0 +1,85 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by MockGen. DO NOT EDIT. +// Source: storage.go + +// Package internal is a generated GoMock package. +package internal + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + fosite "github.com/ory/fosite" + jwt "github.com/ory/fosite/token/jwt" +) + +// MockRFC8693Storage is a mock of RFC8693Storage interface. +type MockRFC8693Storage struct { + ctrl *gomock.Controller + recorder *MockRFC8693StorageMockRecorder +} + +// MockRFC8693StorageMockRecorder is the mock recorder for MockRFC8693Storage. +type MockRFC8693StorageMockRecorder struct { + mock *MockRFC8693Storage +} + +// NewMockRFC8693Storage creates a new mock instance. +func NewMockRFC8693Storage(ctrl *gomock.Controller) *MockRFC8693Storage { + mock := &MockRFC8693Storage{ctrl: ctrl} + mock.recorder = &MockRFC8693StorageMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRFC8693Storage) EXPECT() *MockRFC8693StorageMockRecorder { + return m.recorder +} + +// GetAllowedClientIDs mocks base method. +func (m *MockRFC8693Storage) GetAllowedClientIDs(ctx context.Context, clientID string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllowedClientIDs", ctx, clientID) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAllowedClientIDs indicates an expected call of GetAllowedClientIDs. +func (mr *MockRFC8693StorageMockRecorder) GetAllowedClientIDs(ctx, clientID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllowedClientIDs", reflect.TypeOf((*MockRFC8693Storage)(nil).GetAllowedClientIDs), ctx, clientID) +} + +// GetIDTokenPublicKey mocks base method. +func (m *MockRFC8693Storage) GetIDTokenPublicKey(ctx context.Context, iss, kid string) (interface{}, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetIDTokenPublicKey", ctx, iss, kid) + ret0, _ := ret[0].(interface{}) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetIDTokenPublicKey indicates an expected call of GetIDTokenPublicKey. +func (mr *MockRFC8693StorageMockRecorder) GetIDTokenPublicKey(ctx, iss, kid interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIDTokenPublicKey", reflect.TypeOf((*MockRFC8693Storage)(nil).GetIDTokenPublicKey), ctx, iss, kid) +} + +// GetImpersonateSubject mocks base method. +func (m *MockRFC8693Storage) GetImpersonateSubject(ctx context.Context, claims jwt.MapClaims, req fosite.Requester) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetImpersonateSubject", ctx, claims, req) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetImpersonateSubject indicates an expected call of GetImpersonateSubject. +func (mr *MockRFC8693StorageMockRecorder) GetImpersonateSubject(ctx, claims, req interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetImpersonateSubject", reflect.TypeOf((*MockRFC8693Storage)(nil).GetImpersonateSubject), ctx, claims, req) +} diff --git a/internal/oauth2_token_exchange_strategy.go b/internal/oauth2_token_exchange_strategy.go new file mode 100644 index 000000000..66875129f --- /dev/null +++ b/internal/oauth2_token_exchange_strategy.go @@ -0,0 +1,52 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by MockGen. DO NOT EDIT. +// Source: strategy.go + +// Package internal is a generated GoMock package. +package internal + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + fosite "github.com/ory/fosite" +) + +// MockClientAuthenticationStrategy is a mock of ClientAuthenticationStrategy interface. +type MockClientAuthenticationStrategy struct { + ctrl *gomock.Controller + recorder *MockClientAuthenticationStrategyMockRecorder +} + +// MockClientAuthenticationStrategyMockRecorder is the mock recorder for MockClientAuthenticationStrategy. +type MockClientAuthenticationStrategyMockRecorder struct { + mock *MockClientAuthenticationStrategy +} + +// NewMockClientAuthenticationStrategy creates a new mock instance. +func NewMockClientAuthenticationStrategy(ctrl *gomock.Controller) *MockClientAuthenticationStrategy { + mock := &MockClientAuthenticationStrategy{ctrl: ctrl} + mock.recorder = &MockClientAuthenticationStrategyMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClientAuthenticationStrategy) EXPECT() *MockClientAuthenticationStrategyMockRecorder { + return m.recorder +} + +// CanSkipClientAuth mocks base method. +func (m *MockClientAuthenticationStrategy) CanSkipClientAuth(requester fosite.AccessRequester) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CanSkipClientAuth", requester) + ret0, _ := ret[0].(bool) + return ret0 +} + +// CanSkipClientAuth indicates an expected call of CanSkipClientAuth. +func (mr *MockClientAuthenticationStrategyMockRecorder) CanSkipClientAuth(requester interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanSkipClientAuth", reflect.TypeOf((*MockClientAuthenticationStrategy)(nil).CanSkipClientAuth), requester) +} diff --git a/oauth2.go b/oauth2.go index 3184d39d3..d6ade26bf 100644 --- a/oauth2.go +++ b/oauth2.go @@ -31,7 +31,8 @@ const ( GrantTypeAuthorizationCode GrantType = "authorization_code" GrantTypePassword GrantType = "password" GrantTypeClientCredentials GrantType = "client_credentials" - GrantTypeJWTBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" //nolint:gosec // this is not a hardcoded credential + GrantTypeJWTBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" //nolint:gosec // this is not a hardcoded credential + GrantTypeTokenExchange GrantType = "urn:ietf:params:oauth:grant-type:token-exchange" //nolint:gosec // this is not a hardcoded credential BearerAccessToken string = "bearer" )