@@ -89,6 +89,7 @@ def context(self, ctx: dict[str, Any]) -> "Pipeline[T]":
8989 automatically synchronized back to the original context object
9090 when the pipeline is destroyed or processing completes.
9191 """
92+ self ._user_context = ctx
9293 self .context_manager .update (ctx )
9394 return self
9495
@@ -180,95 +181,6 @@ def apply[U](
180181
181182 return self # type: ignore
182183
183- def branch (
184- self ,
185- branches : dict [str , Transformer [T , Any ]],
186- batch_size : int = 1000 ,
187- max_batch_buffer : int = 1 ,
188- use_queue_chunks : bool = True ,
189- ) -> dict [str , list [Any ]]:
190- """Forks the pipeline into multiple branches for concurrent, parallel processing.
191-
192- This is a **terminal operation** that implements a fan-out pattern where
193- the entire dataset is copied to each branch for independent processing.
194- Each branch processes the complete dataset concurrently using separate
195- transformers, and results are collected and returned in a dictionary.
196-
197- Args:
198- branches: A dictionary where keys are branch names (str) and values
199- are `Transformer` instances of any subtype.
200- batch_size: The number of items to batch together when sending data
201- to branches. Larger batches can improve throughput but
202- use more memory. Defaults to 1000.
203- max_batch_buffer: The maximum number of batches to buffer for each
204- branch queue. Controls memory usage and creates
205- backpressure. Defaults to 1.
206- use_queue_chunks: Whether to use passthrough chunking for the
207- transformers. When True, batches are processed
208- as chunks. Defaults to True.
209-
210- Returns:
211- A dictionary where keys are the branch names and values are lists
212- of all items processed by that branch's transformer.
213-
214- Note:
215- This operation consumes the pipeline's iterator, making subsequent
216- operations on the same pipeline return empty results.
217- """
218- if not branches :
219- self .consume ()
220- return {}
221-
222- source_iterator = self .processed_data
223- branch_items = list (branches .items ())
224- num_branches = len (branch_items )
225- final_results : dict [str , list [Any ]] = {}
226-
227- queues = [Queue (maxsize = max_batch_buffer ) for _ in range (num_branches )]
228-
229- def producer () -> None :
230- """Reads from the source and distributes batches to ALL branch queues."""
231- # Use itertools.batched for clean and efficient batch creation.
232- for batch_tuple in itertools .batched (source_iterator , batch_size ):
233- # The batch is a tuple; convert to a list for consumers.
234- batch_list = list (batch_tuple )
235- for q in queues :
236- q .put (batch_list )
237-
238- # Signal to all consumers that the stream is finished.
239- for q in queues :
240- q .put (None )
241-
242- def consumer (transformer : Transformer , queue : Queue ) -> list [Any ]:
243- """Consumes batches from a queue and runs them through a transformer."""
244-
245- def stream_from_queue () -> Iterator [T ]:
246- while (batch := queue .get ()) is not None :
247- yield batch
248-
249- if use_queue_chunks :
250- transformer = transformer .set_chunker (passthrough_chunks )
251-
252- result_iterator = transformer (stream_from_queue (), self .context_manager ) # type: ignore
253- return list (result_iterator )
254-
255- with ThreadPoolExecutor (max_workers = num_branches + 1 ) as executor :
256- executor .submit (producer )
257-
258- future_to_name = {
259- executor .submit (consumer , transformer , queues [i ]): name for i , (name , transformer ) in enumerate (branch_items )
260- }
261-
262- for future in as_completed (future_to_name ):
263- name = future_to_name [future ]
264- try :
265- final_results [name ] = future .result ()
266- except Exception as e :
267- print (f"Branch '{ name } ' raised an exception: { e } " )
268- final_results [name ] = []
269-
270- return final_results
271-
272184 def buffer (self , size : int , batch_size : int = 1000 ) -> "Pipeline[T]" :
273185 """Inserts a buffer in the pipeline to allow downstream processing to read ahead.
274186
@@ -328,7 +240,7 @@ def __iter__(self) -> Iterator[T]:
328240 """
329241 yield from self .processed_data
330242
331- def to_list (self ) -> list [T ]:
243+ def to_list (self ) -> tuple [ list [T ], dict [ str , Any ] ]:
332244 """Execute the pipeline and return the results as a list.
333245
334246 This is a terminal operation that consumes the pipeline's iterator
@@ -341,9 +253,9 @@ def to_list(self) -> list[T]:
341253 This operation consumes the pipeline's iterator, making subsequent
342254 operations on the same pipeline return empty results.
343255 """
344- return list (self .processed_data )
256+ return list (self .processed_data ), self . context_manager . to_dict ()
345257
346- def each (self , function : PipelineFunction [T ]) -> None :
258+ def each (self , function : PipelineFunction [T ]) -> tuple [ None , dict [ str , Any ]] :
347259 """Apply a function to each element (terminal operation).
348260
349261 This is a terminal operation that processes each element for side effects
@@ -360,7 +272,9 @@ def each(self, function: PipelineFunction[T]) -> None:
360272 for item in self .processed_data :
361273 function (item )
362274
363- def first (self , n : int = 1 ) -> list [T ]:
275+ return None , self .context_manager .to_dict ()
276+
277+ def first (self , n : int = 1 ) -> tuple [list [T ], dict [str , Any ]]:
364278 """Get the first n elements of the pipeline (terminal operation).
365279
366280 This is a terminal operation that consumes up to n elements from the
@@ -381,9 +295,9 @@ def first(self, n: int = 1) -> list[T]:
381295 operations will continue from where this operation left off.
382296 """
383297 assert n >= 1 , "n must be at least 1"
384- return list (itertools .islice (self .processed_data , n ))
298+ return list (itertools .islice (self .processed_data , n )), self . context_manager . to_dict ()
385299
386- def consume (self ) -> None :
300+ def consume (self ) -> tuple [ None , dict [ str , Any ]] :
387301 """Consume the pipeline without returning results (terminal operation).
388302
389303 This is a terminal operation that processes all elements in the pipeline
@@ -396,3 +310,94 @@ def consume(self) -> None:
396310 """
397311 for _ in self .processed_data :
398312 pass
313+
314+ return None , self .context_manager .to_dict ()
315+
316+ def branch (
317+ self ,
318+ branches : dict [str , Transformer [T , Any ]],
319+ batch_size : int = 1000 ,
320+ max_batch_buffer : int = 1 ,
321+ use_queue_chunks : bool = True ,
322+ ) -> tuple [dict [str , list [Any ]], dict [str , Any ]]:
323+ """Forks the pipeline into multiple branches for concurrent, parallel processing.
324+
325+ This is a **terminal operation** that implements a fan-out pattern where
326+ the entire dataset is copied to each branch for independent processing.
327+ Each branch processes the complete dataset concurrently using separate
328+ transformers, and results are collected and returned in a dictionary.
329+
330+ Args:
331+ branches: A dictionary where keys are branch names (str) and values
332+ are `Transformer` instances of any subtype.
333+ batch_size: The number of items to batch together when sending data
334+ to branches. Larger batches can improve throughput but
335+ use more memory. Defaults to 1000.
336+ max_batch_buffer: The maximum number of batches to buffer for each
337+ branch queue. Controls memory usage and creates
338+ backpressure. Defaults to 1.
339+ use_queue_chunks: Whether to use passthrough chunking for the
340+ transformers. When True, batches are processed
341+ as chunks. Defaults to True.
342+
343+ Returns:
344+ A dictionary where keys are the branch names and values are lists
345+ of all items processed by that branch's transformer.
346+
347+ Note:
348+ This operation consumes the pipeline's iterator, making subsequent
349+ operations on the same pipeline return empty results.
350+ """
351+ if not branches :
352+ self .consume ()
353+ return {}, {}
354+
355+ source_iterator = self .processed_data
356+ branch_items = list (branches .items ())
357+ num_branches = len (branch_items )
358+ final_results : dict [str , list [Any ]] = {}
359+
360+ queues = [Queue (maxsize = max_batch_buffer ) for _ in range (num_branches )]
361+
362+ def producer () -> None :
363+ """Reads from the source and distributes batches to ALL branch queues."""
364+ # Use itertools.batched for clean and efficient batch creation.
365+ for batch_tuple in itertools .batched (source_iterator , batch_size ):
366+ # The batch is a tuple; convert to a list for consumers.
367+ batch_list = list (batch_tuple )
368+ for q in queues :
369+ q .put (batch_list )
370+
371+ # Signal to all consumers that the stream is finished.
372+ for q in queues :
373+ q .put (None )
374+
375+ def consumer (transformer : Transformer , queue : Queue ) -> list [Any ]:
376+ """Consumes batches from a queue and runs them through a transformer."""
377+
378+ def stream_from_queue () -> Iterator [T ]:
379+ while (batch := queue .get ()) is not None :
380+ yield batch
381+
382+ if use_queue_chunks :
383+ transformer = transformer .set_chunker (passthrough_chunks )
384+
385+ result_iterator = transformer (stream_from_queue (), self .context_manager ) # type: ignore
386+ return list (result_iterator )
387+
388+ with ThreadPoolExecutor (max_workers = num_branches + 1 ) as executor :
389+ executor .submit (producer )
390+
391+ future_to_name = {
392+ executor .submit (consumer , transformer , queues [i ]): name for i , (name , transformer ) in enumerate (branch_items )
393+ }
394+
395+ for future in as_completed (future_to_name ):
396+ name = future_to_name [future ]
397+ try :
398+ final_results [name ] = future .result ()
399+ except Exception as e :
400+ print (f"Branch '{ name } ' raised an exception: { e } " )
401+ final_results [name ] = []
402+
403+ return final_results , self .context_manager .to_dict ()
0 commit comments