From f34b5f462e57c712bb4be99c92506d5b289ce947 Mon Sep 17 00:00:00 2001 From: Humberto Rocha Date: Fri, 12 Jul 2024 11:00:17 -0400 Subject: [PATCH] Refactor service setup --- docs/api.md | 2 +- pyproject.toml | 2 ++ src/wheke/__about__.py | 2 +- src/wheke/__init__.py | 6 ++-- src/wheke/_core.py | 53 +++++++++++++++++++++++++++++------ src/wheke/_pod.py | 8 +++--- src/wheke/_service.py | 28 ++++++++++++++++-- tests/conftest.py | 21 ++++++++++---- tests/example_app/__init__.py | 16 +++++++---- tests/test_app.py | 39 +++++++++++++------------- tests/test_cli.py | 7 ++--- 11 files changed, 130 insertions(+), 54 deletions(-) diff --git a/docs/api.md b/docs/api.md index 061ea3e..5764b8d 100644 --- a/docs/api.md +++ b/docs/api.md @@ -8,7 +8,7 @@ - Wheke - WhekeSettings - Pod - - ServiceList + - ServiceConfig - aget_service - get_service - get_settings diff --git a/pyproject.toml b/pyproject.toml index e7a1b64..1c699fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ classifiers = [ "Topic :: Internet", ] dependencies = [ + "anyio", "fastapi", "pydantic", "pydantic-settings", @@ -53,6 +54,7 @@ dependencies = [ "pytest", "pytest-cov", "ruff", + "trio", ] [tool.hatch.envs.default.scripts] diff --git a/src/wheke/__about__.py b/src/wheke/__about__.py index 3ced358..493f741 100644 --- a/src/wheke/__about__.py +++ b/src/wheke/__about__.py @@ -1 +1 @@ -__version__ = "0.2.1" +__version__ = "0.3.0" diff --git a/src/wheke/__init__.py b/src/wheke/__init__.py index f85959e..1373cfa 100644 --- a/src/wheke/__init__.py +++ b/src/wheke/__init__.py @@ -1,12 +1,12 @@ from ._core import Wheke from ._demo import demo_pod -from ._pod import Pod, ServiceList -from ._service import aget_service, get_service +from ._pod import Pod +from ._service import ServiceConfig, aget_service, get_service from ._settings import WhekeSettings, get_settings __all__ = [ "Pod", - "ServiceList", + "ServiceConfig", "Wheke", "WhekeSettings", "aget_service", diff --git a/src/wheke/_core.py b/src/wheke/_core.py index 933ac4d..d2ae66f 100644 --- a/src/wheke/_core.py +++ b/src/wheke/_core.py @@ -1,4 +1,6 @@ +from collections.abc import Callable from importlib import import_module +from types import TracebackType from fastapi import FastAPI from fastapi.staticfiles import StaticFiles @@ -6,7 +8,12 @@ from ._cli import empty_callback, version from ._pod import Pod -from ._service import get_service_registry +from ._service import ( + ServiceConfig, + aclose_service_registry, + close_service_registry, + register_service, +) from ._settings import WhekeSettings, get_settings @@ -23,20 +30,22 @@ def __init__( ) -> None: self.pods = [] + settings_factory: Callable + if settings is None: settings_cls = WhekeSettings - settings_obj = WhekeSettings() + settings_factory = WhekeSettings elif isinstance(settings, WhekeSettings): settings_cls = type(settings) - settings_obj = settings + settings_factory = lambda: settings # NOQA: E731 else: settings_cls = settings - settings_obj = settings_cls() + settings_factory = settings_cls - get_service_registry().register_value(settings_cls, settings_obj) + register_service(ServiceConfig(settings_cls, settings_factory, True)) if settings_cls != WhekeSettings: - get_service_registry().register_value(WhekeSettings, settings_obj) + register_service(ServiceConfig(WhekeSettings, settings_factory, True)) for pod in get_settings(WhekeSettings).pods: self.add_pod(pod) @@ -53,8 +62,8 @@ def add_pod(self, pod_to_add: Pod | str) -> None: else: pod = pod_to_add - for service_type, service_factory in pod.services: - get_service_registry().register_factory(service_type, service_factory) + for service_config in pod.services: + register_service(service_config) self.pods.append(pod) @@ -90,3 +99,31 @@ def create_cli(self) -> Typer: cli.add_typer(pod.cli, name=pod.name) return cli + + def close(self) -> None: + close_service_registry() + + async def aclose(self) -> None: + await aclose_service_registry() + + def __enter__(self) -> "Wheke": + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.close() + + async def __aenter__(self) -> "Wheke": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.aclose() diff --git a/src/wheke/_pod.py b/src/wheke/_pod.py index 2ab54c5..5d865af 100644 --- a/src/wheke/_pod.py +++ b/src/wheke/_pod.py @@ -1,10 +1,10 @@ -from collections.abc import Callable +from collections.abc import Iterable from pathlib import Path from fastapi import APIRouter from typer import Typer -ServiceList = list[tuple[type, Callable]] +from ._service import ServiceConfig class Pod: @@ -27,7 +27,7 @@ class Pod: static_path: Path | None "The path to the Pod static files." - services: ServiceList + services: Iterable[ServiceConfig] """ The list of services provided by the Pod. @@ -45,7 +45,7 @@ def __init__( router: APIRouter | None = None, static_url: str | None = None, static_path: str | Path | None = None, - services: ServiceList | None = None, + services: Iterable[ServiceConfig] | None = None, cli: Typer | None = None, ) -> None: self.name = name diff --git a/src/wheke/_service.py b/src/wheke/_service.py index 5eebd6c..c32d2e6 100644 --- a/src/wheke/_service.py +++ b/src/wheke/_service.py @@ -1,14 +1,36 @@ +from collections.abc import Callable +from dataclasses import dataclass from typing import TypeVar from svcs import Container, Registry T = TypeVar("T") -_registry = Registry() +_registry: Registry = Registry() -def get_service_registry() -> Registry: - return _registry +@dataclass +class ServiceConfig: + service_type: type + service_factory: Callable + as_value: bool = False + + +def close_service_registry() -> None: + _registry.close() + + +async def aclose_service_registry() -> None: + await _registry.aclose() + + +def register_service(config: ServiceConfig) -> bool: + if config.as_value: + _registry.register_value(config.service_type, config.service_factory()) + else: + _registry.register_factory(config.service_type, config.service_factory) + + return True def get_service(service_type: type[T]) -> T: diff --git a/tests/conftest.py b/tests/conftest.py index 48902d6..66113bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,26 @@ +from collections.abc import Generator +from typing import Any + import pytest from fastapi.testclient import TestClient from typer import Typer -from tests.example_app import wheke +from tests.example_app import make_wheke @pytest.fixture -def client() -> TestClient: - return TestClient(wheke.create_app()) +def client() -> Generator[TestClient, Any, Any]: + wheke = make_wheke() + + yield TestClient(wheke.create_app()) + + wheke.close() @pytest.fixture -def cli() -> Typer: - return wheke.create_cli() +def cli() -> Generator[Typer, Any, Any]: + wheke = make_wheke() + + yield wheke.create_cli() + + wheke.close() diff --git a/tests/example_app/__init__.py b/tests/example_app/__init__.py index daef3ed..f75c898 100644 --- a/tests/example_app/__init__.py +++ b/tests/example_app/__init__.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends from typer import Typer, echo -from wheke import Pod, Wheke, aget_service, demo_pod, get_service +from wheke import Pod, ServiceConfig, Wheke, aget_service, demo_pod, get_service STATIC_PATH = Path(__file__).parent / "static" @@ -64,12 +64,16 @@ def hello() -> None: static_url="/static", static_path=str(STATIC_PATH), services=[ - (PingService, ping_service_factory), - (APingService, aping_service_factory), + ServiceConfig(PingService, ping_service_factory), + ServiceConfig(APingService, aping_service_factory), ], cli=cli, ) -wheke = Wheke() -wheke.add_pod(demo_pod) -wheke.add_pod(test_pod) + +def make_wheke() -> Wheke: + wheke = Wheke() + wheke.add_pod(demo_pod) + wheke.add_pod(test_pod) + + return wheke diff --git a/tests/test_app.py b/tests/test_app.py index c96ac11..73cdbc5 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,51 +1,52 @@ +import pytest from fastapi import FastAPI, status from fastapi.testclient import TestClient from wheke import Pod, Wheke, WhekeSettings, demo_pod, get_settings from wheke._demo import DEMO_PAGE +pytestmark = pytest.mark.anyio -def test_create_app() -> None: - wheke = Wheke() - app = wheke.create_app() +async def test_create_app() -> None: + async with Wheke() as wheke: + app = wheke.create_app() - assert type(app) is FastAPI + assert type(app) is FastAPI def test_create_app_with_demo_pod_in_settings() -> None: settings = WhekeSettings() settings.pods = ["wheke.demo_pod"] - wheke = Wheke(settings) + with Wheke(settings) as wheke: + app = wheke.create_app() - app = wheke.create_app() - - assert type(app) is FastAPI - assert demo_pod in wheke.pods + assert type(app) is FastAPI + assert demo_pod in wheke.pods def test_create_app_with_empty_pod() -> None: empty_pod = Pod("empty") - wheke = Wheke() - wheke.add_pod(empty_pod) - app = wheke.create_app() + with Wheke() as wheke: + wheke.add_pod(empty_pod) + + app = wheke.create_app() - assert type(app) is FastAPI - assert wheke.pods == [empty_pod] + assert type(app) is FastAPI + assert wheke.pods == [empty_pod] def test_create_app_with_custom_settings_class() -> None: class CustomSettings(WhekeSettings): test_setting: str = "test" - wheke = Wheke(CustomSettings) - - app = wheke.create_app() + with Wheke(CustomSettings) as wheke: + app = wheke.create_app() - assert type(app) is FastAPI - assert get_settings(CustomSettings).test_setting == "test" + assert type(app) is FastAPI + assert get_settings(CustomSettings).test_setting == "test" def test_demo_pod(client: TestClient) -> None: diff --git a/tests/test_cli.py b/tests/test_cli.py index f8903c1..ef38602 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -9,11 +9,10 @@ def test_create_cli() -> None: - wheke = Wheke() + with Wheke() as wheke: + app = wheke.create_cli() - app = wheke.create_cli() - - assert type(app) is Typer + assert type(app) is Typer def test_version() -> None: