Skip to content

Commit ddd900a

Browse files
author
remimd
committed
refactor: Now analyze the types at the module level
1 parent 3218cc5 commit ddd900a

File tree

1 file changed

+51
-53
lines changed

1 file changed

+51
-53
lines changed

injection/_core/module.py

Lines changed: 51 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
runtime_checkable,
4444
)
4545

46-
from type_analyzer import MatchingTypesConfig, iter_matching_types
46+
from type_analyzer import MatchingTypesConfig, iter_matching_types, matching_types
4747

4848
from injection._core.common.asynchronous import (
4949
AsyncCaller,
@@ -58,7 +58,6 @@
5858
from injection._core.common.threading import get_lock
5959
from injection._core.common.type import (
6060
InputType,
61-
TypeDef,
6261
TypeInfo,
6362
get_return_types,
6463
get_yield_hint,
@@ -247,20 +246,6 @@ class Updater[T]:
247246
def make_record(self) -> Record[T]:
248247
return Record(self.injectable, self.mode)
249248

250-
@classmethod
251-
def with_basics(
252-
cls,
253-
on: TypeInfo[T],
254-
/,
255-
injectable: Injectable[T],
256-
mode: Mode | ModeStr,
257-
) -> Self:
258-
return cls(
259-
classes=get_return_types(on),
260-
injectable=injectable,
261-
mode=Mode(mode),
262-
)
263-
264249

265250
@dataclass(repr=False, frozen=True, slots=True)
266251
class Locator(Broker):
@@ -274,20 +259,15 @@ class Locator(Broker):
274259
)
275260

276261
def __getitem__[T](self, cls: InputType[T], /) -> Injectable[T]:
277-
for key_type in self.__iter_key_types((cls,)):
278-
try:
279-
record = self.__records[key_type]
280-
except KeyError:
281-
continue
282-
262+
try:
263+
record = self.__records[cls]
264+
except KeyError as exc:
265+
raise NoInjectable(cls) from exc
266+
else:
283267
return record.injectable
284268

285-
raise NoInjectable(cls)
286-
287269
def __contains__(self, cls: InputType[Any], /) -> bool:
288-
return any(
289-
key_type in self.__records for key_type in self.__iter_key_types((cls,))
290-
)
270+
return cls in self.__records
291271

292272
@property
293273
def is_locked(self) -> bool:
@@ -299,8 +279,7 @@ def __injectables(self) -> frozenset[Injectable[Any]]:
299279

300280
def update[T](self, updater: Updater[T]) -> Self:
301281
record = updater.make_record()
302-
key_types = self.__build_key_types(updater.classes)
303-
records = dict(self.__prepare_for_updating(key_types, record))
282+
records = dict(self.__prepare_for_updating(updater.classes, record))
304283

305284
if records:
306285
event = LocatorDependenciesUpdated(self, records.keys(), record.mode)
@@ -345,21 +324,6 @@ def __prepare_for_updating[T](
345324

346325
yield cls, record
347326

348-
@staticmethod
349-
def __build_key_types[T](classes: Iterable[InputType[T]]) -> frozenset[TypeDef[T]]:
350-
config = MatchingTypesConfig(ignore_none=True)
351-
return frozenset(
352-
itertools.chain.from_iterable(
353-
iter_matching_types(cls, config) for cls in classes
354-
)
355-
)
356-
357-
@staticmethod
358-
def __iter_key_types[T](classes: Iterable[InputType[T]]) -> Iterator[InputType[T]]:
359-
config = MatchingTypesConfig(with_origin=True, with_type_alias_value=True)
360-
for cls in classes:
361-
yield from iter_matching_types(cls, config)
362-
363327
@staticmethod
364328
def __keep_new_record[T](
365329
new: Record[T],
@@ -432,14 +396,22 @@ def __post_init__(self) -> None:
432396
self.__locator.add_listener(self)
433397

434398
def __getitem__[T](self, cls: InputType[T], /) -> Injectable[T]:
399+
key_types = self.__matching_key_types(cls)
400+
435401
for broker in self._iter_brokers():
436-
with suppress(KeyError):
437-
return broker[cls]
402+
for key_type in key_types:
403+
with suppress(KeyError):
404+
return broker[key_type]
438405

439406
raise NoInjectable(cls)
440407

441408
def __contains__(self, cls: InputType[Any], /) -> bool:
442-
return any(cls in broker for broker in self._iter_brokers())
409+
key_types = self.__matching_key_types(cls)
410+
return any(
411+
key_type in broker
412+
for broker in self._iter_brokers()
413+
for key_type in key_types
414+
)
443415

444416
@property
445417
def is_locked(self) -> bool:
@@ -460,8 +432,7 @@ def decorator(wp: Recipe[P, T]) -> Recipe[P, T]:
460432
factory = extract_caller(self.make_injected_function(wp) if inject else wp)
461433
injectable = cls(factory) # type: ignore[arg-type]
462434
hints = on if ignore_type_hint else (wp, on)
463-
updater = Updater.with_basics(hints, injectable, mode)
464-
self.update(updater)
435+
self.update_from(hints, injectable, mode)
465436
return wp
466437

467438
return decorator(wrapped) if wrapped else decorator
@@ -512,8 +483,7 @@ def decorator(
512483
def should_be_injectable[T](self, wrapped: type[T] | None = None, /) -> Any:
513484
def decorator(wp: type[T]) -> type[T]:
514485
injectable = ShouldBeInjectable(wp)
515-
updater = Updater.with_basics(wp, injectable, Mode.FALLBACK)
516-
self.update(updater)
486+
self.update_from(wp, injectable, Mode.FALLBACK)
517487
return wp
518488

519489
return decorator(wrapped) if wrapped else decorator
@@ -571,8 +541,7 @@ def reserve_scoped_slot[T](
571541
mode: Mode | ModeStr = Mode.get_default(),
572542
) -> SlotKey[T]:
573543
injectable = ScopedSlotInjectable(cls, scope_name)
574-
updater = Updater.with_basics(cls, injectable, mode)
575-
self.update(updater)
544+
self.update_from(cls, injectable, mode)
576545
return injectable.key
577546

578547
def inject[**P, T](
@@ -795,6 +764,21 @@ def update[T](self, updater: Updater[T]) -> Self:
795764
self.__locator.update(updater)
796765
return self
797766

767+
def update_from[T](
768+
self,
769+
on: TypeInfo[T],
770+
/,
771+
injectable: Injectable[T],
772+
mode: Mode | ModeStr,
773+
) -> Self:
774+
updater = Updater(
775+
classes=self.__build_key_types(on),
776+
injectable=injectable,
777+
mode=Mode(mode),
778+
)
779+
self.update(updater)
780+
return self
781+
798782
def init_modules(self, *modules: Module) -> Self:
799783
for module in tuple(self.__modules):
800784
self.stop_using(module)
@@ -955,6 +939,20 @@ def from_name(cls, name: str) -> Module:
955939
def default(cls) -> Module:
956940
return cls.from_name("__default__")
957941

942+
@staticmethod
943+
def __build_key_types(on: Any) -> frozenset[Any]:
944+
config = MatchingTypesConfig(ignore_none=True)
945+
return frozenset(
946+
itertools.chain.from_iterable(
947+
iter_matching_types(cls, config) for cls in get_return_types(on)
948+
)
949+
)
950+
951+
@staticmethod
952+
def __matching_key_types(cls: Any) -> tuple[Any, ...]:
953+
config = MatchingTypesConfig(with_origin=True, with_type_alias_value=True)
954+
return matching_types(cls, config)
955+
958956

959957
def mod(name: str | None = None, /) -> Module:
960958
if name is None:

0 commit comments

Comments
 (0)