Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: various DI improvements #449

Merged
merged 4 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions fragments/449.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
- Improved container hierarchical dependency resolution process.
- Add significantly more `DEBUG` logging throughout the DI pipeline.
- `DependencyInjectionManager.enter_context` now needs to be called for every context you wish to enter; it will no longer enter the default context automatically.
- `DependencyInjectionManager.enter_context` now searches the existing container hierarchy to find an existing container for the passed context before trying to create a new one.
- `DependencyInjectionManager.enter_context` now returns a no-op container implementation instead of `None` if DI is disabled globally.
18 changes: 11 additions & 7 deletions lightbulb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,9 +946,11 @@
async def _execute_autocomplete_context(
self, context: context_.AutocompleteContext[t.Any], autocomplete_provider: options_.AutocompleteProvider[t.Any]
) -> None:
async with self.di.enter_context(di_.Contexts.AUTOCOMPLETE) as container:
if container is not None: # type: ignore[reportUnnecessaryComparison]
container.add_value(context_.AutocompleteContext, context)
async with (

Check warning on line 949 in lightbulb/client.py

View check run for this annotation

Codecov / codecov/patch

lightbulb/client.py#L949

Added line #L949 was not covered by tests
self.di.enter_context(di_.Contexts.DEFAULT),
self.di.enter_context(di_.Contexts.AUTOCOMPLETE) as container,
):
container.add_value(context_.AutocompleteContext, context)

Check warning on line 953 in lightbulb/client.py

View check run for this annotation

Codecov / codecov/patch

lightbulb/client.py#L953

Added line #L953 was not covered by tests

try:
await autocomplete_provider(context)
Expand Down Expand Up @@ -1015,10 +1017,12 @@
async def _execute_command_context(self, context: context_.Context) -> None:
pipeline = execution.ExecutionPipeline(context, self.execution_step_order)

async with self.di.enter_context(di_.Contexts.COMMAND) as container:
if container is not None: # type: ignore[reportUnnecessaryComparison]
container.add_value(context_.Context, context)
container.add_value(execution.ExecutionPipeline, pipeline)
async with (

Check warning on line 1020 in lightbulb/client.py

View check run for this annotation

Codecov / codecov/patch

lightbulb/client.py#L1020

Added line #L1020 was not covered by tests
self.di.enter_context(di_.Contexts.DEFAULT),
self.di.enter_context(di_.Contexts.COMMAND) as container,
):
container.add_value(context_.Context, context)
container.add_value(execution.ExecutionPipeline, pipeline)

Check warning on line 1025 in lightbulb/client.py

View check run for this annotation

Codecov / codecov/patch

lightbulb/client.py#L1024-L1025

Added lines #L1024 - L1025 were not covered by tests

try:
await pipeline._run()
Expand Down
71 changes: 28 additions & 43 deletions lightbulb/di/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

__all__ = ["Container"]

import logging
import typing as t

import networkx as nx
Expand All @@ -35,9 +36,11 @@
import types
from collections.abc import Callable

from lightbulb.di import solver
from lightbulb.internal import types as lb_types

T = t.TypeVar("T")
LOGGER = logging.getLogger(__name__)


