1- import asyncio
21from abc import ABC , abstractmethod
32from collections import defaultdict
43from collections .abc import Awaitable , Callable
54from dataclasses import dataclass , field
6- from inspect import isclass
5+ from inspect import getmro , isclass
76from types import GenericAlias
87from typing import Any , Protocol , Self , TypeAliasType , runtime_checkable
98
9+ import anyio
1010import injection
1111
1212from cq ._core .dispatcher .base import BaseDispatcher , Dispatcher
@@ -52,18 +52,12 @@ def decorator(wrapped: type[Handler[[I], O]]) -> type[Handler[[I], O]]:
5252 raise TypeError (f"`{ wrapped } ` isn't a valid handler." )
5353
5454 bus = self .injection_module .find_instance (self .bus_type )
55- lazy_instance = self .injection_module .aget_lazy_instance (
56- wrapped ,
57- default = NotImplemented ,
58- )
59-
60- async def getter () -> Handler [[I ], O ]:
61- return await lazy_instance
55+ factory = self .injection_module .make_async_factory (wrapped )
6256
6357 for input_type in (first_input_type , * input_types ):
64- bus .subscribe (input_type , getter )
58+ bus .subscribe (input_type , factory )
6559
66- return self . injection_module . injectable ( wrapped )
60+ return wrapped
6761
6862 return decorator
6963
@@ -82,7 +76,24 @@ def add_listeners(self, *listeners: Listener[I]) -> Self:
8276 return self
8377
8478 async def _trigger_listeners (self , input_value : I , / ) -> None :
85- await asyncio .gather (* (listener (input_value ) for listener in self .__listeners ))
79+ listeners = self .__listeners
80+
81+ if not listeners :
82+ return
83+
84+ async with anyio .create_task_group () as task_group :
85+ for listener in listeners :
86+ task_group .start_soon (listener , input_value )
87+
88+ @staticmethod
89+ def _make_handle_function (
90+ handler_factory : HandlerFactory [[I ], O ],
91+ ) -> Callable [[I ], Awaitable [O ]]:
92+ async def handle (input_value : I ) -> O :
93+ handler = await handler_factory ()
94+ return await handler .handle (input_value )
95+
96+ return handle
8697
8798
8899class SimpleBus [I , O ](BaseBus [I , O ]):
@@ -96,15 +107,16 @@ def __init__(self) -> None:
96107
97108 async def dispatch (self , input_value : I , / ) -> O :
98109 await self ._trigger_listeners (input_value )
99- input_type = type (input_value )
100110
101- try :
102- handler_factory = self .__handlers [input_type ]
103- except KeyError :
111+ for input_type in getmro (type (input_value )):
112+ if handler_factory := self .__handlers .get (input_type ):
113+ break
114+
115+ else :
104116 return NotImplemented
105117
106- handler = await handler_factory ( )
107- return await self ._invoke_with_middlewares (handler . handle , input_value )
118+ handler = self . _make_handle_function ( handler_factory )
119+ return await self ._invoke_with_middlewares (handler , input_value )
108120
109121 def subscribe (self , input_type : type [I ], factory : HandlerFactory [[I ], O ]) -> Self :
110122 if input_type in self .__handlers :
@@ -127,20 +139,22 @@ def __init__(self) -> None:
127139
128140 async def dispatch (self , input_value : I , / ) -> None :
129141 await self ._trigger_listeners (input_value )
130- handler_factories = self .__handlers .get (type (input_value ))
131142
132- if not handler_factories :
143+ for input_type in getmro (type (input_value )):
144+ if handler_factories := self .__handlers .get (input_type ):
145+ break
146+
147+ else :
133148 return
134149
135- await asyncio .gather (
136- * [
137- self ._invoke_with_middlewares (
138- (await handler_factory ()).handle ,
150+ async with anyio .create_task_group () as task_group :
151+ for handler_factory in handler_factories :
152+ handler = self ._make_handle_function (handler_factory )
153+ task_group .start_soon (
154+ self ._invoke_with_middlewares ,
155+ handler ,
139156 input_value ,
140157 )
141- for handler_factory in handler_factories
142- ]
143- )
144158
145159 def subscribe (
146160 self ,
0 commit comments