Skip to content

Commit

Permalink
[minor] allow setting the rollout command and args (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
joecorall authored Apr 15, 2024
1 parent 82c0c52 commit 3cbe7fe
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 48 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ This service requires two envionrment variables.

- `JWKS_URI` - the URL of the OIDC Provider's [JSON Web Key (JWK) set document](https://www.rfc-editor.org/info/rfc7517). This is used to ensure the JWT was signed by the provider.
- `JWT_AUD` - the audience set in the JWT token.
- `ROLLOUT_CMD` (default: `/bin/bash`) - the command to execute a rollout
- `ROLLOUT_ARGS` (default: `/rollout.sh` ) - the args to pass to `ROLLOUT_CMD`

### GitHub

Expand All @@ -55,4 +57,3 @@ JWT_AUD=aud-string-you-set-in-your-job
- [ ] Install instructions using binary
- [ ] Tag/push versions to dockerhub
- [ ] Allow more custom auth handling
- [ ] Allow more custom rollout than a single bash script
47 changes: 24 additions & 23 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,18 @@ import (
"net/http"
"os"
"os/exec"
"strings"

"github.com/golang-jwt/jwt/v4"
"github.com/lestrrat-go/jwx/jwk"
)

var (
gitLabJwksURL, aud string
)

func main() {

gitLabJwksURL = os.Getenv("JWKS_URI")
if gitLabJwksURL == "" {
if os.Getenv("JWKS_URI") == "" {
log.Fatal("JWKS_URI is required. e.g. JWKS_URI=https://gitlab.com/oauth/discovery/keys")
}
aud = os.Getenv("JWT_AUD")
if aud == "" {
if os.Getenv("JWT_AUD") == "" {
log.Fatal("JWT_AUD is required. This needs to be the aud in the JWT you except this service to handle.")
}

Expand All @@ -36,15 +31,6 @@ func main() {
}
}

func readUserIP(r *http.Request) (string, string) {
realIP := r.Header.Get("X-Real-Ip")
lastIP := r.RemoteAddr
if realIP == "" {
realIP = r.Header.Get("X-Forwarded-For")
}
return realIP, lastIP
}

func Rollout(w http.ResponseWriter, r *http.Request) {
realIp, lastIP := readUserIP(r)

Expand All @@ -71,9 +57,15 @@ func Rollout(w http.ResponseWriter, r *http.Request) {
return
}

// TODO make this more customizable
// but for now this fills the need
cmd := exec.Command("/bin/bash", "/rollout.sh")
name := os.Getenv("ROLLOUT_CMD")
if name == "" {
name = "/bin/bash"
}
args := os.Getenv("ROLLOUT_ARGS")
if args == "" {
args = "/rollout.sh"
}
cmd := exec.Command(name, strings.Split(args, " ")...)

var stdOut, stdErr bytes.Buffer
cmd.Stdout = &stdOut
Expand Down Expand Up @@ -111,10 +103,10 @@ func ParseToken(token *jwt.Token) (interface{}, error) {
}

ctx := context.Background()
gitLabJwksURL = os.Getenv("JWKS_URI")
jwksSet, err := jwk.Fetch(ctx, gitLabJwksURL)
jwksUri := os.Getenv("JWKS_URI")
jwksSet, err := jwk.Fetch(ctx, jwksUri)
if err != nil {
log.Fatalf("Unable to fetch JWK set from %s: %v", gitLabJwksURL, err)
log.Fatalf("Unable to fetch JWK set from %s: %v", jwksUri, err)
}
// Find the appropriate key in JWKS
key, ok := jwksSet.LookupKeyID(kid)
Expand All @@ -129,3 +121,12 @@ func ParseToken(token *jwt.Token) (interface{}, error) {

return pubkey, nil
}

func readUserIP(r *http.Request) (string, string) {
realIP := r.Header.Get("X-Real-Ip")
lastIP := r.RemoteAddr
if realIP == "" {
realIP = r.Header.Get("X-Forwarded-For")
}
return realIP, lastIP
}
139 changes: 115 additions & 24 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,27 @@ func CreateSignedJWT(kid, aud string, exp int64, privateKey *rsa.PrivateKey) (st
return signedToken, nil
}

func TestTokenVerification(t *testing.T) {
os.Setenv("JWT_AUD", "test-success")
// Utility function to create a request with an Authorization header
func createRequest(authHeader string) *http.Request {
req, _ := http.NewRequest("GET", "/", nil)
req.Header.Set("Authorization", authHeader)
return req
}

// TestRollout tests the Rollout function with various scenarios
func TestRollout(t *testing.T) {
testFile := "/tmp/rollout-test.txt"
os.Setenv("ROLLOUT_CMD", "touch")
os.Setenv("ROLLOUT_ARGS", testFile)

// make sure the test file doesn't exist
err := RemoveFileIfExists(testFile)
if err != nil {
log.Fatalf("Unable to cleanup test file: %v", err)
}

// mock the JWKS server response
os.Setenv("JWT_AUD", "test-success")
kid := "no-kidding"
aud := os.Getenv("JWT_AUD")
privateKey, publicKey, err := GenerateRSAKeys()
Expand All @@ -114,58 +132,131 @@ func TestTokenVerification(t *testing.T) {
}
server := setupMockJwksServer(publicKey, kid)
defer server.Close()

jwkURL := fmt.Sprintf("%s/oauth/discovery/keys", server.URL)
os.Setenv("JWKS_URI", jwkURL)

// make sure valid tokens succeed
exp := time.Now().Add(time.Hour * 24).Unix()
// get a valid token
exp := time.Now().Add(time.Hour * 1).Unix()
jwtToken, err := CreateSignedJWT(kid, aud, exp, privateKey)
if err != nil {
t.Fatalf("Unable to create a JWT with our test key: %v", err)
}
token, err := jwt.Parse(jwtToken, ParseToken)
assert.NoError(t, err)
assert.True(t, token.Valid)

// make sure invalid kids fail
jwtToken, err = CreateSignedJWT("just-kidding", aud, exp, privateKey)
badKidJwtToken, err := CreateSignedJWT("just-kidding", aud, exp, privateKey)
if err != nil {
t.Fatalf("Unable to create a JWT with our test key: %v", err)
}
token, err = jwt.Parse(jwtToken, ParseToken)
assert.Error(t, err)
assert.False(t, token.Valid)

// make sure if we pass a JWT signed by another private key it fails
badPrivateKey, _, err := GenerateRSAKeys()
if err != nil {
t.Fatalf("Unable to generate a new private key")
}
jwtToken, err = CreateSignedJWT(kid, aud, exp, badPrivateKey)
badPrivKeyjwtToken, err := CreateSignedJWT(kid, aud, exp, badPrivateKey)
if err != nil {
t.Fatalf("Unable to create a JWT with our new test key: %v", err)
}
token, err = jwt.Parse(jwtToken, ParseToken)
assert.Error(t, err)
assert.False(t, token.Valid)

// make sure expired JWTs fail
expired := time.Now().Add(time.Hour * -1).Unix()
jwtToken, err = CreateSignedJWT(kid, aud, expired, privateKey)
expiredJwtToken, err := CreateSignedJWT(kid, aud, expired, privateKey)
if err != nil {
t.Fatalf("Unable to create a JWT with our test key: %v", err)
}
token, err = jwt.Parse(jwtToken, ParseToken)
assert.Error(t, err)
assert.False(t, token.Valid)

// make sure bad audience JWTs fail
jwtToken, err = CreateSignedJWT(kid, "different-audience", exp, privateKey)
badAudJwtToken, err := CreateSignedJWT(kid, "different-audience", exp, privateKey)
if err != nil {
t.Fatalf("Unable to create a JWT with our test key: %v", err)
}
token, err = jwt.Parse(jwtToken, ParseToken)
assert.Error(t, err)
assert.False(t, token.Valid)

// Define test cases
tests := []struct {
name string
authHeader string
expectedStatus int
expectedBody string
}{
{
name: "No Authorization Header",
authHeader: "",
expectedStatus: http.StatusUnauthorized,
expectedBody: "need authorizaton: bearer xyz header\n",
},
{
name: "Invalid Token",
authHeader: "Bearer invalidtoken",
expectedStatus: http.StatusUnauthorized,
expectedBody: "Failed to verify token.\n",
},
{
name: "Bad kid Token",
authHeader: "Bearer " + badKidJwtToken,
expectedStatus: http.StatusUnauthorized,
expectedBody: "Failed to verify token.\n",
},
{
name: "Signed from wrong JWKS Token",
authHeader: "Bearer " + badPrivKeyjwtToken,
expectedStatus: http.StatusUnauthorized,
expectedBody: "Failed to verify token.\n",
},
{
name: "Expired Token",
authHeader: "Bearer " + expiredJwtToken,
expectedStatus: http.StatusUnauthorized,
expectedBody: "Failed to verify token.\n",
},
{
name: "Bad aud Token",
authHeader: "Bearer " + badAudJwtToken,
expectedStatus: http.StatusUnauthorized,
expectedBody: "Failed to verify token.\n",
},
{
name: "Valid Token and Successful Command",
authHeader: "Bearer " + jwtToken,
expectedStatus: http.StatusOK,
expectedBody: "Rollout complete\n",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
request := createRequest(tt.authHeader)

Rollout(recorder, request)

assert.Equal(t, tt.expectedStatus, recorder.Code)
assert.Equal(t, tt.expectedBody, recorder.Body.String())
})
}

// make sure the rollout command actually ran the command
_, err = os.Stat(testFile)
if err != nil && os.IsNotExist(err) {
t.Errorf("The successful test did not create the expected file")
}

// cleanup
err = RemoveFileIfExists(testFile)
if err != nil {
log.Fatalf("Unable to cleanup test file: %v", err)
}
}

func RemoveFileIfExists(filePath string) error {
_, err := os.Stat(filePath)
if err == nil {
err := os.Remove(filePath)
if err != nil {
return fmt.Errorf("failed to remove file: %v", err)
}
} else if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("error checking file: %v", err)
}

return nil
}

0 comments on commit 3cbe7fe

Please sign in to comment.