diff --git a/go.mod b/go.mod index 774f54b..c6f116b 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.22.0 require ( github.com/golang-jwt/jwt/v4 v4.5.0 + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/lestrrat-go/jwx v1.2.29 github.com/stretchr/testify v1.9.0 ) diff --git a/go.sum b/go.sum index d369b72..0cf2331 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/lestrrat-go/backoff/v2 v2.0.8 h1:oNb5E5isby2kiro9AgdHLv5N5tint1AnDVVf2E2un5A= github.com/lestrrat-go/backoff/v2 v2.0.8/go.mod h1:rHP/q/r9aT27n24JQLa7JhSQZCKBBOiM/uP402WwN8Y= github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k= diff --git a/main.go b/main.go index 7c0d7c3..eef412c 100644 --- a/main.go +++ b/main.go @@ -10,12 +10,11 @@ import ( "os/exec" "strings" - "github.com/golang-jwt/jwt/v4" + "github.com/golang-jwt/jwt/v5" "github.com/lestrrat-go/jwx/jwk" ) func main() { - if os.Getenv("JWKS_URI") == "" { log.Fatal("JWKS_URI is required. e.g. JWKS_URI=https://gitlab.com/oauth/discovery/keys") } @@ -88,12 +87,12 @@ func ParseToken(token *jwt.Token) (interface{}, error) { } // Check audience claim - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - return nil, fmt.Errorf("error retrieving claims from token") - } aud := os.Getenv("JWT_AUD") - if !claims.VerifyAudience(aud, true) { + taud, err := token.Claims.GetAudience() + if err != nil { + return nil, fmt.Errorf("could not get aud claim: %v", err) + } + if !strInSlice(aud, taud) { return nil, fmt.Errorf("invalid audience. Expected: %s", aud) } @@ -130,3 +129,12 @@ func readUserIP(r *http.Request) (string, string) { } return realIP, lastIP } + +func strInSlice(e string, s []string) bool { + for _, a := range s { + if a == e { + return true + } + } + return false +}