Skip to content

Commit 92ec3d7

Browse files
committed
fix: branch is now truly concurrent and can be driven by different transformers
1 parent 496e8a1 commit 92ec3d7

2 files changed

Lines changed: 198 additions & 162 deletions

File tree

laygo/pipeline.py

Lines changed: 55 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# pipeline.py
2-
32
from collections.abc import Callable
43
from collections.abc import Iterable
54
from collections.abc import Iterator
5+
from concurrent.futures import ThreadPoolExecutor
6+
from concurrent.futures import as_completed
67
import itertools
78
import multiprocessing as mp
9+
from queue import Queue
810
from typing import Any
911
from typing import TypeVar
1012
from typing import overload
@@ -15,6 +17,7 @@
1517
from laygo.transformers.transformer import Transformer
1618

1719
T = TypeVar("T")
20+
U = TypeVar("U")
1821
PipelineFunction = Callable[[T], Any]
1922

2023

@@ -147,53 +150,61 @@ def apply[U](
147150

148151
return self # type: ignore
149152

150-
def branch(self, branches: dict[str, Transformer[T, Any]]) -> dict[str, list[Any]]:
151-
"""Forks the pipeline, sending all data to multiple branches and returning the last chunk.
152-
153-
This is a **terminal operation** that implements a fan-out pattern.
154-
It consumes the pipeline's data, sends the **entire dataset** to each
155-
branch transformer, and continuously **overwrites** a shared context value
156-
with the latest processed chunk. The final result is a dictionary
157-
containing only the **last processed chunk** for each branch.
158-
159-
Args:
160-
branches: A dictionary where keys are branch names (str) and values
161-
are `Transformer` instances.
162-
163-
Returns:
164-
A dictionary where keys are the branch names and values are lists
165-
of items from the last processed chunk for that branch.
166-
"""
153+
def branch(
  self,
  branches: dict[str, Transformer[T, Any]],
  batch_size: int = 1000,
  max_batch_buffer: int = 1,
) -> dict[str, list[Any]]:
  """Fork the pipeline into multiple branches for concurrent processing.

  A single producer thread batches the pipeline's data and replicates every
  batch onto one bounded queue per branch; each branch runs its transformer
  over the streamed items in its own thread.

  Args:
    branches: Mapping of branch name to the `Transformer` that branch runs.
    batch_size: Number of items per batch handed to the branch queues.
    max_batch_buffer: Capacity of each branch queue; bounds memory and
      applies back-pressure on the producer.

  Returns:
    A dict mapping each branch name to the list of items its transformer
    produced. A branch whose transformer raised maps to an empty list.

  Raises:
    Exception: Any error raised by the source iterator is re-raised instead
      of being swallowed (previously such a failure deadlocked the pipeline).
  """
  if not branches:
    self.consume()
    return {}

  source_iterator = self.processed_data
  branch_items = list(branches.items())
  num_branches = len(branch_items)
  final_results: dict[str, list[Any]] = {}

  # One bounded queue per branch: a slow branch throttles the producer
  # instead of buffering the whole stream in memory.
  queues: list[Queue] = [Queue(maxsize=max_batch_buffer) for _ in range(num_branches)]

  def producer() -> None:
    """Read the source and replicate every batch to ALL branch queues."""
    try:
      for batch_tuple in itertools.batched(source_iterator, batch_size):
        batch_list = list(batch_tuple)  # consumers expect lists, not tuples
        for q in queues:
          q.put(batch_list)
    finally:
      # Always deliver the end-of-stream sentinel — even when the source
      # iterator raised — otherwise every consumer blocks forever on get().
      for q in queues:
        q.put(None)

  def consumer(transformer: Transformer, queue: Queue) -> list[Any]:
    """Run one branch's transformer over the items streamed from its queue."""
    saw_sentinel = False

    def stream_from_queue() -> Iterator[T]:
      nonlocal saw_sentinel
      while (batch := queue.get()) is not None:
        yield from batch
      saw_sentinel = True

    try:
      return list(transformer(stream_from_queue(), self.ctx))  # type: ignore
    finally:
      # If the transformer raised or stopped consuming early, its sentinel
      # (and possibly buffered batches) are still queued; drain them so the
      # producer is never left blocked on a full queue.
      while not saw_sentinel:
        if queue.get() is None:
          saw_sentinel = True

  with ThreadPoolExecutor(max_workers=num_branches + 1) as executor:
    producer_future = executor.submit(producer)

    future_to_name = {
      executor.submit(consumer, transformer, queues[i]): name
      for i, (name, transformer) in enumerate(branch_items)
    }

    for future in as_completed(future_to_name):
      name = future_to_name[future]
      try:
        final_results[name] = future.result()
      except Exception as e:
        # Preserve the best-effort contract: a failed branch yields [].
        print(f"Branch '{name}' raised an exception: {e}")
        final_results[name] = []

    # Surface producer/source failures instead of returning partial results
    # silently; before this check a source error was never observed.
    producer_future.result()

  return final_results
199210

0 commit comments

Comments
 (0)