Skip to content

Commit ade4f5d

Browse files
authored
refactor: ContextPipeline
1 parent 05987df commit ade4f5d

7 files changed

Lines changed: 224 additions & 95 deletions

File tree

cq/_core/common/typing.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,21 @@
1-
from typing import Protocol
1+
from collections.abc import Callable
2+
from typing import Any, Protocol, overload
23

34

45
class Decorator(Protocol):
56
def __call__[T](self, wrapped: T, /) -> T: ...
7+
8+
9+
class Method[**P, T](Protocol):
10+
@overload
11+
def __call__(self, instance: Any, /, *args: P.args, **kwargs: P.kwargs) -> T: ...
12+
13+
@overload
14+
def __call__(self, /, *args: Any, **kwargs: Any) -> T: ...
15+
16+
def __get__(
17+
self,
18+
instance: object,
19+
owner: type | None = ...,
20+
/,
21+
) -> Callable[P, T]: ...

cq/_core/dispatcher/pipe.py

Lines changed: 183 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,102 @@
1-
from collections import deque
1+
from abc import abstractmethod
22
from collections.abc import Awaitable, Callable
33
from dataclasses import dataclass, field
4-
from typing import TYPE_CHECKING, Any, Protocol, Self, overload
5-
6-
from cq._core.common.typing import Decorator
4+
from functools import partial
5+
from inspect import iscoroutinefunction
6+
from typing import (
7+
TYPE_CHECKING,
8+
Any,
9+
Concatenate,
10+
Protocol,
11+
Self,
12+
overload,
13+
runtime_checkable,
14+
)
15+
16+
from cq._core.common.typing import Decorator, Method
717
from cq._core.dispatcher.base import BaseDispatcher, Dispatcher
8-
from cq._core.middleware import Middleware
18+
from cq._core.middleware import Middleware, MiddlewareGroup
919

10-
type PipeConverter[I, O] = Callable[[O], Awaitable[I]]
20+
type ConvertAsync[**P, I, O] = Callable[Concatenate[O, P], Awaitable[I]]
21+
type ConvertSync[**P, I, O] = Callable[Concatenate[O, P], I]
22+
type Convert[**P, I, O] = ConvertAsync[P, I, O] | ConvertSync[P, I, O]
1123

24+
type ConvertMethodAsync[I, O] = Method[[O], Awaitable[I]]
25+
type ConvertMethodSync[I, O] = Method[[O], I]
26+
type ConvertMethod[I, O] = ConvertMethodAsync[I, O] | ConvertMethodSync[I, O]
1227

13-
class PipeConverterMethod[I, O](Protocol):
14-
def __get__(
15-
self,
16-
instance: object,
17-
owner: type | None = ...,
18-
) -> PipeConverter[I, O]: ...
28+
29+
@runtime_checkable
30+
class PipelineConverter[**P, I, O](Protocol):
31+
__slots__ = ()
32+
33+
@abstractmethod
34+
async def convert(self, output_value: O, /, *args: P.args, **kwargs: P.kwargs) -> I:
35+
raise NotImplementedError
1936

2037

2138
@dataclass(repr=False, eq=False, frozen=True, slots=True)
22-
class PipeStep[I, O]:
23-
converter: PipeConverter[I, O]
39+
class PipelineStep[**P, I, O]:
40+
converter: PipelineConverter[P, I, O]
2441
dispatcher: Dispatcher[I, Any] | None = field(default=None)
2542

2643

