Skip to content

Commit ed0d59c

Browse files
committed
fix: global pipeline context passing
1 parent 5cceb7c commit ed0d59c

2 files changed

Lines changed: 101 additions & 53 deletions

File tree

laygo/pipeline.py

Lines changed: 61 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
# pipeline.py
2+
13
from collections.abc import Callable
24
from collections.abc import Iterable
35
from collections.abc import Iterator
46
import itertools
7+
import multiprocessing as mp
58
from typing import Any
69
from typing import TypeVar
710
from typing import overload
@@ -17,34 +20,79 @@
1720
class Pipeline[T]:
1821
"""
1922
Manages a data source and applies transformers to it.
20-
Provides terminal operations to consume the resulting data.
23+
Always uses a multiprocessing-safe shared context.
2124
"""
2225

2326
def __init__(self, *data: Iterable[T]):
  """Build a pipeline over one or more data sources.

  Multiple sources are chained lazily into a single stream. A
  multiprocessing manager is created up front so the context is
  always safe to share with worker processes.
  """
  if not data:
    raise ValueError("At least one data source must be provided to Pipeline.")
  # Single source: use it directly; several: chain them lazily.
  self.data_source: Iterable[T] = data[0] if len(data) == 1 else itertools.chain.from_iterable(data)
  self.processed_data: Iterator = iter(self.data_source)

  # Manager-backed dict so every worker process sees one shared context.
  self._manager = mp.Manager()
  self.ctx = self._manager.dict()
  # Shared lock lets concurrent transformers update the context safely.
  self.ctx["lock"] = self._manager.Lock()

  # The user's original context, refreshed from self.ctx at the end.
  self._original_context_ref: PipelineContext | None = None
40+
41+
def __del__(self):
  """Best-effort cleanup of the multiprocessing manager on destruction."""
  try:
    # Push the final shared state back to the user's context first,
    # then tear down the manager process.
    self._sync_context_back()
    self._manager.shutdown()
  except Exception:
    # The interpreter may already be shutting down; nothing actionable.
    pass
2948

3049
def context(self, ctx: PipelineContext) -> "Pipeline[T]":
  """Merge *ctx* into the pipeline's shared context.

  A reference to *ctx* is also kept so that, once processing
  finishes, the final pipeline context data can be written back
  into the caller's original context object.
  """
  self._original_context_ref = ctx
  self.ctx.update(ctx)
  return self
3660

61+
def _sync_context_back(self) -> None:
  """Write the final shared-context state back into the user's context.

  Called after processing completes; a no-op when no original
  context was registered via ``context()``.
  """
  if self._original_context_ref is None:
    return
  # Snapshot the proxy dict, dropping the manager lock — it is not
  # meaningful (or serializable) outside this pipeline.
  snapshot = {k: v for k, v in dict(self.ctx).items() if k != "lock"}
  self._original_context_ref.clear()
  self._original_context_ref.update(snapshot)
72+
73+
def transform[U](self, t: Callable[[Transformer[T, T]], Transformer[T, U]]) -> "Pipeline[U]":
74+
"""
75+
Shorthand method to apply a transformation using a lambda function.
76+
Creates a Transformer under the hood and applies it to the pipeline.
77+
78+
Args:
79+
t: A callable that takes a transformer and returns a transformed transformer
80+
81+
Returns:
82+
A new Pipeline with the transformed data
83+
"""
84+
# Create a new transformer and apply the transformation function
85+
transformer = t(Transformer[T, T]())
86+
return self.apply(transformer)
87+
3788
@overload
3889
def apply[U](self, transformer: Transformer[T, U]) -> "Pipeline[U]": ...
3990

4091
@overload
4192
def apply[U](self, transformer: Callable[[Iterable[T]], Iterator[U]]) -> "Pipeline[U]": ...
4293

4394
@overload
44-
def apply[U](
45-
self,
46-
transformer: Callable[[Iterable[T], PipelineContext], Iterator[U]],
47-
) -> "Pipeline[U]": ...
95+
def apply[U](self, transformer: Callable[[Iterable[T], PipelineContext], Iterator[U]]) -> "Pipeline[U]": ...
4896

