|
1 | 1 | """Parallel transformer implementation using multiple threads.""" |
2 | 2 |
|
3 | 3 | from collections import deque |
| 4 | +from collections.abc import Callable |
4 | 5 | from collections.abc import Iterable |
5 | 6 | from collections.abc import Iterator |
6 | 7 | from concurrent.futures import FIRST_COMPLETED |
|
11 | 12 | from functools import partial |
12 | 13 | import itertools |
13 | 14 | import threading |
| 15 | +from typing import Any |
| 16 | +from typing import Union |
| 17 | +from typing import overload |
14 | 18 |
|
| 19 | +from laygo.errors import ErrorHandler |
15 | 20 | from laygo.helpers import PipelineContext |
16 | 21 | from laygo.transformers.transformer import DEFAULT_CHUNK_SIZE |
| 22 | +from laygo.transformers.transformer import ChunkErrorHandler |
17 | 23 | from laygo.transformers.transformer import InternalTransformer |
| 24 | +from laygo.transformers.transformer import PipelineFunction |
18 | 25 | from laygo.transformers.transformer import Transformer |
19 | 26 |
|
20 | 27 |
|
@@ -142,3 +149,53 @@ def result_iterator_manager() -> Iterator[Out]: |
142 | 149 | yield from result_chunk |
143 | 150 |
|
144 | 151 | return result_iterator_manager() |
| 152 | + |
| 153 | + # --- Overridden Chaining Methods to Preserve Type --- |
| 154 | + |
| 155 | + def on_error(self, handler: ChunkErrorHandler[In, Out] | ErrorHandler) -> "ParallelTransformer[In, Out]": |
| 156 | + super().on_error(handler) |
| 157 | + return self |
| 158 | + |
| 159 | + def map[U](self, function: PipelineFunction[Out, U]) -> "ParallelTransformer[In, U]": |
| 160 | + super().map(function) |
| 161 | + return self # type: ignore |
| 162 | + |
| 163 | + def filter(self, predicate: PipelineFunction[Out, bool]) -> "ParallelTransformer[In, Out]": |
| 164 | + super().filter(predicate) |
| 165 | + return self |
| 166 | + |
| 167 | + @overload |
| 168 | + def flatten[T](self: "ParallelTransformer[In, list[T]]") -> "ParallelTransformer[In, T]": ... |
| 169 | + @overload |
| 170 | + def flatten[T](self: "ParallelTransformer[In, tuple[T, ...]]") -> "ParallelTransformer[In, T]": ... |
| 171 | + @overload |
| 172 | + def flatten[T](self: "ParallelTransformer[In, set[T]]") -> "ParallelTransformer[In, T]": ... |
| 173 | + def flatten[T]( # type: ignore |
| 174 | + self: Union[ |
| 175 | + "ParallelTransformer[In, list[T]]", "ParallelTransformer[In, tuple[T, ...]]", "ParallelTransformer[In, set[T]]" |
| 176 | + ], |
| 177 | + ) -> "ParallelTransformer[In, T]": |
| 178 | + super().flatten() # type: ignore |
| 179 | + return self # type: ignore |
| 180 | + |
| 181 | + def tap(self, function: PipelineFunction[Out, Any]) -> "ParallelTransformer[In, Out]": |
| 182 | + super().tap(function) |
| 183 | + return self |
| 184 | + |
| 185 | + def apply[T]( |
| 186 | + self, t: Callable[["ParallelTransformer[In, Out]"], "Transformer[In, T]"] |
| 187 | + ) -> "ParallelTransformer[In, T]": |
| 188 | + super().apply(t) # type: ignore |
| 189 | + return self # type: ignore |
| 190 | + |
| 191 | + def catch[U]( |
| 192 | + self, |
| 193 | + sub_pipeline_builder: Callable[[Transformer[Out, Out]], Transformer[Out, U]], |
| 194 | + on_error: ChunkErrorHandler[Out, U] | None = None, |
| 195 | + ) -> "ParallelTransformer[In, U]": |
| 196 | + super().catch(sub_pipeline_builder, on_error) |
| 197 | + return self # type: ignore |
| 198 | + |
| 199 | + def short_circuit(self, function: Callable[[PipelineContext], bool | None]) -> "ParallelTransformer[In, Out]": |
| 200 | + super().short_circuit(function) |
| 201 | + return self |
0 commit comments