55from concurrent .futures import ThreadPoolExecutor
66from concurrent .futures import as_completed
77import itertools
8- import multiprocessing as mp
98from queue import Queue
109from typing import Any
1110from typing import TypeVar
1211from typing import overload
1312
14- from laygo .helpers import PipelineContext
13+ from laygo .context import IContextManager
14+ from laygo .context import SimpleContextManager
1515from laygo .helpers import is_context_aware
1616from laygo .transformers .transformer import Transformer
1717from laygo .transformers .transformer import passthrough_chunks
@@ -44,12 +44,14 @@ class Pipeline[T]:
4444 pipeline effectively single-use unless the data source is re-initialized.
4545 """
4646
47- def __init__ (self , * data : Iterable [T ]) -> None :
47+ def __init__ (self , * data : Iterable [T ], context_manager : IContextManager | None = None ) -> None :
4848 """Initialize a pipeline with one or more data sources.
4949
5050 Args:
5151 *data: One or more iterable data sources. If multiple sources are
5252 provided, they will be chained together.
53+ context_manager: An instance of a class that implements IContextManager.
54+ If None, a SimpleContextManager is used by default.
5355
5456 Raises:
5557 ValueError: If no data sources are provided.
@@ -59,25 +61,16 @@ def __init__(self, *data: Iterable[T]) -> None:
5961 self .data_source : Iterable [T ] = itertools .chain .from_iterable (data ) if len (data ) > 1 else data [0 ]
6062 self .processed_data : Iterator = iter (self .data_source )
6163
62- # Always create a shared context with multiprocessing manager
63- self ._manager = mp .Manager ()
64- self .ctx = self ._manager .dict ()
65- # Add a shared lock to the context for safe concurrent updates
66- self .ctx ["lock" ] = self ._manager .Lock ()
67-
68- # Store reference to original context for final synchronization
69- self ._original_context_ref : PipelineContext | None = None
64+ # Rule 1: Pipeline creates a simple context manager by default.
65+ self .context_manager = context_manager or SimpleContextManager ()
7066
def __del__(self) -> None:
    """Clean up the context manager when the pipeline is destroyed.

    The ``hasattr`` check guards against partially-initialized instances
    (e.g. ``__init__`` raised before ``context_manager`` was assigned).
    Any exception from ``shutdown()`` is suppressed: ``__del__`` may run
    during interpreter shutdown, where raising cannot be handled by callers
    and only emits unactionable "Exception ignored in __del__" noise.
    """
    try:
        if hasattr(self, "context_manager"):
            self.context_manager.shutdown()
    except Exception:
        # Best-effort cleanup only; never let finalization raise.
        pass
7871
def context(self, ctx: dict[str, Any]) -> "Pipeline[T]":
    """Merge the given values into the pipeline's context manager.

    Note: unlike the pre-refactor API, the supplied dictionary is only read
    here — changes made by transformers during execution are NOT synchronized
    back into ``ctx``; all shared state lives in ``self.context_manager``.

    Args:
        ctx: Key/value pairs merged into the pipeline's context manager via
            its ``update`` method; existing keys with the same name are
            overwritten.

    Returns:
        This pipeline instance, enabling method chaining.
    """
    self.context_manager.update(ctx)
    return self
10494
def _sync_context_back(self) -> None:
    """No-op retained for backward compatibility.

    Earlier versions copied the pipeline's shared (multiprocessing) context
    back into a caller-supplied context object after processing completed.
    Shared state now lives in the pipeline's context manager, so there is
    nothing left to synchronize; the method is kept so existing callers do
    not break.
    """
    return None
117104
118105 def transform [U ](self , t : Callable [[Transformer [T , T ]], Transformer [T , U ]]) -> "Pipeline[U]" :
119106 """Apply a transformation using a lambda function.
@@ -146,13 +133,13 @@ def apply[U](self, transformer: Transformer[T, U]) -> "Pipeline[U]": ...
146133 def apply [U ](self , transformer : Callable [[Iterable [T ]], Iterator [U ]]) -> "Pipeline[U]" : ...
147134
148135 @overload
149- def apply [U ](self , transformer : Callable [[Iterable [T ], PipelineContext ], Iterator [U ]]) -> "Pipeline[U]" : ...
136+ def apply [U ](self , transformer : Callable [[Iterable [T ], IContextManager ], Iterator [U ]]) -> "Pipeline[U]" : ...
150137
151138 def apply [U ](
152139 self ,
153140 transformer : Transformer [T , U ]
154141 | Callable [[Iterable [T ]], Iterator [U ]]
155- | Callable [[Iterable [T ], PipelineContext ], Iterator [U ]],
142+ | Callable [[Iterable [T ], IContextManager ], Iterator [U ]],
156143 ) -> "Pipeline[U]" :
157144 """Apply a transformer to the current data source.
158145
@@ -181,10 +168,11 @@ def apply[U](
181168 """
182169 match transformer :
183170 case Transformer ():
184- self .processed_data = transformer (self .processed_data , self .ctx ) # type: ignore
171+ # Pass the pipeline's context manager to the transformer
172+ self .processed_data = transformer (self .processed_data , self .context_manager ) # type: ignore
185173 case _ if callable (transformer ):
186174 if is_context_aware (transformer ):
187- self .processed_data = transformer (self .processed_data , self .ctx ) # type: ignore
175+ self .processed_data = transformer (self .processed_data , self .context_manager ) # type: ignore
188176 else :
189177 self .processed_data = transformer (self .processed_data ) # type: ignore
190178 case _:
@@ -256,12 +244,12 @@ def consumer(transformer: Transformer, queue: Queue) -> list[Any]:
256244
257245 def stream_from_queue () -> Iterator [T ]:
258246 while (batch := queue .get ()) is not None :
259- yield batch
247+ yield from batch
260248
261249 if use_queue_chunks :
262250 transformer = transformer .set_chunker (passthrough_chunks )
263251
264- result_iterator = transformer (stream_from_queue (), self .ctx ) # type: ignore
252+ result_iterator = transformer (stream_from_queue (), self .context_manager ) # type: ignore
265253 return list (result_iterator )
266254
267255 with ThreadPoolExecutor (max_workers = num_branches + 1 ) as executor :
0 commit comments