Skip to content

Commit 68a368e

Browse files
authored
feat: Bidirectional link possible between modules
1 parent 8ad8760 commit 68a368e

4 files changed

Lines changed: 77 additions & 20 deletions

File tree

injection/_core/module.py

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -120,15 +120,25 @@ def __str__(self) -> str:
120120
return f"`{self.module}` has propagated an event: {self.origin}"
121121

122122
@property
123-
def history(self) -> Iterator[Event]:
124-
if isinstance(self.event, ModuleEventProxy):
125-
yield from self.event.history
126-
127-
yield self.event
123+
def is_duplicate(self) -> bool:
124+
module, origin = self.module, self.origin
125+
return any(
126+
module is event.module and origin is event.origin
127+
for event in self.proxy_history
128+
)
128129

129130
@property
130131
def origin(self) -> Event:
131-
return next(self.history)
132+
reversed_proxy_history = reversed(tuple(self.proxy_history))
133+
return next(reversed_proxy_history, self).event
134+
135+
@property
136+
def proxy_history(self) -> Iterator[ModuleEventProxy]:
137+
event = self.event
138+
139+
if isinstance(event, ModuleEventProxy):
140+
yield event
141+
yield from event.proxy_history
132142

133143

134144
@dataclass(frozen=True, slots=True)
@@ -162,8 +172,10 @@ def __str__(self) -> str:
162172

163173
@dataclass(frozen=True, slots=True)
164174
class UnlockCalled(Event):
175+
module: Module
176+
165177
def __str__(self) -> str:
166-
return "An `unlock` method has been called."
178+
return f"`{self.module}.unlock` has been called."
167179

168180

169181
"""
@@ -420,23 +432,18 @@ def __post_init__(self) -> None:
420432
self.__locator.add_listener(self)
421433

422434
def __getitem__[T](self, cls: InputType[T], /) -> Injectable[T]:
423-
for broker in self.__brokers:
435+
for broker in self._iter_brokers():
424436
with suppress(KeyError):
425437
return broker[cls]
426438

427439
raise NoInjectable(cls)
428440

429441
def __contains__(self, cls: InputType[Any], /) -> bool:
430-
return any(cls in broker for broker in self.__brokers)
442+
return any(cls in broker for broker in self._iter_brokers())
431443

432444
@property
433445
def is_locked(self) -> bool:
434-
return any(broker.is_locked for broker in self.__brokers)
435-
436-
@property
437-
def __brokers(self) -> Iterator[Broker]:
438-
yield from self.__modules
439-
yield self.__locator
446+
return any(broker.is_locked for broker in self._iter_brokers())
440447

441448
def injectable[**P, T](
442449
self,
@@ -857,19 +864,19 @@ def change_priority(self, module: Module, priority: Priority | PriorityStr) -> S
857864
return self
858865

859866
def unlock(self) -> Self:
860-
event = UnlockCalled()
867+
event = UnlockCalled(self)
861868

862869
with self.dispatch(event, lock_bypass=True):
863870
self.unsafe_unlocking()
864871

865872
return self
866873

867874
def unsafe_unlocking(self) -> None:
868-
for broker in self.__brokers:
875+
for broker in self._iter_brokers():
869876
broker.unsafe_unlocking()
870877

871878
async def all_ready(self) -> None:
872-
for broker in self.__brokers:
879+
for broker in self._iter_brokers():
873880
await broker.all_ready()
874881

875882
def add_logger(self, logger: Logger) -> Self:
@@ -884,8 +891,12 @@ def remove_listener(self, listener: EventListener) -> Self:
884891
self.__channel.remove_listener(listener)
885892
return self
886893

887-
def on_event(self, event: Event, /) -> ContextManager[None]:
894+
def on_event(self, event: Event, /) -> ContextManager[None] | None:
888895
self_event = ModuleEventProxy(self, event)
896+
897+
if self_event.is_duplicate:
898+
return None
899+
889900
return self.dispatch(self_event)
890901

891902
@contextmanager
@@ -899,6 +910,20 @@ def dispatch(self, event: Event, *, lock_bypass: bool = False) -> Iterator[None]
899910
finally:
900911
self.__debug(event)
901912

913+
def _iter_brokers(self, visited: set[Module] | None = None, /) -> Iterator[Broker]:
914+
if visited is None:
915+
visited = set()
916+
917+
if self in visited:
918+
return
919+
920+
visited.add(self)
921+
922+
for module in self.__modules:
923+
yield from module._iter_brokers(visited)
924+
925+
yield self.__locator
926+
902927
def __debug(self, message: object) -> None:
903928
for logger in self.__loggers:
904929
logger.debug(message)

injection/loaders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,12 +176,12 @@ def _unload(self, name: str, /) -> None:
176176

177177
def __init_subsets_for(self, module: Module) -> Module:
178178
if not self.__is_empty and not self.__is_initialized(module):
179+
self.__mark_initialized(module)
179180
target_modules = tuple(
180181
self.__init_subsets_for(mod(name))
181182
for name in self.module_subsets.get(module.name, ())
182183
)
183184
module.init_modules(*target_modules)
184-
self.__mark_initialized(module)
185185

186186
return module
187187

tests/core/test_module.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,16 @@ def test_use_with_module_already_in_use_raise_module_error(
292292

293293
event_history.assert_length(1)
294294

295+
def test_use_with_bidirectional_use(self, module, event_history):
296+
second_module = Module()
297+
third_module = Module()
298+
299+
module.use(second_module)
300+
second_module.use(module)
301+
module.use(third_module)
302+
303+
event_history.assert_length(4)
304+
295305
"""
296306
stop_using
297307
"""

tests/loaders/test_profile_loader.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,28 @@ class B(A): ...
108108

109109
assert type(find_instance(A)) is A
110110

111+
def test_load_with_bidirectional_link(self):
112+
profile_name_1 = uuid4().hex
113+
profile_name_2 = uuid4().hex
114+
115+
@mod(profile_name_1).injectable
116+
class BaseConfig: ...
117+
118+
@mod(profile_name_2).injectable
119+
@dataclass
120+
class Dependency:
121+
config: BaseConfig
122+
123+
loader = ProfileLoader(
124+
{
125+
profile_name_1: [profile_name_2],
126+
profile_name_2: [profile_name_1],
127+
}
128+
)
129+
130+
with loader.load(profile_name_1):
131+
assert find_instance(Dependency)
132+
111133
def test_load_with_default_profile_do_nothing(self):
112134
default_profile_name = mod().name
113135
global_profile_name = uuid4().hex

0 commit comments

Comments
 (0)