4343 runtime_checkable ,
4444)
4545
46- from type_analyzer import MatchingTypesConfig , iter_matching_types
46+ from type_analyzer import MatchingTypesConfig , iter_matching_types , matching_types
4747
4848from injection ._core .common .asynchronous import (
4949 AsyncCaller ,
5858from injection ._core .common .threading import get_lock
5959from injection ._core .common .type import (
6060 InputType ,
61- TypeDef ,
6261 TypeInfo ,
6362 get_return_types ,
6463 get_yield_hint ,
@@ -247,20 +246,6 @@ class Updater[T]:
247246 def make_record (self ) -> Record [T ]:
248247 return Record (self .injectable , self .mode )
249248
250- @classmethod
251- def with_basics (
252- cls ,
253- on : TypeInfo [T ],
254- / ,
255- injectable : Injectable [T ],
256- mode : Mode | ModeStr ,
257- ) -> Self :
258- return cls (
259- classes = get_return_types (on ),
260- injectable = injectable ,
261- mode = Mode (mode ),
262- )
263-
264249
265250@dataclass (repr = False , frozen = True , slots = True )
266251class Locator (Broker ):
@@ -274,20 +259,15 @@ class Locator(Broker):
274259 )
275260
276261 def __getitem__ [T ](self , cls : InputType [T ], / ) -> Injectable [T ]:
277- for key_type in self .__iter_key_types ((cls ,)):
278- try :
279- record = self .__records [key_type ]
280- except KeyError :
281- continue
282-
262+ try :
263+ record = self .__records [cls ]
264+ except KeyError as exc :
265+ raise NoInjectable (cls ) from exc
266+ else :
283267 return record .injectable
284268
285- raise NoInjectable (cls )
286-
287269 def __contains__ (self , cls : InputType [Any ], / ) -> bool :
288- return any (
289- key_type in self .__records for key_type in self .__iter_key_types ((cls ,))
290- )
270+ return cls in self .__records
291271
292272 @property
293273 def is_locked (self ) -> bool :
@@ -299,8 +279,7 @@ def __injectables(self) -> frozenset[Injectable[Any]]:
299279
300280 def update [T ](self , updater : Updater [T ]) -> Self :
301281 record = updater .make_record ()
302- key_types = self .__build_key_types (updater .classes )
303- records = dict (self .__prepare_for_updating (key_types , record ))
282+ records = dict (self .__prepare_for_updating (updater .classes , record ))
304283
305284 if records :
306285 event = LocatorDependenciesUpdated (self , records .keys (), record .mode )
@@ -345,21 +324,6 @@ def __prepare_for_updating[T](
345324
346325 yield cls , record
347326
348- @staticmethod
349- def __build_key_types [T ](classes : Iterable [InputType [T ]]) -> frozenset [TypeDef [T ]]:
350- config = MatchingTypesConfig (ignore_none = True )
351- return frozenset (
352- itertools .chain .from_iterable (
353- iter_matching_types (cls , config ) for cls in classes
354- )
355- )
356-
357- @staticmethod
358- def __iter_key_types [T ](classes : Iterable [InputType [T ]]) -> Iterator [InputType [T ]]:
359- config = MatchingTypesConfig (with_origin = True , with_type_alias_value = True )
360- for cls in classes :
361- yield from iter_matching_types (cls , config )
362-
363327 @staticmethod
364328 def __keep_new_record [T ](
365329 new : Record [T ],
@@ -432,14 +396,22 @@ def __post_init__(self) -> None:
432396 self .__locator .add_listener (self )
433397
434398 def __getitem__ [T ](self , cls : InputType [T ], / ) -> Injectable [T ]:
399+ key_types = self .__matching_key_types (cls )
400+
435401 for broker in self ._iter_brokers ():
436- with suppress (KeyError ):
437- return broker [cls ]
402+ for key_type in key_types :
403+ with suppress (KeyError ):
404+ return broker [key_type ]
438405
439406 raise NoInjectable (cls )
440407
441408 def __contains__ (self , cls : InputType [Any ], / ) -> bool :
442- return any (cls in broker for broker in self ._iter_brokers ())
409+ key_types = self .__matching_key_types (cls )
410+ return any (
411+ key_type in broker
412+ for broker in self ._iter_brokers ()
413+ for key_type in key_types
414+ )
443415
444416 @property
445417 def is_locked (self ) -> bool :
@@ -460,8 +432,7 @@ def decorator(wp: Recipe[P, T]) -> Recipe[P, T]:
460432 factory = extract_caller (self .make_injected_function (wp ) if inject else wp )
461433 injectable = cls (factory ) # type: ignore[arg-type]
462434 hints = on if ignore_type_hint else (wp , on )
463- updater = Updater .with_basics (hints , injectable , mode )
464- self .update (updater )
435+ self .update_from (hints , injectable , mode )
465436 return wp
466437
467438 return decorator (wrapped ) if wrapped else decorator
@@ -512,8 +483,7 @@ def decorator(
512483 def should_be_injectable [T ](self , wrapped : type [T ] | None = None , / ) -> Any :
513484 def decorator (wp : type [T ]) -> type [T ]:
514485 injectable = ShouldBeInjectable (wp )
515- updater = Updater .with_basics (wp , injectable , Mode .FALLBACK )
516- self .update (updater )
486+ self .update_from (wp , injectable , Mode .FALLBACK )
517487 return wp
518488
519489 return decorator (wrapped ) if wrapped else decorator
@@ -571,8 +541,7 @@ def reserve_scoped_slot[T](
571541 mode : Mode | ModeStr = Mode .get_default (),
572542 ) -> SlotKey [T ]:
573543 injectable = ScopedSlotInjectable (cls , scope_name )
574- updater = Updater .with_basics (cls , injectable , mode )
575- self .update (updater )
544+ self .update_from (cls , injectable , mode )
576545 return injectable .key
577546
578547 def inject [** P , T ](
@@ -795,6 +764,21 @@ def update[T](self, updater: Updater[T]) -> Self:
795764 self .__locator .update (updater )
796765 return self
797766
767+ def update_from [T ](
768+ self ,
769+ on : TypeInfo [T ],
770+ / ,
771+ injectable : Injectable [T ],
772+ mode : Mode | ModeStr ,
773+ ) -> Self :
774+ updater = Updater (
775+ classes = self .__build_key_types (on ),
776+ injectable = injectable ,
777+ mode = Mode (mode ),
778+ )
779+ self .update (updater )
780+ return self
781+
798782 def init_modules (self , * modules : Module ) -> Self :
799783 for module in tuple (self .__modules ):
800784 self .stop_using (module )
@@ -955,6 +939,20 @@ def from_name(cls, name: str) -> Module:
955939 def default (cls ) -> Module :
956940 return cls .from_name ("__default__" )
957941
942+ @staticmethod
943+ def __build_key_types (on : Any ) -> frozenset [Any ]:
944+ config = MatchingTypesConfig (ignore_none = True )
945+ return frozenset (
946+ itertools .chain .from_iterable (
947+ iter_matching_types (cls , config ) for cls in get_return_types (on )
948+ )
949+ )
950+
951+ @staticmethod
952+ def __matching_key_types (cls : Any ) -> tuple [Any , ...]:
953+ config = MatchingTypesConfig (with_origin = True , with_type_alias_value = True )
954+ return matching_types (cls , config )
955+
958956
959957def mod (name : str | None = None , / ) -> Module :
960958 if name is None :
0 commit comments