1- from typing import Dict , Optional
1+ from typing import Dict , Type
22
33from typing_extensions import override
44
1212from pipelex .cogt .llm .llm_models .llm_engine_factory import LLMEngineFactory
1313from pipelex .cogt .llm .llm_worker_abstract import LLMWorkerAbstract
1414from pipelex .cogt .llm .llm_worker_factory import LLMWorkerFactory
15+ from pipelex .cogt .llm .llm_worker_internal_abstract import LLMWorkerInternalAbstract
1516from pipelex .cogt .ocr .ocr_engine_factory import OcrEngineFactory
1617from pipelex .cogt .ocr .ocr_worker_abstract import OcrWorkerAbstract
1718from pipelex .cogt .ocr .ocr_worker_factory import OcrWorkerFactory
@@ -31,10 +32,16 @@ def __init__(self):
3132 def teardown (self ):
3233 self .imgg_worker_factory = ImggWorkerFactory ()
3334 self .ocr_worker_factory = OcrWorkerFactory ()
34- self .llm_workers .clear ()
35- self .imgg_workers .clear ()
36- self .ocr_workers .clear ()
37- log .verbose ("InferenceManagerAsync reset" )
35+ for llm_worker in self .llm_workers .values ():
36+ llm_worker .teardown ()
37+ self .llm_workers = {}
38+ for imgg_worker in self .imgg_workers .values ():
39+ imgg_worker .teardown ()
40+ self .imgg_workers = {}
41+ for ocr_worker in self .ocr_workers .values ():
42+ ocr_worker .teardown ()
43+ self .ocr_workers = {}
44+ log .verbose ("InferenceManager teardown done" )
3845
3946 def print_workers (self ):
4047 log .debug ("LLM Workers:" )
@@ -60,15 +67,15 @@ def setup_llm_workers(self):
6067 llm_handle_to_llm_engine_blueprint = get_llm_deck ().llm_handles
6168 log .verbose (f"{ len (llm_handle_to_llm_engine_blueprint )} LLM engine_cards found" )
6269 for llm_handle , llm_engine_blueprint in llm_handle_to_llm_engine_blueprint .items ():
63- self ._setup_one_llm_worker (llm_engine_blueprint = llm_engine_blueprint , llm_handle = llm_handle )
70+ self ._setup_one_internal_llm_worker (llm_engine_blueprint = llm_engine_blueprint , llm_handle = llm_handle )
6471 log .verbose (f"Setup LLM worker for '{ llm_handle } ' on { llm_engine_blueprint .llm_platform_choice } " )
6572 log .debug ("Done setting up LLM Workers (async)" )
6673
67- def _setup_one_llm_worker (
74+ def _setup_one_internal_llm_worker (
6875 self ,
6976 llm_engine_blueprint : LLMEngineBlueprint ,
7077 llm_handle : str ,
71- ) -> LLMWorkerAbstract :
78+ ) -> LLMWorkerInternalAbstract :
7279 llm_engine = LLMEngineFactory .make_llm_engine (llm_engine_blueprint = llm_engine_blueprint )
7380 llm_worker = LLMWorkerFactory .make_llm_worker (
7481 llm_engine = llm_engine ,
@@ -78,27 +85,34 @@ def _setup_one_llm_worker(
7885 return llm_worker
7986
8087 @override
81- def get_llm_worker (
82- self ,
83- llm_handle : str ,
84- specific_llm_engine_blueprint : Optional [LLMEngineBlueprint ] = None ,
85- ) -> LLMWorkerAbstract :
88+ def get_llm_worker (self , llm_handle : str ) -> LLMWorkerAbstract :
8689 if llm_worker := self .llm_workers .get (llm_handle ):
8790 return llm_worker
8891 if not get_config ().cogt .inference_manager_config .is_auto_setup_preset_llm :
8992 raise InferenceManagerWorkerSetupError (
9093 f"No LLM worker for '{ llm_handle } ', set it up or enable cogt.inference_manager_config.is_auto_setup_preset_llm"
9194 )
9295
93- if not specific_llm_engine_blueprint :
94- specific_llm_engine_blueprint = get_llm_deck ().get_llm_engine_blueprint (llm_handle = llm_handle )
95- llm_worker = self ._setup_one_llm_worker (
96- llm_engine_blueprint = specific_llm_engine_blueprint ,
96+ llm_engine_blueprint = get_llm_deck ().get_llm_engine_blueprint (llm_handle = llm_handle )
97+ llm_worker = self ._setup_one_internal_llm_worker (
98+ llm_engine_blueprint = llm_engine_blueprint ,
9799 llm_handle = llm_handle ,
98100 )
99101
100102 return llm_worker
101103
104+ @override
105+ def set_llm_worker_from_external_plugin (
106+ self ,
107+ llm_handle : str ,
108+ llm_worker_class : Type [LLMWorkerAbstract ],
109+ should_warn_if_already_registered : bool = True ,
110+ ):
111+ if llm_handle in self .llm_workers :
112+ if should_warn_if_already_registered :
113+ log .warning (f"LLM worker for '{ llm_handle } ' already registered, skipping" )
114+ self .llm_workers [llm_handle ] = llm_worker_class (reporting_delegate = get_report_delegate ())
115+
102116 ####################################################################################################
103117 # Manage IMGG Workers
104118 ####################################################################################################
0 commit comments