1- from collections import deque
1+ from abc import abstractmethod
22from collections .abc import Awaitable , Callable
33from dataclasses import dataclass , field
4- from typing import TYPE_CHECKING , Any , Protocol , Self , overload
5-
6- from cq ._core .common .typing import Decorator
4+ from functools import partial
5+ from inspect import iscoroutinefunction
6+ from typing import (
7+ TYPE_CHECKING ,
8+ Any ,
9+ Concatenate ,
10+ Protocol ,
11+ Self ,
12+ overload ,
13+ runtime_checkable ,
14+ )
15+
16+ from cq ._core .common .typing import Decorator , Method
717from cq ._core .dispatcher .base import BaseDispatcher , Dispatcher
8- from cq ._core .middleware import Middleware
18+ from cq ._core .middleware import Middleware , MiddlewareGroup
919
10- type PipeConverter [I , O ] = Callable [[O ], Awaitable [I ]]
20+ type ConvertAsync [** P , I , O ] = Callable [Concatenate [O , P ], Awaitable [I ]]
21+ type ConvertSync [** P , I , O ] = Callable [Concatenate [O , P ], I ]
22+ type Convert [** P , I , O ] = ConvertAsync [P , I , O ] | ConvertSync [P , I , O ]
1123
24+ type ConvertMethodAsync [I , O ] = Method [[O ], Awaitable [I ]]
25+ type ConvertMethodSync [I , O ] = Method [[O ], I ]
26+ type ConvertMethod [I , O ] = ConvertMethodAsync [I , O ] | ConvertMethodSync [I , O ]
1227
13- class PipeConverterMethod [I , O ](Protocol ):
14- def __get__ (
15- self ,
16- instance : object ,
17- owner : type | None = ...,
18- ) -> PipeConverter [I , O ]: ...
28+
29+ @runtime_checkable
30+ class PipelineConverter [** P , I , O ](Protocol ):
31+ __slots__ = ()
32+
33+ @abstractmethod
34+ async def convert (self , output_value : O , / , * args : P .args , ** kwargs : P .kwargs ) -> I :
35+ raise NotImplementedError
1936
2037
2138@dataclass (repr = False , eq = False , frozen = True , slots = True )
22- class PipeStep [ I , O ]:
23- converter : PipeConverter [ I , O ]
39+ class PipelineStep [ ** P , I , O ]:
40+ converter : PipelineConverter [ P , I , O ]
2441 dispatcher : Dispatcher [I , Any ] | None = field (default = None )
2542
2643
44+ @dataclass (repr = False , eq = False , frozen = True , slots = True )
45+ class PipelineSteps [** P , I , O ]:
46+ default_dispatcher : Dispatcher [Any , Any ]
47+ __steps : list [PipelineStep [P , Any , Any ]] = field (default_factory = list , init = False )
48+
49+ def add [T ](
50+ self ,
51+ converter : PipelineConverter [P , T , Any ],
52+ dispatcher : Dispatcher [T , Any ] | None ,
53+ ) -> Self :
54+ self .__steps .append (PipelineStep (converter , dispatcher ))
55+ return self
56+
57+ async def execute (self , input_value : I , / , * args : P .args , ** kwargs : P .kwargs ) -> O :
58+ dispatcher = self .default_dispatcher
59+
60+ for step in self .__steps :
61+ output_value = await dispatcher .dispatch (input_value )
62+ input_value = await step .converter .convert (output_value , * args , ** kwargs )
63+
64+ if input_value is None :
65+ return NotImplemented
66+
67+ dispatcher = step .dispatcher or self .default_dispatcher
68+
69+ return await dispatcher .dispatch (input_value )
70+
71+
2772class Pipe [I , O ](BaseDispatcher [I , O ]):
28- __slots__ = ("__dispatcher" , " __steps" )
73+ __slots__ = ("__steps" , )
2974
30- __dispatcher : Dispatcher [Any , Any ]
31- __steps : list [PipeStep [Any , Any ]]
75+ __steps : PipelineSteps [[], I , O ]
3276
3377 def __init__ (self , dispatcher : Dispatcher [Any , Any ]) -> None :
3478 super ().__init__ ()
35- self .__dispatcher = dispatcher
36- self .__steps = []
79+ self .__steps = PipelineSteps (dispatcher )
3780
3881 if TYPE_CHECKING : # pragma: no cover
3982
4083 @overload
4184 def step [T ](
4285 self ,
43- wrapped : PipeConverter [T , Any ],
86+ wrapped : ConvertAsync [[], T , Any ],
87+ / ,
88+ * ,
89+ dispatcher : Dispatcher [T , Any ] | None = ...,
90+ ) -> ConvertAsync [[], T , Any ]: ...
91+
92+ @overload
93+ def step [T ](
94+ self ,
95+ wrapped : ConvertSync [[], T , Any ],
4496 / ,
4597 * ,
4698 dispatcher : Dispatcher [T , Any ] | None = ...,
47- ) -> PipeConverter [ T , Any ]: ...
99+ ) -> ConvertSync [[], T , Any ]: ...
48100
49101 @overload
50102 def step (
@@ -57,14 +109,18 @@ def step(
57109
58110 def step [T ](
59111 self ,
60- wrapped : PipeConverter [ T , Any ] | None = None ,
112+ wrapped : Convert [[], T , Any ] | None = None ,
61113 / ,
62114 * ,
63115 dispatcher : Dispatcher [T , Any ] | None = None ,
64116 ) -> Any :
65- def decorator (wp : PipeConverter [T , Any ]) -> PipeConverter [T , Any ]:
66- step = PipeStep (wp , dispatcher )
67- self .__steps .append (step )
117+ def decorator (wp : Convert [[], T , Any ]) -> Convert [[], T , Any ]:
118+ converter = (
119+ _AsyncPipelineConverter (wp )
120+ if iscoroutinefunction (wp )
121+ else _SyncPipelineConverter (wp )
122+ )
123+ self .__steps .add (converter , dispatcher )
68124 return wp
69125
70126 return decorator (wrapped ) if wrapped else decorator
@@ -75,47 +131,23 @@ def add_static_step[T](
75131 * ,
76132 dispatcher : Dispatcher [T , Any ] | None = None ,
77133 ) -> Self :
78- @self .step (dispatcher = dispatcher )
79- async def converter (_ : Any ) -> T :
80- return input_value
81-
134+ converter = _StaticPipelineConverter (input_value )
135+ self .__steps .add (converter , dispatcher )
82136 return self
83137
84138 async def dispatch (self , input_value : I , / ) -> O :
85- return await self ._invoke_with_middlewares (self .__execute , input_value )
86-
87- async def __execute (self , input_value : I ) -> O :
88- dispatcher = self .__dispatcher
89-
90- for step in self .__steps :
91- output_value = await dispatcher .dispatch (input_value )
92- input_value = await step .converter (output_value )
93-
94- if input_value is None :
95- return NotImplemented
96-
97- dispatcher = step .dispatcher or self .__dispatcher
98-
99- return await dispatcher .dispatch (input_value )
100-
101-
102- @dataclass (repr = False , eq = False , frozen = True , slots = True )
103- class ContextPipelineStep [I , O ]:
104- converter : PipeConverterMethod [I , O ]
105- dispatcher : Dispatcher [I , Any ] | None = field (default = None )
139+ return await self ._invoke_with_middlewares (self .__steps .execute , input_value )
106140
107141
108142class ContextPipeline [I ]:
109- __slots__ = ("__dispatcher" , "__middlewares " , "__steps" )
143+ __slots__ = ("__middleware_group " , "__steps" )
110144
111- __dispatcher : Dispatcher [Any , Any ]
112- __middlewares : deque [Middleware [Any , Any ]]
113- __steps : list [ContextPipelineStep [Any , Any ]]
145+ __middleware_group : MiddlewareGroup [[I ], Any ]
146+ __steps : PipelineSteps [[object , type | None ], I , Any ]
114147
115148 def __init__ (self , dispatcher : Dispatcher [Any , Any ]) -> None :
116- self .__dispatcher = dispatcher
117- self .__middlewares = deque ()
118- self .__steps = []
149+ self .__middleware_group = MiddlewareGroup ()
150+ self .__steps = PipelineSteps (dispatcher )
119151
120152 if TYPE_CHECKING : # pragma: no cover
121153
@@ -145,23 +177,32 @@ def __get__[O](
145177
146178 instance = owner ()
147179
148- pipeline = self .__new_pipeline ( instance , owner )
149- return BoundContextPipeline (instance , pipeline )
180+ dispatch_method = partial ( self .__execute , context = instance , context_type = owner )
181+ return BoundContextPipeline (dispatch_method )
150182
151183 def add_middlewares (self , * middlewares : Middleware [[I ], Any ]) -> Self :
152- self .__middlewares . extendleft ( reversed ( middlewares ) )
184+ self .__middleware_group . add ( * middlewares )
153185 return self
154186
155187 if TYPE_CHECKING : # pragma: no cover
156188
157189 @overload
158190 def step [T ](
159191 self ,
160- wrapped : PipeConverterMethod [T , Any ],
192+ wrapped : ConvertMethodAsync [T , Any ],
193+ / ,
194+ * ,
195+ dispatcher : Dispatcher [T , Any ] | None = ...,
196+ ) -> ConvertMethodAsync [T , Any ]: ...
197+
198+ @overload
199+ def step [T ](
200+ self ,
201+ wrapped : ConvertMethodSync [T , Any ],
161202 / ,
162203 * ,
163204 dispatcher : Dispatcher [T , Any ] | None = ...,
164- ) -> PipeConverterMethod [T , Any ]: ...
205+ ) -> ConvertMethodSync [T , Any ]: ...
165206
166207 @overload
167208 def step (
@@ -174,38 +215,98 @@ def step(
174215
175216 def step [T ](
176217 self ,
177- wrapped : PipeConverterMethod [T , Any ] | None = None ,
218+ wrapped : ConvertMethod [T , Any ] | None = None ,
178219 / ,
179220 * ,
180221 dispatcher : Dispatcher [T , Any ] | None = None ,
181222 ) -> Any :
182- def decorator (wp : PipeConverterMethod [T , Any ]) -> PipeConverterMethod [T , Any ]:
183- step = ContextPipelineStep (wp , dispatcher )
184- self .__steps .append (step )
223+ def decorator (wp : ConvertMethod [T , Any ]) -> ConvertMethod [T , Any ]:
224+ converter = (
225+ _AsyncContextPipelineConverter (wp )
226+ if iscoroutinefunction (wp )
227+ else _SyncContextPipelineConverter (wp )
228+ )
229+ self .__steps .add (converter , dispatcher )
185230 return wp
186231
187232 return decorator (wrapped ) if wrapped else decorator
188233
189- def __new_pipeline [ T ](
234+ async def __execute [ O ](
190235 self ,
191- context : T ,
192- context_type : type [ T ] | None ,
193- ) -> Pipe [ I , Any ]:
194- pipeline : Pipe [ I , Any ] = Pipe ( self . __dispatcher )
195- pipeline . add_middlewares ( * self . __middlewares )
196-
197- for step in self .__steps :
198- converter = step . converter . __get__ ( context , context_type )
199- pipeline . step ( converter , dispatcher = step . dispatcher )
200-
201- return pipeline
236+ input_value : I ,
237+ / ,
238+ * ,
239+ context : O ,
240+ context_type : type [ O ] | None ,
241+ ) -> O :
242+ await self .__middleware_group . invoke (
243+ lambda i : self . __steps . execute ( i , context , context_type ),
244+ input_value ,
245+ )
246+ return context
202247
203248
204249@dataclass (repr = False , eq = False , frozen = True , slots = True )
205250class BoundContextPipeline [I , O ](Dispatcher [I , O ]):
206- context : O
207- pipeline : Pipe [I , Any ]
251+ dispatch_method : Callable [[I ], Awaitable [O ]]
208252
209253 async def dispatch (self , input_value : I , / ) -> O :
210- await self .pipeline .dispatch (input_value )
211- return self .context
254+ return await self .dispatch_method (input_value )
255+
256+
257+ @dataclass (repr = False , eq = False , frozen = True , slots = True )
258+ class _AsyncPipelineConverter [** P , I , O ](PipelineConverter [P , I , O ]):
259+ converter : ConvertAsync [P , I , O ]
260+
261+ async def convert (self , output_value : O , / , * args : P .args , ** kwargs : P .kwargs ) -> I :
262+ return await self .converter (output_value , * args , ** kwargs )
263+
264+
265+ @dataclass (repr = False , eq = False , frozen = True , slots = True )
266+ class _SyncPipelineConverter [** P , I , O ](PipelineConverter [P , I , O ]):
267+ converter : ConvertSync [P , I , O ]
268+
269+ async def convert (self , output_value : O , / , * args : P .args , ** kwargs : P .kwargs ) -> I :
270+ return self .converter (output_value , * args , ** kwargs )
271+
272+
273+ @dataclass (repr = False , eq = False , frozen = True , slots = True )
274+ class _StaticPipelineConverter [I ](PipelineConverter [..., I , Any ]):
275+ input_value : I
276+
277+ async def convert (self , output_value : Any , / , * args : Any , ** kwargs : Any ) -> I :
278+ return self .input_value
279+
280+
281+ @dataclass (repr = False , eq = False , frozen = True , slots = True )
282+ class _AsyncContextPipelineConverter [I , O ](
283+ PipelineConverter [[object , type | None ], I , O ],
284+ ):
285+ converter : ConvertMethodAsync [I , O ]
286+
287+ async def convert (
288+ self ,
289+ output_value : O ,
290+ / ,
291+ context : object ,
292+ context_type : type | None ,
293+ ) -> I :
294+ method = self .converter .__get__ (context , context_type )
295+ return await method (output_value )
296+
297+
298+ @dataclass (repr = False , eq = False , frozen = True , slots = True )
299+ class _SyncContextPipelineConverter [I , O ](
300+ PipelineConverter [[object , type | None ], I , O ],
301+ ):
302+ converter : ConvertMethodSync [I , O ]
303+
304+ async def convert (
305+ self ,
306+ output_value : O ,
307+ / ,
308+ context : object ,
309+ context_type : type | None ,
310+ ) -> I :
311+ method = self .converter .__get__ (context , context_type )
312+ return method (output_value )
0 commit comments