Skip to content

Commit

Permalink
feat: add the 'get_current_user' dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
amisadmin committed Aug 31, 2022
1 parent 9705484 commit 207e33b
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 154 deletions.
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@ def vip_roles_and_article_update(request: Request):

```

### 依赖项
### 依赖项(推荐)

- 推荐场景: 路由集合,FastAPI应用
- 推荐场景: 单个路由,路由集合,FastAPI应用.

```python
from fastapi import Depends
Expand All @@ -156,11 +156,10 @@ from fastapi_user_auth.auth import Auth
from fastapi_user_auth.auth.models import User


# 路由参数依赖项
@app.get("/auth/admin_roles_depend_1")
def admin_roles(request: Request,
auth_result: Tuple[Auth, User] = Depends(auth.requires('admin')())):
return request.user
# 路由参数依赖项, 推荐使用此方式
@app.get("/auth/admin_roles_depend_1")
def admin_roles(user: User = Depends(auth.get_current_user)):
return user # or request.user


# 路径操作装饰器依赖项
Expand Down Expand Up @@ -200,6 +199,7 @@ from fastapi_user_auth.auth.models import User


async def get_request_user(request: Request) -> Optional[User]:
# user= await auth.get_current_user(request)
if await auth.requires('admin', response=False)(request):
return request.user
else:
Expand Down
6 changes: 3 additions & 3 deletions fastapi_user_auth/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from fastapi_user_auth.auth.models import BaseUser, User, Group, Permission, Role
from fastapi_user_auth.auth.schemas import UserLoginOut
from pydantic import BaseModel
from sqlalchemy import insert, update
from sqlalchemy import insert, update, select
from starlette import status
from starlette.requests import Request
from starlette.responses import Response
Expand Down Expand Up @@ -126,10 +126,10 @@ async def handle(
**kwargs
) -> BaseApiOut[BaseModel]: # self.schema_submit_out
auth: Auth = request.auth
user = await auth.get_user_by_username(data.username)
user = await auth.db.scalar(select(self.user_model).where(self.user_model.username == data.username))
if user:
return BaseApiOut(status = -1, msg = _('Username has been registered!'), data = None)
user = await auth.get_user_by_whereclause(self.user_model.email == data.email)
user = await auth.db.scalar(select(self.user_model).where(self.user_model.email == data.email))
if user:
return BaseApiOut(status = -2, msg = _('Email has been registered!'), data = None)
user = self.user_model.parse_obj(data)
Expand Down
63 changes: 35 additions & 28 deletions fastapi_user_auth/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Coroutine
from typing import Type, Any, TypeVar, Optional, Sequence, Tuple, Union, Callable, Generic

from fastapi import FastAPI, HTTPException, Depends, Form
from fastapi import FastAPI, HTTPException, Depends, Form, params
from fastapi.security import OAuth2PasswordBearer
from fastapi.security.utils import get_authorization_scheme_param
from fastapi_amis_admin.crud.base import RouterMixin
Expand All @@ -15,6 +15,7 @@
from fastapi_amis_admin.utils.translation import i18n as _
from passlib.context import CryptContext
from pydantic import BaseModel, SecretStr
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from sqlalchemy_database import AsyncDatabase, Database
from sqlmodel import select
Expand Down Expand Up @@ -46,16 +47,7 @@ def get_user_token(request: Request) -> Optional[str]:
return token

async def authenticate(self, request: Request) -> Tuple["Auth", Optional[_UserModelT]]:
if request.scope.get('auth'): # 防止重复授权
return request.scope.get('auth'), request.scope.get('user')
request.scope["auth"], request.scope["user"] = self.auth, None
token = self.get_user_token(request)
if not token:
return self.auth, None
token_data = await self.token_store.read_token(token)
if token_data is not None:
request.scope["user"]: _UserModelT = await self.auth.get_user_by_username(token_data.username)
return request.auth, request.user
return self.auth, await self.auth.get_current_user(request)