44+
@dataclass(repr=False, eq=False, frozen=True, slots=True)
45+
class PipelineSteps[**P, I, O]:
46+
default_dispatcher: Dispatcher[Any, Any]
47+
__steps: list[PipelineStep[P, Any, Any]] = field(default_factory=list, init=False)
48+
49+
def add[T](
50+
self,
51+
converter: PipelineConverter[P, T, Any],
52+
dispatcher: Dispatcher[T, Any] | None,
53+
) -> Self:
54+
self.__steps.append(PipelineStep(converter, dispatcher))
55+
return self
56+
57+
async def execute(self, input_value: I, /, *args: P.args, **kwargs: P.kwargs) -> O:
58+
dispatcher = self.default_dispatcher
59+
60+
for step in self.__steps:
61+
output_value = await dispatcher.dispatch(input_value)
62+
input_value = await step.converter.convert(output_value, *args, **kwargs)
63+
64+
if input_value is None:
65+
return NotImplemented
66+
67+
dispatcher = step.dispatcher or self.default_dispatcher
68+
69+
return await dispatcher.dispatch(input_value)
70+
71+
2772
class Pipe[I, O](BaseDispatcher[I, O]):
28-
__slots__ = ("__dispatcher", "__steps")
73+
__slots__ = ("__steps",)
2974

30-
__dispatcher: Dispatcher[Any, Any]
31-
__steps: list[PipeStep[Any, Any]]
75+
__steps: PipelineSteps[[], I, O]
3276

3377
def __init__(self, dispatcher: Dispatcher[Any, Any]) -> None:
3478
super().__init__()
35-
self.__dispatcher = dispatcher
36-
self.__steps = []
79+
self.__steps = PipelineSteps(dispatcher)
3780

3881
if TYPE_CHECKING: # pragma: no cover
3982

4083
@overload
4184
def step[T](
4285
self,
43-
wrapped: PipeConverter[T, Any],
86+
wrapped: ConvertAsync[[], T, Any],
87+
/,
88+
*,
89+
dispatcher: Dispatcher[T, Any] | None = ...,
90+
) -> ConvertAsync[[], T, Any]: ...
91+
92+
@overload
93+
def step[T](
94+
self,
95+
wrapped: ConvertSync[[], T, Any],
4496
/,
4597
*,
4698
dispatcher: Dispatcher[T, Any] | None = ...,
47-
) -> PipeConverter[T, Any]: ...
99+
) -> ConvertSync[[], T, Any]: ...
48100

