|
1 | 1 | # pipeline.py |
2 | | - |
3 | 2 | from collections.abc import Callable |
4 | 3 | from collections.abc import Iterable |
5 | 4 | from collections.abc import Iterator |
| 5 | +from concurrent.futures import ThreadPoolExecutor |
| 6 | +from concurrent.futures import as_completed |
6 | 7 | import itertools |
7 | 8 | import multiprocessing as mp |
| 9 | +from queue import Queue |
8 | 10 | from typing import Any |
9 | 11 | from typing import TypeVar |
10 | 12 | from typing import overload |
|
15 | 17 | from laygo.transformers.transformer import Transformer |
16 | 18 |
|
17 | 19 | T = TypeVar("T") |
| 20 | +U = TypeVar("U") |
18 | 21 | PipelineFunction = Callable[[T], Any] |
19 | 22 |
|
20 | 23 |
|
@@ -147,53 +150,61 @@ def apply[U]( |
147 | 150 |
|
148 | 151 | return self # type: ignore |
149 | 152 |
|
150 | | - def branch(self, branches: dict[str, Transformer[T, Any]]) -> dict[str, list[Any]]: |
151 | | - """Forks the pipeline, sending all data to multiple branches and returning the last chunk. |
152 | | -
|
153 | | - This is a **terminal operation** that implements a fan-out pattern. |
154 | | - It consumes the pipeline's data, sends the **entire dataset** to each |
155 | | - branch transformer, and continuously **overwrites** a shared context value |
156 | | - with the latest processed chunk. The final result is a dictionary |
157 | | - containing only the **last processed chunk** for each branch. |
158 | | -
|
159 | | - Args: |
160 | | - branches: A dictionary where keys are branch names (str) and values |
161 | | - are `Transformer` instances. |
162 | | -
|
163 | | - Returns: |
164 | | - A dictionary where keys are the branch names and values are lists |
165 | | - of items from the last processed chunk for that branch. |
166 | | - """ |
| 153 | + def branch( |
| 154 | + self, |
| 155 | + branches: dict[str, Transformer[T, Any]], |
| 156 | + batch_size: int = 1000, |
| 157 | + max_batch_buffer: int = 1, |
| 158 | + ) -> dict[str, list[Any]]: |
| 159 | + """Forks the pipeline into multiple branches for concurrent, parallel processing.""" |
167 | 160 | if not branches: |
168 | 161 | self.consume() |
169 | 162 | return {} |
170 | 163 |
|
171 | | - # 1. Build a single "fan-out" transformer by chaining taps. |
172 | | - fan_out_transformer = Transformer[T, T]() |
173 | | - |
174 | | - for name, branch_transformer in branches.items(): |
175 | | - # Create a "collector" that runs the user's logic and then |
176 | | - # overwrites the context with its latest chunk. |
177 | | - collector = Transformer.from_transformer(branch_transformer) |
178 | | - |
179 | | - # This is the side-effect operation that overwrites the context. |
180 | | - def overwrite_context_with_chunk(chunk: list[Any], ctx: PipelineContext, name=name) -> list[Any]: |
181 | | - # This is an atomic assignment for manager dicts; no lock needed. |
182 | | - ctx[name] = chunk |
183 | | - # Return the chunk unmodified to satisfy the _pipe interface. |
184 | | - return chunk |
185 | | - |
186 | | - # Add this as the final step in the collector's pipeline. |
187 | | - collector._pipe(overwrite_context_with_chunk) |
188 | | - |
189 | | - # Tap the main transformer. The collector will run as a side-effect. |
190 | | - fan_out_transformer.tap(collector) |
191 | | - |
192 | | - # 2. Apply the fan-out transformer and consume the entire pipeline. |
193 | | - self.apply(fan_out_transformer).consume() |
194 | | - |
195 | | - # 3. Collect the final state from the context. |
196 | | - final_results = {name: self.ctx.get(name, []) for name in branches} |
| 164 | + source_iterator = self.processed_data |
| 165 | + branch_items = list(branches.items()) |
| 166 | + num_branches = len(branch_items) |
| 167 | + final_results: dict[str, list[Any]] = {} |
| 168 | + |
| 169 | + queues = [Queue(maxsize=max_batch_buffer) for _ in range(num_branches)] |
| 170 | + |
| 171 | + def producer() -> None: |
| 172 | + """Reads from the source and distributes batches to ALL branch queues.""" |
| 173 | + # Use itertools.batched for clean and efficient batch creation. |
| 174 | + for batch_tuple in itertools.batched(source_iterator, batch_size): |
| 175 | + # The batch is a tuple; convert to a list for consumers. |
| 176 | + batch_list = list(batch_tuple) |
| 177 | + for q in queues: |
| 178 | + q.put(batch_list) |
| 179 | + |
| 180 | + # Signal to all consumers that the stream is finished. |
| 181 | + for q in queues: |
| 182 | + q.put(None) |
| 183 | + |
| 184 | + def consumer(transformer: Transformer, queue: Queue) -> list[Any]: |
| 185 | + """Consumes batches from a queue and runs them through a transformer.""" |
| 186 | + |
| 187 | + def stream_from_queue() -> Iterator[T]: |
| 188 | + while (batch := queue.get()) is not None: |
| 189 | + yield from batch |
| 190 | + |
| 191 | + result_iterator = transformer(stream_from_queue(), self.ctx) # type: ignore |
| 192 | + return list(result_iterator) |
| 193 | + |
| 194 | + with ThreadPoolExecutor(max_workers=num_branches + 1) as executor: |
| 195 | + executor.submit(producer) |
| 196 | + |
| 197 | + future_to_name = { |
| 198 | + executor.submit(consumer, transformer, queues[i]): name for i, (name, transformer) in enumerate(branch_items) |
| 199 | + } |
| 200 | + |
| 201 | + for future in as_completed(future_to_name): |
| 202 | + name = future_to_name[future] |
| 203 | + try: |
| 204 | + final_results[name] = future.result() |
| 205 | + except Exception as e: |
| 206 | + print(f"Branch '{name}' raised an exception: {e}") |
| 207 | + final_results[name] = [] |
197 | 208 |
|
198 | 209 | return final_results |
199 | 210 |
|
|
0 commit comments