From 8fe8c664c3747c66ae132700ce6f3101a266891b Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Mon, 15 Apr 2024 16:26:17 -0400 Subject: [PATCH] Allow passing other JWT claims --- README.md | 5 ++++- main.go | 21 ++++++++++++++++++++- main_test.go | 39 ++++++++++++++++++++++++++++++++------- 3 files changed, 56 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 0766d90..d553e6c 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,10 @@ This service requires two envionrment variables. - `JWKS_URI` - the URL of the OIDC Provider's [JSON Web Key (JWK) set document](https://www.rfc-editor.org/info/rfc7517). This is used to ensure the JWT was signed by the provider. - `JWT_AUD` - the audience set in the JWT token. +- `CUSTOM_CLAIMS` - (optional) JSON of key/value pairs to validate in the JWT e.g. +``` +{"foo": "bar", "foo2": "bar2"} +``` - `ROLLOUT_CMD` (default: `/bin/bash`) - the command to execute a rollout - `ROLLOUT_ARGS` (default: `/rollout.sh` ) - the args to pass to `ROLLOUT_CMD` @@ -56,4 +60,3 @@ JWT_AUD=aud-string-you-set-in-your-job - [ ] Add a full example for GitHub - [ ] Install instructions using binary - [ ] Tag/push versions to dockerhub -- [ ] Allow more custom auth handling diff --git a/main.go b/main.go index eef412c..127a2d4 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,7 @@ package main import ( "bytes" "context" + "encoding/json" "fmt" "log" "net/http" @@ -51,7 +52,25 @@ func Rollout(w http.ResponseWriter, r *http.Request) { } if !token.Valid { - log.Println("Invalid token for", realIp, ",", lastIP, err.Error()) + log.Println("Invalid token for", realIp, ",", lastIP) + http.Error(w, "Invalid token", http.StatusUnauthorized) + return + } + + if claims, ok := token.Claims.(jwt.MapClaims); ok { + ccStr := os.Getenv("CUSTOM_CLAIMS") + log.Println(ccStr) + var cc map[string]string + json.Unmarshal([]byte(ccStr), &cc) + for k, v := range cc { + if claims[k] != v { + log.Println("Claim for", k, "doesn't match", realIp, ",", lastIP) + http.Error(w, "Invalid token", http.StatusUnauthorized) + return + } + } + } else { + log.Println("Unable to read token claims for", realIp, ",", lastIP) http.Error(w, "Invalid token", http.StatusUnauthorized) return } diff --git a/main_test.go b/main_test.go index 1d79bcd..ad5075a 100644 --- a/main_test.go +++ b/main_test.go @@ -80,13 +80,14 @@ func setupMockJwksServer(pub *rsa.PublicKey, kid string) *httptest.Server { return httptest.NewServer(handler) } -func CreateSignedJWT(kid, aud string, exp int64, privateKey *rsa.PrivateKey) (string, error) { +func CreateSignedJWT(kid, aud, claim string, exp int64, privateKey *rsa.PrivateKey) (string, error) { // Define the claims of the token. You can add more claims based on your needs. claims := jwt.MapClaims{ "sub": "1234567890", "aud": aud, "iat": time.Now().Unix(), "exp": exp, + "foo": claim, } // Create a new token object with the claims and the signing method @@ -126,6 +127,7 @@ func TestRollout(t *testing.T) { os.Setenv("JWT_AUD", "test-success") kid := "no-kidding" aud := os.Getenv("JWT_AUD") + claim := "bar" privateKey, publicKey, err := GenerateRSAKeys() if err != nil { log.Fatalf("Unable to generate RSA keys: %v", err) @@ -137,13 +139,13 @@ func TestRollout(t *testing.T) { // get a valid token exp := time.Now().Add(time.Hour * 1).Unix() - jwtToken, err := CreateSignedJWT(kid, aud, exp, privateKey) + jwtToken, err := CreateSignedJWT(kid, aud, claim, exp, privateKey) if err != nil { t.Fatalf("Unable to create a JWT with our test key: %v", err) } // make sure invalid kids fail - badKidJwtToken, err := CreateSignedJWT("just-kidding", aud, exp, privateKey) + badKidJwtToken, err := CreateSignedJWT("just-kidding", aud, claim, exp, privateKey) if err != nil { t.Fatalf("Unable to create a JWT with our test key: %v", err) } @@ -153,20 +155,26 @@ func TestRollout(t *testing.T) { if err != nil { t.Fatalf("Unable to generate a new private key") } - badPrivKeyjwtToken, err := CreateSignedJWT(kid, aud, exp, badPrivateKey) + badPrivKeyjwtToken, err := CreateSignedJWT(kid, aud, claim, exp, badPrivateKey) if err != nil { t.Fatalf("Unable to create a JWT with our new test key: %v", err) } // make sure expired JWTs fail expired := time.Now().Add(time.Hour * -1).Unix() - expiredJwtToken, err := CreateSignedJWT(kid, aud, expired, privateKey) + expiredJwtToken, err := CreateSignedJWT(kid, aud, claim, expired, privateKey) if err != nil { t.Fatalf("Unable to create a JWT with our test key: %v", err) } // make sure bad audience JWTs fail - badAudJwtToken, err := CreateSignedJWT(kid, "different-audience", exp, privateKey) + badAudJwtToken, err := CreateSignedJWT(kid, "different-audience", claim, exp, privateKey) + if err != nil { + t.Fatalf("Unable to create a JWT with our test key: %v", err) + } + + // make sure JWTs with a bad custom claim fail + badClaimJwtToken, err := CreateSignedJWT(kid, aud, "bad-claim", exp, privateKey) if err != nil { t.Fatalf("Unable to create a JWT with our test key: %v", err) } @@ -177,6 +185,7 @@ func TestRollout(t *testing.T) { authHeader string expectedStatus int expectedBody string + claim map[string]string }{ { name: "No Authorization Header", @@ -214,6 +223,18 @@ func TestRollout(t *testing.T) { expectedStatus: http.StatusUnauthorized, expectedBody: "Failed to verify token.\n", }, + { + name: "Bad custom claim", + authHeader: "Bearer " + badClaimJwtToken, + expectedStatus: http.StatusUnauthorized, + expectedBody: "Invalid token\n", + }, + { + name: "No custom claim", + authHeader: "Bearer " + jwtToken, + expectedStatus: http.StatusOK, + expectedBody: "Rollout complete\n", + }, { name: "Valid Token and Successful Command", authHeader: "Bearer " + jwtToken, @@ -221,12 +242,16 @@ func TestRollout(t *testing.T) { expectedBody: "Rollout complete\n", }, } - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { recorder := httptest.NewRecorder() request := createRequest(tt.authHeader) + if tt.name == "No custom claim" { + os.Setenv("CUSTOM_CLAIMS", "") + } else { + os.Setenv("CUSTOM_CLAIMS", `{"foo": "bar"}`) + } Rollout(recorder, request) assert.Equal(t, tt.expectedStatus, recorder.Code)