49101
@overload
50102
def step(
@@ -57,14 +109,18 @@ def step(
57109

58110
def step[T](
59111
self,
60-
wrapped: PipeConverter[T, Any] | None = None,
112+
wrapped: Convert[[], T, Any] | None = None,
61113
/,
62114
*,
63115
dispatcher: Dispatcher[T, Any] | None = None,
64116
) -> Any:
65-
def decorator(wp: PipeConverter[T, Any]) -> PipeConverter[T, Any]:
66-
step = PipeStep(wp, dispatcher)
67-
self.__steps.append(step)
117+
def decorator(wp: Convert[[], T, Any]) -> Convert[[], T, Any]:
118+
converter = (
119+
_AsyncPipelineConverter(wp)
120+
if iscoroutinefunction(wp)
121+
else _SyncPipelineConverter(wp)
122+
)
123+
self.__steps.add(converter, dispatcher)
68124
return wp
69125

70126
return decorator(wrapped) if wrapped else decorator
@@ -75,47 +131,23 @@ def add_static_step[T](
75131
*,
76132
dispatcher: Dispatcher[T, Any] | None = None,
77133
) -> Self:
78-
@self.step(dispatcher=dispatcher)
79-
async def converter(_: Any) -> T:
80-
return input_value
81-
134+
converter = _StaticPipelineConverter(input_value)
135+
self.__steps.add(converter, dispatcher)
82136
return self
83137

84138
async def dispatch(self, input_value: I, /) -> O:
85-
return await self._invoke_with_middlewares(self.__execute, input_value)
86-
87-
async def __execute(self, input_value: I) -> O:
88-
dispatcher = self.__dispatcher
89-
90-
for step in self.__steps:
91-
output_value = await dispatcher.dispatch(input_value)
92-
input_value = await step.converter(output_value)
93-
94-
if input_value is None:
95-
return NotImplemented
96-
97-
dispatcher = step.dispatcher or self.__dispatcher
98-
99-
return await dispatcher.dispatch(input_value)
100-
101-
102-
@dataclass(repr=False, eq=False, frozen=True, slots=True)
103-
class ContextPipelineStep[I, O]:
104-
converter: PipeConverterMethod[I, O]
105-
dispatcher: Dispatcher[I, Any] | None = field(default=None)
139+
return await self._invoke_with_middlewares(self.__steps.execute, input_value)
106140

107141

108142
class ContextPipeline[I]:
109-
__slots__ = ("__dispatcher", "__middlewares", "__steps")
143+
__slots__ = ("__middleware_group", "__steps")
110144

111-
__dispatcher: Dispatcher[Any, Any]
112-
__middlewares: deque[Middleware[Any, Any]]
113-
__steps: list[ContextPipelineStep[Any, Any]]
145+
__middleware_group: MiddlewareGroup[[I], Any]
146+
__steps: PipelineSteps[[object, type | None], I, Any]
114147

115148
def __init__(self, dispatcher: Dispatcher[Any, Any]) -> None:
116-
self.__dispatcher = dispatcher
117-
self.__middlewares = deque()
118-
self.__steps = []
149+
self.__middleware_group = MiddlewareGroup()
150+
self.__steps = PipelineSteps(dispatcher)
119151

120152
if TYPE_CHECKING: # pragma: no cover
121153

@@ -145,23 +177,32 @@ def __get__[O](
145177

146178
instance = owner()
147179

148-
pipeline = self.__new_pipeline(instance, owner)
149-
return BoundContextPipeline(instance, pipeline)
180+
dispatch_method = partial(self.__execute, context=instance, context_type=owner)
181+
return BoundContextPipeline(dispatch_method)
150182

151183
def add_middlewares(self, *middlewares: Middleware[[I], Any]) -> Self:
152-
self.__middlewares.extendleft(reversed(middlewares))
184+
self.__middleware_group.add(*middlewares)
153185
return self
154186

155187
if TYPE_CHECKING: # pragma: no cover
156188

157189
@overload
158190
def step[T](
159191
self,
160-
wrapped: PipeConverterMethod[T, Any],
192+
wrapped: ConvertMethodAsync[T, Any],
193+
/,
194+
*,
195+
dispatcher: Dispatcher[T, Any] | None = ...,
196+
) -> ConvertMethodAsync[T, Any]: ...
197+
198+
@overload
199+
def step[T](
200+
self,
201+
wrapped: ConvertMethodSync[T, Any],
161202
/,
162203
*,
163204
dispatcher: Dispatcher[T, Any] | None = ...,
164-
) -> PipeConverterMethod[T, Any]: ...
205+
) -> ConvertMethodSync[T, Any]: ...
165206

166207
@overload
167208
def step(
@@ -174,38 +215,98 @@ def step(
174215

175216
def step[T](
176217
self,
177-
wrapped: PipeConverterMethod[T, Any] | None = None,
218+
wrapped: ConvertMethod[T, Any] | None = None,
178219
/,
179220
*,
180221
dispatcher: Dispatcher[T, Any] | None = None,
181222
) -> Any:
182-
def decorator(wp: PipeConverterMethod[T, Any]) -> PipeConverterMethod[T, Any]:
183-
step = ContextPipelineStep(wp, dispatcher)
184-
self.__steps.append(step)
223+
def decorator(wp: ConvertMethod[T, Any]) -> ConvertMethod[T, Any]:
224+
converter = (
225+
_AsyncContextPipelineConverter(wp)
226+
if iscoroutinefunction(wp)
227+
else _SyncContextPipelineConverter(wp)
228+
)
229+
self.__steps.add(converter, dispatcher)
185230
return wp
186231

187232
return decorator(wrapped) if wrapped else decorator
188233

189-
def __new_pipeline[T](
234+
async def __execute[O](
190235
self,
191-
context: T,
192-
context_type: type[T] | None,
193-
) -> Pipe[I, Any]:
194-
pipeline: Pipe[I, Any] = Pipe(self.__dispatcher)
195-
pipeline.add_middlewares(*self.__middlewares)
196-
197-
for step in self.__steps:
198-
converter = step.converter.__get__(context, context_type)
199-
pipeline.step(converter, dispatcher=step.dispatcher)
200-
201-
return pipeline
236+
input_value: I,
237+
/,
238+
*,
239+
context: O,
240+
context_type: type[O] | None,
241+
) -> O:
242+
await self.__middleware_group.invoke(
243+
lambda i: self.__steps.execute(i, context, context_type),
244+
input_value,
245+
)
246+
return context
202247

203248

204249
@dataclass(repr=False, eq=False, frozen=True, slots=True)
205250
class BoundContextPipeline[I, O](Dispatcher[I, O]):
206-
context: O
207-
pipeline: Pipe[I, Any]
251+
dispatch_method: Callable[[I], Awaitable[O]]
208252

209253
async def dispatch(self, input_value: I, /) -> O:
210-
await self.pipeline.dispatch(input_value)
211-
return self.context
254+
return await self.dispatch_method(input_value)
255+
256+
257+
@dataclass(repr=False, eq=False, frozen=True, slots=True)
258+
class _AsyncPipelineConverter[**P, I, O](PipelineConverter[P, I, O]):
259+
converter: ConvertAsync[P, I, O]
260+
261+
async def convert(self, output_value: O, /, *args: P.args, **kwargs: P.kwargs) -> I:
262+
return await self.converter(output_value, *args, **kwargs)
263+
264+
265+
@dataclass(repr=False, eq=False, frozen=True, slots=True)
266+
class _SyncPipelineConverter[**P, I, O](PipelineConverter[P, I, O]):
267+
converter: ConvertSync[P, I, O]
268+
269+
async def convert(self, output_value: O, /, *args: P.args, **kwargs: P.kwargs) -> I:
270+
return self.converter(output_value, *args, **kwargs)
271+
272+
273+
@dataclass(repr=False, eq=False, frozen=True, slots=True)
274+
class _StaticPipelineConverter[I](PipelineConverter[..., I, Any]):
275+
input_value: I
276+
277+
async def convert(self, output_value: Any, /, *args: Any, **kwargs: Any) -> I:
278+
return self.input_value
279+
280+
281+
@dataclass(repr=False, eq=False, frozen=True, slots=True)
282+
class _AsyncContextPipelineConverter[I, O](
283+
PipelineConverter[[object, type | None], I, O],
284+
):
285+
converter: ConvertMethodAsync[I, O]
286+
287+
async def convert(
288+
self,
289+
output_value: O,
290+
/,
291+
context: object,
292+
context_type: type | None,
293+
) -> I:
294+
method = self.converter.__get__(context, context_type)
295+
return await method(output_value)
296+
297+
298+
@dataclass(repr=False, eq=False, frozen=True, slots=True)
299+
class _SyncContextPipelineConverter[I, O](
300+
PipelineConverter[[object, type | None], I, O],
301+
):
302+
converter: ConvertMethodSync[I, O]
303+
304+
async def convert(
305+
self,
306+
output_value: O,
307+
/,
308+
context: object,
309+
context_type: type | None,
310+
) -> I:
311+
method = self.converter.__get__(context, context_type)
312+
return method(output_value)

cq/_core/middleware.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class _BoundMiddleware[**P, T]:
5959
call_next: Callable[P, Awaitable[T]]
6060
middleware: ClassicMiddleware[P, T]
6161

62-
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
62+
async def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
6363
return await self.middleware(self.call_next, *args, **kwargs)
6464

6565

0 commit comments

Comments
 (0)