Skip to content

Commit 7af0fdd

Browse files
authored
feat: No longer instantiate a dependency if it has been explicitly passed
1 parent e6737cf commit 7af0fdd

File tree

3 files changed

+104
-83
lines changed

3 files changed

+104
-83
lines changed

injection/_core/module.py

Lines changed: 41 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
AsyncIterator,
99
Awaitable,
1010
Callable,
11+
Container,
1112
Generator,
1213
Iterable,
1314
Iterator,
@@ -18,6 +19,7 @@
1819
from enum import StrEnum
1920
from functools import partial, partialmethod, singledispatchmethod, update_wrapper
2021
from inspect import (
22+
BoundArguments,
2123
Signature,
2224
isasyncgenfunction,
2325
isclass,
@@ -739,28 +741,32 @@ def mod(name: str | None = None, /) -> Module:
739741
class Dependencies:
740742
lazy_mapping: Lazy[Mapping[str, Injectable[Any]]]
741743

742-
def __iter__(self) -> Iterator[tuple[str, Any]]:
743-
for name, injectable in self.items():
744+
def iter(self, exclude: Container[str]) -> Iterator[tuple[str, Any]]:
745+
for name, injectable in self.items(exclude):
744746
with suppress(SkipInjectable):
745747
yield name, injectable.get_instance()
746748

747-
async def __aiter__(self) -> AsyncIterator[tuple[str, Any]]:
748-
for name, injectable in self.items():
749+
async def aiter(self, exclude: Container[str]) -> AsyncIterator[tuple[str, Any]]:
750+
for name, injectable in self.items(exclude):
749751
with suppress(SkipInjectable):
750752
yield name, await injectable.aget_instance()
751753

752754
@property
753755
def are_resolved(self) -> bool:
754756
return self.lazy_mapping.is_set
755757

756-
async def aget_arguments(self) -> dict[str, Any]:
757-
return {key: value async for key, value in self}
758+
async def aget_arguments(self, *, exclude: Container[str]) -> dict[str, Any]:
759+
return {key: value async for key, value in self.aiter(exclude)}
758760

759-
def get_arguments(self) -> dict[str, Any]:
760-
return dict(self)
761+
def get_arguments(self, *, exclude: Container[str]) -> dict[str, Any]:
762+
return dict(self.iter(exclude))
761763

762-
def items(self) -> Iterator[tuple[str, Injectable[Any]]]:
763-
return iter((~self.lazy_mapping).items())
764+
def items(self, exclude: Container[str]) -> Iterator[tuple[str, Injectable[Any]]]:
765+
return (
766+
(name, injectable)
767+
for name, injectable in (~self.lazy_mapping).items()
768+
if name not in exclude
769+
)
764770

765771
@classmethod
766772
def from_iterable(cls, iterable: Iterable[tuple[str, Injectable[Any]]]) -> Self:
@@ -858,21 +864,21 @@ def signature(self) -> Signature:
858864
def wrapped(self) -> Callable[P, T]:
859865
return self.__wrapped
860866

861-
async def abind(
862-
self,
863-
args: Iterable[Any] = (),
864-
kwargs: Mapping[str, Any] | None = None,
865-
) -> Arguments:
866-
additional_arguments = await self.__dependencies.aget_arguments()
867-
return self.__bind(args, kwargs, additional_arguments)
867+
async def abind(self, args: Iterable[Any], kwargs: Mapping[str, Any]) -> Arguments:
868+
arguments = self.__get_arguments(args, kwargs)
869+
dependencies = await self.__dependencies.aget_arguments(exclude=arguments)
870+
if dependencies:
871+
return self.__merge_arguments(arguments, dependencies)
868872

869-
def bind(
870-
self,
871-
args: Iterable[Any] = (),
872-
kwargs: Mapping[str, Any] | None = None,
873-
) -> Arguments:
874-
additional_arguments = self.__dependencies.get_arguments()
875-
return self.__bind(args, kwargs, additional_arguments)
873+
return Arguments(args, kwargs)
874+
875+
def bind(self, args: Iterable[Any], kwargs: Mapping[str, Any]) -> Arguments:
876+
arguments = self.__get_arguments(args, kwargs)
877+
dependencies = self.__dependencies.get_arguments(exclude=arguments)
878+
if dependencies:
879+
return self.__merge_arguments(arguments, dependencies)
880+
881+
return Arguments(args, kwargs)
876882

877883
async def acall(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
878884
with self.__lock:
@@ -921,20 +927,20 @@ def _(self, event: ModuleEvent, /) -> Iterator[None]:
921927
yield
922928
self.update(event.module)
923929

924-
def __bind(
930+
def __get_arguments(
925931
self,
926932
args: Iterable[Any],
927-
kwargs: Mapping[str, Any] | None,
928-
additional_arguments: dict[str, Any] | None,
929-
) -> Arguments:
930-
if kwargs is None:
931-
kwargs = {}
932-
933-
if not additional_arguments:
934-
return Arguments(args, kwargs)
935-
933+
kwargs: Mapping[str, Any],
934+
) -> dict[str, Any]:
936935
bound = self.signature.bind_partial(*args, **kwargs)
937-
bound.arguments = bound.arguments | additional_arguments | bound.arguments
936+
return bound.arguments
937+
938+
def __merge_arguments(
939+
self,
940+
arguments: dict[str, Any],
941+
additional_arguments: dict[str, Any],
942+
) -> Arguments:
943+
bound = BoundArguments(self.signature, additional_arguments | arguments) # type: ignore[arg-type]
938944
return Arguments(bound.args, bound.kwargs)
939945

940946
def __run_tasks(self) -> None:

tests/test_inject.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,3 +279,18 @@ def _method(this=..., _: SomeInjectable = ...):
279279
@injectable
280280
class A:
281281
method = _method
282+
283+
def test_inject_with_passing_argument_do_not_lock_module(self, module):
284+
assert not module.is_locked
285+
286+
@module.singleton
287+
class A: ...
288+
289+
@module.inject
290+
def function(a: A): ...
291+
292+
function(A())
293+
assert not module.is_locked
294+
295+
function()
296+
assert module.is_locked

0 commit comments

Comments
 (0)