From 3cbe7fe2dc8ea640aa782b3682f31d9e08825f5d Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Mon, 15 Apr 2024 11:24:49 -0400 Subject: [PATCH] [minor] allow setting the rollout command and args (#1) --- README.md | 3 +- main.go | 47 ++++++++--------- main_test.go | 139 ++++++++++++++++++++++++++++++++++++++++++--------- 3 files changed, 141 insertions(+), 48 deletions(-) diff --git a/README.md b/README.md index bcd4171..0766d90 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,8 @@ This service requires two envionrment variables. - `JWKS_URI` - the URL of the OIDC Provider's [JSON Web Key (JWK) set document](https://www.rfc-editor.org/info/rfc7517). This is used to ensure the JWT was signed by the provider. - `JWT_AUD` - the audience set in the JWT token. +- `ROLLOUT_CMD` (default: `/bin/bash`) - the command to execute a rollout +- `ROLLOUT_ARGS` (default: `/rollout.sh` ) - the args to pass to `ROLLOUT_CMD` ### GitHub @@ -55,4 +57,3 @@ JWT_AUD=aud-string-you-set-in-your-job - [ ] Install instructions using binary - [ ] Tag/push versions to dockerhub - [ ] Allow more custom auth handling -- [ ] Allow more custom rollout than a single bash script diff --git a/main.go b/main.go index 419852c..7c0d7c3 100644 --- a/main.go +++ b/main.go @@ -8,23 +8,18 @@ import ( "net/http" "os" "os/exec" + "strings" "github.com/golang-jwt/jwt/v4" "github.com/lestrrat-go/jwx/jwk" ) -var ( - gitLabJwksURL, aud string -) - func main() { - gitLabJwksURL = os.Getenv("JWKS_URI") - if gitLabJwksURL == "" { + if os.Getenv("JWKS_URI") == "" { log.Fatal("JWKS_URI is required. e.g. JWKS_URI=https://gitlab.com/oauth/discovery/keys") } - aud = os.Getenv("JWT_AUD") - if aud == "" { + if os.Getenv("JWT_AUD") == "" { log.Fatal("JWT_AUD is required. This needs to be the aud in the JWT you except this service to handle.") } @@ -36,15 +31,6 @@ func main() { } } -func readUserIP(r *http.Request) (string, string) { - realIP := r.Header.Get("X-Real-Ip") - lastIP := r.RemoteAddr - if realIP == "" { - realIP = r.Header.Get("X-Forwarded-For") - } - return realIP, lastIP -} - func Rollout(w http.ResponseWriter, r *http.Request) { realIp, lastIP := readUserIP(r) @@ -71,9 +57,15 @@ func Rollout(w http.ResponseWriter, r *http.Request) { return } - // TODO make this more customizable - // but for now this fills the need - cmd := exec.Command("/bin/bash", "/rollout.sh") + name := os.Getenv("ROLLOUT_CMD") + if name == "" { + name = "/bin/bash" + } + args := os.Getenv("ROLLOUT_ARGS") + if args == "" { + args = "/rollout.sh" + } + cmd := exec.Command(name, strings.Split(args, " ")...) var stdOut, stdErr bytes.Buffer cmd.Stdout = &stdOut @@ -111,10 +103,10 @@ func ParseToken(token *jwt.Token) (interface{}, error) { } ctx := context.Background() - gitLabJwksURL = os.Getenv("JWKS_URI") - jwksSet, err := jwk.Fetch(ctx, gitLabJwksURL) + jwksUri := os.Getenv("JWKS_URI") + jwksSet, err := jwk.Fetch(ctx, jwksUri) if err != nil { - log.Fatalf("Unable to fetch JWK set from %s: %v", gitLabJwksURL, err) + log.Fatalf("Unable to fetch JWK set from %s: %v", jwksUri, err) } // Find the appropriate key in JWKS key, ok := jwksSet.LookupKeyID(kid) @@ -129,3 +121,12 @@ func ParseToken(token *jwt.Token) (interface{}, error) { return pubkey, nil } + +func readUserIP(r *http.Request) (string, string) { + realIP := r.Header.Get("X-Real-Ip") + lastIP := r.RemoteAddr + if realIP == "" { + realIP = r.Header.Get("X-Forwarded-For") + } + return realIP, lastIP +} diff --git a/main_test.go b/main_test.go index 5e13a72..3fb954d 100644 --- a/main_test.go +++ b/main_test.go @@ -103,9 +103,27 @@ func CreateSignedJWT(kid, aud string, exp int64, privateKey *rsa.PrivateKey) (st return signedToken, nil } -func TestTokenVerification(t *testing.T) { - os.Setenv("JWT_AUD", "test-success") +// Utility function to create a request with an Authorization header +func createRequest(authHeader string) *http.Request { + req, _ := http.NewRequest("GET", "/", nil) + req.Header.Set("Authorization", authHeader) + return req +} + +// TestRollout tests the Rollout function with various scenarios +func TestRollout(t *testing.T) { + testFile := "/tmp/rollout-test.txt" + os.Setenv("ROLLOUT_CMD", "touch") + os.Setenv("ROLLOUT_ARGS", testFile) + // make sure the test file doesn't exist + err := RemoveFileIfExists(testFile) + if err != nil { + log.Fatalf("Unable to cleanup test file: %v", err) + } + + // mock the JWKS server response + os.Setenv("JWT_AUD", "test-success") kid := "no-kidding" aud := os.Getenv("JWT_AUD") privateKey, publicKey, err := GenerateRSAKeys() @@ -114,58 +132,131 @@ func TestTokenVerification(t *testing.T) { } server := setupMockJwksServer(publicKey, kid) defer server.Close() - jwkURL := fmt.Sprintf("%s/oauth/discovery/keys", server.URL) os.Setenv("JWKS_URI", jwkURL) - // make sure valid tokens succeed - exp := time.Now().Add(time.Hour * 24).Unix() + // get a valid token + exp := time.Now().Add(time.Hour * 1).Unix() jwtToken, err := CreateSignedJWT(kid, aud, exp, privateKey) if err != nil { t.Fatalf("Unable to create a JWT with our test key: %v", err) } - token, err := jwt.Parse(jwtToken, ParseToken) - assert.NoError(t, err) - assert.True(t, token.Valid) // make sure invalid kids fail - jwtToken, err = CreateSignedJWT("just-kidding", aud, exp, privateKey) + badKidJwtToken, err := CreateSignedJWT("just-kidding", aud, exp, privateKey) if err != nil { t.Fatalf("Unable to create a JWT with our test key: %v", err) } - token, err = jwt.Parse(jwtToken, ParseToken) - assert.Error(t, err) - assert.False(t, token.Valid) // make sure if we pass a JWT signed by another private key it fails badPrivateKey, _, err := GenerateRSAKeys() if err != nil { t.Fatalf("Unable to generate a new private key") } - jwtToken, err = CreateSignedJWT(kid, aud, exp, badPrivateKey) + badPrivKeyjwtToken, err := CreateSignedJWT(kid, aud, exp, badPrivateKey) if err != nil { t.Fatalf("Unable to create a JWT with our new test key: %v", err) } - token, err = jwt.Parse(jwtToken, ParseToken) - assert.Error(t, err) - assert.False(t, token.Valid) // make sure expired JWTs fail expired := time.Now().Add(time.Hour * -1).Unix() - jwtToken, err = CreateSignedJWT(kid, aud, expired, privateKey) + expiredJwtToken, err := CreateSignedJWT(kid, aud, expired, privateKey) if err != nil { t.Fatalf("Unable to create a JWT with our test key: %v", err) } - token, err = jwt.Parse(jwtToken, ParseToken) - assert.Error(t, err) - assert.False(t, token.Valid) // make sure bad audience JWTs fail - jwtToken, err = CreateSignedJWT(kid, "different-audience", exp, privateKey) + badAudJwtToken, err := CreateSignedJWT(kid, "different-audience", exp, privateKey) if err != nil { t.Fatalf("Unable to create a JWT with our test key: %v", err) } - token, err = jwt.Parse(jwtToken, ParseToken) - assert.Error(t, err) - assert.False(t, token.Valid) + + // Define test cases + tests := []struct { + name string + authHeader string + expectedStatus int + expectedBody string + }{ + { + name: "No Authorization Header", + authHeader: "", + expectedStatus: http.StatusUnauthorized, + expectedBody: "need authorizaton: bearer xyz header\n", + }, + { + name: "Invalid Token", + authHeader: "Bearer invalidtoken", + expectedStatus: http.StatusUnauthorized, + expectedBody: "Failed to verify token.\n", + }, + { + name: "Bad kid Token", + authHeader: "Bearer " + badKidJwtToken, + expectedStatus: http.StatusUnauthorized, + expectedBody: "Failed to verify token.\n", + }, + { + name: "Signed from wrong JWKS Token", + authHeader: "Bearer " + badPrivKeyjwtToken, + expectedStatus: http.StatusUnauthorized, + expectedBody: "Failed to verify token.\n", + }, + { + name: "Expired Token", + authHeader: "Bearer " + expiredJwtToken, + expectedStatus: http.StatusUnauthorized, + expectedBody: "Failed to verify token.\n", + }, + { + name: "Bad aud Token", + authHeader: "Bearer " + badAudJwtToken, + expectedStatus: http.StatusUnauthorized, + expectedBody: "Failed to verify token.\n", + }, + { + name: "Valid Token and Successful Command", + authHeader: "Bearer " + jwtToken, + expectedStatus: http.StatusOK, + expectedBody: "Rollout complete\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + request := createRequest(tt.authHeader) + + Rollout(recorder, request) + + assert.Equal(t, tt.expectedStatus, recorder.Code) + assert.Equal(t, tt.expectedBody, recorder.Body.String()) + }) + } + + // make sure the rollout command actually ran the command + _, err = os.Stat(testFile) + if err != nil && os.IsNotExist(err) { + t.Errorf("The successful test did not create the expected file") + } + + // cleanup + err = RemoveFileIfExists(testFile) + if err != nil { + log.Fatalf("Unable to cleanup test file: %v", err) + } +} + +func RemoveFileIfExists(filePath string) error { + _, err := os.Stat(filePath) + if err == nil { + err := os.Remove(filePath) + if err != nil { + return fmt.Errorf("failed to remove file: %v", err) + } + } else if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("error checking file: %v", err) + } + + return nil }