class Container:
Expand All @@ -49,33 +52,22 @@ class Container:
parent: The parent container. Defaults to None.
"""

__slots__ = ("_closed", "_graph", "_instances", "_parent", "_registry")
__slots__ = ("_closed", "_graph", "_instances", "_parent", "_registry", "_tag")

def __init__(self, registry: registry_.Registry, *, parent: Container | None = None) -> None:
def __init__(
self, registry: registry_.Registry, *, parent: Container | None = None, tag: solver.Context | None = None
) -> None:
self._registry = registry
self._registry._freeze(self)

self._parent = parent
self._tag = tag

self._closed = False

self._graph: nx.DiGraph[str] = nx.DiGraph(self._parent._graph) if self._parent is not None else nx.DiGraph()
self._graph: nx.DiGraph[str] = nx.DiGraph(self._registry._graph)
self._instances: dict[str, t.Any] = {}

# Add our registry entries to the graphs
for node, node_data in self._registry._graph.nodes.items():
new_node_data = dict(node_data)

# Set the origin container if this is a concrete dependency instead of a transient one
if node_data.get("factory") is not None:
new_node_data["container"] = self

# If we are overriding a previously defined dependency with our own
if node in self._graph and node_data.get("factory") is not None:
self._graph.remove_edges_from(list(self._graph.out_edges(node)))

self._graph.add_node(node, **new_node_data)
self._graph.add_edges_from(self._registry._graph.edges)

self.add_value(Container, self)

async def __aenter__(self) -> Container:
Expand Down Expand Up @@ -152,23 +144,25 @@ def add_value(

if dependency_id in self._graph:
self._graph.remove_edges_from(list(self._graph.out_edges(dependency_id)))
self._graph.add_node(dependency_id, container=self, teardown=teardown)
self._graph.add_node(dependency_id, factory=lambda: None, teardown=teardown)

async def _get(self, dependency_id: str) -> t.Any:
if self._closed:
raise exceptions.ContainerClosedException

# TODO - look into whether locking is necessary - how likely are we to have race conditions

data = self._graph.nodes.get(dependency_id)
if data is None or data.get("container") is None:
raise exceptions.DependencyNotSatisfiableException(
f"could not create dependency {dependency_id!r} - not provided by this or a parent container"
)
if (existing := self._instances.get(dependency_id)) is not None:
return existing

existing_dependency = data["container"]._instances.get(dependency_id)
if existing_dependency is not None:
return existing_dependency
if (data := self._graph.nodes.get(dependency_id)) is None or data.get("factory") is None:
if self._parent is None:
raise exceptions.DependencyNotSatisfiableException(
f"cannot create dependency {dependency_id!r} - not provided by this or a parent container"
)

LOGGER.debug("dependency %r not provided by this container - checking parent", dependency_id)
return await self._parent._get(dependency_id)

# TODO - look into caching individual dependency creation order globally
# - may speed up using subsequent containers (i.e. for each command)
Expand All @@ -177,47 +171,38 @@ async def _get(self, dependency_id: str) -> t.Any:
assert isinstance(subgraph, nx.DiGraph)

try:
creation_order = reversed(list(nx.topological_sort(subgraph)))
creation_order = list(reversed(list(nx.topological_sort(subgraph))))
except nx.NetworkXUnfeasible:
raise exceptions.CircularDependencyException(
f"cannot provide {dependency_id!r} - circular dependency found during creation"
)

LOGGER.debug("dependency %r depends on %s", dependency_id, creation_order[:-1])
for dep_id in creation_order:
if (container := self._graph.nodes[dep_id].get("container")) is None:
raise exceptions.DependencyNotSatisfiableException(
f"could not create dependency {dep_id!r} - not provided by this or a parent container"
)

# We already have the dependency we need
if dep_id in container._instances:
if dep_id in self._instances:
continue

node_data = self._graph.nodes[dep_id]
# Check that we actually know how to create the dependency - this should have been caught earlier
# by checking that node["container"] was present - but just in case, we check for the factory
if node_data.get("factory") is None:
raise exceptions.DependencyNotSatisfiableException(
f"could not create dependency {dep_id!r} - do not know how to instantiate"
)

# Get the dependencies for this dependency from the container this dependency was defined in.
# This prevents 'scope promotion' - a dependency from the parent container requiring one from the
# child container, and hence the lifecycle of the child dependency being extended to
# that of the parent.
sub_dependencies: dict[str, t.Any] = {}
try:
LOGGER.debug("checking sub-dependencies for %r", dep_id)
for sub_dependency_id, param_name in node_data["factory_params"].items():
sub_dependencies[param_name] = await node_data["container"]._get(sub_dependency_id)
sub_dependencies[param_name] = await self._get(sub_dependency_id)
except exceptions.DependencyNotSatisfiableException as e:
raise exceptions.DependencyNotSatisfiableException(
f"could not create dependency {dep_id!r} - failed creating sub-dependency"
) from e

# Cache the created dependency in the correct container to ensure the correct lifecycle
container._instances[dep_id] = await utils.maybe_await(node_data["factory"](**sub_dependencies))
self._instances[dep_id] = await utils.maybe_await(node_data["factory"](**sub_dependencies))

return self._graph.nodes[dependency_id]["container"]._instances[dependency_id]
return self._instances[dependency_id]

async def get(self, typ: type[T]) -> T:
"""
Expand Down
131 changes: 93 additions & 38 deletions lightbulb/di/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,6 @@
import os
import sys
import typing as t
from collections.abc import AsyncIterator
from collections.abc import Awaitable
from collections.abc import Callable
from collections.abc import Coroutine

