Skip to content

Commit a7b1653

Browse files
authored
feat: ✨ Static pipe step
1 parent f39d43d commit a7b1653

8 files changed

Lines changed: 58 additions & 18 deletions

File tree

cq/_core/dispatcher/base.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ async def dispatch(self, input_value: I, /) -> O:
1515
raise NotImplementedError
1616

1717
@abstractmethod
18-
def dispatch_no_wait(self, first_input_value: I, /, *input_values: I) -> None:
18+
def dispatch_no_wait(self, *input_values: I) -> None:
1919
raise NotImplementedError
2020

2121
@abstractmethod
@@ -31,12 +31,9 @@ class BaseDispatcher[I, O](Dispatcher[I, O], ABC):
3131
def __init__(self) -> None:
3232
self.__middleware_group = MiddlewareGroup()
3333

34-
def dispatch_no_wait(self, first_input_value: I, /, *input_values: I) -> None:
34+
def dispatch_no_wait(self, *input_values: I) -> None:
3535
asyncio.gather(
36-
*(
37-
self.dispatch(input_value)
38-
for input_value in (first_input_value, *input_values)
39-
),
36+
*(self.dispatch(input_value) for input_value in input_values),
4037
return_exceptions=True,
4138
)
4239

cq/_core/dispatcher/bus.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,8 @@ class SubscriberDecorator[I, O]:
4040
bus_type: BusType[I, O] | TypeAliasType | GenericAlias
4141
injection_module: injection.Module = field(default_factory=injection.mod)
4242

43-
def __call__[T](
44-
self,
45-
first_input_type: type[I],
46-
/,
47-
*input_types: type[I],
48-
) -> Callable[[T], T]:
49-
def decorator(wrapped: T) -> T:
43+
def __call__(self, first_input_type: type[I], /, *input_types: type[I]): # type: ignore[no-untyped-def]
44+
def decorator(wrapped): # type: ignore[no-untyped-def]
5045
if not isclass(wrapped) or not issubclass(wrapped, Handler):
5146
raise TypeError(f"`{wrapped}` isn't a valid handler.")
5247

@@ -56,7 +51,7 @@ def decorator(wrapped: T) -> T:
5651
for input_type in (first_input_type, *input_types):
5752
bus.subscribe(input_type, factory)
5853

59-
return wrapped # type: ignore[return-value]
54+
return wrapped
6055

6156
return decorator
6257

cq/_core/dispatcher/pipe.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections.abc import Callable
22
from dataclasses import dataclass, field
3-
from typing import Any, Awaitable
3+
from typing import Any, Awaitable, Self
44

55
from cq._core.dispatcher.base import BaseDispatcher, Dispatcher
66

@@ -38,6 +38,18 @@ def decorator(wp): # type: ignore[no-untyped-def]
3838

3939
return decorator(wrapped) if wrapped else decorator
4040

41+
def add_static_step[T](
42+
self,
43+
input_value: T,
44+
*,
45+
dispatcher: Dispatcher[T, Any] | None = None,
46+
) -> Self:
47+
@self.step(dispatcher=dispatcher)
48+
async def converter(_: Any) -> T:
49+
return input_value
50+
51+
return self
52+
4153
async def dispatch(self, input_value: I, /) -> O:
4254
return await self._invoke_with_middlewares(self.__execute, input_value)
4355

cq/_core/middleware.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from dataclasses import dataclass, field
33
from typing import Self
44

5+
from cq.exceptions import MiddlewareError
6+
57
type MiddlewareResult[T] = AsyncGenerator[None, T]
68
type Middleware[**P, T] = Callable[P, MiddlewareResult[T]]
79

@@ -47,7 +49,9 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
4749
await generator.athrow(exc)
4850
else:
4951
await generator.asend(value)
50-
break
52+
raise MiddlewareError(
53+
f"Too many `yield` keywords in `{middleware}`."
54+
)
5155

5256
except StopAsyncIteration:
5357
...

cq/exceptions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
__all__ = ("CQError", "MiddlewareError")
2+
3+
4+
class CQError(Exception): ...
5+
6+
7+
class MiddlewareError(CQError): ...

cq/middlewares/retry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from cq import MiddlewareResult
55

6+
__all__ = ("RetryMiddleware",)
7+
68

79
class RetryMiddleware:
810
__slots__ = ("__delay", "__exceptions", "__retry")

tests/core/dispatcher/test_pipe.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ async def step_converter_1(length: int) -> int:
3232

3333
assert await pipe.dispatch("hello") == "*****"
3434

35+
# Custom dispatcher
3536
other_bus: Bus[Any, Any] = SimpleBus()
3637

3738
other_bus.subscribe(str, ToTupleHandler)
@@ -41,3 +42,8 @@ async def step_converter_2(hidden_string: str) -> str:
4142
return hidden_string
4243

4344
assert await pipe.dispatch("hello") == ("*", "*", "*", "*", "*")
45+
46+
# Add static step
47+
pipe.add_static_step(2)
48+
49+
assert await pipe.dispatch("hello") == "**"

tests/core/test_middleware.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from __future__ import annotations
2-
31
from typing import Any
42

53
import pytest
64

75
from cq._core.middleware import MiddlewareGroup, MiddlewareResult
6+
from cq.exceptions import MiddlewareError
87
from tests.helpers.history import HistoryMiddleware
98

109

@@ -62,6 +61,24 @@ async def handler() -> str:
6261
assert isinstance(record.result, ValueError)
6362
assert record.is_failed
6463

64+
async def test_invoke_with_too_many_yield_raise_middleware_error(
65+
self,
66+
group: MiddlewareGroup[..., Any],
67+
) -> None:
68+
async def handler() -> None: ...
69+
70+
async def too_many_yield_middleware(
71+
*args: Any,
72+
**kwargs: Any,
73+
) -> MiddlewareResult[Any]:
74+
yield
75+
yield
76+
77+
group.add(too_many_yield_middleware)
78+
79+
with pytest.raises(MiddlewareError):
80+
await group.invoke(handler)
81+
6582
async def test_invoke_with_multiple_yield_return_any(
6683
self,
6784
group: MiddlewareGroup[..., Any],

0 commit comments

Comments
 (0)