11import asyncio
2- from abc import abstractmethod
2+ from abc import ABC , abstractmethod
33from collections import defaultdict
4- from collections .abc import Callable
4+ from collections .abc import Awaitable , Callable
55from dataclasses import dataclass , field
66from inspect import isclass
77from types import GenericAlias
8- from typing import Protocol , Self , TypeAliasType , runtime_checkable
8+ from typing import Any , Protocol , Self , TypeAliasType , runtime_checkable
99
1010import injection
1111
1414type HandlerType [** P , T ] = type [Handler [P , T ]]
1515type HandlerFactory [** P , T ] = Callable [..., Handler [P , T ]]
1616
17+ type Listener [T ] = Callable [[T ], Awaitable [Any ]]
18+
1719type BusType [I , O ] = type [Bus [I , O ]]
1820
1921
@@ -34,6 +36,10 @@ class Bus[I, O](Dispatcher[I, O], Protocol):
3436 def subscribe (self , input_type : type [I ], factory : HandlerFactory [[I ], O ]) -> Self :
3537 raise NotImplementedError
3638
39+ @abstractmethod
40+ def add_listeners (self , * listeners : Listener [I ]) -> Self :
41+ raise NotImplementedError
42+
3743
3844@dataclass (eq = False , frozen = True , slots = True )
3945class SubscriberDecorator [I , O ]:
@@ -59,7 +65,24 @@ def __find_bus(self) -> Bus[I, O]:
5965 return self .injection_module .find_instance (self .bus_type )
6066
6167
62- class SimpleBus [I , O ](BaseDispatcher [I , O ], Bus [I , O ]):
68+ class BaseBus [I , O ](BaseDispatcher [I , O ], Bus [I , O ], ABC ):
69+ __slots__ = ("__listeners" ,)
70+
71+ __listeners : list [Listener [I ]]
72+
73+ def __init__ (self ) -> None :
74+ super ().__init__ ()
75+ self .__listeners = []
76+
77+ def add_listeners (self , * listeners : Listener [I ]) -> Self :
78+ self .__listeners .extend (listeners )
79+ return self
80+
81+ async def _trigger_listeners (self , input_value : I , / ) -> None :
82+ await asyncio .gather (* (listener (input_value ) for listener in self .__listeners ))
83+
84+
85+ class SimpleBus [I , O ](BaseBus [I , O ]):
6386 __slots__ = ("__handlers" ,)
6487
6588 __handlers : dict [type [I ], HandlerFactory [[I ], O ]]
@@ -69,6 +92,7 @@ def __init__(self) -> None:
6992 self .__handlers = {}
7093
7194 async def dispatch (self , input_value : I , / ) -> O :
95+ await self ._trigger_listeners (input_value )
7296 input_type = type (input_value )
7397
7498 try :
@@ -91,7 +115,7 @@ def subscribe(self, input_type: type[I], factory: HandlerFactory[[I], O]) -> Sel
91115 return self
92116
93117
94- class TaskBus [I ](BaseDispatcher [ I , None ], Bus [I , None ]):
118+ class TaskBus [I ](BaseBus [I , None ]):
95119 __slots__ = ("__handlers" ,)
96120
97121 __handlers : dict [type [I ], list [HandlerFactory [[I ], None ]]]
@@ -101,6 +125,7 @@ def __init__(self) -> None:
101125 self .__handlers = defaultdict (list )
102126
103127 async def dispatch (self , input_value : I , / ) -> None :
128+ await self ._trigger_listeners (input_value )
104129 handler_factories = self .__handlers .get (type (input_value ))
105130
106131 if not handler_factories :
0 commit comments