diff --git a/README.md b/README.md index 2edbcf8..af61d65 100644 --- a/README.md +++ b/README.md @@ -145,9 +145,9 @@ def vip_roles_and_article_update(request: Request): ``` -### 依赖项 +### 依赖项(推荐) -- 推荐场景: 路由集合,FastAPI应用 +- 推荐场景: 单个路由,路由集合,FastAPI应用. ```python from fastapi import Depends @@ -156,11 +156,10 @@ from fastapi_user_auth.auth import Auth from fastapi_user_auth.auth.models import User -# 路由参数依赖项 -@app.get("/auth/admin_roles_depend_1") -def admin_roles(request: Request, - auth_result: Tuple[Auth, User] = Depends(auth.requires('admin')())): - return request.user +# 路由参数依赖项, 推荐使用此方式 +@app.get("/auth/admin_roles_depend_1") +def admin_roles(user: User = Depends(auth.get_current_user)): + return user # or request.user # 路径操作装饰器依赖项 @@ -200,6 +199,7 @@ from fastapi_user_auth.auth.models import User async def get_request_user(request: Request) -> Optional[User]: + # user= await auth.get_current_user(request) if await auth.requires('admin', response=False)(request): return request.user else: diff --git a/fastapi_user_auth/admin.py b/fastapi_user_auth/admin.py index 4df173d..f0d6e55 100644 --- a/fastapi_user_auth/admin.py +++ b/fastapi_user_auth/admin.py @@ -11,7 +11,7 @@ from fastapi_user_auth.auth.models import BaseUser, User, Group, Permission, Role from fastapi_user_auth.auth.schemas import UserLoginOut from pydantic import BaseModel -from sqlalchemy import insert, update +from sqlalchemy import insert, update, select from starlette import status from starlette.requests import Request from starlette.responses import Response @@ -126,10 +126,10 @@ async def handle( **kwargs ) -> BaseApiOut[BaseModel]: # self.schema_submit_out auth: Auth = request.auth - user = await auth.get_user_by_username(data.username) + user = await auth.db.scalar(select(self.user_model).where(self.user_model.username == data.username)) if user: return BaseApiOut(status = -1, msg = _('Username has been registered!'), data = None) - user = await auth.get_user_by_whereclause(self.user_model.email == data.email) + user = await auth.db.scalar(select(self.user_model).where(self.user_model.email == data.email)) if user: return BaseApiOut(status = -2, msg = _('Email has been registered!'), data = None) user = self.user_model.parse_obj(data) diff --git a/fastapi_user_auth/auth/auth.py b/fastapi_user_auth/auth/auth.py index d8aaec2..0c28fd9 100644 --- a/fastapi_user_auth/auth/auth.py +++ b/fastapi_user_auth/auth/auth.py @@ -5,7 +5,7 @@ from collections.abc import Coroutine from typing import Type, Any, TypeVar, Optional, Sequence, Tuple, Union, Callable, Generic -from fastapi import FastAPI, HTTPException, Depends, Form +from fastapi import FastAPI, HTTPException, Depends, Form, params from fastapi.security import OAuth2PasswordBearer from fastapi.security.utils import get_authorization_scheme_param from fastapi_amis_admin.crud.base import RouterMixin @@ -15,6 +15,7 @@ from fastapi_amis_admin.utils.translation import i18n as _ from passlib.context import CryptContext from pydantic import BaseModel, SecretStr +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from sqlalchemy_database import AsyncDatabase, Database from sqlmodel import select @@ -46,16 +47,7 @@ def get_user_token(request: Request) -> Optional[str]: return token async def authenticate(self, request: Request) -> Tuple["Auth", Optional[_UserModelT]]: - if request.scope.get('auth'): # 防止重复授权 - return request.scope.get('auth'), request.scope.get('user') - request.scope["auth"], request.scope["user"] = self.auth, None - token = self.get_user_token(request) - if not token: - return self.auth, None - token_data = await self.token_store.read_token(token) - if token_data is not None: - request.scope["user"]: _UserModelT = await self.auth.get_user_by_username(token_data.username) - return request.auth, request.user + return self.auth, await self.auth.get_current_user(request) def attach_middleware(self, app: FastAPI): app.add_middleware(AuthenticationMiddleware, backend = self) # 添加auth中间件 @@ -78,14 +70,8 @@ def __init__( self.backend = self.backend or AuthBackend(self, token_store or DbTokenStore(self.db)) self.pwd_context = pwd_context - async def get_user_by_username(self, username: str) -> Optional[_UserModelT]: - return await self.get_user_by_whereclause(self.user_model.username == username) - - async def get_user_by_whereclause(self, *whereclause: Any) -> Optional[_UserModelT]: - return await self.db.async_scalar(select(self.user_model).where(*whereclause)) - async def authenticate_user(self, username: str, password: Union[str, SecretStr]) -> Optional[_UserModelT]: - user = await self.get_user_by_username(username) + user = await self.db.async_scalar(select(self.user_model).where(self.user_model.username == username)) if user: pwd = password.get_secret_value() if isinstance(password, SecretStr) else password pwd2 = user.password.get_secret_value() if isinstance(user.password, SecretStr) else user.password @@ -93,6 +79,25 @@ async def authenticate_user(self, username: str, password: Union[str, SecretStr] return user return None + @cached_property + def get_current_user(self): + async def _get_current_user( + request: Request, + session: Union[Session, AsyncSession, None] = Depends(self.db.session_generator) + ) -> Optional[_UserModelT]: + if request.scope.get('auth'): # 防止重复授权 + return request.scope.get('user') + request.scope["auth"], request.scope["user"] = self, None + token = self.backend.get_user_token(request) + if not token: + return None + token_data = await self.backend.token_store.read_token(token) + if token_data is not None: + request.scope["user"]: _UserModelT = await self.db.async_get(self.user_model, token_data.id, session = session) + return request.user + + return _get_current_user + def requires( self, roles: Union[str, Sequence[str]] = None, @@ -103,21 +108,22 @@ def requires( response: Union[bool, Response] = None, ) -> Callable: # sourcery no-metrics - async def has_requires(conn: HTTPConnection) -> bool: - # todo websocket support - await self.backend.authenticate(conn) # type:ignore - if not conn.user: - return False - return await self.db.async_run_sync( - conn.user.has_requires, + async def has_requires(user: _UserModelT) -> bool: + return user and await self.db.async_run_sync( + user.has_requires, roles = roles, groups = groups, permissions = permissions, is_session = True ) - async def depend(request: Request) -> Union[bool, Response]: - if not await has_requires(request): + async def depend( + request: Request, + user: _UserModelT = Depends(self.get_current_user), + ) -> Union[bool, Response]: + if isinstance(user, params.Depends): + user = await self.get_current_user(request) + if not await has_requires(user): if response is not None: return response code, headers = status_code, {} @@ -150,7 +156,8 @@ async def websocket_wrapper( ) -> None: websocket = kwargs.get("websocket", args[idx] if args else None) assert isinstance(websocket, WebSocket) - if not await has_requires(websocket): + user = await self.get_current_user(websocket) # type: ignore + if not await has_requires(user): await websocket.close() else: await func(*args, **kwargs) diff --git a/tests/test_auth/conftest.py b/tests/test_auth/conftest.py index 775afd3..fde6dfd 100644 --- a/tests/test_auth/conftest.py +++ b/tests/test_auth/conftest.py @@ -15,43 +15,41 @@ from fastapi_user_auth.auth.models import User from tests.conftest import async_db, sync_db - -@pytest.fixture(params=[async_db, sync_db]) +@pytest.fixture(params = [async_db, sync_db]) async def db(request) -> Union[Database, AsyncDatabase]: database = request.param - await database.async_run_sync(SQLModel.metadata.create_all, is_session=False) + await database.async_run_sync(SQLModel.metadata.create_all, is_session = False) yield database - await database.async_run_sync(SQLModel.metadata.drop_all, is_session=False) - + await database.async_run_sync(SQLModel.metadata.drop_all, is_session = False) app = FastAPI() # 创建auth实例 -auth = Auth(db=async_db) +auth = Auth(db = async_db) # 注册auth基础路由 -auth_router = AuthRouter(auth=auth) +auth_router = AuthRouter(auth = auth) app.include_router(auth_router.router) - class UserClient: + def __init__(self, client: TestClient = None, user: User = None) -> None: self.client: TestClient = client or TestClient(app) self.user: User = user - def get_login_client(username: str = None, password: str = None) -> UserClient: client = TestClient(app) if not username or not password: return UserClient() - response = client.post('/auth/gettoken', - data={'username': username, 'password': password}, - headers={"Content-Type": "application/x-www-form-urlencoded"}) + response = client.post( + '/auth/gettoken', + data = {'username': username, 'password': password}, + headers = {"Content-Type": "application/x-www-form-urlencoded"} + ) data = response.json() assert data['data']['access_token'] user = User.parse_obj(data['data']) assert user.is_active assert user.username == username - return UserClient(client=client, user=user) - + return UserClient(client = client, user = user) @pytest.fixture def logins(request) -> UserClient: @@ -64,52 +62,54 @@ def logins(request) -> UserClient: user = user_data.get(request.param) or {} return get_login_client(**user) - -@pytest.fixture(scope="session") +@pytest.fixture(scope = "session") async def prepare_database() -> AsyncGenerator[None, None]: - await auth.db.async_run_sync(SQLModel.metadata.create_all, is_session=False) + await auth.db.async_run_sync(SQLModel.metadata.create_all, is_session = False) yield - await auth.db.async_run_sync(SQLModel.metadata.drop_all, is_session=False) - + await auth.db.async_run_sync(SQLModel.metadata.drop_all, is_session = False) -@pytest.fixture(scope="session") +@pytest.fixture(scope = "session") def event_loop(): loop = asyncio.get_event_loop_policy().new_event_loop() yield loop loop.close() - -@pytest_asyncio.fixture(scope="session", autouse=True) +@pytest_asyncio.fixture(scope = "session", autouse = True) async def fake_users(prepare_database): await auth.db.async_run_sync(create_fake_users) - # noinspection PyTypeChecker def create_fake_users(session: Session): # init permission - admin_perm = Permission(key='admin', name='admin permission') - vip_perm = Permission(key='vip', name='vip permission') - test_perm = Permission(key='test', name='test permission') + admin_perm = Permission(key = 'admin', name = 'admin permission') + vip_perm = Permission(key = 'vip', name = 'vip permission') + test_perm = Permission(key = 'test', name = 'test permission') session.add_all([admin_perm, vip_perm, test_perm]) session.flush([admin_perm, vip_perm, test_perm]) # init role - admin_role = Role(key='admin', name='admin role', permissions=[admin_perm]) - vip_role = Role(key='vip', name='vip role', permissions=[vip_perm]) - test_role = Role(key='test', name='test role', permissions=[test_perm]) + admin_role = Role(key = 'admin', name = 'admin role', permissions = [admin_perm]) + vip_role = Role(key = 'vip', name = 'vip role', permissions = [vip_perm]) + test_role = Role(key = 'test', name = 'test role', permissions = [test_perm]) session.add_all([admin_role, vip_role, test_role]) session.flush([admin_role, vip_role, test_role]) # init group - admin_group = Group(key='admin', name='admin group', roles=[admin_role]) - vip_group = Group(key='vip', name='vip group', roles=[vip_role]) - test_group = Group(key='test', name='test group', roles=[test_role]) + admin_group = Group(key = 'admin', name = 'admin group', roles = [admin_role]) + vip_group = Group(key = 'vip', name = 'vip group', roles = [vip_role]) + test_group = Group(key = 'test', name = 'test group', roles = [test_role]) session.add_all([admin_group, vip_group, test_group]) session.flush([admin_group, vip_group, test_group]) # init user - admin_user = User(username='admin', password=auth.pwd_context.hash('admin'), email='admin@amis.work', - roles=[admin_role], groups=[admin_group]) - vip_user = User(username='vip', password=auth.pwd_context.hash('vip'), email='vip@amis.work', roles=[vip_role], - groups=[vip_group]) - test_user = User(username='test', password=auth.pwd_context.hash('test'), email='test@amis.work', roles=[test_role], - groups=[test_group]) + admin_user = User( + username = 'admin', password = auth.pwd_context.hash('admin'), email = 'admin@amis.work', + roles = [admin_role], groups = [admin_group] + ) + vip_user = User( + username = 'vip', password = auth.pwd_context.hash('vip'), email = 'vip@amis.work', roles = [vip_role], + groups = [vip_group] + ) + test_user = User( + username = 'test', password = auth.pwd_context.hash('test'), email = 'test@amis.work', roles = [test_role], + groups = [test_group] + ) session.add_all([admin_user, vip_user, test_user]) session.flush([admin_user, vip_user, test_user]) diff --git a/tests/test_auth/test_auth.py b/tests/test_auth/test_auth.py index f60304b..48f1694 100644 --- a/tests/test_auth/test_auth.py +++ b/tests/test_auth/test_auth.py @@ -4,7 +4,6 @@ from fastapi_user_auth.auth.models import User from tests.test_auth.conftest import auth - async def test_create_role_user(): user = await auth.create_role_user('admin2') assert user.username == 'admin2' @@ -14,7 +13,6 @@ async def test_create_role_user(): role = result.roles[0] assert role.key == 'admin2' - async def test_authenticate_user(): # error user = await auth.authenticate_user('admin', 'admin1') @@ -23,16 +21,3 @@ async def test_authenticate_user(): # admin user = await auth.authenticate_user('admin', 'admin') assert user.username == 'admin' - - -async def test_get_user_by_whereclause(): - user = await auth.get_user_by_whereclause(User.id == 1) - assert user.username == 'admin' - - user = await auth.get_user_by_whereclause(User.username == 'admin') - assert user.username == 'admin' - - -async def test_get_user_by_username(): - user = await auth.get_user_by_username('admin') - assert user.username == 'admin' diff --git a/tests/test_auth/test_auth_mount.py b/tests/test_auth/test_auth_mount.py index d7103ff..d7afd8e 100644 --- a/tests/test_auth/test_auth_mount.py +++ b/tests/test_auth/test_auth_mount.py @@ -11,39 +11,34 @@ app.mount('/subapp2', subapp2) auth.backend.attach_middleware(subapp2) -subapp3 = FastAPI(dependencies=[Depends(auth.requires('admin')())]) +subapp3 = FastAPI(dependencies = [Depends(auth.requires('admin')())]) app.mount('/subapp3', subapp3) - # auth decorator @subapp1.get("/auth/user") @auth.requires() def user(request: Request): return request.user - @subapp2.get("/auth/user") def user_2(request: Request): if request.user: return request.user else: - raise HTTPException(status_code=403) - + raise HTTPException(status_code = 403) @subapp3.get("/auth/user") @auth.requires() def user_3(request: Request): return request.user - path_admin_auth = { "/subapp1/auth/user", "/subapp2/auth/user", "/subapp3/auth/user", } - -@pytest.mark.parametrize("logins", ['admin'], indirect=True) +@pytest.mark.parametrize("logins", ['admin'], indirect = True) @pytest.mark.parametrize("path", list(path_admin_auth)) def test_admin_auth(logins: UserClient, path): response = logins.client.get(path) diff --git a/tests/test_auth/test_auth_requires.py b/tests/test_auth/test_auth_requires.py index 4476bec..94cdd10 100644 --- a/tests/test_auth/test_auth_requires.py +++ b/tests/test_auth/test_auth_requires.py @@ -9,115 +9,107 @@ from fastapi_user_auth.auth.models import User from tests.test_auth.conftest import UserClient, app, auth - # auth decorator @app.get("/auth/user") @auth.requires() def user(request: Request): return request.user - @app.get("/auth/admin_roles") @auth.requires('admin') def admin_roles(request: Request): return request.user - @app.get("/auth/vip_roles") @auth.requires(['vip']) def vip_roles(request: Request): return request.user - @app.get("/auth/admin_or_vip_roles") -@auth.requires(roles=['admin', 'vip']) +@auth.requires(roles = ['admin', 'vip']) def admin_or_vip_roles(request: Request): return request.user - # auth async decorator @app.get("/auth/admin_roles_async") @auth.requires('admin') async def admin_roles_async(request: Request): return request.user - # auth depend -@app.get("/auth/user_1", dependencies=[Depends(auth.backend.authenticate)]) +@app.get("/auth/user_1", dependencies = [Depends(auth.backend.authenticate)]) def user_1(request: Request): if request.user: return request.user else: - raise HTTPException(status_code=403) - + raise HTTPException(status_code = 403) @app.get("/auth/user_2") def user_2( - request: Request, - auth_result: Tuple[Auth, User] = Depends(auth.backend.authenticate) + request: Request, + auth_result: Tuple[Auth, User] = Depends(auth.backend.authenticate) ): if request.user: return request.user else: - raise HTTPException(status_code=403) - + raise HTTPException(status_code = 403) @app.get("/auth/user_3") async def user_3(request: Request): if await auth.requires()(request): return request.user +@app.get("/auth/user_4") +async def user_4(request: Request, user: User = Depends(auth.get_current_user)): + if user is None: + raise HTTPException(status_code = 403) + return request.user -@app.get("/auth/admin_roles_depend_1", dependencies=[Depends(auth.requires('admin')())]) +@app.get("/auth/admin_roles_depend_1", dependencies = [Depends(auth.requires('admin')())]) def admin_roles_1(request: Request): return request.user - @app.get("/auth/admin_roles_depend_2") def admin_roles_2( - request: Request, - auth_result: Union[bool, Response] = Depends(auth.requires('admin')()) + request: Request, + auth_result: Union[bool, Response] = Depends(auth.requires('admin')()) ): return request.user - # auth group @app.get("/auth/admin_groups") -@auth.requires(groups='admin') +@auth.requires(groups = 'admin') async def admin_groups(request: Request): return request.user - @app.get("/auth/vip_groups") -@auth.requires(groups=['vip']) +@auth.requires(groups = ['vip']) async def vip_groups(request: Request): return request.user - @app.get("/auth/admin_or_vip_groups") -@auth.requires(groups=['admin', 'vip']) +@auth.requires(groups = ['admin', 'vip']) async def admin_or_vip_groups(request: Request): return request.user - # auth permission @app.get("/auth/permissions") -@auth.requires(permissions=['test']) +@auth.requires(permissions = ['test']) async def route(request: Request): return request.user - -@pytest.mark.parametrize("logins", ['guest'], indirect=True) +@pytest.mark.parametrize("logins", ['guest'], indirect = True) def test_router_token(logins: UserClient): - response = logins.client.post('/auth/gettoken', - data={'username': 'admin', 'password': 'Incorrect'}, - headers={"Content-Type": "application/x-www-form-urlencoded"}) + response = logins.client.post( + '/auth/gettoken', + data = {'username': 'admin', 'password': 'Incorrect'}, + headers = {"Content-Type": "application/x-www-form-urlencoded"} + ) data = response.json() assert data['data'] is None - -@pytest.mark.parametrize("logins", ['admin', 'vip', 'test', 'guest'], indirect=True) +@pytest.mark.parametrize("logins", ['admin', 'vip', 'test', 'guest'], indirect = True) def test_router_userinfo(logins: UserClient): response = logins.client.get('/auth/userinfo') data = response.json() @@ -127,12 +119,12 @@ def test_router_userinfo(logins: UserClient): else: assert data['detail'] == 'Forbidden' - path_all = { "/auth/user", "/auth/user_1", "/auth/user_2", "/auth/user_3", + "/auth/user_4", # auth role '/auth/admin_roles', '/auth/vip_roles', @@ -152,6 +144,7 @@ def test_router_userinfo(logins: UserClient): "/auth/user_1", "/auth/user_2", "/auth/user_3", + "/auth/user_4", '/auth/admin_roles', "/auth/admin_or_vip_roles", "/auth/admin_roles_depend_1", @@ -166,6 +159,7 @@ def test_router_userinfo(logins: UserClient): "/auth/user_1", "/auth/user_2", "/auth/user_3", + "/auth/user_4", '/auth/vip_roles', "/auth/admin_or_vip_roles", '/auth/vip_groups', @@ -176,11 +170,11 @@ def test_router_userinfo(logins: UserClient): "/auth/user_1", "/auth/user_2", "/auth/user_3", + "/auth/user_4", '/auth/permissions', } - -@pytest.mark.parametrize("logins", ['admin'], indirect=True) +@pytest.mark.parametrize("logins", ['admin'], indirect = True) @pytest.mark.parametrize("path", list(path_admin_auth)) def test_admin_auth(logins: UserClient, path): response = logins.client.get(path) @@ -188,16 +182,14 @@ def test_admin_auth(logins: UserClient, path): assert data['id'] == logins.user.id assert data['username'] == logins.user.username - -@pytest.mark.parametrize("logins", ['admin'], indirect=True) +@pytest.mark.parametrize("logins", ['admin'], indirect = True) @pytest.mark.parametrize("path", list(path_all - path_admin_auth)) def test_admin_forbidden(logins: UserClient, path): response = logins.client.get(path) data = response.json() assert data['detail'] == 'Forbidden' - -@pytest.mark.parametrize("logins", ['vip'], indirect=True) +@pytest.mark.parametrize("logins", ['vip'], indirect = True) @pytest.mark.parametrize("path", list(path_vip_auth)) def test_vip_auth(logins: UserClient, path): response = logins.client.get(path) @@ -205,16 +197,14 @@ def test_vip_auth(logins: UserClient, path): assert data['id'] == logins.user.id assert data['username'] == logins.user.username - -@pytest.mark.parametrize("logins", ['vip'], indirect=True) +@pytest.mark.parametrize("logins", ['vip'], indirect = True) @pytest.mark.parametrize("path", list(path_all - path_vip_auth)) def test_vip_forbidden(logins: UserClient, path): response = logins.client.get(path) data = response.json() assert data['detail'] == 'Forbidden' - -@pytest.mark.parametrize("logins", ['test'], indirect=True) +@pytest.mark.parametrize("logins", ['test'], indirect = True) @pytest.mark.parametrize("path", list(path_test_auth)) def test_test_auth(logins: UserClient, path): response = logins.client.get(path) @@ -222,8 +212,7 @@ def test_test_auth(logins: UserClient, path): assert data['id'] == logins.user.id assert data['username'] == logins.user.username - -@pytest.mark.parametrize("logins", ['guest'], indirect=True) +@pytest.mark.parametrize("logins", ['guest'], indirect = True) @pytest.mark.parametrize("path", list(path_all)) def test_guest_forbidden(logins: UserClient, path): response = logins.client.get(path) diff --git a/tests/test_auth/test_backend.py b/tests/test_auth/test_backend.py index b36915c..cd7a41f 100644 --- a/tests/test_auth/test_backend.py +++ b/tests/test_auth/test_backend.py @@ -4,32 +4,29 @@ from fastapi_user_auth.auth.backends.jwt import JwtTokenStore from fastapi_user_auth.auth.schemas import BaseTokenData -token_data = BaseTokenData(id=1, username='test') - +token_data = BaseTokenData(id = 1, username = 'test') @pytest.mark.asyncio async def test_jwt_token_store(): - store = JwtTokenStore(secret_key='09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7') + store = JwtTokenStore(secret_key = '09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7') token = await store.write_token(token_data) assert token - data = await store.read_token(token=token) + data = await store.read_token(token = token) assert data == token_data with pytest.raises(NotImplementedError): await store.destroy_token(token) - @pytest.mark.asyncio async def test_db_token_store(db): store = DbTokenStore(db) token = await store.write_token(token_data) assert token - data = await store.read_token(token=token) + data = await store.read_token(token = token) assert data == token_data - await store.destroy_token(token=token) - data = await store.read_token(token=token) + await store.destroy_token(token = token) + data = await store.read_token(token = token) assert data is None - @pytest.mark.asyncio async def test_redis_token_store(): pass