from lightbulb import utils
from lightbulb.di import container
Expand All @@ -55,11 +51,16 @@
from lightbulb.internal import marker

if t.TYPE_CHECKING:
from lightbulb.internal import types
from collections.abc import AsyncIterator
from collections.abc import Awaitable
from collections.abc import Callable
from collections.abc import Coroutine

from lightbulb.internal import types as lb_types

P = t.ParamSpec("P")
R = t.TypeVar("R")
AsyncFnT = t.TypeVar("AsyncFnT", bound=Callable[..., Coroutine[t.Any, t.Any, t.Any]])
T = t.TypeVar("T")

DI_ENABLED: t.Final[bool] = os.environ.get("LIGHTBULB_DI_DISABLED", "false").lower() != "true"
DI_CONTAINER: contextvars.ContextVar[container.Container | None] = contextvars.ContextVar(
Expand Down Expand Up @@ -132,6 +133,32 @@ class Contexts:
}


class _NoOpContainer(container.Container):
__slots__ = ()

def add_factory(
self,
typ: type[T],
factory: Callable[..., lb_types.MaybeAwaitable[T]],
*,
teardown: Callable[[T], lb_types.MaybeAwaitable[None]] | None = None,
) -> None: ...

def add_value(
self,
typ: type[T],
value: T,
*,
teardown: Callable[[T], lb_types.MaybeAwaitable[None]] | None = None,
) -> None: ...

def _get(self, dependency_id: str) -> t.Any:
raise exceptions.DependencyNotSatisfiableException("dependency injection is globally disabled")


_NOOP_CONTAINER = _NoOpContainer(registry.Registry(), tag=Contexts.DEFAULT)


class DependencyInjectionManager:
"""Class which contains dependency injection functionality."""

Expand Down Expand Up @@ -167,9 +194,8 @@ async def enter_context(self, context: Context = Contexts.DEFAULT, /) -> AsyncIt
Context manager that ensures a dependency injection context is available for the nested operations.

Args:
context: The context to enter. If you are trying to enter a non-default (:obj:`~DiContext.DEFAULT`) context,
the default context will be entered first to ensure its dependencies are available. Defaults to
:obj:`~DiContext.DEFAULT`.
context: The context to enter. If a container for the given context already exists, it will be returned
and a new container will not be created.

Yields:
:obj:`~lightbulb.di.container.Container`: The container that has been entered.
Expand All @@ -179,50 +205,74 @@ async def enter_context(self, context: Context = Contexts.DEFAULT, /) -> AsyncIt
.. code-block:: python

# Enter a specific context ('client' is your lightbulb.Client instance)
with client.di.enter_context(lightbulb.di.DiContext.COMMAND):
async with client.di.enter_context(lightbulb.di.Contexts.COMMAND):
await some_function_that_needs_dependencies()

Note:
If you want to enter multiple contexts - i.e. a command context that requires the default context to
be available first - you should call this once for each context that is needed.

.. code-block:: python

async with (
client.di.enter_context(lightbulb.di.Contexts.DEFAULT),
client.di.enter_context(lightbulb.di.Contexts.COMMAND)
):
...

