Skip to content

Commit cbc5ff1

Browse files
committed
chore: implemented branch execution using processes
1 parent 5e2036b commit cbc5ff1

2 files changed

Lines changed: 228 additions & 71 deletions

File tree

laygo/pipeline.py

Lines changed: 176 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,19 @@
22
from collections.abc import Callable
33
from collections.abc import Iterable
44
from collections.abc import Iterator
5+
from collections.abc import Mapping
56
from concurrent.futures import ThreadPoolExecutor
67
from concurrent.futures import as_completed
78
import itertools
9+
from multiprocessing import Manager
810
from queue import Queue
911
from typing import Any
12+
from typing import Literal
1013
from typing import TypeVar
1114
from typing import overload
1215

16+
from loky import get_reusable_executor
17+
1318
from laygo.context import IContextManager
1419
from laygo.context.parallel import ParallelContextManager
1520
from laygo.context.types import IContextHandle
@@ -21,6 +26,29 @@
2126
PipelineFunction = Callable[[T], Any]
2227

2328

29+
# This function must be defined at the top level of the module (e.g., after imports)
# so it can be pickled and shipped to worker processes by loky/multiprocessing.
def _branch_consumer_process[T](transformer: Transformer, queue: "Queue", context_handle: IContextHandle) -> list[Any]:
  """
  The entry point for a consumer process. It reconstructs the necessary
  objects and runs a dedicated pipeline instance on the data from its queue.

  Args:
    transformer: The branch's transformer to apply to the streamed items.
    queue: The process-safe queue feeding this branch. NOTE(review): despite
      the ``"Queue"`` annotation this is a ``multiprocessing.Manager().Queue()``
      proxy at runtime, not ``queue.Queue`` — confirm and consider fixing the hint.
    context_handle: Serializable handle used to rebuild a context proxy
      inside the worker process.

  Returns:
    The list of items produced by running the branch's transformer.
  """
  # Re-create the context proxy within the new process
  context_proxy = context_handle.create_proxy()

  def stream_from_queue() -> Iterator[T]:
    """A generator that yields items from the process-safe queue."""
    # A None sentinel (sent by the producer) terminates the stream.
    while (batch := queue.get()) is not None:
      yield from batch

  try:
    # Each consumer process runs its own mini-pipeline
    branch_pipeline = Pipeline(stream_from_queue(), context_manager=context_proxy)
    result_list, _ = branch_pipeline.apply(transformer).to_list()
    return result_list
  finally:
    # Always release the proxy's resources, even if the transformer raises.
    context_proxy.shutdown()
