From b8e5a6b70934914b028be33fbe5d3b364502f519 Mon Sep 17 00:00:00 2001 From: NitiwatOwen Date: Fri, 29 Dec 2023 15:52:11 +0700 Subject: [PATCH] fix: type --- src/internal/domain/dto/token/token.dto.go | 10 +- src/internal/service/jwt/jwt.service.go | 1 - src/internal/service/jwt/jwt.service_test.go | 1 - src/internal/service/token/token.service.go | 16 +-- .../service/token/token.service_test.go | 103 ++++++++---------- 5 files changed, 58 insertions(+), 73 deletions(-) diff --git a/src/internal/domain/dto/token/token.dto.go b/src/internal/domain/dto/token/token.dto.go index d0db8ac..b538444 100644 --- a/src/internal/domain/dto/token/token.dto.go +++ b/src/internal/domain/dto/token/token.dto.go @@ -14,14 +14,14 @@ type UserCredential struct { type AuthPayload struct { jwt.RegisteredClaims - UserID string `json:"user_id"` - Role constant.Role `json:"role"` - AuthSessionID string `json:"auth_session_id"` + UserID string `json:"user_id"` + AuthSessionID string `json:"auth_session_id"` } type AccessTokenCache struct { - Token string `json:"token"` - RefreshToken string `json:"refresh_token"` + Token string `json:"token"` + Role constant.Role `json:"role"` + RefreshToken string `json:"refresh_token"` } type RefreshTokenCache struct { diff --git a/src/internal/service/jwt/jwt.service.go b/src/internal/service/jwt/jwt.service.go index 954f844..7731531 100644 --- a/src/internal/service/jwt/jwt.service.go +++ b/src/internal/service/jwt/jwt.service.go @@ -30,7 +30,6 @@ func (s *serviceImpl) SignAuth(userId string, role constant.Role, authSessionId IssuedAt: s.jwtUtil.GetNumericDate(time.Now()), }, UserID: userId, - Role: role, AuthSessionID: authSessionId, } diff --git a/src/internal/service/jwt/jwt.service_test.go b/src/internal/service/jwt/jwt.service_test.go index 808d3ff..a1b6230 100644 --- a/src/internal/service/jwt/jwt.service_test.go +++ b/src/internal/service/jwt/jwt.service_test.go @@ -51,7 +51,6 @@ func (t *JwtServiceTest) SetupTest() { IssuedAt: numericDate, }, UserID: userId, - Role: role, AuthSessionID: authSessionId, } diff --git a/src/internal/service/token/token.service.go b/src/internal/service/token/token.service.go index b3a3bfe..383f9aa 100644 --- a/src/internal/service/token/token.service.go +++ b/src/internal/service/token/token.service.go @@ -1,6 +1,7 @@ package token import ( + _jwt "github.com/golang-jwt/jwt/v4" "github.com/isd-sgcu/johnjud-auth/src/internal/constant" tokenDto "github.com/isd-sgcu/johnjud-auth/src/internal/domain/dto/token" "github.com/isd-sgcu/johnjud-auth/src/internal/utils" @@ -39,6 +40,7 @@ func (s *serviceImpl) CreateCredential(userId string, role constant.Role, authSe accessTokenCache := &tokenDto.AccessTokenCache{ Token: accessToken, + Role: role, RefreshToken: refreshToken, } err = s.accessTokenCache.SetValue(authSessionId, accessTokenCache, jwtConf.ExpiresIn) @@ -71,17 +73,17 @@ func (s *serviceImpl) Validate(token string) (*tokenDto.UserCredential, error) { return nil, err } - payloads := jwtToken.Claims.(tokenDto.AuthPayload) - if payloads.Issuer != s.jwtService.GetConfig().Issuer { + payloads := jwtToken.Claims.(_jwt.MapClaims) + if payloads["iss"] != s.jwtService.GetConfig().Issuer { return nil, errors.New("invalid token") } - if time.Unix(payloads.ExpiresAt.Unix(), 0).Before(time.Now()) { + if time.Unix(int64(payloads["exp"].(float64)), 0).Before(time.Now()) { return nil, errors.New("expired token") } accessTokenCache := &tokenDto.AccessTokenCache{} - err = s.accessTokenCache.GetValue(payloads.AuthSessionID, accessTokenCache) + err = s.accessTokenCache.GetValue(payloads["auth_session_id"].(string), accessTokenCache) if err != nil { if err != redis.Nil { return nil, err @@ -94,9 +96,9 @@ func (s *serviceImpl) Validate(token string) (*tokenDto.UserCredential, error) { } userCredential := &tokenDto.UserCredential{ - UserID: payloads.UserID, - Role: payloads.Role, - AuthSessionID: payloads.AuthSessionID, + UserID: payloads["user_id"].(string), + Role: accessTokenCache.Role, + AuthSessionID: payloads["auth_session_id"].(string), RefreshToken: accessTokenCache.RefreshToken, } return userCredential, nil diff --git a/src/internal/service/token/token.service_test.go b/src/internal/service/token/token.service_test.go index 355841b..ad1cd71 100644 --- a/src/internal/service/token/token.service_test.go +++ b/src/internal/service/token/token.service_test.go @@ -61,6 +61,7 @@ func (t *TokenServiceTest) SetupTest() { func (t *TokenServiceTest) TestCreateCredentialSuccess() { accessTokenCache := &tokenDto.AccessTokenCache{ Token: t.accessToken, + Role: t.role, RefreshToken: t.refreshToken.String(), } refreshTokenCache := &tokenDto.RefreshTokenCache{ @@ -120,6 +121,7 @@ func (t *TokenServiceTest) TestCreateCredentialSignAuthFailed() { func (t *TokenServiceTest) TestCreateCredentialSetAccessTokenFailed() { accessTokenCache := &tokenDto.AccessTokenCache{ Token: t.accessToken, + Role: t.role, RefreshToken: t.refreshToken.String(), } setCacheErr := errors.New("Internal server error") @@ -147,6 +149,7 @@ func (t *TokenServiceTest) TestCreateCredentialSetAccessTokenFailed() { func (t *TokenServiceTest) TestCreateCredentialSetRefreshTokenFailed() { accessTokenCache := &tokenDto.AccessTokenCache{ Token: t.accessToken, + Role: t.role, RefreshToken: t.refreshToken.String(), } refreshTokenCache := &tokenDto.RefreshTokenCache{ @@ -180,19 +183,16 @@ func (t *TokenServiceTest) TestCreateCredentialSetRefreshTokenFailed() { func (t *TokenServiceTest) TestValidateSuccess() { expected := &tokenDto.UserCredential{ UserID: t.userId, - Role: constant.USER, + Role: "", AuthSessionID: t.authSessionId, RefreshToken: "", } - payloads := tokenDto.AuthPayload{ - RegisteredClaims: _jwt.RegisteredClaims{ - Issuer: t.jwtConfig.Issuer, - ExpiresAt: _jwt.NewNumericDate(time.Now().Add(time.Second * time.Duration(t.jwtConfig.ExpiresIn))), - IssuedAt: _jwt.NewNumericDate(time.Now()), - }, - UserID: t.userId, - Role: t.role, - AuthSessionID: t.authSessionId, + payloads := _jwt.MapClaims{ + "iss": t.jwtConfig.Issuer, + "exp": float64(_jwt.NewNumericDate(time.Now().Add(time.Second * time.Duration(t.jwtConfig.ExpiresIn))).Unix()), + "iat": float64(_jwt.NewNumericDate(time.Now()).Unix()), + "user_id": t.userId, + "auth_session_id": t.authSessionId, } jwtToken := &_jwt.Token{ Method: _jwt.SigningMethodHS256, @@ -209,7 +209,7 @@ func (t *TokenServiceTest) TestValidateSuccess() { jwtService.On("VerifyAuth", t.validateToken).Return(jwtToken, nil) jwtService.On("GetConfig").Return(t.jwtConfig) - accessTokenRepo.EXPECT().GetValue(payloads.AuthSessionID, accessTokenCache).Return(nil) + accessTokenRepo.EXPECT().GetValue(payloads["auth_session_id"].(string), accessTokenCache).Return(nil) tokenSvc := NewService(&jwtService, accessTokenRepo, refreshTokenRepo, &uuidUtil) actual, err := tokenSvc.Validate(t.validateToken) @@ -221,15 +221,12 @@ func (t *TokenServiceTest) TestValidateSuccess() { func (t *TokenServiceTest) TestValidateInvalidIssuer() { expected := errors.New("invalid token") - payloads := tokenDto.AuthPayload{ - RegisteredClaims: _jwt.RegisteredClaims{ - Issuer: "InvalidIssuer", - ExpiresAt: _jwt.NewNumericDate(time.Now().Add(time.Second * time.Duration(t.jwtConfig.ExpiresIn))), - IssuedAt: _jwt.NewNumericDate(time.Now()), - }, - UserID: t.userId, - Role: t.role, - AuthSessionID: t.authSessionId, + payloads := _jwt.MapClaims{ + "iss": "invalid issuer", + "exp": float64(_jwt.NewNumericDate(time.Now().Add(time.Second * time.Duration(t.jwtConfig.ExpiresIn))).Unix()), + "iat": float64(_jwt.NewNumericDate(time.Now()).Unix()), + "user_id": t.userId, + "auth_session_id": t.authSessionId, } jwtToken := &_jwt.Token{ @@ -257,15 +254,12 @@ func (t *TokenServiceTest) TestValidateInvalidIssuer() { func (t *TokenServiceTest) TestValidateExpireToken() { expected := errors.New("expired token") - payloads := tokenDto.AuthPayload{ - RegisteredClaims: _jwt.RegisteredClaims{ - Issuer: t.jwtConfig.Issuer, - ExpiresAt: _jwt.NewNumericDate(time.Now().Add(time.Second * (-time.Duration(t.jwtConfig.ExpiresIn)))), - IssuedAt: _jwt.NewNumericDate(time.Now()), - }, - UserID: t.userId, - Role: t.role, - AuthSessionID: t.authSessionId, + payloads := _jwt.MapClaims{ + "iss": t.jwtConfig.Issuer, + "exp": float64(_jwt.NewNumericDate(time.Now().Add(time.Second * (-time.Duration(t.jwtConfig.ExpiresIn)))).Unix()), + "iat": float64(_jwt.NewNumericDate(time.Now()).Unix()), + "user_id": t.userId, + "auth_session_id": t.authSessionId, } jwtToken := &_jwt.Token{ Method: _jwt.SigningMethodHS256, @@ -311,15 +305,12 @@ func (t *TokenServiceTest) TestValidateVerifyFailed() { func (t *TokenServiceTest) TestValidateGetCacheKeyNotFound() { expected := errors.New("invalid token") - payloads := tokenDto.AuthPayload{ - RegisteredClaims: _jwt.RegisteredClaims{ - Issuer: t.jwtConfig.Issuer, - ExpiresAt: _jwt.NewNumericDate(time.Now().Add(time.Second * (time.Duration(t.jwtConfig.ExpiresIn)))), - IssuedAt: _jwt.NewNumericDate(time.Now()), - }, - UserID: t.userId, - Role: t.role, - AuthSessionID: t.authSessionId, + payloads := _jwt.MapClaims{ + "iss": t.jwtConfig.Issuer, + "exp": float64(_jwt.NewNumericDate(time.Now().Add(time.Second * time.Duration(t.jwtConfig.ExpiresIn))).Unix()), + "iat": float64(_jwt.NewNumericDate(time.Now()).Unix()), + "user_id": t.userId, + "auth_session_id": t.authSessionId, } jwtToken := &_jwt.Token{ Method: _jwt.SigningMethodHS256, @@ -336,7 +327,7 @@ func (t *TokenServiceTest) TestValidateGetCacheKeyNotFound() { jwtService.On("VerifyAuth", t.validateToken).Return(jwtToken, nil) jwtService.On("GetConfig").Return(t.jwtConfig) - accessTokenRepo.EXPECT().GetValue(payloads.AuthSessionID, accessTokenCache).Return(redis.Nil) + accessTokenRepo.EXPECT().GetValue(payloads["auth_session_id"].(string), accessTokenCache).Return(redis.Nil) tokenSvc := NewService(&jwtService, accessTokenRepo, refreshTokenRepo, &uuidUtil) actual, err := tokenSvc.Validate(t.validateToken) @@ -346,15 +337,12 @@ func (t *TokenServiceTest) TestValidateGetCacheKeyNotFound() { } func (t *TokenServiceTest) TestValidateGetCacheInternalFailed() { - payloads := tokenDto.AuthPayload{ - RegisteredClaims: _jwt.RegisteredClaims{ - Issuer: t.jwtConfig.Issuer, - ExpiresAt: _jwt.NewNumericDate(time.Now().Add(time.Second * (time.Duration(t.jwtConfig.ExpiresIn)))), - IssuedAt: _jwt.NewNumericDate(time.Now()), - }, - UserID: t.userId, - Role: t.role, - AuthSessionID: t.authSessionId, + payloads := _jwt.MapClaims{ + "iss": t.jwtConfig.Issuer, + "exp": float64(_jwt.NewNumericDate(time.Now().Add(time.Second * time.Duration(t.jwtConfig.ExpiresIn))).Unix()), + "iat": float64(_jwt.NewNumericDate(time.Now()).Unix()), + "user_id": t.userId, + "auth_session_id": t.authSessionId, } jwtToken := &_jwt.Token{ Method: _jwt.SigningMethodHS256, @@ -374,7 +362,7 @@ func (t *TokenServiceTest) TestValidateGetCacheInternalFailed() { jwtService.On("VerifyAuth", t.validateToken).Return(jwtToken, nil) jwtService.On("GetConfig").Return(t.jwtConfig) - accessTokenRepo.EXPECT().GetValue(payloads.AuthSessionID, accessTokenCache).Return(getCacheErr) + accessTokenRepo.EXPECT().GetValue(payloads["auth_session_id"].(string), accessTokenCache).Return(getCacheErr) tokenSvc := NewService(&jwtService, accessTokenRepo, refreshTokenRepo, &uuidUtil) actual, err := tokenSvc.Validate(t.validateToken) @@ -387,15 +375,12 @@ func (t *TokenServiceTest) TestValidateInvalidToken() { invalidToken := faker.Word() expected := errors.New("invalid token") - payloads := tokenDto.AuthPayload{ - RegisteredClaims: _jwt.RegisteredClaims{ - Issuer: t.jwtConfig.Issuer, - ExpiresAt: _jwt.NewNumericDate(time.Now().Add(time.Second * (time.Duration(t.jwtConfig.ExpiresIn)))), - IssuedAt: _jwt.NewNumericDate(time.Now()), - }, - UserID: t.userId, - Role: t.role, - AuthSessionID: t.authSessionId, + payloads := _jwt.MapClaims{ + "iss": t.jwtConfig.Issuer, + "exp": float64(_jwt.NewNumericDate(time.Now().Add(time.Second * time.Duration(t.jwtConfig.ExpiresIn))).Unix()), + "iat": float64(_jwt.NewNumericDate(time.Now()).Unix()), + "user_id": t.userId, + "auth_session_id": t.authSessionId, } jwtToken := &_jwt.Token{ Method: _jwt.SigningMethodHS256, @@ -412,7 +397,7 @@ func (t *TokenServiceTest) TestValidateInvalidToken() { jwtService.On("VerifyAuth", invalidToken).Return(jwtToken, nil) jwtService.On("GetConfig").Return(t.jwtConfig) - accessTokenRepo.EXPECT().GetValue(payloads.AuthSessionID, accessTokenCache).Return(nil) + accessTokenRepo.EXPECT().GetValue(payloads["auth_session_id"].(string), accessTokenCache).Return(nil) tokenSvc := NewService(&jwtService, accessTokenRepo, refreshTokenRepo, &uuidUtil) actual, err := tokenSvc.Validate(invalidToken)