Warning:
If you have disabled dependency injection using the ``LIGHTBULB_DI_DISABLED`` environment variable,
this method will do nothing and the context manager will return :obj:`None`. Most users will never
have to worry about this, but it is something to consider. The type-hint does not reflect this
to prevent your type-checker complaining about not checking for :obj:`None`.
"""
if not DI_ENABLED:
# I'm not sure how to deal with this - but I definitely don't want to hint the return type
# as optional because that adds annoying assertions further down the line for users
#
# I think I should just account for this internally within the library and document the
# behaviour - chances are almost all users will never come across this
yield None # type: ignore
# Return a container that will never register dependencies and cannot have dependencies
# retrieved from it - it will always raise an error if someone tries to use DI while it is
# globally disabled.
yield _NOOP_CONTAINER
return

initial_token, initial = None, DI_CONTAINER.get(None)
if initial is None:
if self._default_container is None:
self._default_container = container.Container(self._registries[Contexts.DEFAULT])
self._default_container.add_value(DefaultContainer, self._default_container)
LOGGER.debug("attempting to enter context %r", context)

new_container: container.Container | None = None
created: bool = False

token, value = None, DI_CONTAINER.get(None)
if value is not None:
LOGGER.debug("searching for existing container for context %r", context)
this = value
while this:
if this._tag == context:
new_container = this
LOGGER.debug("existing container found for context %r", context)
break

initial_token = DI_CONTAINER.set(self._default_container)
this = this._parent

ctx_token: contextvars.Token[container.Container | None] | None = None
if context != Contexts.DEFAULT:
new_container = container.Container(self._registries[context], parent=DI_CONTAINER.get())
if new_container is None:
LOGGER.debug("creating new container for context %r", context)

new_container = container.Container(self._registries[context], parent=value, tag=context)
new_container.add_value(_CONTAINER_TYPE_BY_CONTEXT[context], new_container)
ctx_token = DI_CONTAINER.set(new_container)

if context == Contexts.DEFAULT:
self._default_container = new_container

created = True

token = DI_CONTAINER.set(new_container)
LOGGER.debug("entered context %r", context)

try:
if (ct := DI_CONTAINER.get(None)) is not None:
if ct is self.default_container:
yield ct
else:
async with ct:
yield ct
if new_container is self._default_container or not created:
yield new_container
else:
async with new_container:
yield new_container
finally:
if ctx_token is not None:
DI_CONTAINER.reset(ctx_token)
if initial_token is not None:
DI_CONTAINER.reset(initial_token)
DI_CONTAINER.reset(token)
LOGGER.debug("cleared context %r", context)

async def close(self) -> None:
"""
Expand All @@ -237,7 +287,7 @@ async def close(self) -> None:
self._default_container = None


CANNOT_INJECT = object()
CANNOT_INJECT: t.Final[t.Any] = marker.Marker("CANNOT_INJECT")


def _parse_injectable_params(func: Callable[..., t.Any]) -> tuple[list[tuple[str, t.Any]], dict[str, t.Any]]:
Expand Down Expand Up @@ -336,14 +386,19 @@ async def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
if di_container is None:
raise exceptions.DependencyNotSatisfiableException("no DI context is available")

LOGGER.debug("requesting dependency for type %r", type)
new_kwargs[name] = await di_container.get(type)

if len(new_kwargs) > len(kwargs):
func_name = ((self._self.__class__.__name__ + ".") if self._self else "") + self._func.__name__
LOGGER.debug("calling function %r with resolved dependencies", func_name)

if self._self is not None:
return await utils.maybe_await(self._func(self._self, *args, **new_kwargs))
return await utils.maybe_await(self._func(*args, **new_kwargs))


def with_di(func: Callable[P, types.MaybeAwaitable[R]]) -> Callable[P, Coroutine[t.Any, t.Any, R]]:
def with_di(func: Callable[P, lb_types.MaybeAwaitable[R]]) -> Callable[P, Coroutine[t.Any, t.Any, R]]:
"""
Decorator that enables dependency injection on the decorated function. If dependency injection
has been disabled globally then this function does nothing and simply returns the object that was passed in.
Expand Down
Loading