4997
def apply[U](
5098
self,
@@ -53,42 +101,26 @@ def apply[U](
53101
| Callable[[Iterable[T], PipelineContext], Iterator[U]],
54102
) -> "Pipeline[U]":
55103
"""
56-
Applies a transformer to the current data source.
104+
Applies a transformer to the current data source. The pipeline's
105+
managed context is passed down.
57106
"""
58-
59107
match transformer:
60108
case Transformer():
61-
# If a Transformer instance is provided, use its __call__ method
109+
# The transformer is called with self.ctx, which is the
110+
# shared mp.Manager.dict proxy when inside a 'with' block.
62111
self.processed_data = transformer(self.processed_data, self.ctx) # type: ignore
63112
case _ if callable(transformer):
64-
# If a callable function is provided, call it with the current data and context
65-
66113
if is_context_aware(transformer):
67114
processed_transformer = transformer
68115
else:
69116
processed_transformer = lambda data, ctx: transformer(data) # type: ignore # noqa: E731
70-
71117
self.processed_data = processed_transformer(self.processed_data, self.ctx) # type: ignore
72118
case _:
73119
raise TypeError("Transformer must be a Transformer instance or a callable function")
74120

75121
return self # type: ignore
76122

77-
def transform[U](self, t: Callable[[Transformer[T, T]], Transformer[T, U]]) -> "Pipeline[U]":
78-
"""
79-
Shorthand method to apply a transformation using a lambda function.
80-
Creates a Transformer under the hood and applies it to the pipeline.
81-
82-
Args:
83-
t: A callable that takes a transformer and returns a transformed transformer
84-
85-
Returns:
86-
A new Pipeline with the transformed data
87-
"""
88-
# Create a new transformer and apply the transformation function
89-
transformer = t(Transformer[T, T]())
90-
return self.apply(transformer)
91-
123+
# ... The rest of the Pipeline class (transform, __iter__, to_list, etc.) remains unchanged ...
92124
def __iter__(self) -> Iterator[T]:
  """Allows the pipeline to be iterated over."""
  # Lazily forward each processed item to the consumer.
  for item in self.processed_data:
    yield item

laygo/transformers/parallel.py

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import copy
1212
import itertools
1313
import multiprocessing as mp
14+
from multiprocessing.managers import DictProxy
1415
from typing import Any
1516
from typing import Union
1617
from typing import overload
@@ -85,30 +86,45 @@ def from_transformer[T, U](
8586
)
8687

8788
def __call__(self, data: Iterable[In], context: PipelineContext | None = None) -> Iterator[Out]:
  """
  Executes the transformer on data concurrently. It uses the shared
  context provided by the Pipeline, if available.

  Args:
    data: Input items to process in parallel chunks.
    context: Optional pipeline context; falls back to self.context.

  Yields:
    Transformed items, chunk by chunk.
  """
  run_context = context if context is not None else self.context

  # A DictProxy means the context is already managed by the Pipeline.
  if isinstance(run_context, DictProxy):
    # Use the existing shared context and lock from the Pipeline.
    # It is live in the workers, so no copy-back is needed here;
    # the Pipeline handles final state synchronization.
    yield from self._execute_with_context(data, run_context)
  else:
    # Fallback for standalone use: create a temporary manager.
    with mp.Manager() as manager:
      shared_context = manager.dict(dict(run_context))
      if "lock" not in shared_context:
        shared_context["lock"] = manager.Lock()

      try:
        yield from self._execute_with_context(data, shared_context)
      finally:
        # FIX: run the copy-back in a finally so the caller's context
        # is updated even if the consumer stops iterating early and
        # this generator is closed before the data is exhausted.
        final_context_state = dict(shared_context)
        final_context_state.pop("lock", None)  # manager locks are not portable
        run_context.update(final_context_state)
118+
119+
def _execute_with_context(self, data: Iterable[In], shared_context: MutableMapping[str, Any]) -> Iterator[Out]:
  """Helper to run the execution logic with a given context."""
  with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
    # Pick the chunk generator matching the requested ordering.
    generate = self._unordered_generator if not self.ordered else self._ordered_generator
    for chunk in generate(self._chunk_generator(data), executor, shared_context):
      yield from chunk
112128

113129
# ... The rest of the file remains the same ...
114130
def _ordered_generator(

0 commit comments

Comments
 (0)