22from collections .abc import Callable
33from collections .abc import Iterable
44from collections .abc import Iterator
5+ from collections .abc import Mapping
56from concurrent .futures import ThreadPoolExecutor
67from concurrent .futures import as_completed
78import itertools
9+ from multiprocessing import Manager
810from queue import Queue
911from typing import Any
12+ from typing import Literal
1013from typing import TypeVar
1114from typing import overload
1215
16+ from loky import get_reusable_executor
17+
1318from laygo .context import IContextManager
1419from laygo .context .parallel import ParallelContextManager
1520from laygo .context .types import IContextHandle
# A per-item pipeline operation: receives one element of type T; the return
# value is unconstrained (Any).
PipelineFunction = Callable[[T], Any]
2227
2328
29+ # This function must be defined at the top level of the module (e.g., after imports)
30+ def _branch_consumer_process [T ](transformer : Transformer , queue : "Queue" , context_handle : IContextHandle ) -> list [Any ]:
31+ """
32+ The entry point for a consumer process. It reconstructs the necessary
33+ objects and runs a dedicated pipeline instance on the data from its queue.
34+ """
35+ # Re-create the context proxy within the new process
36+ context_proxy = context_handle .create_proxy ()
37+
38+ def stream_from_queue () -> Iterator [T ]:
39+ """A generator that yields items from the process-safe queue."""
40+ while (batch := queue .get ()) is not None :
41+ yield from batch
42+
43+ try :
44+ # Each consumer process runs its own mini-pipeline
45+ branch_pipeline = Pipeline (stream_from_queue (), context_manager = context_proxy )
46+ result_list , _ = branch_pipeline .apply (transformer ).to_list ()
47+ return result_list
48+ finally :
49+ context_proxy .shutdown ()
50+
51+
2452class Pipeline [T ]:
2553 """Manages a data source and applies transformers to it.
2654
@@ -303,12 +331,78 @@ def consume(self) -> tuple[None, dict[str, Any]]:
303331
304332 return None , self .context_manager .to_dict ()
305333
334+ def _producer_fanout (
335+ self ,
336+ source_iterator : Iterator [T ],
337+ queues : dict [str , Queue ],
338+ batch_size : int ,
339+ ) -> None :
340+ """Producer for fan-out: sends every item to every branch."""
341+ for batch_tuple in itertools .batched (source_iterator , batch_size ):
342+ batch_list = list (batch_tuple )
343+ for q in queues .values ():
344+ q .put (batch_list )
345+ for q in queues .values ():
346+ q .put (None )
347+
348+ def _producer_router (
349+ self ,
350+ source_iterator : Iterator [T ],
351+ queues : dict [str , Queue ],
352+ parsed_branches : list [tuple [str , Transformer , Callable ]],
353+ batch_size : int ,
354+ ) -> None :
355+ """Producer for router (`first_match=True`): sends item to the first matching branch."""
356+ buffers = {name : [] for name , _ , _ in parsed_branches }
357+ for item in source_iterator :
358+ for name , _ , condition in parsed_branches :
359+ if condition (item ):
360+ branch_buffer = buffers [name ]
361+ branch_buffer .append (item )
362+ if len (branch_buffer ) >= batch_size :
363+ queues [name ].put (branch_buffer )
364+ buffers [name ] = []
365+ break
366+ for name , buffer_list in buffers .items ():
367+ if buffer_list :
368+ queues [name ].put (buffer_list )
369+ for q in queues .values ():
370+ q .put (None )
371+
372+ def _producer_broadcast (
373+ self ,
374+ source_iterator : Iterator [T ],
375+ queues : dict [str , Queue ],
376+ parsed_branches : list [tuple [str , Transformer , Callable ]],
377+ batch_size : int ,
378+ ) -> None :
379+ """Producer for broadcast (`first_match=False`): sends item to all matching branches."""
380+ buffers = {name : [] for name , _ , _ in parsed_branches }
381+ for item in source_iterator :
382+ item_matches = [name for name , _ , condition in parsed_branches if condition (item )]
383+
384+ for name in item_matches :
385+ buffers [name ].append (item )
386+ branch_buffer = buffers [name ]
387+ if len (branch_buffer ) >= batch_size :
388+ queues [name ].put (branch_buffer )
389+ buffers [name ] = []
390+
391+ for name , buffer_list in buffers .items ():
392+ if buffer_list :
393+ queues [name ].put (buffer_list )
394+ for q in queues .values ():
395+ q .put (None )
396+
    # --- branch(): overloads and executor dispatch ---
398+
    # Overload 1: unconditional fan-out — a plain transformer per branch
    # means every item is sent to every branch.
    @overload
    def branch(
        self,
        branches: Mapping[str, Transformer[T, Any]],
        *,
        executor_type: Literal["thread", "process"] = "thread",
        batch_size: int = 1000,
        max_batch_buffer: int = 1,
    ) -> tuple[dict[str, list[Any]], dict[str, Any]]: ...
@@ -317,17 +411,19 @@ def branch(
    # Overload 2: conditional branching — each branch pairs a transformer
    # with a predicate; `first_match` selects router vs. broadcast semantics.
    @overload
    def branch(
        self,
        branches: Mapping[str, tuple[Transformer[T, Any], Callable[[T], bool]]],
        *,
        executor_type: Literal["thread", "process"] = "thread",
        first_match: bool = True,
        batch_size: int = 1000,
        max_batch_buffer: int = 1,
    ) -> tuple[dict[str, list[Any]], dict[str, Any]]: ...
326421
327422 def branch (
328423 self ,
329- branches : dict [str , Transformer [T , Any ]] | dict [str , tuple [Transformer [T , Any ], Callable [[T ], bool ]]],
424+ branches : Mapping [str , Transformer [T , Any ]] | Mapping [str , tuple [Transformer [T , Any ], Callable [[T ], bool ]]],
330425 * ,
426+ executor_type : Literal ["thread" , "process" ] = "thread" ,
331427 first_match : bool = True ,
332428 batch_size : int = 1000 ,
333429 max_batch_buffer : int = 1 ,
@@ -350,9 +446,11 @@ def branch(
350446
351447 Args:
352448 branches: A dictionary defining the branches.
353- first_match (bool): Determines the routing logic for conditional branches.
354- batch_size (int): The number of items to batch for processing.
355- max_batch_buffer (int): The max number of batches to buffer per branch.
449+ executor_type: The parallelism model. 'thread' for I/O-bound tasks,
450+ 'process' for CPU-bound tasks. Defaults to 'thread'.
451+ first_match: Determines the routing logic for conditional branches.
452+ batch_size: The number of items to batch for processing.
453+ max_batch_buffer: The max number of batches to buffer per branch.
356454
357455 Returns:
358456 A tuple containing a dictionary of results and the final context.
@@ -378,85 +476,93 @@ def branch(
378476 else :
379477 producer_fn = self ._producer_broadcast
380478
381- return self ._execute_branching (
382- producer_fn = producer_fn ,
383- parsed_branches = parsed_branches ,
384- batch_size = batch_size ,
385- max_batch_buffer = max_batch_buffer ,
386- )
387-
388- def _producer_fanout (
389- self ,
390- source_iterator : Iterator [T ],
391- queues : dict [str , Queue ],
392- batch_size : int ,
393- ) -> None :
394- """Producer for fan-out: sends every item to every branch."""
395- for batch_tuple in itertools .batched (source_iterator , batch_size ):
396- batch_list = list (batch_tuple )
397- for q in queues .values ():
398- q .put (batch_list )
399- for q in queues .values ():
400- q .put (None )
479+ # Dispatch to the correct executor based on the chosen type
480+ if executor_type == "thread" :
481+ return self ._execute_branching_thread (
482+ producer_fn = producer_fn ,
483+ parsed_branches = parsed_branches ,
484+ batch_size = batch_size ,
485+ max_batch_buffer = max_batch_buffer ,
486+ )
487+ elif executor_type == "process" :
488+ return self ._execute_branching_process (
489+ producer_fn = producer_fn ,
490+ parsed_branches = parsed_branches ,
491+ batch_size = batch_size ,
492+ max_batch_buffer = max_batch_buffer ,
493+ )
494+ else :
495+ raise ValueError (f"Unsupported executor_type: '{ executor_type } '. Must be 'thread' or 'process'." )
401496
402- def _producer_router (
497+ def _execute_branching_process (
403498 self ,
404- source_iterator : Iterator [ T ] ,
405- queues : dict [ str , Queue ] ,
499+ * ,
500+ producer_fn : Callable ,
406501 parsed_branches : list [tuple [str , Transformer , Callable ]],
407502 batch_size : int ,
408- ) -> None :
409- """Producer for router (`first_match=True`): sends item to the first matching branch."""
410- buffers = {name : [] for name , _ , _ in parsed_branches }
411- for item in source_iterator :
412- for name , _ , condition in parsed_branches :
413- if condition (item ):
414- branch_buffer = buffers [name ]
415- branch_buffer .append (item )
416- if len (branch_buffer ) >= batch_size :
417- queues [name ].put (branch_buffer )
418- buffers [name ] = []
419- break
420- for name , buffer_list in buffers .items ():
421- if buffer_list :
422- queues [name ].put (buffer_list )
423- for q in queues .values ():
424- q .put (None )
503+ max_batch_buffer : int ,
504+ ) -> tuple [dict [str , list [Any ]], dict [str , Any ]]:
505+ """Branching execution using a process pool for consumers."""
506+ source_iterator = self .processed_data
507+ num_branches = len (parsed_branches )
508+ final_results : dict [str , list [Any ]] = {name : [] for name , _ , _ in parsed_branches }
509+ context_handle = self .context_manager .get_handle ()
425510
426- def _producer_broadcast (
427- self ,
428- source_iterator : Iterator [T ],
429- queues : dict [str , Queue ],
430- parsed_branches : list [tuple [str , Transformer , Callable ]],
431- batch_size : int ,
432- ) -> None :
433- """Producer for broadcast (`first_match=False`): sends item to all matching branches."""
434- buffers = {name : [] for name , _ , _ in parsed_branches }
435- for item in source_iterator :
436- item_matches = [name for name , _ , condition in parsed_branches if condition (item )]
511+ # A Manager creates queues that can be shared between processes
512+ manager = Manager ()
513+ queues = {name : manager .Queue (maxsize = max_batch_buffer ) for name , _ , _ in parsed_branches }
437514
438- for name in item_matches :
439- buffers [name ].append (item )
440- branch_buffer = buffers [name ]
441- if len (branch_buffer ) >= batch_size :
442- queues [name ].put (branch_buffer )
443- buffers [name ] = []
515+ # The producer must run in a thread to access the pipeline's iterator,
516+ # while consumers run in processes for true CPU parallelism.
517+ producer_executor = ThreadPoolExecutor (max_workers = 1 )
518+ consumer_executor = get_reusable_executor (max_workers = num_branches )
444519
445- for name , buffer_list in buffers .items ():
446- if buffer_list :
447- queues [name ].put (buffer_list )
448- for q in queues .values ():
449- q .put (None )
520+ try :
521+ # Determine arguments for the producer function
522+ producer_args : tuple
523+ if producer_fn == self ._producer_fanout :
524+ producer_args = (source_iterator , queues , batch_size )
525+ else :
526+ producer_args = (source_iterator , queues , parsed_branches , batch_size )
527+
528+ # Submit the producer to the thread pool
529+ producer_future = producer_executor .submit (producer_fn , * producer_args )
530+
531+ # Submit consumers to the process pool
532+ future_to_name = {
533+ consumer_executor .submit (_branch_consumer_process , transformer , queues [name ], context_handle ): name
534+ for name , transformer , _ in parsed_branches
535+ }
536+
537+ # Collect results as they complete
538+ for future in as_completed (future_to_name ):
539+ name = future_to_name [future ]
540+ try :
541+ final_results [name ] = future .result ()
542+ except Exception :
543+ final_results [name ] = []
544+
545+ # Check for producer errors after consumers are done
546+ producer_future .result ()
547+
548+ finally :
549+ producer_executor .shutdown ()
550+ # The reusable executor from loky is managed globally
551+
552+ final_context = self .context_manager .to_dict ()
553+ return final_results , final_context
450554
451- def _execute_branching (
    # Thread-based counterpart of _execute_branching_process.
556+ def _execute_branching_thread (
452557 self ,
453558 * ,
454559 producer_fn : Callable ,
455560 parsed_branches : list [tuple [str , Transformer , Callable ]],
456561 batch_size : int ,
457562 max_batch_buffer : int ,
458563 ) -> tuple [dict [str , list [Any ]], dict [str , Any ]]:
459- """Shared execution logic for all branching modes."""
564+ """Shared execution logic for thread-based branching modes."""
565+ # ... (The original implementation of _execute_branching goes here)
460566 source_iterator = self .processed_data
461567 num_branches = len (parsed_branches )
462568 final_results : dict [str , list [Any ]] = {name : [] for name , _ , _ in parsed_branches }
@@ -474,7 +580,6 @@ def stream_from_queue() -> Iterator[T]:
474580 return result_list
475581
476582 with ThreadPoolExecutor (max_workers = num_branches + 1 ) as executor :
477- # The producer needs different arguments depending on the type
478583 producer_args : tuple
479584 if producer_fn == self ._producer_fanout :
480585 producer_args = (source_iterator , queues , batch_size )
0 commit comments