1+ # pipeline.py
2+
13from collections .abc import Callable
24from collections .abc import Iterable
35from collections .abc import Iterator
46import itertools
7+ import multiprocessing as mp
58from typing import Any
69from typing import TypeVar
710from typing import overload
1720class Pipeline [T ]:
1821 """
1922 Manages a data source and applies transformers to it.
20- Provides terminal operations to consume the resulting data .
23+ Always uses a multiprocessing-safe shared context .
2124 """
2225
2326 def __init__ (self , * data : Iterable [T ]):
2427 if len (data ) == 0 :
2528 raise ValueError ("At least one data source must be provided to Pipeline." )
2629 self .data_source : Iterable [T ] = itertools .chain .from_iterable (data ) if len (data ) > 1 else data [0 ]
2730 self .processed_data : Iterator = iter (self .data_source )
28- self .ctx = PipelineContext ()
31+
32+ # Always create a shared context with multiprocessing manager
33+ self ._manager = mp .Manager ()
34+ self .ctx = self ._manager .dict ()
35+ # Add a shared lock to the context for safe concurrent updates
36+ self .ctx ["lock" ] = self ._manager .Lock ()
37+
38+ # Store reference to original context for final synchronization
39+ self ._original_context_ref : PipelineContext | None = None
40+
41+ def __del__ (self ):
42+ """Clean up the multiprocessing manager when the pipeline is destroyed."""
43+ try :
44+ self ._sync_context_back ()
45+ self ._manager .shutdown ()
46+ except Exception :
47+ pass # Ignore errors during cleanup
2948
3049 def context (self , ctx : PipelineContext ) -> "Pipeline[T]" :
3150 """
32- Sets the context for the pipeline.
51+ Updates the pipeline context and stores a reference to the original context.
52+ When the pipeline finishes processing, the original context will be updated
53+ with the final pipeline context data.
3354 """
34- self .ctx = ctx
55+ # Store reference to the original context
56+ self ._original_context_ref = ctx
57+ # Copy the context data to the pipeline's shared context
58+ self .ctx .update (ctx )
3559 return self
3660
61+ def _sync_context_back (self ) -> None :
62+ """
63+ Synchronize the final pipeline context back to the original context reference.
64+ This is called after processing is complete.
65+ """
66+ if self ._original_context_ref is not None :
67+ # Copy the final context state back to the original context reference
68+ final_context_state = dict (self .ctx )
69+ final_context_state .pop ("lock" , None ) # Remove non-serializable lock
70+ self ._original_context_ref .clear ()
71+ self ._original_context_ref .update (final_context_state )
72+
73+ def transform [U ](self , t : Callable [[Transformer [T , T ]], Transformer [T , U ]]) -> "Pipeline[U]" :
74+ """
75+ Shorthand method to apply a transformation using a lambda function.
76+ Creates a Transformer under the hood and applies it to the pipeline.
77+
78+ Args:
79+ t: A callable that takes a transformer and returns a transformed transformer
80+
81+ Returns:
82+ A new Pipeline with the transformed data
83+ """
84+ # Create a new transformer and apply the transformation function
85+ transformer = t (Transformer [T , T ]())
86+ return self .apply (transformer )
87+
3788 @overload
3889 def apply [U ](self , transformer : Transformer [T , U ]) -> "Pipeline[U]" : ...
3990
4091 @overload
4192 def apply [U ](self , transformer : Callable [[Iterable [T ]], Iterator [U ]]) -> "Pipeline[U]" : ...
4293
4394 @overload
44- def apply [U ](
45- self ,
46- transformer : Callable [[Iterable [T ], PipelineContext ], Iterator [U ]],
47- ) -> "Pipeline[U]" : ...
95+ def apply [U ](self , transformer : Callable [[Iterable [T ], PipelineContext ], Iterator [U ]]) -> "Pipeline[U]" : ...
4896
4997 def apply [U ](
5098 self ,
@@ -53,42 +101,26 @@ def apply[U](
53101 | Callable [[Iterable [T ], PipelineContext ], Iterator [U ]],
54102 ) -> "Pipeline[U]" :
55103 """
56- Applies a transformer to the current data source.
104+ Applies a transformer to the current data source. The pipeline's
105+ managed context is passed down.
57106 """
58-
59107 match transformer :
60108 case Transformer ():
61- # If a Transformer instance is provided, use its __call__ method
109+ # The transformer is called with self.ctx, which is the
110+ # shared mp.Manager.dict proxy when inside a 'with' block.
62111 self .processed_data = transformer (self .processed_data , self .ctx ) # type: ignore
63112 case _ if callable (transformer ):
64- # If a callable function is provided, call it with the current data and context
65-
66113 if is_context_aware (transformer ):
67114 processed_transformer = transformer
68115 else :
69116 processed_transformer = lambda data , ctx : transformer (data ) # type: ignore # noqa: E731
70-
71117 self .processed_data = processed_transformer (self .processed_data , self .ctx ) # type: ignore
72118 case _:
73119 raise TypeError ("Transformer must be a Transformer instance or a callable function" )
74120
75121 return self # type: ignore
76122
77- def transform [U ](self , t : Callable [[Transformer [T , T ]], Transformer [T , U ]]) -> "Pipeline[U]" :
78- """
79- Shorthand method to apply a transformation using a lambda function.
80- Creates a Transformer under the hood and applies it to the pipeline.
81-
82- Args:
83- t: A callable that takes a transformer and returns a transformed transformer
84-
85- Returns:
86- A new Pipeline with the transformed data
87- """
88- # Create a new transformer and apply the transformation function
89- transformer = t (Transformer [T , T ]())
90- return self .apply (transformer )
91-
123+ # ... The rest of the Pipeline class (transform, __iter__, to_list, etc.) remains unchanged ...
92124 def __iter__ (self ) -> Iterator [T ]:
93125 """Allows the pipeline to be iterated over."""
94126 yield from self .processed_data