diff --git a/main.go b/main.go index 626e195..ae3a5aa 100644 --- a/main.go +++ b/main.go @@ -5,7 +5,7 @@ import ( "context" "encoding/json" "fmt" - "log" + "log/slog" "net/http" "os" "os/exec" @@ -22,17 +22,20 @@ func init() { func main() { if os.Getenv("JWKS_URI") == "" { - log.Fatal("JWKS_URI is required. e.g. JWKS_URI=https://gitlab.com/oauth/discovery/keys") + slog.Error("JWKS_URI is required. e.g. JWKS_URI=https://gitlab.com/oauth/discovery/keys") + os.Exit(1) } if os.Getenv("JWT_AUD") == "" { - log.Fatal("JWT_AUD is required. This needs to be the aud in the JWT you except this service to handle.") + slog.Error("JWT_AUD is required. This needs to be the aud in the JWT you except this service to handle.") + os.Exit(1) } http.HandleFunc("/", Rollout) - log.Println("Server is running on :8080") + slog.Info("Server is running on :8080") err := http.ListenAndServe(":8080", nil) if err != nil { - log.Fatal("Unable to start service") + slog.Error("Unable to start service") + os.Exit(1) } } @@ -41,7 +44,7 @@ func Rollout(w http.ResponseWriter, r *http.Request) { a := r.Header.Get("Authorization") if len(a) < 10 { - log.Println("Not auth header for", realIp, ",", lastIP) + slog.Info("Not auth header", "forwarded-ip", realIp, "lasthop-ip", lastIP) http.Error(w, "need authorizaton: bearer xyz header", http.StatusUnauthorized) return } @@ -51,13 +54,13 @@ func Rollout(w http.ResponseWriter, r *http.Request) { // Parse and verify the token token, err := jwt.Parse(tokenString, ParseToken) if err != nil { - log.Println("Failed to verify token for", realIp, ",", lastIP, err.Error()) + slog.Info("Failed to verify token for", "forwarded-ip", realIp, "lasthop-ip", lastIP, "err", err.Error()) http.Error(w, "Failed to verify token.", http.StatusUnauthorized) return } if !token.Valid { - log.Println("Invalid token for", realIp, ",", lastIP) + slog.Info("Invalid token for", "forwarded-ip", realIp, "lasthop-ip", lastIP) http.Error(w, "Invalid token", http.StatusUnauthorized) return } @@ -68,19 +71,19 @@ func Rollout(w http.ResponseWriter, r *http.Request) { var cc map[string]string err = json.Unmarshal([]byte(ccStr), &cc) if err != nil { - log.Println("Unable to read token claims for", realIp, ",", lastIP) + slog.Info("Unable to read token claims", "forwarded-ip", realIp, "lasthop-ip", lastIP) http.Error(w, "Invalid token", http.StatusUnauthorized) return } for k, v := range cc { if claims[k] != v { - log.Println("Claim for", k, "doesn't match", realIp, ",", lastIP) + slog.Info("Claim doesn't match", "claim", k, "forwarded-ip", realIp, "lasthop-ip", lastIP) http.Error(w, "Invalid token", http.StatusUnauthorized) return } } } else if !ok { - log.Println("Unable to read token claims for", realIp, ",", lastIP) + slog.Info("Unable to read token claims", "forwarded-ip", realIp, "lasthop-ip", lastIP) http.Error(w, "Invalid token", http.StatusUnauthorized) return } @@ -96,13 +99,12 @@ func Rollout(w http.ResponseWriter, r *http.Request) { cmd.Stderr = &stdErr cmd.Env = os.Environ() if err := cmd.Run(); err != nil { - log.Printf("Error running %s command: %s", cmd.String(), stdOut.String()) - log.Printf("stderr: %s", stdErr.String()) + slog.Error("Error running", "command", cmd.String(), "stdout", stdOut.String(), "stderr", stdErr.String()) http.Error(w, "Script execution failed", http.StatusInternalServerError) return } - log.Println("Rollout complete for", realIp, ",", lastIP) + slog.Info("Rollout complete for", "forwarded-ip", realIp, "lasthop-ip", lastIP) fmt.Fprintln(w, "Rollout complete") } @@ -130,7 +132,7 @@ func ParseToken(token *jwt.Token) (interface{}, error) { jwksUri := os.Getenv("JWKS_URI") jwksSet, err := jwk.Fetch(ctx, jwksUri) if err != nil { - log.Fatalf("Unable to fetch JWK set from %s: %v", jwksUri, err) + return nil, fmt.Errorf("Unable to fetch JWK set from %s: %v", jwksUri, err) } // Find the appropriate key in JWKS key, ok := jwksSet.LookupKeyID(kid) @@ -171,7 +173,8 @@ func getArgs() []string { } rolloutArgs, err := shlex.Split(args) if err != nil { - log.Fatalf("Error parsing ROLLOUT_ARGS %s: %v", args, err) + slog.Error("Error parsing ROLLOUT_ARGS", args, "err", err) + os.Exit(1) } return rolloutArgs diff --git a/main_test.go b/main_test.go index f3ee9b0..c1cfd5c 100644 --- a/main_test.go +++ b/main_test.go @@ -6,7 +6,7 @@ import ( "encoding/base64" "encoding/json" "fmt" - "log" + "log/slog" "math/big" "net/http" "net/http/httptest" @@ -73,12 +73,14 @@ func setupMockJwksServer(pub *rsa.PublicKey, kid string) *httptest.Server { w.WriteHeader(http.StatusOK) jwks, err := mockJWKS(pub, kid) if err != nil { - log.Fatalf("Unable to generate RSA keys: %v", err) + slog.Error("Unable to generate RSA keys", "err", err) + os.Exit(1) } _, err = w.Write([]byte(jwks)) if err != nil { - log.Fatalf("Unable to generate RSA keys: %v", err) + slog.Error("Unable to generate RSA keys", "err", err) + os.Exit(1) } }) @@ -95,7 +97,8 @@ func createMockJwksServer() *httptest.Server { claim = "bar" privateKey, publicKey, err = GenerateRSAKeys() if err != nil { - log.Fatalf("Unable to generate RSA keys: %v", err) + slog.Error("Unable to generate RSA keys", "err", err) + os.Exit(1) } testServer := setupMockJwksServer(publicKey, kid) os.Setenv("JWKS_URI", fmt.Sprintf("%s/oauth/discovery/keys", testServer.URL)) @@ -144,7 +147,8 @@ func TestRollout(t *testing.T) { // make sure the test file doesn't exist err := RemoveFileIfExists(testFile) if err != nil { - log.Fatalf("Unable to cleanup test file: %v", err) + slog.Error("Unable to cleanup test file", "err", err) + os.Exit(1) } s := createMockJwksServer() @@ -298,7 +302,8 @@ func TestRollout(t *testing.T) { // cleanup err = RemoveFileIfExists(f) if err != nil { - log.Fatalf("Unable to cleanup test file: %v", err) + slog.Error("Unable to cleanup test file", "file", f, "err", err) + os.Exit(1) } } }