@@ -303,98 +303,196 @@ def consume(self) -> tuple[None, dict[str, Any]]:
303303
304304 return None , self .context_manager .to_dict ()
305305
# Overload 1: Unconditional fan-out — every item goes to every branch.
@overload
def branch(
    self,
    branches: dict[str, Transformer[T, Any]],
    *,
    batch_size: int = 1000,
    max_batch_buffer: int = 1,
) -> tuple[dict[str, list[Any]], dict[str, Any]]: ...

# Overload 2: Conditional routing — each branch carries a predicate.
@overload
def branch(
    self,
    branches: dict[str, tuple[Transformer[T, Any], Callable[[T], bool]]],
    *,
    first_match: bool = True,
    batch_size: int = 1000,
    max_batch_buffer: int = 1,
) -> tuple[dict[str, list[Any]], dict[str, Any]]: ...

def branch(
    self,
    branches: dict[str, Transformer[T, Any]] | dict[str, tuple[Transformer[T, Any], Callable[[T], bool]]],
    *,
    first_match: bool = True,
    batch_size: int = 1000,
    max_batch_buffer: int = 1,
) -> tuple[dict[str, list[Any]], dict[str, Any]]:
    """
    Forks the pipeline for parallel processing with optional conditional routing.

    This is a **terminal operation** that consumes the pipeline.

    **1. Unconditional Fan-Out:**
    If `branches` is a `dict[str, Transformer]`, every item is sent to every branch.

    **2. Conditional Routing:**
    If `branches` is a `dict[str, tuple[Transformer, condition]]`, the `first_match`
    argument determines the routing logic:
    - `first_match=True` (default): Routes each item to the **first** branch
      whose condition is met. This acts as a router.
    - `first_match=False`: Routes each item to **all** branches whose
      conditions are met. This acts as a conditional broadcast.

    Args:
        branches: A dictionary defining the branches. All values must share
            one shape — either all `Transformer` or all
            `(Transformer, condition)` tuples; mixing the two forms within
            one dict is not supported (the mode is detected from the first
            value only).
        first_match (bool): Routing logic for conditional branches; ignored
            for the unconditional fan-out form.
        batch_size (int): The number of items to batch for processing.
        max_batch_buffer (int): The max number of batches to buffer per
            branch queue (provides backpressure).

    Returns:
        A tuple containing a dictionary of results (branch name -> list of
        processed items) and the final context.
    """
    if not branches:
        # Still drain the source so this remains a terminal operation, and
        # return the real final context instead of discarding it (bug fix:
        # previously `{}` was returned here, unlike every other exit path).
        _, final_context = self.consume()
        return {}, final_context

    # The dict's value shape decides the mode: tuple values => conditional.
    first_value = next(iter(branches.values()))
    is_conditional = isinstance(first_value, tuple)

    # Normalise both shapes to (name, transformer, predicate) triples; the
    # unconditional form gets an always-true predicate.
    parsed_branches: list[tuple[str, Transformer[T, Any], Callable[[T], bool]]]
    if is_conditional:
        parsed_branches = [(name, trans, cond) for name, (trans, cond) in branches.items()]  # type: ignore
    else:
        parsed_branches = [(name, trans, lambda _: True) for name, trans in branches.items()]  # type: ignore

    # Select the producer strategy that will feed the branch queues.
    producer_fn: Callable
    if not is_conditional:
        producer_fn = self._producer_fanout
    elif first_match:
        producer_fn = self._producer_router
    else:
        producer_fn = self._producer_broadcast

    return self._execute_branching(
        producer_fn=producer_fn,
        parsed_branches=parsed_branches,
        batch_size=batch_size,
        max_batch_buffer=max_batch_buffer,
    )
