Skip to content

Commit

Permalink
fix: return 401 for expired oauth2 tokens (#18)
Browse files Browse the repository at this point in the history
This change updates the OAuth2 middleware to return a 401 Unauthorized
response when a token is expired. The previous behavior of returning 400
was non-compliant. We've also added endpoints to introspect access
tokens.
  • Loading branch information
disintegrator authored Oct 30, 2024
1 parent 2cc2c38 commit ba9cff9
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 1 deletion.
1 change: 1 addition & 0 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func main() {
flag.Parse()

r := mux.NewRouter()
r.HandleFunc("/oauth2/token", auth.HandleOAuth2InspectToken).Methods(http.MethodGet)
r.HandleFunc("/oauth2/token", auth.HandleOAuth2).Methods(http.MethodPost)
r.HandleFunc("/auth", auth.HandleAuth).Methods(http.MethodPost)
r.HandleFunc("/auth/customsecurity/{customSchemeType}", auth.HandleCustomAuth).Methods(http.MethodGet)
Expand Down
55 changes: 55 additions & 0 deletions internal/auth/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"log"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -75,6 +76,38 @@ type OAuth2TokenResponse struct {
ExpiresIn int `json:"expires_in"`
}

func HandleOAuth2InspectToken(w http.ResponseWriter, r *http.Request) {
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
w.Header().Set("Content-Type", "application/json")

authz := r.Header.Get("Authorization")
if authz == "" {
http.Error(w, `{"error": "unauthorized"}`, http.StatusUnauthorized)
return
}
if !strings.HasPrefix(authz, "Bearer ") {
http.Error(w, `{"error": "invalid authorization"}`, http.StatusBadRequest)
return
}

token := authz[len("Bearer "):]
claims, err := ParseToken(token)
if err != nil {
http.Error(w, `{"error": "invalid token"}`, http.StatusBadRequest)
return
}

updatedExpiry := GetTokenExpiry(claims)

claims["exp"] = float64(updatedExpiry.Unix())

if err := enc.Encode(claims); err != nil {
http.Error(w, `{"error": "failed to encode response"}`, http.StatusInternalServerError)
return
}
}

func HandleOAuth2(w http.ResponseWriter, r *http.Request) {
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
Expand Down Expand Up @@ -165,6 +198,10 @@ func HandleOAuth2(w http.ResponseWriter, r *http.Request) {

now := time.Now()
expires := now.Add(time.Hour)
forcedExpiry := r.Header.Get("x-oauth2-expire-at")
if exp, err := time.Parse(time.RFC3339, forcedExpiry); err == nil {
expires = exp
}

accessTokenID := gofakeit.UUID()
accessTokenClaims := jwt.MapClaims{
Expand Down Expand Up @@ -248,6 +285,24 @@ func RefreshToken(refreshClaims jwt.MapClaims) {
tokenDB.Store(tokenID, expiry)
}

func GetTokenExpiry(tokenClaims jwt.MapClaims) time.Time {
tokenDBLastAccess.Store(time.Now())

tokenID := tokenClaims["id"].(string)

exp, found := tokenDB.Load(tokenID)
if found {
return exp.(time.Time)
}

expiryClaim, err := tokenClaims.GetExpirationTime()
if err != nil {
panic(err)
}

return expiryClaim.Time
}

func IsTokenExpired(tokenClaims jwt.MapClaims) bool {
tokenDBLastAccess.Store(time.Now())

Expand Down
3 changes: 2 additions & 1 deletion internal/middleware/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ func OAuth2(h http.Handler) http.Handler {
}

if auth.IsTokenExpired(claims) {
auth.SendOAuth2Error(w, auth.ErrCodeInvalidRequest, "token has expired")
w.Header().Set("Content-Type", "application/json")
http.Error(w, `{"error": "token has expired"}`, http.StatusUnauthorized)
return
}

Expand Down

0 comments on commit ba9cff9

Please sign in to comment.