11import asyncio
2- from abc import ABC , abstractmethod
2+ from abc import abstractmethod
33from collections import defaultdict
44from collections .abc import Callable
55from dataclasses import dataclass , field
99
1010import injection
1111
12- from cq ._core .middleware import Middleware , MiddlewareGroup
12+ from cq ._core .dispatcher . base import BaseDispatcher , Dispatcher
1313
1414type HandlerType [** P , T ] = type [Handler [P , T ]]
1515type HandlerFactory [** P , T ] = Callable [..., Handler [P , T ]]
@@ -27,30 +27,13 @@ async def handle(self, *args: P.args, **kwargs: P.kwargs) -> T:
2727
2828
2929@runtime_checkable
30- class Bus [I , O ](Protocol ):
30+ class Bus [I , O ](Dispatcher [ I , O ], Protocol ):
3131 __slots__ = ()
3232
33- @abstractmethod
34- async def dispatch (self , input_value : I , / ) -> O :
35- raise NotImplementedError
36-
37- def dispatch_no_wait (self , first_input_value : I , / , * input_values : I ) -> None :
38- asyncio .gather (
39- * (
40- self .dispatch (input_value )
41- for input_value in (first_input_value , * input_values )
42- ),
43- return_exceptions = True ,
44- )
45-
4633 @abstractmethod
4734 def subscribe (self , input_type : type [I ], factory : HandlerFactory [[I ], O ]) -> Self :
4835 raise NotImplementedError
4936
50- @abstractmethod
51- def add_middlewares (self , * middlewares : Middleware [[I ], O ]) -> Self :
52- raise NotImplementedError
53-
5437
5538@dataclass (eq = False , frozen = True , slots = True )
5639class SubscriberDecorator [I , O ]:
@@ -81,23 +64,7 @@ def __find_bus(self) -> Bus[I, O]:
8164 return self .injection_module .find_instance (self .bus_type )
8265
8366
84- class _BaseBus [I , O ](Bus [I , O ], ABC ):
85- __slots__ = ("__middleware_group" ,)
86-
87- __middleware_group : MiddlewareGroup [[I ], O ]
88-
89- def __init__ (self ) -> None :
90- self .__middleware_group = MiddlewareGroup ()
91-
92- def add_middlewares (self , * middlewares : Middleware [[I ], O ]) -> Self :
93- self .__middleware_group .add (* middlewares )
94- return self
95-
96- async def _invoke (self , handler : Handler [[I ], O ], input_value : I , / ) -> O :
97- return await self .__middleware_group .invoke (handler .handle , input_value )
98-
99-
100- class SimpleBus [I , O ](_BaseBus [I , O ]):
67+ class SimpleBus [I , O ](BaseDispatcher [I , O ], Bus [I , O ]):
10168 __slots__ = ("__handlers" ,)
10269
10370 __handlers : dict [type [I ], HandlerFactory [[I ], O ]]
@@ -114,7 +81,10 @@ async def dispatch(self, input_value: I, /) -> O:
11481 except KeyError :
11582 return NotImplemented
11683
117- return await self ._invoke (handler_factory (), input_value )
84+ return await self ._invoke_with_middlewares (
85+ handler_factory ().handle ,
86+ input_value ,
87+ )
11888
11989 def subscribe (self , input_type : type [I ], factory : HandlerFactory [[I ], O ]) -> Self :
12090 if input_type in self .__handlers :
@@ -126,7 +96,7 @@ def subscribe(self, input_type: type[I], factory: HandlerFactory[[I], O]) -> Sel
12696 return self
12797
12898
129- class TaskBus [I ](_BaseBus [I , None ]):
99+ class TaskBus [I ](BaseDispatcher [ I , None ], Bus [I , None ]):
130100 __slots__ = ("__handlers" ,)
131101
132102 __handlers : dict [type [I ], list [HandlerFactory [[I ], None ]]]
@@ -143,7 +113,10 @@ async def dispatch(self, input_value: I, /) -> None:
143113
144114 await asyncio .gather (
145115 * (
146- self ._invoke (handler_factory (), input_value )
116+ self ._invoke_with_middlewares (
117+ handler_factory ().handle ,
118+ input_value ,
119+ )
147120 for handler_factory in handler_factories
148121 )
149122 )
0 commit comments