387+
388+ def _producer_fanout (
389+ self ,
390+ source_iterator : Iterator [T ],
391+ queues : dict [str , Queue ],
392+ batch_size : int ,
393+ ) -> None :
394+ """Producer for fan-out: sends every item to every branch."""
395+ for batch_tuple in itertools .batched (source_iterator , batch_size ):
396+ batch_list = list (batch_tuple )
397+ for q in queues .values ():
398+ q .put (batch_list )
399+ for q in queues .values ():
400+ q .put (None )
401+
402+ def _producer_router (
403+ self ,
404+ source_iterator : Iterator [T ],
405+ queues : dict [str , Queue ],
406+ parsed_branches : list [tuple [str , Transformer , Callable ]],
407+ batch_size : int ,
408+ ) -> None :
409+ """Producer for router (`first_match=True`): sends item to the first matching branch."""
410+ buffers = {name : [] for name , _ , _ in parsed_branches }
411+ for item in source_iterator :
412+ for name , _ , condition in parsed_branches :
413+ if condition (item ):
414+ branch_buffer = buffers [name ]
415+ branch_buffer .append (item )
416+ if len (branch_buffer ) >= batch_size :
417+ queues [name ].put (branch_buffer )
418+ buffers [name ] = []
419+ break
420+ for name , buffer_list in buffers .items ():
421+ if buffer_list :
422+ queues [name ].put (buffer_list )
423+ for q in queues .values ():
424+ q .put (None )
425+
426+ def _producer_broadcast (
427+ self ,
428+ source_iterator : Iterator [T ],
429+ queues : dict [str , Queue ],
430+ parsed_branches : list [tuple [str , Transformer , Callable ]],
431+ batch_size : int ,
432+ ) -> None :
433+ """Producer for broadcast (`first_match=False`): sends item to all matching branches."""
434+ buffers = {name : [] for name , _ , _ in parsed_branches }
435+ for item in source_iterator :
436+ item_matches = [name for name , _ , condition in parsed_branches if condition (item )]
437+
438+ for name in item_matches :
439+ buffers [name ].append (item )
440+ branch_buffer = buffers [name ]
441+ if len (branch_buffer ) >= batch_size :
442+ queues [name ].put (branch_buffer )
443+ buffers [name ] = []
444+
445+ for name , buffer_list in buffers .items ():
446+ if buffer_list :
447+ queues [name ].put (buffer_list )
448+ for q in queues .values ():
449+ q .put (None )
450+
def _execute_branching(
    self,
    *,
    producer_fn: Callable,
    parsed_branches: list[tuple[str, Transformer, Callable]],
    batch_size: int,
    max_batch_buffer: int,
) -> tuple[dict[str, list[Any]], dict[str, Any]]:
    """Shared execution logic for all branching modes.

    Runs one producer thread (draining `self.processed_data` into per-branch
    queues) plus one consumer thread per branch. Each consumer wraps its
    queue in a new `Pipeline` that shares the parent's context manager, so
    all branches write into the same context.

    Args:
        producer_fn: One of `_producer_fanout` / `_producer_router` /
            `_producer_broadcast`.
        parsed_branches: `(name, transformer, predicate)` triples.
        batch_size: Number of items per batch handed to the queues.
        max_batch_buffer: Queue capacity in batches (backpressure).

    Returns:
        A tuple of (results per branch name, final merged context).
    """
    source_iterator = self.processed_data
    num_branches = len(parsed_branches)
    final_results: dict[str, list[Any]] = {name: [] for name, _, _ in parsed_branches}
    queues = {name: Queue(maxsize=max_batch_buffer) for name, _, _ in parsed_branches}

    # The fan-out producer takes no branch metadata; router/broadcast also
    # need the predicates.
    producer_args: tuple
    if producer_fn == self._producer_fanout:
        producer_args = (source_iterator, queues, batch_size)
    else:
        producer_args = (source_iterator, queues, parsed_branches, batch_size)

    def run_producer() -> None:
        """Run the producer, guaranteeing end-of-stream sentinels.

        Bug fix: if the producer raised mid-stream, the `None` sentinels
        were never sent, so every consumer blocked forever on `queue.get()`
        and the executor shutdown dead-locked.
        """
        try:
            producer_fn(*producer_args)
        except BaseException:
            # A duplicate sentinel is harmless: consumers stop at the first
            # None they see. NOTE(review): this put() can still block if a
            # consumer has also died while its queue is full — confirm
            # whether that combination needs handling.
            for q in queues.values():
                q.put(None)
            raise

    def consumer(transformer: Transformer, queue: Queue, context_handle: IContextHandle) -> list[Any]:
        """Consumes batches from a queue and processes them."""

        def stream_from_queue() -> Iterator[T]:
            # A None batch is the producer's end-of-stream sentinel.
            while (batch := queue.get()) is not None:
                yield from batch

        # Share the parent's context manager so branch writes are merged.
        branch_pipeline = Pipeline(stream_from_queue(), context_manager=context_handle.create_proxy())  # type: ignore
        result_list, _ = branch_pipeline.apply(transformer).to_list()
        return result_list

    with ThreadPoolExecutor(max_workers=num_branches + 1) as executor:
        executor.submit(run_producer)

        future_to_name = {
            executor.submit(consumer, transformer, queues[name], self.context_manager.get_handle()): name
            for name, transformer, _ in parsed_branches
        }

        for future in as_completed(future_to_name):
            name = future_to_name[future]
            try:
                final_results[name] = future.result()
            except Exception:
                # NOTE(review): a failing branch silently yields [] — the
                # error is swallowed; consider logging or re-raising.
                final_results[name] = []

    # The with-block joined all threads; now read the final context state.
    final_context = self.context_manager.to_dict()
    return final_results, final_context
0 commit comments