-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(middlewares/auth/temporary): 添加用于创建临时令牌的包
- Loading branch information
Showing
3 changed files
with
197 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
// SPDX-FileCopyrightText: 2024 caixw | ||
// | ||
// SPDX-License-Identifier: MIT | ||
|
||
// Package temporary 用于创建一个一次性的令牌 | ||
package temporary | ||
|
||
import ( | ||
"errors" | ||
"net/http" | ||
"time" | ||
|
||
"github.com/issue9/cache" | ||
"github.com/issue9/mux/v9/header" | ||
"github.com/issue9/web" | ||
"github.com/issue9/web/openapi" | ||
|
||
"github.com/issue9/webuse/v7/internal/mauth" | ||
"github.com/issue9/webuse/v7/middlewares/auth" | ||
) | ||
|
||
type tokenType int | ||
|
||
const tokenContext tokenType = 0 | ||
|
||
type Response struct { | ||
XMLName struct{} `json:"-" cbor:"-" xml:"token" yaml:"-"` | ||
Token string `json:"token" xml:"token" cbor:"token" comment:"access token"` // 访问令牌 | ||
Expire int `json:"expire" xml:"expire,attr" cbor:"expire" comment:"access token expired"` // 访问令牌的有效时长,单位为秒 | ||
} | ||
|
||
type Temporary[T any] struct { | ||
cache web.Cache | ||
ttl time.Duration | ||
expire int | ||
once bool | ||
unauthProblemID string | ||
invalidTokenProblemID string | ||
} | ||
|
||
// New 创建 [Temporary] 对象 | ||
// | ||
// ttl 表示令牌的过期时间。 | ||
// once 是否为一次性令牌,如果为 true,在验证成功之后,该令牌将自动失效; | ||
// unauthProblemID 验证不通过时的错误代码; | ||
// invalidTokenProblemID 令牌无效时返回的错误代码; | ||
func New[T any](s web.Server, ttl time.Duration, once bool, unauthProblemID, invalidTokenProblemID string) *Temporary[T] { | ||
return &Temporary[T]{ | ||
cache: web.NewCache(s.UniqueID(), s.Cache()), | ||
ttl: ttl, | ||
expire: int(ttl.Seconds()), | ||
once: once, | ||
unauthProblemID: unauthProblemID, | ||
invalidTokenProblemID: invalidTokenProblemID, | ||
} | ||
} | ||
|
||
// New 创建令牌 | ||
// | ||
// v 为令牌关联的数据,之后通过验证接口可以访问该数据; | ||
func (t *Temporary[T]) New(ctx *web.Context, v T, status int) web.Responser { | ||
token := ctx.Server().UniqueID() | ||
if err := t.cache.Set(token, v, t.ttl); err != nil { | ||
return ctx.Error(err, "") | ||
} | ||
|
||
return web.Response(status, &Response{Token: token, Expire: t.expire}) | ||
} | ||
|
||
func (t *Temporary[T]) Middleware(next web.HandlerFunc, method, _, _ string) web.HandlerFunc { | ||
if method == http.MethodOptions { | ||
return next | ||
} | ||
|
||
return func(ctx *web.Context) web.Responser { | ||
token := auth.GetBearerToken(ctx, header.Authorization) | ||
if token == "" { | ||
return ctx.Problem(t.unauthProblemID) | ||
} | ||
|
||
var v T | ||
err := t.cache.Get(token, &v) | ||
switch { | ||
case errors.Is(err, cache.ErrCacheMiss()): | ||
return ctx.Problem(t.unauthProblemID) | ||
case err != nil: | ||
return ctx.Error(err, t.invalidTokenProblemID) | ||
default: | ||
mauth.Set(ctx, v) | ||
ctx.SetVar(tokenContext, token) | ||
|
||
if t.once { | ||
if err := t.cache.Delete(token); err != nil { | ||
ctx.Server().Logs().ERROR().Error(err) // 只记录错误,不反馈给客户端。 | ||
} | ||
} | ||
|
||
return next(ctx) | ||
} | ||
} | ||
} | ||
|
||
func (t *Temporary[T]) Logout(ctx *web.Context) error { | ||
if key, found := ctx.GetVar(tokenContext); found { | ||
return t.cache.Delete(key.(string)) | ||
} | ||
return nil | ||
} | ||
|
||
func (t *Temporary[T]) GetInfo(ctx *web.Context) (T, bool) { | ||
return mauth.Get[T](ctx) | ||
} | ||
|
||
// SecurityScheme 声明支持 openapi 的 [openapi.SecurityScheme] 对象 | ||
func SecurityScheme(id string, desc web.LocaleStringer) *openapi.SecurityScheme { | ||
return &openapi.SecurityScheme{ | ||
ID: id, | ||
Type: openapi.SecuritySchemeTypeHTTP, | ||
Description: desc, | ||
Scheme: auth.Bearer, | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
// SPDX-FileCopyrightText: 2024 caixw | ||
// | ||
// SPDX-License-Identifier: MIT | ||
|
||
package temporary | ||
|
||
import ( | ||
"encoding/json" | ||
"net/http" | ||
"testing" | ||
"time" | ||
|
||
"github.com/issue9/assert/v4" | ||
"github.com/issue9/mux/v9/header" | ||
"github.com/issue9/web" | ||
"github.com/issue9/web/server/servertest" | ||
|
||
"github.com/issue9/webuse/v7/internal/testserver" | ||
"github.com/issue9/webuse/v7/middlewares/auth" | ||
) | ||
|
||
var _ auth.Auth[string] = &Temporary[string]{} | ||
|
||
func TestTemporary(t *testing.T) { | ||
a := assert.New(t, false) | ||
s := testserver.New(a) | ||
|
||
temp := New[string](s, time.Second, true, web.ProblemForbidden, web.ProblemBadRequest) | ||
a.NotNil(temp) | ||
s.Routers() | ||
|
||
r := s.Routers().New("default", nil) | ||
r.Post("/login", func(ctx *web.Context) web.Responser { | ||
return temp.New(ctx, "5", http.StatusCreated) | ||
}) | ||
|
||
r.Get("/info", func(ctx *web.Context) web.Responser { | ||
if info, ok := temp.GetInfo(ctx); ok { | ||
return web.OK(info) // info == /login 中传递的值 "5" | ||
} | ||
panic("永远不可能达到此处") | ||
}, temp) | ||
|
||
defer servertest.Run(a, s)() | ||
defer s.Close(0) | ||
|
||
// 未登录 | ||
servertest.Get(a, "http://localhost:8080/info"). | ||
Do(nil). | ||
Status(http.StatusForbidden) | ||
|
||
servertest.Post(a, "http://localhost:8080/login", nil). | ||
Do(nil). | ||
Status(http.StatusCreated). | ||
BodyFunc(func(a *assert.Assertion, body []byte) { | ||
resp := &Response{} | ||
a.NotError(json.Unmarshal(body, resp)). | ||
NotEmpty(resp.Token). | ||
Equal(1, resp.Expire) | ||
|
||
// 正常访问 | ||
servertest.Get(a, "http://localhost:8080/info"). | ||
Header(header.Authorization, auth.BearerToken(resp.Token)). | ||
Do(nil). | ||
Status(http.StatusOK). | ||
StringBody(`"5"`) | ||
|
||
// 再次访问,令牌失效 | ||
servertest.Get(a, "http://localhost:8080/info"). | ||
Header(header.Authorization, auth.BearerToken(resp.Token)). | ||
Do(nil). | ||
Status(http.StatusForbidden) | ||
}) | ||
} |