2020from laygo .context .types import IContextHandle
2121from laygo .helpers import is_context_aware
2222from laygo .transformers .transformer import Transformer
23+ from laygo .transformers .types import BaseTransformer
2324
2425T = TypeVar ("T" )
2526U = TypeVar ("U" )
2627PipelineFunction = Callable [[T ], Any ]
2728
2829
2930# This function must be defined at the top level of the module (e.g., after imports)
30- def _branch_consumer_process [T ](transformer : Transformer , queue : "Queue" , context_handle : IContextHandle ) -> list [Any ]:
31+ def _branch_consumer_process [T ](
32+ transformer : BaseTransformer , queue : "Queue" , context_handle : IContextHandle
33+ ) -> list [Any ]:
3134 """Entry point for a consumer process in parallel branching.
3235
3336 Reconstructs the necessary objects and runs a dedicated pipeline instance
@@ -457,7 +460,7 @@ def _producer_broadcast(
457460 @overload
458461 def branch (
459462 self ,
460- branches : Mapping [str , Transformer [T , Any ]],
463+ branches : Mapping [str , BaseTransformer [T , Any ]],
461464 * ,
462465 executor_type : Literal ["thread" , "process" ] = "thread" ,
463466 batch_size : int = 1000 ,
@@ -468,7 +471,7 @@ def branch(
468471 @overload
469472 def branch (
470473 self ,
471- branches : Mapping [str , tuple [Transformer [T , Any ], Callable [[T ], bool ]]],
474+ branches : Mapping [str , tuple [BaseTransformer [T , Any ], Callable [[T ], bool ]]],
472475 * ,
473476 executor_type : Literal ["thread" , "process" ] = "thread" ,
474477 first_match : bool = True ,
@@ -478,7 +481,7 @@ def branch(
478481
479482 def branch (
480483 self ,
481- branches : Mapping [str , Transformer [T , Any ]] | Mapping [str , tuple [Transformer [T , Any ], Callable [[T ], bool ]]],
484+ branches : Mapping [str , BaseTransformer [T , Any ]] | Mapping [str , tuple [BaseTransformer [T , Any ], Callable [[T ], bool ]]],
482485 * ,
483486 executor_type : Literal ["thread" , "process" ] = "thread" ,
484487 first_match : bool = True ,
@@ -519,7 +522,7 @@ def branch(
519522 first_value = next (iter (branches .values ()))
520523 is_conditional = isinstance (first_value , tuple )
521524
522- parsed_branches : list [tuple [str , Transformer [T , Any ], Callable [[T ], bool ]]]
525+ parsed_branches : list [tuple [str , BaseTransformer [T , Any ], Callable [[T ], bool ]]]
523526 if is_conditional :
524527 parsed_branches = [(name , trans , cond ) for name , (trans , cond ) in branches .items ()] # type: ignore
525528 else :
@@ -555,7 +558,7 @@ def _execute_branching_process(
555558 self ,
556559 * ,
557560 producer_fn : Callable ,
558- parsed_branches : list [tuple [str , Transformer , Callable ]],
561+ parsed_branches : list [tuple [str , BaseTransformer , Callable ]],
559562 batch_size : int ,
560563 max_batch_buffer : int ,
561564 ) -> tuple [dict [str , list [Any ]], dict [str , Any ]]:
@@ -629,7 +632,7 @@ def _execute_branching_thread(
629632 self ,
630633 * ,
631634 producer_fn : Callable ,
632- parsed_branches : list [tuple [str , Transformer , Callable ]],
635+ parsed_branches : list [tuple [str , BaseTransformer , Callable ]],
633636 batch_size : int ,
634637 max_batch_buffer : int ,
635638 ) -> tuple [dict [str , list [Any ]], dict [str , Any ]]:
@@ -654,7 +657,7 @@ def _execute_branching_thread(
654657 final_results : dict [str , list [Any ]] = {name : [] for name , _ , _ in parsed_branches }
655658 queues = {name : Queue (maxsize = max_batch_buffer ) for name , _ , _ in parsed_branches }
656659
657- def consumer (transformer : Transformer , queue : Queue , context_handle : IContextHandle ) -> list [Any ]:
660+ def consumer (transformer : BaseTransformer , queue : Queue , context_handle : IContextHandle ) -> list [Any ]:
658661 """Consume batches from a queue and process them with a transformer.
659662
660663 Creates a mini-pipeline for the transformer and processes all
0 commit comments