From 1c82000af4f064d00ac1f0b08f5a31469b6c17ae Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Mon, 15 Apr 2024 09:06:42 -0400 Subject: [PATCH] Refactor and add tests around JWT auth --- go.mod | 4 ++ go.sum | 1 + main.go | 191 +++++++++++++++++++++++++-------------------------- main_test.go | 171 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 270 insertions(+), 97 deletions(-) create mode 100644 main_test.go diff --git a/go.mod b/go.mod index 4b6b28a..774f54b 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,11 @@ go 1.22.0 require ( github.com/golang-jwt/jwt/v4 v4.5.0 github.com/lestrrat-go/jwx v1.2.29 + github.com/stretchr/testify v1.9.0 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect @@ -16,5 +18,7 @@ require ( github.com/lestrrat-go/iter v1.0.2 // indirect github.com/lestrrat-go/option v1.0.1 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/crypto v0.22.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 975eaa9..d369b72 100644 --- a/go.sum +++ b/go.sum @@ -79,6 +79,7 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/main.go b/main.go index 7c242c1..419852c 100644 --- a/main.go +++ b/main.go @@ -14,111 +14,23 @@ import ( ) var ( - gitLabJwksURL = "https://%s/oauth/discovery/keys" - aud string + gitLabJwksURL, aud string ) -func init() { - domain := os.Getenv("GITLAB_DOMAIN") - if domain == "" { - log.Fatal("GITLAB_DOMAIN is required. You could use GITLAB_DOMAIN=gitlab.com") - } - gitLabJwksURL = fmt.Sprintf(gitLabJwksURL, domain) +func main() { + gitLabJwksURL = os.Getenv("JWKS_URI") + if gitLabJwksURL == "" { + log.Fatal("JWKS_URI is required. e.g. JWKS_URI=https://gitlab.com/oauth/discovery/keys") + } aud = os.Getenv("JWT_AUD") if aud == "" { log.Fatal("JWT_AUD is required. This needs to be the aud in the JWT you except this service to handle.") } -} - -func main() { - ctx := context.Background() - - // Fetch the JWKS from GitLab - set, err := jwk.Fetch(ctx, gitLabJwksURL) - if err != nil { - fmt.Printf("Failed to fetch JWKS: %v\n", err) - return - } - - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - realIp, lastIP := readUserIP(r) - - a := r.Header.Get("Authorization") - if len(a) < 10 { - log.Println("Not auth header for", realIp, ",", lastIP) - http.Error(w, "need authorizaton: bearer xyz header", http.StatusUnauthorized) - return - } - // Assuming "Bearer " prefix - tokenString := a[7:] - - // Parse and verify the token - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) - } - - // Check audience claim - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - return nil, fmt.Errorf("error retrieving claims from token") - } - if !claims.VerifyAudience(aud, true) { - return nil, fmt.Errorf("invalid audience. Expected: %s", aud) - } - - kid, ok := token.Header["kid"].(string) - if !ok { - return nil, fmt.Errorf("expecting JWT header to have string 'kid'") - } - - // Find the appropriate key in JWKS - key, ok := set.LookupKeyID(kid) - if !ok { - return nil, fmt.Errorf("unable to find key '%s'", kid) - } - - var pubkey interface{} - if err := key.Raw(&pubkey); err != nil { - return nil, fmt.Errorf("failed to get raw key: %v", err) - } - - return pubkey, nil - }) - - if err != nil { - log.Println("Failed to verify token for", realIp, ",", lastIP, err.Error()) - http.Error(w, "Failed to verify token.", http.StatusUnauthorized) - return - } - - if !token.Valid { - log.Println("Invalid token for", realIp, ",", lastIP, err.Error()) - http.Error(w, "Invalid token", http.StatusUnauthorized) - return - } - - // TODO make this more customizable - // but for now this fills the need - cmd := exec.Command("/bin/bash", "/rollout.sh") - - var stdErr bytes.Buffer - cmd.Stderr = &stdErr - cmd.Env = os.Environ() - if err := cmd.Run(); err != nil { - log.Printf("Error running %s command: %s", cmd.String(), stdErr.String()) - http.Error(w, "Script execution failed", http.StatusInternalServerError) - return - } - - log.Println("Rollout complete for", realIp, ",", lastIP) - fmt.Fprintln(w, "Rollout complete") - }) - - fmt.Println("Server is running on http://localhost:8080/") - err = http.ListenAndServe(":8080", nil) + http.HandleFunc("/", Rollout) + log.Println("Server is running on :8080") + err := http.ListenAndServe(":8080", nil) if err != nil { log.Fatal("Unable to start service") } @@ -132,3 +44,88 @@ func readUserIP(r *http.Request) (string, string) { } return realIP, lastIP } + +func Rollout(w http.ResponseWriter, r *http.Request) { + realIp, lastIP := readUserIP(r) + + a := r.Header.Get("Authorization") + if len(a) < 10 { + log.Println("Not auth header for", realIp, ",", lastIP) + http.Error(w, "need authorizaton: bearer xyz header", http.StatusUnauthorized) + return + } + // Assuming "Bearer " prefix + tokenString := a[7:] + + // Parse and verify the token + token, err := jwt.Parse(tokenString, ParseToken) + if err != nil { + log.Println("Failed to verify token for", realIp, ",", lastIP, err.Error()) + http.Error(w, "Failed to verify token.", http.StatusUnauthorized) + return + } + + if !token.Valid { + log.Println("Invalid token for", realIp, ",", lastIP, err.Error()) + http.Error(w, "Invalid token", http.StatusUnauthorized) + return + } + + // TODO make this more customizable + // but for now this fills the need + cmd := exec.Command("/bin/bash", "/rollout.sh") + + var stdOut, stdErr bytes.Buffer + cmd.Stdout = &stdOut + cmd.Stderr = &stdErr + cmd.Env = os.Environ() + if err := cmd.Run(); err != nil { + log.Printf("Error running %s command: %s", cmd.String(), stdOut.String()) + log.Printf("stderr: %s", stdErr.String()) + http.Error(w, "Script execution failed", http.StatusInternalServerError) + return + } + + log.Println("Rollout complete for", realIp, ",", lastIP) + fmt.Fprintln(w, "Rollout complete") +} + +func ParseToken(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + // Check audience claim + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, fmt.Errorf("error retrieving claims from token") + } + aud := os.Getenv("JWT_AUD") + if !claims.VerifyAudience(aud, true) { + return nil, fmt.Errorf("invalid audience. Expected: %s", aud) + } + + kid, ok := token.Header["kid"].(string) + if !ok { + return nil, fmt.Errorf("expecting JWT header to have string 'kid'") + } + + ctx := context.Background() + gitLabJwksURL = os.Getenv("JWKS_URI") + jwksSet, err := jwk.Fetch(ctx, gitLabJwksURL) + if err != nil { + log.Fatalf("Unable to fetch JWK set from %s: %v", gitLabJwksURL, err) + } + // Find the appropriate key in JWKS + key, ok := jwksSet.LookupKeyID(kid) + if !ok { + return nil, fmt.Errorf("unable to find key '%s'", kid) + } + + var pubkey interface{} + if err := key.Raw(&pubkey); err != nil { + return nil, fmt.Errorf("failed to get raw key: %v", err) + } + + return pubkey, nil +} diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..5e13a72 --- /dev/null +++ b/main_test.go @@ -0,0 +1,171 @@ +package main + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "log" + "math/big" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/stretchr/testify/assert" +) + +// createJWKS creates a JWKS JSON representation with a single RSA key. +func mockJWKS(pub *rsa.PublicKey, kid string) (string, error) { + jwks := struct { + Keys []map[string]interface{} `json:"keys"` + }{ + Keys: []map[string]interface{}{ + { + "kty": "RSA", + "use": "sig", + "kid": kid, + "alg": "RS256", + "n": encodeBigInt(pub.N), + "e": encodeBigInt(big.NewInt(int64(pub.E))), + }, + }, + } + + jwksJSON, err := json.Marshal(jwks) + if err != nil { + return "", fmt.Errorf("failed to marshal JWKS: %v", err) + } + return string(jwksJSON), nil +} + +// GenerateRSAKeys generates and returns RSA private and public keys. +func GenerateRSAKeys() (*rsa.PrivateKey, *rsa.PublicKey, error) { + // Generate a private key + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate private key: %v", err) + } + // Extract the public key from the private key + publicKey := &privateKey.PublicKey + return privateKey, publicKey, nil +} + +// encodeBigInt encodes big integers like RSA modulus and exponent to the +// Base64 URL-encoded format used in JWKS. +func encodeBigInt(n *big.Int) string { + return base64.RawURLEncoding.EncodeToString(n.Bytes()) +} + +// Set up the mock server +func setupMockJwksServer(pub *rsa.PublicKey, kid string) *httptest.Server { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + jwks, err := mockJWKS(pub, kid) + if err != nil { + log.Fatalf("Unable to generate RSA keys: %v", err) + } + + _, err = w.Write([]byte(jwks)) + if err != nil { + log.Fatalf("Unable to generate RSA keys: %v", err) + } + }) + + return httptest.NewServer(handler) +} + +func CreateSignedJWT(kid, aud string, exp int64, privateKey *rsa.PrivateKey) (string, error) { + // Define the claims of the token. You can add more claims based on your needs. + claims := jwt.MapClaims{ + "sub": "1234567890", + "aud": aud, + "iat": time.Now().Unix(), + "exp": exp, + } + + // Create a new token object with the claims and the signing method + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + + token.Header["kid"] = kid + + // Sign the token with the private key + signedToken, err := token.SignedString(privateKey) + if err != nil { + return "", fmt.Errorf("failed to sign the token: %v", err) + } + + return signedToken, nil +} + +func TestTokenVerification(t *testing.T) { + os.Setenv("JWT_AUD", "test-success") + + kid := "no-kidding" + aud := os.Getenv("JWT_AUD") + privateKey, publicKey, err := GenerateRSAKeys() + if err != nil { + log.Fatalf("Unable to generate RSA keys: %v", err) + } + server := setupMockJwksServer(publicKey, kid) + defer server.Close() + + jwkURL := fmt.Sprintf("%s/oauth/discovery/keys", server.URL) + os.Setenv("JWKS_URI", jwkURL) + + // make sure valid tokens succeed + exp := time.Now().Add(time.Hour * 24).Unix() + jwtToken, err := CreateSignedJWT(kid, aud, exp, privateKey) + if err != nil { + t.Fatalf("Unable to create a JWT with our test key: %v", err) + } + token, err := jwt.Parse(jwtToken, ParseToken) + assert.NoError(t, err) + assert.True(t, token.Valid) + + // make sure invalid kids fail + jwtToken, err = CreateSignedJWT("just-kidding", aud, exp, privateKey) + if err != nil { + t.Fatalf("Unable to create a JWT with our test key: %v", err) + } + token, err = jwt.Parse(jwtToken, ParseToken) + assert.Error(t, err) + assert.False(t, token.Valid) + + // make sure if we pass a JWT signed by another private key it fails + badPrivateKey, _, err := GenerateRSAKeys() + if err != nil { + t.Fatalf("Unable to generate a new private key") + } + jwtToken, err = CreateSignedJWT(kid, aud, exp, badPrivateKey) + if err != nil { + t.Fatalf("Unable to create a JWT with our new test key: %v", err) + } + token, err = jwt.Parse(jwtToken, ParseToken) + assert.Error(t, err) + assert.False(t, token.Valid) + + // make sure expired JWTs fail + expired := time.Now().Add(time.Hour * -1).Unix() + jwtToken, err = CreateSignedJWT(kid, aud, expired, privateKey) + if err != nil { + t.Fatalf("Unable to create a JWT with our test key: %v", err) + } + token, err = jwt.Parse(jwtToken, ParseToken) + assert.Error(t, err) + assert.False(t, token.Valid) + + // make sure bad audience JWTs fail + jwtToken, err = CreateSignedJWT(kid, "different-audience", exp, privateKey) + if err != nil { + t.Fatalf("Unable to create a JWT with our test key: %v", err) + } + token, err = jwt.Parse(jwtToken, ParseToken) + assert.Error(t, err) + assert.False(t, token.Valid) +}