diff --git a/docs/source/by-examples/050_dependencies.md b/docs/source/by-examples/050_dependencies.md index 12404f7a..c09d9d84 100644 --- a/docs/source/by-examples/050_dependencies.md +++ b/docs/source/by-examples/050_dependencies.md @@ -383,10 +383,8 @@ Lightbulb will enable dependency injection on a specific subset of your methods These are listed below: - {meth}`@lightbulb.invoke ` -- {meth}`@Client.register ` - {meth}`@Client.error_handler ` - {meth}`@Client.task ` -- {meth}`@Loader.command ` (due to it calling `Client.register` internally) - {meth}`@Loader.listener ` - {meth}`@Loader.task ` diff --git a/fragments/443.feature.md b/fragments/443.feature.md new file mode 100644 index 00000000..b63a43f5 --- /dev/null +++ b/fragments/443.feature.md @@ -0,0 +1,3 @@ +- Add `__contains__` method to `di.Container` to allow checking if a dependency is registered. + +- Allow parameter-injected dependencies to be optional, and have fallbacks if one is not available. diff --git a/lightbulb/di/container.py b/lightbulb/di/container.py index 88e9b3cf..f6514608 100644 --- a/lightbulb/di/container.py +++ b/lightbulb/di/container.py @@ -22,6 +22,7 @@ __all__ = ["Container"] +import types import typing as t import networkx as nx @@ -30,14 +31,17 @@ from lightbulb.di import exceptions from lightbulb.di import registry as registry_ from lightbulb.di import utils as di_utils +from lightbulb.internal import marker if t.TYPE_CHECKING: - import types from collections.abc import Callable from lightbulb.internal import types as lb_types T = t.TypeVar("T") +D = t.TypeVar("D") + +_MISSING = marker.Marker("MISSING") class Container: @@ -78,6 +82,17 @@ def __init__(self, registry: registry_.Registry, *, parent: Container | None = N self.add_value(Container, self) + def __contains__(self, item: type[t.Any]) -> bool: + dep_id = di_utils.get_dependency_id(item) + if dep_id not in self._graph: + return False + + container = self._graph.nodes[dep_id]["container"] + if dep_id in container._instances: + return True + + return container._graph.nodes[dep_id].get("factory") is not None + async def __aenter__(self) -> Container: return self @@ -116,9 +131,15 @@ def add_factory( Returns: :obj:`None` + Raises: + :obj:`ValueError`: When attempting to add a dependency for ``NoneType``. + See Also: :meth:`lightbulb.di.registry.Registry.add_factory` for factory and teardown function spec. """ + if typ is types.NoneType: + raise ValueError("cannot register type 'NoneType' - 'None' is used for optional dependencies") + dependency_id = di_utils.get_dependency_id(typ) if dependency_id in self._graph: @@ -144,9 +165,15 @@ def add_value( Returns: :obj:`None` + Raises: + :obj:`ValueError`: When attempting to add a dependency for ``NoneType``. + See Also: :meth:`lightbulb.di.registry.Registry.add_value` for teardown function spec. """ + if typ is types.NoneType: + raise ValueError("cannot register type 'NoneType' - 'None' is used for optional dependencies") + dependency_id = di_utils.get_dependency_id(typ) self._instances[dependency_id] = value @@ -154,7 +181,7 @@ def add_value( self._graph.remove_edges_from(list(self._graph.out_edges(dependency_id))) self._graph.add_node(dependency_id, container=self, teardown=teardown) - async def _get(self, dependency_id: str) -> t.Any: + async def _get(self, dependency_id: str, *, allow_missing: bool = False) -> t.Any: if self._closed: raise exceptions.ContainerClosedException @@ -162,6 +189,8 @@ async def _get(self, dependency_id: str) -> t.Any: data = self._graph.nodes.get(dependency_id) if data is None or data.get("container") is None: + if allow_missing: + return _MISSING raise exceptions.DependencyNotSatisfiableException existing_dependency = data["container"]._instances.get(dependency_id) @@ -183,6 +212,9 @@ async def _get(self, dependency_id: str) -> t.Any: for dep_id in creation_order: if (container := self._graph.nodes[dep_id].get("container")) is None: + if allow_missing: + return _MISSING + raise exceptions.DependencyNotSatisfiableException( f"could not create dependency {dep_id!r} - not provided by this or a parent container" ) @@ -206,7 +238,11 @@ async def _get(self, dependency_id: str) -> t.Any: sub_dependencies: dict[str, t.Any] = {} try: for sub_dependency_id, param_name in node_data["factory_params"].items(): - sub_dependencies[param_name] = await node_data["container"]._get(sub_dependency_id) + sub_dependency = await node_data["container"]._get(sub_dependency_id, allow_missing=allow_missing) + if sub_dependency is _MISSING: + return _MISSING + + sub_dependencies[param_name] = sub_dependency except exceptions.DependencyNotSatisfiableException as e: raise exceptions.DependencyNotSatisfiableException( f"could not create dependency {dep_id!r} - failed creating sub-dependency" @@ -217,12 +253,19 @@ async def _get(self, dependency_id: str) -> t.Any: return self._graph.nodes[dependency_id]["container"]._instances[dependency_id] - async def get(self, typ: type[T]) -> T: + @t.overload + async def get(self, typ: type[T], /) -> T: ... + @t.overload + async def get(self, typ: type[T], /, *, default: D) -> T | D: ... + + async def get(self, typ: type[T], /, *, default: D = _MISSING) -> T | D: """ Get a dependency from this container, instantiating it and sub-dependencies if necessary. Args: typ: The type used when registering the dependency. + default: The default value to return if the dependency is not satisfiable. If not provided, this will + raise a :obj:`~lightbulb.di.exceptions.DependencyNotSatisfiableException`. Returns: The dependency for the given type. @@ -235,4 +278,9 @@ async def get(self, typ: type[T]) -> T: for any other reason. """ dependency_id = di_utils.get_dependency_id(typ) - return t.cast(T, await self._get(dependency_id)) + + dependency = await self._get(dependency_id, allow_missing=default is not _MISSING) + if dependency is _MISSING: + return default + + return t.cast(T, dependency) diff --git a/lightbulb/di/registry.py b/lightbulb/di/registry.py index ac1d4a27..34f73007 100644 --- a/lightbulb/di/registry.py +++ b/lightbulb/di/registry.py @@ -22,6 +22,7 @@ __all__ = ["Registry"] +import types import typing as t import networkx as nx @@ -33,7 +34,7 @@ from collections.abc import Callable from lightbulb.di import container - from lightbulb.internal import types + from lightbulb.internal import types as lb_types T = t.TypeVar("T") @@ -84,7 +85,7 @@ def register_value( typ: type[T], value: T, *, - teardown: Callable[[T], types.MaybeAwaitable[None]] | None = None, + teardown: Callable[[T], lb_types.MaybeAwaitable[None]] | None = None, ) -> None: """ Registers a pre-existing value as a dependency. @@ -100,15 +101,16 @@ def register_value( Raises: :obj:`lightbulb.di.exceptions.RegistryFrozenException`: If the registry is frozen. + :obj:`ValueError`: When attempting to register a dependency for ``NoneType``. """ self.register_factory(typ, lambda: value, teardown=teardown) def register_factory( self, typ: type[T], - factory: Callable[..., types.MaybeAwaitable[T]], + factory: Callable[..., lb_types.MaybeAwaitable[T]], *, - teardown: Callable[[T], types.MaybeAwaitable[None]] | None = None, + teardown: Callable[[T], lb_types.MaybeAwaitable[None]] | None = None, ) -> None: """ Registers a factory for creating a dependency. @@ -127,10 +129,14 @@ def register_factory( Raises: :obj:`lightbulb.di.exceptions.RegistryFrozenException`: If the registry is frozen. :obj:`lightbulb.di.exceptions.CircularDependencyException`: If the factory requires itself as a dependency. + :obj:`ValueError`: When attempting to register a dependency for ``NoneType``. """ if self._active_containers: raise exceptions.RegistryFrozenException + if typ is types.NoneType: + raise ValueError("cannot register type 'NoneType' - 'None' is used for optional dependencies") + dependency_id = di_utils.get_dependency_id(typ) # We are overriding a previously defined dependency and want to strip the edges, so we don't have diff --git a/lightbulb/di/solver.py b/lightbulb/di/solver.py index 352987fe..8db2bc93 100644 --- a/lightbulb/di/solver.py +++ b/lightbulb/di/solver.py @@ -42,10 +42,12 @@ import logging import os import sys +import types import typing as t from collections.abc import AsyncIterator from collections.abc import Awaitable from collections.abc import Callable +from collections.abc import Sequence from lightbulb import utils from lightbulb.di import container @@ -54,7 +56,7 @@ from lightbulb.internal import marker if t.TYPE_CHECKING: - from lightbulb.internal import types + from lightbulb.internal import types as lb_types P = t.ParamSpec("P") R = t.TypeVar("R") @@ -236,12 +238,24 @@ async def close(self) -> None: self._default_container = None -CANNOT_INJECT = object() +class ParamInfo(t.NamedTuple): + name: str + types: Sequence[t.Any] + optional: bool + injectable: bool -def _parse_injectable_params(func: Callable[..., t.Any]) -> tuple[list[tuple[str, t.Any]], dict[str, t.Any]]: - positional_or_keyword_params: list[tuple[str, t.Any]] = [] - keyword_only_params: dict[str, t.Any] = {} +def _get_requested_types(annotation: t.Any) -> tuple[Sequence[t.Any], bool]: + if t.get_origin(annotation) in (t.Union, types.UnionType): + args = t.get_args(annotation) + + return tuple(a for a in args if a is not types.NoneType), types.NoneType in args + return (annotation,), False + + +def _parse_injectable_params(func: Callable[..., t.Any]) -> tuple[list[ParamInfo], list[ParamInfo]]: + positional_or_keyword_params: list[ParamInfo] = [] + keyword_only_params: list[ParamInfo] = [] parameters = inspect.signature(func, locals={"lightbulb": sys.modules["lightbulb"]}, eval_str=True).parameters for parameter in parameters.values(): @@ -254,15 +268,19 @@ def _parse_injectable_params(func: Callable[..., t.Any]) -> tuple[list[tuple[str # If it has a default that isn't INJECTED or ((default := parameter.default) is not inspect.Parameter.empty and default is not INJECTED) ): + # We need to know about ALL pos-or-kw arguments so that we can exclude ones passed in + # when the injection-enabled function is called - this isn't the same for kw-only args if parameter.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD): - positional_or_keyword_params.append((parameter.name, CANNOT_INJECT)) + positional_or_keyword_params.append(ParamInfo(parameter.name, (), False, False)) continue + requested_types, optional = _get_requested_types(parameter.annotation) + param_info = ParamInfo(parameter.name, requested_types, optional, True) if parameter.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD: - positional_or_keyword_params.append((parameter.name, parameter.annotation)) + positional_or_keyword_params.append(param_info) else: # It has to be a keyword-only parameter - keyword_only_params[parameter.name] = parameter.annotation + keyword_only_params.append(param_info) return positional_or_keyword_params, keyword_only_params @@ -288,8 +306,8 @@ def __init__( self, func: Callable[..., Awaitable[t.Any]], self_: t.Any = None, - _cached_pos_or_kw_params: list[tuple[str, t.Any]] | None = None, - _cached_kw_only_params: dict[str, t.Any] | None = None, + _cached_pos_or_kw_params: list[ParamInfo] | None = None, + _cached_kw_only_params: list[ParamInfo] | None = None, ) -> None: self._func = func self._self: t.Any = self_ @@ -320,22 +338,41 @@ async def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any: di_container: container.Container | None = DI_CONTAINER.get(None) - injectables = { - name: type - for name, type in self._pos_or_kw_params[len(args) + (self._self is not None) :] - if name not in new_kwargs - } - injectables.update({name: type for name, type in self._kw_only_params.items() if name not in new_kwargs}) - - for name, type in injectables.items(): - # Skip any arguments that we can't inject - if type is CANNOT_INJECT: + maybe_injectables = [*self._pos_or_kw_params[len(args) + (self._self is not None) :], *self._kw_only_params] + for param in maybe_injectables: + # Skip any parameters we already have a value for, or is not valid to be injected + if param.name in new_kwargs or not param.injectable: continue if di_container is None: raise exceptions.DependencyNotSatisfiableException("no DI context is available") - new_kwargs[name] = await di_container.get(type) + # Resolve the dependency, or None if the dependency is unsatisfied and is optional + if len(param.types) == 1: + default_kwarg = {"default": None} if param.optional else {} + new_kwargs[param.name] = await di_container.get(param.types[0], **default_kwarg) + continue + + for i, type in enumerate(param.types): + resolved = await di_container.get(type, default=None) + + # Check if this is the last type to check, and we couldn't resolve a dependency for it + if resolved is None and i == (len(param.types) - 1): + # If this dependency is optional then set value to 'None' + if param.optional: + new_kwargs[param.name] = None + break + # We can't supply this dependency, so raise an exception + raise exceptions.DependencyNotSatisfiableException( + f"could not satisfy any dependencies for types {param.types}" + ) + + # We couldn't supply this type, so continue and check the next one + if resolved is None: + continue + # We could supply this type, set the parameter to the dependency and skip to the next parameter + new_kwargs[param.name] = resolved + break if self._self is not None: return await utils.maybe_await(self._func(self._self, *args, **new_kwargs)) @@ -347,10 +384,10 @@ def with_di(func: AsyncFnT) -> AsyncFnT: ... @t.overload -def with_di(func: Callable[P, types.MaybeAwaitable[R]]) -> Callable[P, Awaitable[R]]: ... +def with_di(func: Callable[P, lb_types.MaybeAwaitable[R]]) -> Callable[P, Awaitable[R]]: ... -def with_di(func: Callable[P, types.MaybeAwaitable[R]]) -> Callable[P, Awaitable[R]]: +def with_di(func: Callable[P, lb_types.MaybeAwaitable[R]]) -> Callable[P, Awaitable[R]]: """ Decorator that enables dependency injection on the decorated function. If dependency injection has been disabled globally then this function does nothing and simply returns the object that was passed in. diff --git a/tests/di/test_container.py b/tests/di/test_container.py index 4b5e3e4e..6b2e6ecb 100644 --- a/tests/di/test_container.py +++ b/tests/di/test_container.py @@ -19,6 +19,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. import inspect +import types import typing as t from unittest import mock @@ -247,6 +248,98 @@ def f(_: A) -> object: async with di.Container(registry) as container: await container.get(B) + @pytest.mark.asyncio + async def test_non_direct_circular_dependency_raises_exception(self) -> None: + # fmt: off + def f_a(_: B) -> object: return object() + + def f_b(_: A) -> object: return object() + + # fmt: on + + registry = di.Registry() + registry.register_factory(A, f_a) + registry.register_factory(B, f_b) + + with pytest.raises(di.CircularDependencyException): + async with di.Container(registry) as c: + await c.get(A) + + @pytest.mark.asyncio + async def test_get_transient_dependency_raises_exception(self) -> None: + def f_a(_: B) -> object: + return object() + + registry = di.Registry() + registry.register_factory(A, f_a) + + with pytest.raises(di.DependencyNotSatisfiableException): + async with di.Container(registry) as c: + await c.get(B) + + @pytest.mark.asyncio + async def test_get_from_closed_container_raises_exception(self) -> None: + registry = di.Registry() + registry.register_factory(object, lambda: object()) + + with pytest.raises(di.ContainerClosedException): + async with di.Container(registry) as c: + pass + await c.get(object) + + @pytest.mark.asyncio + async def test_get_with_default_when_dependency_not_available_returns_default(self) -> None: + registry = di.Registry() + + async with di.Container(registry) as c: + assert await c.get(object, default=None) is None + + @pytest.mark.asyncio + async def test_get_with_default_when_sub_dependency_not_available_returns_default(self) -> None: + registry = di.Registry() + + def f1(_: str) -> object: + return object() + + registry.register_factory(object, f1) + + async with di.Container(registry) as c: + assert await c.get(object, default=None) is None + + @pytest.mark.asyncio + async def test__contains__returns_true_when_dependency_known_by_value(self) -> None: + registry = di.Registry() + async with di.Container(registry) as container: + container.add_value(object, object()) + assert object in container + + @pytest.mark.asyncio + async def test__contains__returns_true_when_dependency_known_by_factory(self) -> None: + registry = di.Registry() + async with di.Container(registry) as container: + container.add_factory(object, lambda: object()) + assert object in container + + @pytest.mark.asyncio + async def test__contains__returns_false_when_dependency_not_known(self) -> None: + registry = di.Registry() + async with di.Container(registry) as container: + assert object not in container + + @pytest.mark.asyncio + async def test_cannot_register_dependency_by_value_for_NoneType(self) -> None: + registry = di.Registry() + with pytest.raises(ValueError): + async with di.Container(registry) as container: + container.add_value(types.NoneType, None) + + @pytest.mark.asyncio + async def test_cannot_register_dependency_by_factory_for_NoneType(self) -> None: + registry = di.Registry() + with pytest.raises(ValueError): + async with di.Container(registry) as container: + container.add_factory(types.NoneType, lambda: None) + class TestContainerWithParent: @pytest.mark.asyncio @@ -313,38 +406,28 @@ def f(_: A) -> object: await cc.get(B) @pytest.mark.asyncio - async def test_non_direct_circular_dependency_raises_exception(self) -> None: - # fmt: off - def f_a(_: B) -> object: return object() - def f_b(_: A) -> object: return object() - # fmt: on + async def test__contains__returns_true_when_dependency_known_by_parent(self) -> None: + r1 = di.Registry() + r1.register_value(object, object()) - registry = di.Registry() - registry.register_factory(A, f_a) - registry.register_factory(B, f_b) + r2 = di.Registry() - with pytest.raises(di.CircularDependencyException): - async with di.Container(registry) as c: - await c.get(A) + async with di.Container(r1) as r1, di.Container(r2, parent=r1) as r2: + assert object in r2 @pytest.mark.asyncio - async def test_get_transient_dependency_raises_exception(self) -> None: - def f_a(_: B) -> object: + async def test_parent_dependency_returns_default_if_given_and_depends_on_child_dependency(self) -> None: + def f(_: A) -> object: return object() - registry = di.Registry() - registry.register_factory(A, f_a) - - with pytest.raises(di.DependencyNotSatisfiableException): - async with di.Container(registry) as c: - await c.get(B) + parent_registry = di.Registry() + parent_registry.register_factory(B, f) - @pytest.mark.asyncio - async def test_get_from_closed_container_raises_exception(self) -> None: - registry = di.Registry() - registry.register_factory(object, lambda: object()) + child_registry = di.Registry() + child_registry.register_factory(A, lambda: object()) - with pytest.raises(di.ContainerClosedException): - async with di.Container(registry) as c: - pass - await c.get(object) + async with ( + di.Container(parent_registry) as pc, + di.Container(child_registry, parent=pc) as cc, + ): + assert await cc.get(B, default=None) is None diff --git a/tests/di/test_registry.py b/tests/di/test_registry.py index 826a143f..e1e5eed0 100644 --- a/tests/di/test_registry.py +++ b/tests/di/test_registry.py @@ -18,6 +18,7 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import types import typing as t import pytest @@ -70,3 +71,13 @@ def test__contains__returns_false_when_NewType_dependency_not_registered(self) - registry = di.Registry() T = t.NewType("T", object) assert T not in registry + + def test_cannot_register_dependency_by_value_for_NoneType(self) -> None: + registry = di.Registry() + with pytest.raises(ValueError): + registry.register_value(types.NoneType, None) + + def test_cannot_register_dependency_by_factory_for_NoneType(self) -> None: + registry = di.Registry() + with pytest.raises(ValueError): + registry.register_factory(types.NoneType, lambda: None) diff --git a/tests/di/test_solver.py b/tests/di/test_solver.py index f694a403..5e7b2ae5 100644 --- a/tests/di/test_solver.py +++ b/tests/di/test_solver.py @@ -23,7 +23,8 @@ import pytest import lightbulb -from lightbulb.di.solver import CANNOT_INJECT +from lightbulb.di import solver +from lightbulb.di.solver import ParamInfo from lightbulb.di.solver import _parse_injectable_params @@ -32,7 +33,7 @@ def test_parses_positional_only_arg_correctly(self) -> None: def m(foo: object, /) -> None: ... pos, kw = _parse_injectable_params(m) - assert pos[0][1] is CANNOT_INJECT and len(kw) == 0 + assert not pos[0].injectable and len(kw) == 0 def test_parses_var_positional_arg_correctly(self) -> None: def m(*foo: object) -> None: ... @@ -50,24 +51,33 @@ def test_parses_args_with_non_INJECTED_default_correctly(self) -> None: def m(foo: object = object()) -> None: ... pos, kw = _parse_injectable_params(m) - assert pos[0][1] is CANNOT_INJECT and len(kw) == 0 + assert not pos[0].injectable and len(kw) == 0 def test_parses_args_with_no_annotation_correctly(self) -> None: def m(foo) -> None: # type: ignore[unknownParameterType] ... pos, kw = _parse_injectable_params(m) # type: ignore[unknownArgumentType] - assert pos[0][1] is CANNOT_INJECT and len(kw) == 0 + assert not pos[0].injectable and len(kw) == 0 def test_parses_args_correctly(self) -> None: def m( - foo: str, bar: int = lightbulb.di.INJECTED, *, baz: float, bork: bool = lightbulb.di.INJECTED + foo: str, + bar: int | float = lightbulb.di.INJECTED, + *, + baz: float | None, + bork: t.Union[bool, object] = lightbulb.di.INJECTED, + qux: t.Optional[object] = lightbulb.di.INJECTED, ) -> None: ... pos, kw = _parse_injectable_params(m) - assert pos == [("foo", str), ("bar", int)] - assert kw == {"baz": float, "bork": bool} + assert pos == [ParamInfo("foo", (str,), False, True), ParamInfo("bar", (int, float), False, True)] + assert kw == [ + ParamInfo("baz", (float,), True, True), + ParamInfo("bork", (bool, object), False, True), + ParamInfo("qux", (object,), True, True), + ] class TestMethodInjection: @@ -237,6 +247,90 @@ def m(foo, obj: object = lightbulb.di.INJECTED) -> None: # type: ignore[reportU async with manager.enter_context(lightbulb.di.Contexts.DEFAULT): await m("bar") + @pytest.mark.asyncio + async def test_None_provided_if_dependency_not_available_for_optional_parameter(self) -> None: + manager = lightbulb.di.DependencyInjectionManager() + + @lightbulb.di.with_di + async def m(foo: object | None = lightbulb.di.INJECTED) -> None: + assert foo is None + + async with manager.enter_context(lightbulb.di.Contexts.DEFAULT): + await m() + + @pytest.mark.asyncio + async def test_second_dependency_provided_if_first_not_available(self) -> None: + manager = lightbulb.di.DependencyInjectionManager() + + value = object() + manager.registry_for(lightbulb.di.Contexts.DEFAULT).register_value(object, value) + + @lightbulb.di.with_di + async def m(foo: str | object = lightbulb.di.INJECTED) -> None: + assert foo is value + + async with manager.enter_context(lightbulb.di.Contexts.DEFAULT): + await m() + + @pytest.mark.asyncio + async def test_first_dependency_provided_if_both_are_available(self) -> None: + manager = lightbulb.di.DependencyInjectionManager() + + manager.registry_for(lightbulb.di.Contexts.DEFAULT).register_value(str, "bar") + manager.registry_for(lightbulb.di.Contexts.DEFAULT).register_value(object, object()) + + @lightbulb.di.with_di + async def m(foo: str | object = lightbulb.di.INJECTED) -> None: + assert foo == "bar" + + async with manager.enter_context(lightbulb.di.Contexts.DEFAULT): + await m() + + @pytest.mark.asyncio + async def test_None_provided_if_no_dependencies_available(self) -> None: + manager = lightbulb.di.DependencyInjectionManager() + + @lightbulb.di.with_di + async def m(foo: str | object | None = lightbulb.di.INJECTED) -> None: + assert foo is None + + async with manager.enter_context(lightbulb.di.Contexts.DEFAULT): + await m() + + @pytest.mark.asyncio + async def test_exception_raised_when_no_dependencies_available(self) -> None: + manager = lightbulb.di.DependencyInjectionManager() + + @lightbulb.di.with_di + async def m(foo: str | object = lightbulb.di.INJECTED) -> None: ... + + with pytest.raises(lightbulb.di.DependencyNotSatisfiableException): + async with manager.enter_context(lightbulb.di.Contexts.DEFAULT): + await m() + + +class TestLazyInjecting: + def test_getattr_passes_through_to_function(self) -> None: + @lightbulb.di.with_di + def m() -> None: ... + + assert m.__name__ == "m" + + def test_setattr_passes_through_to_function(self) -> None: + def m() -> None: ... + + fn = lightbulb.di.with_di(m) + fn.__lb_foo__ = "bar" # type: ignore[reportFunctionMemberAccess] + + assert m.__lb_foo__ == "bar" # type: ignore[reportFunctionMemberAccess] + + def test__get__within_class_does_not_assign_self(self) -> None: + class Foo: + @lightbulb.di.with_di + def m(self) -> None: ... + + assert Foo.m._self is None # type: ignore[reportFunctionMemberAccess] + class TestDependencyInjectionManager: @pytest.mark.asyncio @@ -272,3 +366,15 @@ async def test_default_container_closed_once_manager_closed(self) -> None: await manager.close() assert default_container._closed assert manager.default_container is None + + +class TestWithDiDecorator: + def test_does_not_enable_injection_when_injection_already_enabled(self) -> None: + method = lightbulb.di.with_di(lambda: None) + assert lightbulb.di.with_di(method) is method + + def test_does_not_enable_injection_when_injection_globally_disabled(self) -> None: + solver.DI_ENABLED = False + method = lambda: None # noqa: E731 + assert lightbulb.di.with_di(method) is method + solver.DI_ENABLED = True