Skip to content

Commit

Permalink
Merge branch 'go-oauth2-master'
Browse files Browse the repository at this point in the history
  • Loading branch information
tsmethurst committed Sep 4, 2021
2 parents 902aba1 + 73154db commit 9e4344c
Show file tree
Hide file tree
Showing 13 changed files with 1,282 additions and 543 deletions.
2 changes: 2 additions & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[*.go]
end_of_line = crlf
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ _testmain.go
*.test
*.prof

coverage.txt

# OSX
*.DS_Store
*.db
Expand Down
141 changes: 141 additions & 0 deletions example/client/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package main

import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"

"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
)

const (
authServerURL = "http://localhost:9096"
)

var (
config = oauth2.Config{
ClientID: "222222",
ClientSecret: "22222222",
Scopes: []string{"all"},
RedirectURL: "http://localhost:9094/oauth2",
Endpoint: oauth2.Endpoint{
AuthURL: authServerURL + "/oauth/authorize",
TokenURL: authServerURL + "/oauth/token",
},
}
globalToken *oauth2.Token // Non-concurrent security
)

func main() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
u := config.AuthCodeURL("xyz",
oauth2.SetAuthURLParam("code_challenge", genCodeChallengeS256("s256example")),
oauth2.SetAuthURLParam("code_challenge_method", "S256"))
http.Redirect(w, r, u, http.StatusFound)
})

http.HandleFunc("/oauth2", func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
state := r.Form.Get("state")
if state != "xyz" {
http.Error(w, "State invalid", http.StatusBadRequest)
return
}
code := r.Form.Get("code")
if code == "" {
http.Error(w, "Code not found", http.StatusBadRequest)
return
}
token, err := config.Exchange(context.Background(), code, oauth2.SetAuthURLParam("code_verifier", "s256example"))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
globalToken = token

e := json.NewEncoder(w)
e.SetIndent("", " ")
e.Encode(token)
})

http.HandleFunc("/refresh", func(w http.ResponseWriter, r *http.Request) {
if globalToken == nil {
http.Redirect(w, r, "/", http.StatusFound)
return
}

globalToken.Expiry = time.Now()
token, err := config.TokenSource(context.Background(), globalToken).Token()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

globalToken = token
e := json.NewEncoder(w)
e.SetIndent("", " ")
e.Encode(token)
})

http.HandleFunc("/try", func(w http.ResponseWriter, r *http.Request) {
if globalToken == nil {
http.Redirect(w, r, "/", http.StatusFound)
return
}

resp, err := http.Get(fmt.Sprintf("%s/test?access_token=%s", authServerURL, globalToken.AccessToken))
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
defer resp.Body.Close()

io.Copy(w, resp.Body)
})

http.HandleFunc("/pwd", func(w http.ResponseWriter, r *http.Request) {
token, err := config.PasswordCredentialsToken(context.Background(), "test", "test")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

globalToken = token
e := json.NewEncoder(w)
e.SetIndent("", " ")
e.Encode(token)
})

http.HandleFunc("/client", func(w http.ResponseWriter, r *http.Request) {
cfg := clientcredentials.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
TokenURL: config.Endpoint.TokenURL,
}

token, err := cfg.Token(context.Background())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

e := json.NewEncoder(w)
e.SetIndent("", " ")
e.Encode(token)
})

log.Println("Client is running at 9094 port.Please open http://localhost:9094")
log.Fatal(http.ListenAndServe(":9094", nil))
}

func genCodeChallengeS256(s string) string {
s256 := sha256.Sum256([]byte(s))
return base64.URLEncoding.EncodeToString(s256[:])
}
240 changes: 240 additions & 0 deletions example/server/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
package main

import (
"context"
"encoding/json"
"flag"
"fmt"
"io"
"log"
"net/http"
"net/http/httputil"
"net/url"
"os"
"time"

"github.com/go-session/session"
"github.com/superseriousbusiness/oauth2/v4/errors"
"github.com/superseriousbusiness/oauth2/v4/generates"
"github.com/superseriousbusiness/oauth2/v4/manage"
"github.com/superseriousbusiness/oauth2/v4/models"
"github.com/superseriousbusiness/oauth2/v4/server"
"github.com/superseriousbusiness/oauth2/v4/store"
)

var (
dumpvar bool
idvar string
secretvar string
domainvar string
portvar int
)

func init() {
flag.BoolVar(&dumpvar, "d", true, "Dump requests and responses")
flag.StringVar(&idvar, "i", "222222", "The client id being passed in")
flag.StringVar(&secretvar, "s", "22222222", "The client secret being passed in")
flag.StringVar(&domainvar, "r", "http://localhost:9094", "The domain of the redirect url")
flag.IntVar(&portvar, "p", 9096, "the base port for the server")
}

func main() {
flag.Parse()
if dumpvar {
log.Println("Dumping requests")
}
manager := manage.NewDefaultManager()
manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)