50+
51+
2452
class Pipeline[T]:
2553
"""Manages a data source and applies transformers to it.
2654
@@ -303,12 +331,78 @@ def consume(self) -> tuple[None, dict[str, Any]]:
303331

304332
return None, self.context_manager.to_dict()
305333

334+
def _producer_fanout(
  self,
  source_iterator: Iterator[T],
  queues: dict[str, Queue],
  batch_size: int,
) -> None:
  """Producer for fan-out: sends every item to every branch."""
  # Every branch receives an identical copy of each batch.
  for chunk in itertools.batched(source_iterator, batch_size):
    payload = list(chunk)
    for branch_queue in queues.values():
      branch_queue.put(payload)
  # A trailing None sentinel tells each consumer the stream is exhausted.
  for branch_queue in queues.values():
    branch_queue.put(None)
347+
348+
def _producer_router(
  self,
  source_iterator: Iterator[T],
  queues: dict[str, Queue],
  parsed_branches: list[tuple[str, Transformer, Callable]],
  batch_size: int,
) -> None:
  """Producer for router (`first_match=True`): sends item to the first matching branch."""
  # Per-branch accumulation buffers so items are shipped in batches.
  pending: dict[str, list] = {branch_name: [] for branch_name, _, _ in parsed_branches}
  for element in source_iterator:
    for branch_name, _, predicate in parsed_branches:
      if not predicate(element):
        continue
      pending[branch_name].append(element)
      if len(pending[branch_name]) >= batch_size:
        queues[branch_name].put(pending[branch_name])
        pending[branch_name] = []
      # First match wins: do not offer the item to later branches.
      break
  # Flush any partially-filled buffers before signalling completion.
  for branch_name, leftover in pending.items():
    if leftover:
      queues[branch_name].put(leftover)
  for branch_queue in queues.values():
    branch_queue.put(None)
371+
372+
def _producer_broadcast(
  self,
  source_iterator: Iterator[T],
  queues: dict[str, Queue],
  parsed_branches: list[tuple[str, Transformer, Callable]],
  batch_size: int,
) -> None:
  """Producer for broadcast (`first_match=False`): sends item to all matching branches."""
  # Per-branch accumulation buffers so items are shipped in batches.
  pending: dict[str, list] = {branch_name: [] for branch_name, _, _ in parsed_branches}
  for element in source_iterator:
    # Unlike the router, every matching branch receives the item.
    for branch_name, _, predicate in parsed_branches:
      if not predicate(element):
        continue
      bucket = pending[branch_name]
      bucket.append(element)
      if len(bucket) >= batch_size:
        queues[branch_name].put(bucket)
        pending[branch_name] = []
  # Flush any partially-filled buffers before signalling completion.
  for branch_name, leftover in pending.items():
    if leftover:
      queues[branch_name].put(leftover)
  for branch_queue in queues.values():
    branch_queue.put(None)
396+
397+
# Overload 1: Unconditional fan-out — every item is sent to every branch.
@overload
def branch(
  self,
  branches: Mapping[str, Transformer[T, Any]],
  *,
  executor_type: Literal["thread", "process"] = "thread",
  batch_size: int = 1000,
  max_batch_buffer: int = 1,
) -> tuple[dict[str, list[Any]], dict[str, Any]]: ...
@@ -317,17 +411,19 @@ def branch(
317411
@overload
318412
def branch(
319413
self,
320-
branches: dict[str, tuple[Transformer[T, Any], Callable[[T], bool]]],
414+
branches: Mapping[str, tuple[Transformer[T, Any], Callable[[T], bool]]],
321415
*,
416+
executor_type: Literal["thread", "process"] = "thread",
322417
first_match: bool = True,
323418
batch_size: int = 1000,
324419
max_batch_buffer: int = 1,
325420
) -> tuple[dict[str, list[Any]], dict[str, Any]]: ...
326421

327422
def branch(
328423
self,
329-
branches: dict[str, Transformer[T, Any]] | dict[str, tuple[Transformer[T, Any], Callable[[T], bool]]],
424+
branches: Mapping[str, Transformer[T, Any]] | Mapping[str, tuple[Transformer[T, Any], Callable[[T], bool]]],
330425
*,
426+
executor_type: Literal["thread", "process"] = "thread",
331427
first_match: bool = True,
332428
batch_size: int = 1000,
333429
max_batch_buffer: int = 1,
@@ -350,9 +446,11 @@ def branch(
350446
351447
Args:
352448
branches: A dictionary defining the branches.
353-
first_match (bool): Determines the routing logic for conditional branches.
354-
batch_size (int): The number of items to batch for processing.
355-
max_batch_buffer (int): The max number of batches to buffer per branch.
449+
executor_type: The parallelism model. 'thread' for I/O-bound tasks,
450+
'process' for CPU-bound tasks. Defaults to 'thread'.
451+
first_match: Determines the routing logic for conditional branches.
452+
batch_size: The number of items to batch for processing.
453+
max_batch_buffer: The max number of batches to buffer per branch.
356454
357455
Returns:
358456
A tuple containing a dictionary of results and the final context.
@@ -378,85 +476,93 @@ def branch(
378476
else:
379477
producer_fn = self._producer_broadcast
380478

381-
return self._execute_branching(
382-
producer_fn=producer_fn,
383-
parsed_branches=parsed_branches,
384-
batch_size=batch_size,
385-
max_batch_buffer=max_batch_buffer,
386-
)
387-
388-
def _producer_fanout(
389-
self,
390-
source_iterator: Iterator[T],
391-
queues: dict[str, Queue],
392-
batch_size: int,
393-
) -> None:
394-
"""Producer for fan-out: sends every item to every branch."""
395-
for batch_tuple in itertools.batched(source_iterator, batch_size):
396-
batch_list = list(batch_tuple)
397-
for q in queues.values():
398-
q.put(batch_list)
399-
for q in queues.values():
400-
q.put(None)
479+
# Dispatch to the correct executor based on the chosen type
480+
if executor_type == "thread":
481+
return self._execute_branching_thread(
482+
producer_fn=producer_fn,
483+
parsed_branches=parsed_branches,
484+
batch_size=batch_size,
485+
max_batch_buffer=max_batch_buffer,
486+
)
487+
elif executor_type == "process":
488+
return self._execute_branching_process(
489+
producer_fn=producer_fn,
490+
parsed_branches=parsed_branches,
491+
batch_size=batch_size,
492+
max_batch_buffer=max_batch_buffer,
493+
)
494+
else:
495+
raise ValueError(f"Unsupported executor_type: '{executor_type}'. Must be 'thread' or 'process'.")
401496

402-
def _producer_router(
497+
def _execute_branching_process(
  self,
  *,
  producer_fn: Callable,
  parsed_branches: list[tuple[str, Transformer, Callable]],
  batch_size: int,
  max_batch_buffer: int,
) -> tuple[dict[str, list[Any]], dict[str, Any]]:
  """Branching execution using a process pool for consumers.

  Args:
    producer_fn: One of the `_producer_*` methods; runs in a thread and
      feeds batches into the per-branch queues.
    parsed_branches: `(name, transformer, condition)` triples.
    batch_size: Number of items per batch sent to a branch.
    max_batch_buffer: Max batches buffered per branch queue (backpressure).

  Returns:
    A tuple of (per-branch result lists, final context dict).
  """
  source_iterator = self.processed_data
  num_branches = len(parsed_branches)
  final_results: dict[str, list[Any]] = {name: [] for name, _, _ in parsed_branches}
  context_handle = self.context_manager.get_handle()

  # A Manager creates queues that can be shared between processes.
  manager = Manager()
  queues = {name: manager.Queue(maxsize=max_batch_buffer) for name, _, _ in parsed_branches}

  # The producer must run in a thread to access the pipeline's iterator,
  # while consumers run in processes for true CPU parallelism.
  producer_executor = ThreadPoolExecutor(max_workers=1)
  consumer_executor = get_reusable_executor(max_workers=num_branches)

  try:
    # The fan-out producer takes no branch conditions; router/broadcast do.
    producer_args: tuple
    if producer_fn == self._producer_fanout:
      producer_args = (source_iterator, queues, batch_size)
    else:
      producer_args = (source_iterator, queues, parsed_branches, batch_size)

    # Submit the producer to the thread pool
    producer_future = producer_executor.submit(producer_fn, *producer_args)

    # Submit consumers to the process pool
    future_to_name = {
      consumer_executor.submit(_branch_consumer_process, transformer, queues[name], context_handle): name
      for name, transformer, _ in parsed_branches
    }

    # Collect results as they complete
    for future in as_completed(future_to_name):
      name = future_to_name[future]
      try:
        final_results[name] = future.result()
      except Exception:
        # Deliberate best-effort: a failed branch yields an empty result
        # instead of aborting the sibling branches.
        final_results[name] = []

    # Check for producer errors after consumers are done.
    # NOTE(review): if the producer dies before sending the None sentinels,
    # consumers can block on queue.get() — consider a timeout. TODO confirm.
    producer_future.result()

  finally:
    producer_executor.shutdown()
    # Shut down the Manager's server process; without this it leaks and
    # lingers until interpreter exit.
    manager.shutdown()
    # The reusable executor from loky is managed globally

  final_context = self.context_manager.to_dict()
  return final_results, final_context
450554

451-
def _execute_branching(
555+
# Rename original _execute_branching to be specific
556+
def _execute_branching_thread(
452557
self,
453558
*,
454559
producer_fn: Callable,
455560
parsed_branches: list[tuple[str, Transformer, Callable]],
456561
batch_size: int,
457562
max_batch_buffer: int,
458563
) -> tuple[dict[str, list[Any]], dict[str, Any]]:
459-
"""Shared execution logic for all branching modes."""
564+
"""Shared execution logic for thread-based branching modes."""
565+
# ... (The original implementation of _execute_branching goes here)
460566
source_iterator = self.processed_data
461567
num_branches = len(parsed_branches)
462568
final_results: dict[str, list[Any]] = {name: [] for name, _, _ in parsed_branches}
@@ -474,7 +580,6 @@ def stream_from_queue() -> Iterator[T]:
474580
return result_list
475581

476582
with ThreadPoolExecutor(max_workers=num_branches + 1) as executor:
477-
# The producer needs different arguments depending on the type
478583
producer_args: tuple
479584
if producer_fn == self._producer_fanout:
480585
producer_args = (source_iterator, queues, batch_size)

tests/test_pipeline.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
"""Tests for the Pipeline class."""
22

3+
import os
4+
import time
5+
36
from laygo import Pipeline
47
from laygo.context.types import IContextManager
58
from laygo.transformers.transformer import createTransformer
@@ -565,3 +568,52 @@ def test_branch_conditional_broadcast_mode(self):
565568
assert sorted(result["strings"]) == ["A", "B"]
566569
# The float (99.9) AND the integers (1, 2, 3) are processed by the 'numbers' branch.
567570
assert sorted(result["numbers"]) == [10.0, 20.0, 30.0, 999.0]
571+
572+
def test_branch_process_executor(self):
  """Test branching with executor_type='process' for CPU-bound work.

  Verifies both the computed results and that branch work actually ran in
  worker processes distinct from the test process (via PIDs recorded in the
  shared context).
  """

  # Setup: A CPU-bound task is ideal for demonstrating process parallelism.
  def heavy_computation(x: int) -> int:
    # A simple but non-trivial calculation
    time.sleep(0.01)  # Simulate work
    return x * x

  # This function will run inside the worker process to check its PID.
  # NOTE(review): these are local closures — this relies on loky's
  # cloudpickle-based serialization; plain multiprocessing could not pickle
  # them. TODO confirm.
  def check_pid(chunk: list[int], ctx: IContextManager) -> list[int]:
    # Store the worker's process ID in the shared context
    if chunk:
      ctx[f"pid_for_item_{chunk[0]}"] = os.getpid()
    return chunk

  data = [1, 2, 3, 4]
  pipeline = Pipeline(data)
  main_pid = os.getpid()

  # Define branches with CPU-bound work and the PID check
  branches = {
    "evens": (
      createTransformer(int).filter(lambda x: x % 2 == 0).map(heavy_computation)._pipe(check_pid),
      lambda x: True,  # Condition to route data
    ),
    "odds": (
      createTransformer(int).filter(lambda x: x % 2 != 0).map(heavy_computation)._pipe(check_pid),
      lambda x: True,
    ),
  }

  # Action: Execute the branch with the process executor
  result, context = pipeline.branch(
    branches,
    first_match=False,  # Use broadcast to send to all matching
    executor_type="process",
  )

  # Assert: The computational results are correct
  assert sorted(result["evens"]) == [4, 16]  # 2*2, 4*4
  assert sorted(result["odds"]) == [1, 9]  # 1*1, 3*3

  # Assert: The work was done in different processes
  worker_pids = {v for k, v in context.items() if "pid" in k}
  assert len(worker_pids) > 0, "No worker PIDs were found in the context."
  for pid in worker_pids:
    assert pid != main_pid, f"Worker PID {pid} is the same as the main PID."

0 commit comments

Comments
 (0)