Skip to content

Commit f272fe6

Browse files
committed
feat: conditional branching
1 parent ba2be73 commit f272fe6

1 file changed

Lines changed: 156 additions & 58 deletions

File tree

laygo/pipeline.py

Lines changed: 156 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -303,98 +303,196 @@ def consume(self) -> tuple[None, dict[str, Any]]:
303303

304304
return None, self.context_manager.to_dict()
305305

306+
# Overload 1: Unconditional fan-out
307+
@overload
306308
def branch(
307309
self,
308310
branches: dict[str, Transformer[T, Any]],
311+
*,
312+
batch_size: int = 1000,
313+
max_batch_buffer: int = 1,
314+
) -> tuple[dict[str, list[Any]], dict[str, Any]]: ...
315+
316+
# Overload 2: Conditional routing
317+
@overload
318+
def branch(
319+
self,
320+
branches: dict[str, tuple[Transformer[T, Any], Callable[[T], bool]]],
321+
*,
322+
first_match: bool = True,
323+
batch_size: int = 1000,
324+
max_batch_buffer: int = 1,
325+
) -> tuple[dict[str, list[Any]], dict[str, Any]]: ...
326+
327+
def branch(
328+
self,
329+
branches: dict[str, Transformer[T, Any]] | dict[str, tuple[Transformer[T, Any], Callable[[T], bool]]],
330+
*,
331+
first_match: bool = True,
309332
batch_size: int = 1000,
310333
max_batch_buffer: int = 1,
311334
) -> tuple[dict[str, list[Any]], dict[str, Any]]:
312-
"""Forks the pipeline into multiple branches for concurrent, parallel processing.
335+
"""
336+
Forks the pipeline for parallel processing with optional conditional routing.
337+
338+
This is a **terminal operation** that consumes the pipeline.
313339
314-
This is a **terminal operation** that implements a fan-out pattern where
315-
the entire dataset is copied to each branch for independent processing.
316-
Each branch gets its own Pipeline instance with isolated context management,
317-
and results are collected and returned in a dictionary.
340+
**1. Unconditional Fan-Out:**
341+
If `branches` is a `Dict[str, Transformer]`, every item is sent to every branch.
342+
343+
**2. Conditional Routing:**
344+
If `branches` is a `Dict[str, Tuple[Transformer, condition]]`, the `first_match`
345+
argument determines the routing logic:
346+
- `first_match=True` (default): Routes each item to the **first** branch
347+
whose condition is met. This acts as a router.
348+
- `first_match=False`: Routes each item to **all** branches whose
349+
conditions are met. This acts as a conditional broadcast.
318350
319351
Args:
320-
branches: A dictionary where keys are branch names (str) and values
321-
are `Transformer` instances of any subtype.
322-
batch_size: The number of items to batch together when sending data
323-
to branches. Larger batches can improve throughput but
324-
use more memory. Defaults to 1000.
325-
max_batch_buffer: The maximum number of batches to buffer for each
326-
branch queue. Controls memory usage and creates
327-
backpressure. Defaults to 1.
352+
branches: A dictionary defining the branches.
353+
first_match (bool): Determines the routing logic for conditional branches.
354+
batch_size (int): The number of items to batch for processing.
355+
max_batch_buffer (int): The max number of batches to buffer per branch.
328356
329357
Returns:
330-
A tuple containing:
331-
- A dictionary where keys are the branch names and values are lists
332-
of all items processed by that branch's transformer.
333-
- A merged dictionary of all context values from all branches.
334-
335-
Note:
336-
This operation consumes the pipeline's iterator, making subsequent
337-
operations on the same pipeline return empty results.
358+
A tuple containing a dictionary of results and the final context.
338359
"""
339360
if not branches:
340361
self.consume()
341362
return {}, {}
342363

364+
first_value = next(iter(branches.values()))
365+
is_conditional = isinstance(first_value, tuple)
366+
367+
parsed_branches: list[tuple[str, Transformer[T, Any], Callable[[T], bool]]]
368+
if is_conditional:
369+
parsed_branches = [(name, trans, cond) for name, (trans, cond) in branches.items()] # type: ignore
370+
else:
371+
parsed_branches = [(name, trans, lambda _: True) for name, trans in branches.items()] # type: ignore
372+
373+
producer_fn: Callable
374+
if not is_conditional:
375+
producer_fn = self._producer_fanout
376+
elif first_match:
377+
producer_fn = self._producer_router
378+
else:
379+
producer_fn = self._producer_broadcast
380+
381+
return self._execute_branching(
382+
producer_fn=producer_fn,
383+
parsed_branches=parsed_branches,
384+
batch_size=batch_size,
385+
max_batch_buffer=max_batch_buffer,
386+
)
387+
388+
def _producer_fanout(
389+
self,
390+
source_iterator: Iterator[T],
391+
queues: dict[str, Queue],
392+
batch_size: int,
393+
) -> None:
394+
"""Producer for fan-out: sends every item to every branch."""
395+
for batch_tuple in itertools.batched(source_iterator, batch_size):
396+
batch_list = list(batch_tuple)
397+
for q in queues.values():
398+
q.put(batch_list)
399+
for q in queues.values():
400+
q.put(None)
401+
402+
def _producer_router(
403+
self,
404+
source_iterator: Iterator[T],
405+
queues: dict[str, Queue],
406+
parsed_branches: list[tuple[str, Transformer, Callable]],
407+
batch_size: int,
408+
) -> None:
409+
"""Producer for router (`first_match=True`): sends item to the first matching branch."""
410+
buffers = {name: [] for name, _, _ in parsed_branches}
411+
for item in source_iterator:
412+
for name, _, condition in parsed_branches:
413+
if condition(item):
414+
branch_buffer = buffers[name]
415+
branch_buffer.append(item)
416+
if len(branch_buffer) >= batch_size:
417+
queues[name].put(branch_buffer)
418+
buffers[name] = []
419+
break
420+
for name, buffer_list in buffers.items():
421+
if buffer_list:
422+
queues[name].put(buffer_list)
423+
for q in queues.values():
424+
q.put(None)
425+
426+
def _producer_broadcast(
427+
self,
428+
source_iterator: Iterator[T],
429+
queues: dict[str, Queue],
430+
parsed_branches: list[tuple[str, Transformer, Callable]],
431+
batch_size: int,
432+
) -> None:
433+
"""Producer for broadcast (`first_match=False`): sends item to all matching branches."""
434+
buffers = {name: [] for name, _, _ in parsed_branches}
435+
for item in source_iterator:
436+
item_matches = [name for name, _, condition in parsed_branches if condition(item)]
437+
438+
for name in item_matches:
439+
buffers[name].append(item)
440+
branch_buffer = buffers[name]
441+
if len(branch_buffer) >= batch_size:
442+
queues[name].put(branch_buffer)
443+
buffers[name] = []
444+
445+
for name, buffer_list in buffers.items():
446+
if buffer_list:
447+
queues[name].put(buffer_list)
448+
for q in queues.values():
449+
q.put(None)
450+
451+
def _execute_branching(
    self,
    *,
    producer_fn: Callable,
    parsed_branches: list[tuple[str, Transformer, Callable]],
    batch_size: int,
    max_batch_buffer: int,
) -> tuple[dict[str, list[Any]], dict[str, Any]]:
    """Shared fan-out/collect machinery behind every `branch` mode.

    Runs one producer thread plus one consumer thread per branch, wired
    together with bounded queues (the bound provides backpressure), and
    gathers each branch's output together with the shared context.
    """
    source_iterator = self.processed_data
    branch_count = len(parsed_branches)
    results: dict[str, list[Any]] = {name: [] for name, _, _ in parsed_branches}
    queues = {name: Queue(maxsize=max_batch_buffer) for name, _, _ in parsed_branches}

    def consumer(transformer: Transformer, queue: Queue, context_handle: IContextHandle) -> list[Any]:
        """Drain one branch queue through its transformer and return the output."""

        def stream_from_queue() -> Iterator[T]:
            # `None` is the producer's end-of-stream sentinel.
            while (batch := queue.get()) is not None:
                yield from batch

        branch_pipeline = Pipeline(stream_from_queue(), context_manager=context_handle.create_proxy())  # type: ignore
        output, _ = branch_pipeline.apply(transformer).to_list()
        return output

    with ThreadPoolExecutor(max_workers=branch_count + 1) as executor:
        # The fan-out producer takes no branch list; the conditional
        # producers need it to evaluate their predicates.
        if producer_fn == self._producer_fanout:
            producer_args: tuple = (source_iterator, queues, batch_size)
        else:
            producer_args = (source_iterator, queues, parsed_branches, batch_size)
        executor.submit(producer_fn, *producer_args)

        future_to_name = {
            executor.submit(consumer, transformer, queues[name], self.context_manager.get_handle()): name
            for name, transformer, _ in parsed_branches
        }

        for future in as_completed(future_to_name):
            branch_name = future_to_name[future]
            try:
                results[branch_name] = future.result()
            except Exception:
                # A failed branch yields an empty result rather than
                # aborting its siblings.
                results[branch_name] = []

    # Context is shared across branches via the context manager.
    return results, self.context_manager.to_dict()

0 commit comments

Comments
 (0)