// token store
manager.MustTokenStorage(store.NewMemoryTokenStore())

// generate jwt access token
// manager.MapAccessGenerate(generates.NewJWTAccessGenerate("", []byte("00000000"), jwt.SigningMethodHS512))
manager.MapAccessGenerate(generates.NewAccessGenerate())

clientStore := store.NewClientStore()
clientStore.Set(context.Background(), idvar, models.New(idvar, secretvar, domainvar, ""))
manager.MapClientStorage(clientStore)

srv := server.NewServer(server.NewConfig(), manager)

srv.SetPasswordAuthorizationHandler(func(username, password string) (userID string, err error) {
if username == "test" && password == "test" {
userID = "test"
}
return
})

srv.SetUserAuthorizationHandler(userAuthorizeHandler)

srv.SetInternalErrorHandler(func(err error) (re *errors.Response) {
log.Println("Internal Error:", err.Error())
return
})

srv.SetResponseErrorHandler(func(re *errors.Response) {
log.Println("Response Error:", re.Error.Error())
})

http.HandleFunc("/login", loginHandler)
http.HandleFunc("/auth", authHandler)

http.HandleFunc("/oauth/authorize", func(w http.ResponseWriter, r *http.Request) {
if dumpvar {
dumpRequest(os.Stdout, "authorize", r)
}

store, err := session.Start(r.Context(), w, r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

var form url.Values
if v, ok := store.Get("ReturnUri"); ok {
form = v.(url.Values)
}
r.Form = form

store.Delete("ReturnUri")
store.Save()

err = srv.HandleAuthorizeRequest(w, r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
}
})

http.HandleFunc("/oauth/token", func(w http.ResponseWriter, r *http.Request) {
if dumpvar {
_ = dumpRequest(os.Stdout, "token", r) // Ignore the error
}

err := srv.HandleTokenRequest(w, r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
})

http.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
if dumpvar {
_ = dumpRequest(os.Stdout, "test", r) // Ignore the error
}
token, err := srv.ValidationBearerToken(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

data := map[string]interface{}{
"expires_in": int64(token.GetAccessCreateAt().Add(token.GetAccessExpiresIn()).Sub(time.Now()).Seconds()),
"client_id": token.GetClientID(),
"user_id": token.GetUserID(),
}
e := json.NewEncoder(w)
e.SetIndent("", " ")
e.Encode(data)
})

log.Printf("Server is running at %d port.\n", portvar)
log.Printf("Point your OAuth client Auth endpoint to %s:%d%s", "http://localhost", portvar, "/oauth/authorize")
log.Printf("Point your OAuth client Token endpoint to %s:%d%s", "http://localhost", portvar, "/oauth/token")
log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", portvar), nil))
}

func dumpRequest(writer io.Writer, header string, r *http.Request) error {
data, err := httputil.DumpRequest(r, true)
if err != nil {
return err
}
writer.Write([]byte("\n" + header + ": \n"))
writer.Write(data)
return nil
}

func userAuthorizeHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) {
if dumpvar {
_ = dumpRequest(os.Stdout, "userAuthorizeHandler", r) // Ignore the error
}
store, err := session.Start(r.Context(), w, r)
if err != nil {
return
}

uid, ok := store.Get("LoggedInUserID")
if !ok {
if r.Form == nil {
r.ParseForm()
}

store.Set("ReturnUri", r.Form)
store.Save()

w.Header().Set("Location", "/login")
w.WriteHeader(http.StatusFound)
return
}

userID = uid.(string)
store.Delete("LoggedInUserID")
store.Save()
return
}

func loginHandler(w http.ResponseWriter, r *http.Request) {
if dumpvar {
_ = dumpRequest(os.Stdout, "login", r) // Ignore the error
}
store, err := session.Start(r.Context(), w, r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

if r.Method == "POST" {
if r.Form == nil {
if err := r.ParseForm(); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
store.Set("LoggedInUserID", r.Form.Get("username"))
store.Save()

w.Header().Set("Location", "/auth")
w.WriteHeader(http.StatusFound)
return
}
outputHTML(w, r, "static/login.html")
}

func authHandler(w http.ResponseWriter, r *http.Request) {
if dumpvar {
_ = dumpRequest(os.Stdout, "auth", r) // Ignore the error
}
store, err := session.Start(nil, w, r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

if _, ok := store.Get("LoggedInUserID"); !ok {
w.Header().Set("Location", "/login")
w.WriteHeader(http.StatusFound)
return
}

outputHTML(w, r, "static/auth.html")
}

func outputHTML(w http.ResponseWriter, req *http.Request, filename string) {
file, err := os.Open(filename)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
defer file.Close()
fi, _ := file.Stat()
http.ServeContent(w, req, file.Name(), fi.ModTime(), file)
}
Loading

0 comments on commit 9e4344c

Please sign in to comment.