def attach_middleware(self, app: FastAPI):
app.add_middleware(AuthenticationMiddleware, backend = self) # 添加auth中间件
Expand All @@ -78,21 +70,34 @@ def __init__(
self.backend = self.backend or AuthBackend(self, token_store or DbTokenStore(self.db))
self.pwd_context = pwd_context

async def get_user_by_username(self, username: str) -> Optional[_UserModelT]:
return await self.get_user_by_whereclause(self.user_model.username == username)

async def get_user_by_whereclause(self, *whereclause: Any) -> Optional[_UserModelT]:
return await self.db.async_scalar(select(self.user_model).where(*whereclause))

async def authenticate_user(self, username: str, password: Union[str, SecretStr]) -> Optional[_UserModelT]:
user = await self.get_user_by_username(username)
user = await self.db.async_scalar(select(self.user_model).where(self.user_model.username == username))
if user:
pwd = password.get_secret_value() if isinstance(password, SecretStr) else password
pwd2 = user.password.get_secret_value() if isinstance(user.password, SecretStr) else user.password
if self.pwd_context.verify(pwd, pwd2): # 用户存在 且 密码验证通过
return user
return None

@cached_property
def get_current_user(self):
async def _get_current_user(
request: Request,
session: Union[Session, AsyncSession, None] = Depends(self.db.session_generator)
) -> Optional[_UserModelT]:
if request.scope.get('auth'): # 防止重复授权
return request.scope.get('user')
request.scope["auth"], request.scope["user"] = self, None
token = self.backend.get_user_token(request)
if not token:
return None
token_data = await self.backend.token_store.read_token(token)
if token_data is not None:
request.scope["user"]: _UserModelT = await self.db.async_get(self.user_model, token_data.id, session = session)
return request.user

return _get_current_user

def requires(
self,
roles: Union[str, Sequence[str]] = None,
Expand All @@ -103,21 +108,22 @@ def requires(
response: Union[bool, Response] = None,
) -> Callable: # sourcery no-metrics

async def has_requires(conn: HTTPConnection) -> bool:
# todo websocket support
await self.backend.authenticate(conn) # type:ignore
if not conn.user:
return False
return await self.db.async_run_sync(
conn.user.has_requires,
async def has_requires(user: _UserModelT) -> bool:
return user and await self.db.async_run_sync(
user.has_requires,
roles = roles,
groups = groups,
permissions = permissions,
is_session = True
)

async def depend(request: Request) -> Union[bool, Response]:
if not await has_requires(request):
async def depend(
request: Request,
user: _UserModelT = Depends(self.get_current_user),
) -> Union[bool, Response]:
if isinstance(user, params.Depends):
user = await self.get_current_user(request)
if not await has_requires(user):
if response is not None:
return response
code, headers = status_code, {}
Expand Down Expand Up @@ -150,7 +156,8 @@ async def websocket_wrapper(
) -> None:
websocket = kwargs.get("websocket", args[idx] if args else None)
assert isinstance(websocket, WebSocket)
if not await has_requires(websocket):
user = await self.get_current_user(websocket) # type: ignore
if not await has_requires(user):
await websocket.close()
else:
await func(*args, **kwargs)
Expand Down
76 changes: 38 additions & 38 deletions tests/test_auth/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,43 +15,41 @@
from fastapi_user_auth.auth.models import User
from tests.conftest import async_db, sync_db


@pytest.fixture(params=[async_db, sync_db])
@pytest.fixture(params = [async_db, sync_db])
async def db(request) -> Union[Database, AsyncDatabase]:
database = request.param
await database.async_run_sync(SQLModel.metadata.create_all, is_session=False)
await database.async_run_sync(SQLModel.metadata.create_all, is_session = False)
yield database
await database.async_run_sync(SQLModel.metadata.drop_all, is_session=False)

await database.async_run_sync(SQLModel.metadata.drop_all, is_session = False)

app = FastAPI()
# 创建auth实例
auth = Auth(db=async_db)
auth = Auth(db = async_db)
# 注册auth基础路由
auth_router = AuthRouter(auth=auth)
auth_router = AuthRouter(auth = auth)
app.include_router(auth_router.router)


class UserClient:

def __init__(self, client: TestClient = None, user: User = None) -> None:
self.client: TestClient = client or TestClient(app)
self.user: User = user


def get_login_client(username: str = None, password: str = None) -> UserClient:
client = TestClient(app)
if not username or not password:
return UserClient()
response = client.post('/auth/gettoken',
data={'username': username, 'password': password},
headers={"Content-Type": "application/x-www-form-urlencoded"})
response = client.post(
'/auth/gettoken',
data = {'username': username, 'password': password},
headers = {"Content-Type": "application/x-www-form-urlencoded"}
)
data = response.json()
assert data['data']['access_token']
user = User.parse_obj(data['data'])
assert user.is_active
assert user.username == username
return UserClient(client=client, user=user)

return UserClient(client = client, user = user)

@pytest.fixture
def logins(request) -> UserClient:
Expand All @@ -64,52 +62,54 @@ def logins(request) -> UserClient:
user = user_data.get(request.param) or {}
return get_login_client(**user)


@pytest.fixture(scope="session")
@pytest.fixture(scope = "session")
async def prepare_database() -> AsyncGenerator[None, None]:
await auth.db.async_run_sync(SQLModel.metadata.create_all, is_session=False)
await auth.db.async_run_sync(SQLModel.metadata.create_all, is_session = False)
yield
await auth.db.async_run_sync(SQLModel.metadata.drop_all, is_session=False)

await auth.db.async_run_sync(SQLModel.metadata.drop_all, is_session = False)

@pytest.fixture(scope="session")
@pytest.fixture(scope = "session")
def event_loop():
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()


@pytest_asyncio.fixture(scope="session", autouse=True)
@pytest_asyncio.fixture(scope = "session", autouse = True)
async def fake_users(prepare_database):
await auth.db.async_run_sync(create_fake_users)


# noinspection PyTypeChecker
def create_fake_users(session: Session):
# init permission
admin_perm = Permission(key='admin', name='admin permission')
vip_perm = Permission(key='vip', name='vip permission')
test_perm = Permission(key='test', name='test permission')
admin_perm = Permission(key = 'admin', name = 'admin permission')
vip_perm = Permission(key = 'vip', name = 'vip permission')
test_perm = Permission(key = 'test', name = 'test permission')
session.add_all([admin_perm, vip_perm, test_perm])
session.flush([admin_perm, vip_perm, test_perm])
# init role
admin_role = Role(key='admin', name='admin role', permissions=[admin_perm])
vip_role = Role(key='vip', name='vip role', permissions=[vip_perm])
test_role = Role(key='test', name='test role', permissions=[test_perm])
admin_role = Role(key = 'admin', name = 'admin role', permissions = [admin_perm])
vip_role = Role(key = 'vip', name = 'vip role', permissions = [vip_perm])
test_role = Role(key = 'test', name = 'test role', permissions = [test_perm])
session.add_all([admin_role, vip_role, test_role])
session.flush([admin_role, vip_role, test_role])
# init group
admin_group = Group(key='admin', name='admin group', roles=[admin_role])
vip_group = Group(key='vip', name='vip group', roles=[vip_role])
test_group = Group(key='test', name='test group', roles=[test_role])
admin_group = Group(key = 'admin', name = 'admin group', roles = [admin_role])
vip_group = Group(key = 'vip', name = 'vip group', roles = [vip_role])
test_group = Group(key = 'test', name = 'test group', roles = [test_role])
session.add_all([admin_group, vip_group, test_group])
session.flush([admin_group, vip_group, test_group])
# init user
admin_user = User(username='admin', password=auth.pwd_context.hash('admin'), email='admin@amis.work',
roles=[admin_role], groups=[admin_group])
vip_user = User(username='vip', password=auth.pwd_context.hash('vip'), email='vip@amis.work', roles=[vip_role],
groups=[vip_group])
test_user = User(username='test', password=auth.pwd_context.hash('test'), email='test@amis.work', roles=[test_role],
groups=[test_group])
admin_user = User(
username = 'admin', password = auth.pwd_context.hash('admin'), email = 'admin@amis.work',
roles = [admin_role], groups = [admin_group]
)
vip_user = User(
username = 'vip', password = auth.pwd_context.hash('vip'), email = 'vip@amis.work', roles = [vip_role],
groups = [vip_group]
)
test_user = User(
username = 'test', password = auth.pwd_context.hash('test'), email = 'test@amis.work', roles = [test_role],
groups = [test_group]
)
session.add_all([admin_user, vip_user, test_user])
session.flush([admin_user, vip_user, test_user])
15 changes: 0 additions & 15 deletions tests/test_auth/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from fastapi_user_auth.auth.models import User
from tests.test_auth.conftest import auth


async def test_create_role_user():
user = await auth.create_role_user('admin2')
assert user.username == 'admin2'
Expand All @@ -14,7 +13,6 @@ async def test_create_role_user():
role = result.roles[0]
assert role.key == 'admin2'


async def test_authenticate_user():
# error
user = await auth.authenticate_user('admin', 'admin1')
Expand All @@ -23,16 +21,3 @@ async def test_authenticate_user():
# admin
user = await auth.authenticate_user('admin', 'admin')
assert user.username == 'admin'


async def test_get_user_by_whereclause():
user = await auth.get_user_by_whereclause(User.id == 1)
assert user.username == 'admin'

user = await auth.get_user_by_whereclause(User.username == 'admin')
assert user.username == 'admin'


async def test_get_user_by_username():
user = await auth.get_user_by_username('admin')
assert user.username == 'admin'
11 changes: 3 additions & 8 deletions tests/test_auth/test_auth_mount.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,34 @@
app.mount('/subapp2', subapp2)
auth.backend.attach_middleware(subapp2)

subapp3 = FastAPI(dependencies=[Depends(auth.requires('admin')())])
subapp3 = FastAPI(dependencies = [Depends(auth.requires('admin')())])
app.mount('/subapp3', subapp3)


# auth decorator
@subapp1.get("/auth/user")
@auth.requires()
def user(request: Request):
return request.user


@subapp2.get("/auth/user")
def user_2(request: Request):
if request.user:
return request.user
else:
raise HTTPException(status_code=403)

raise HTTPException(status_code = 403)

@subapp3.get("/auth/user")
@auth.requires()
def user_3(request: Request):
return request.user


path_admin_auth = {
"/subapp1/auth/user",
"/subapp2/auth/user",
"/subapp3/auth/user",
}


@pytest.mark.parametrize("logins", ['admin'], indirect=True)
@pytest.mark.parametrize("logins", ['admin'], indirect = True)
@pytest.mark.parametrize("path", list(path_admin_auth))
def test_admin_auth(logins: UserClient, path):
response = logins.client.get(path)
Expand Down
Loading

0 comments on commit 207e33b

Please sign in to comment.