Skip to content

Commit 11ee87d

Browse files
authored
Refactor/plugin management (Pipelex#129)
### πŸ“ Description - handle external plugins for LLMs ### πŸ”„ Type of Change - [ ] πŸ› Bug fix - [X] ✨ New feature - [ ] πŸ’₯ Breaking change - [ ] πŸ“š Documentation update - [X] 🧹 Code refactor - [ ] ⚑ Performance improvement - [X] βœ… Test update ### πŸ§ͺ Tests - test_external_plugin.py
1 parent dbcca5f commit 11ee87d

56 files changed

Lines changed: 645 additions & 368 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

β€Ždocs/pages/advanced-customization/index.mdβ€Ž

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ from pipelex import Pipelex
1616
pipelex = Pipelex(
1717
template_provider=MyTemplateProvider(),
1818
llm_model_provider=MyLLMProvider(),
19-
plugin_manager=MyPluginManager(),
2019
inference_manager=MyInferenceManager(),
2120
pipeline_tracker=MyPipelineTracker(),
2221
activity_manager=MyActivityManager(),
@@ -38,7 +37,6 @@ from pipelex.hub import PipelexHub
3837
hub = PipelexHub()
3938
hub.set_template_provider(MyTemplateProvider())
4039
hub.set_llm_models_provider(MyLLMProvider())
41-
hub.set_plugin_manager(MyPluginManager())
4240
# ... and so on for other components
4341
```
4442

@@ -88,54 +86,45 @@ Pipelex supports injection of the following components:
8886
- Default: `LLMModelLibrary`
8987
- [Details](llm-model-provider-injection.md)
9088

91-
3. **Plugin Manager** (`PluginManager`)
92-
93-
- Protocol: `PluginManagerProtocol`
94-
- Default: `PluginManager`
95-
- [Details](plugin-manager-injection.md)
96-
97-
4. **Inference Manager** (`InferenceManager`)
89+
3. **Inference Manager** (`InferenceManager`)
9890

9991
- Protocol: `InferenceManagerProtocol`
10092
- Default: `InferenceManager`
10193
- [Details](inference-manager-injection.md)
10294

103-
5. **Reporting Delegate** (`ReportingManager`)
95+
4. **Reporting Delegate** (`ReportingManager`)
10496

10597
- Protocol: `ReportingProtocol`
10698
- Default: `ReportingManager` or `ReportingNoOp` if disabled
10799
- [Details](reporting-delegate-injection.md)
108100

109-
6. **Pipeline Tracker** (`PipelineTracker`)
101+
5. **Pipeline Tracker** (`PipelineTracker`)
110102

111103
- Protocol: `PipelineTrackerProtocol`
112104
- Default: `PipelineTracker` or `PipelineTrackerNoOp` if disabled
113105
- [Details](pipeline-tracker-injection.md)
114106

115-
7. **Activity Manager** (`ActivityManager`)
107+
6. **Activity Manager** (`ActivityManager`)
116108

117109
- Protocol: `ActivityManagerProtocol`
118110
- Default: `ActivityManager` or `ActivityManagerNoOp` if disabled
119111
- [Details](activity-manager-injection.md)
120112

121-
8. **Secrets Provider** (`EnvSecretsProvider`)
113+
7. **Secrets Provider** (`EnvSecretsProvider`)
122114

123115
- Protocol: `SecretsProviderProtocol`
124116
- Default: `EnvSecretsProvider`
125117
- [Details](secrets-provider-injection.md)
126118

127-
9. **Content Generator** (`ContentGenerator`)
119+
8. **Content Generator** (`ContentGenerator`)
128120

129121
- Protocol: `ContentGeneratorProtocol`
130122
- Default: `ContentGenerator`
131123
- [Details](content-generator-injection.md)
132124

133-
10. **Pipe Router** (`PipeRouter`)
125+
9. **Pipe Router** (`PipeRouter`)
134126

135127
- Protocol: `PipeRouterProtocol`
136128
- Default: `PipeRouter`
137129
- [Details](pipe-router-injection.md)
138130

139-
## Best Practices
140-
141-
⚠️ Under construction

β€Žpipelex/cogt/exceptions.pyβ€Ž

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ class SdkTypeError(CogtError):
2323
pass
2424

2525

26+
class SdkRegistryError(CogtError):
27+
pass
28+
29+
2630
class LLMWorkerError(CogtError):
2731
pass
2832

@@ -128,5 +132,9 @@ def __init__(self, dependency_name: str, extra_name: str, message: Optional[str]
128132
super().__init__(error_msg)
129133

130134

135+
class MissingPluginError(CogtError):
136+
pass
137+
138+
131139
class OcrCapabilityError(CogtError):
132140
pass

β€Žpipelex/cogt/imgg/imgg_worker_factory.pyβ€Ž

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from pipelex.cogt.imgg.imgg_platform import ImggPlatform
66
from pipelex.cogt.imgg.imgg_worker_abstract import ImggWorkerAbstract
77
from pipelex.cogt.llm.llm_models.llm_platform import LLMPlatform
8-
from pipelex.cogt.plugin_manager import PluginHandle
98
from pipelex.hub import get_plugin_manager, get_secret
109
from pipelex.plugins.openai.openai_imgg_worker import OpenAIImggWorker
10+
from pipelex.plugins.plugin_sdk_registry import PluginSdkHandle
1111
from pipelex.reporting.reporting_protocol import ReportingProtocol
1212
from pipelex.tools.secrets.secrets_errors import SecretNotFoundError
1313

@@ -22,8 +22,8 @@ def make_imgg_worker(
2222
imgg_engine: ImggEngine,
2323
reporting_delegate: Optional[ReportingProtocol] = None,
2424
) -> ImggWorkerAbstract:
25-
imgg_sdk_handle = PluginHandle.get_for_imgg_engine(imgg_platform=imgg_engine.imgg_platform)
26-
plugin_manager = get_plugin_manager()
25+
imgg_sdk_handle = PluginSdkHandle.get_for_imgg_engine(imgg_platform=imgg_engine.imgg_platform)
26+
plugin_sdk_registry = get_plugin_manager().plugin_sdk_registry
2727
imgg_worker: ImggWorkerAbstract
2828
match imgg_engine.imgg_platform:
2929
case ImggPlatform.FAL_AI:
@@ -41,7 +41,9 @@ def make_imgg_worker(
4141

4242
from pipelex.plugins.fal.fal_imgg_worker import FalImggWorker
4343

44-
imgg_sdk_instance = plugin_manager.get_imgg_sdk_instance(imgg_sdk_handle=imgg_sdk_handle) or plugin_manager.set_imgg_sdk_instance(
44+
imgg_sdk_instance = plugin_sdk_registry.get_imgg_sdk_instance(
45+
imgg_sdk_handle=imgg_sdk_handle
46+
) or plugin_sdk_registry.set_imgg_sdk_instance(
4547
imgg_sdk_handle=imgg_sdk_handle,
4648
imgg_sdk_instance=FalAsyncClient(key=fal_api_key),
4749
)
@@ -54,7 +56,9 @@ def make_imgg_worker(
5456
case ImggPlatform.OPENAI:
5557
from pipelex.plugins.openai.openai_factory import OpenAIFactory
5658

57-
imgg_sdk_instance = plugin_manager.get_llm_sdk_instance(llm_sdk_handle=imgg_sdk_handle) or plugin_manager.set_llm_sdk_instance(
59+
imgg_sdk_instance = plugin_sdk_registry.get_llm_sdk_instance(
60+
llm_sdk_handle=imgg_sdk_handle
61+
) or plugin_sdk_registry.set_llm_sdk_instance(
5862
llm_sdk_handle=imgg_sdk_handle,
5963
llm_sdk_instance=OpenAIFactory.make_openai_client(llm_platform=LLMPlatform.OPENAI),
6064
)

β€Žpipelex/cogt/inference/inference_manager.pyβ€Ž

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Optional
1+
from typing import Dict, Type
22

33
from typing_extensions import override
44

@@ -12,6 +12,7 @@
1212
from pipelex.cogt.llm.llm_models.llm_engine_factory import LLMEngineFactory
1313
from pipelex.cogt.llm.llm_worker_abstract import LLMWorkerAbstract
1414
from pipelex.cogt.llm.llm_worker_factory import LLMWorkerFactory
15+
from pipelex.cogt.llm.llm_worker_internal_abstract import LLMWorkerInternalAbstract
1516
from pipelex.cogt.ocr.ocr_engine_factory import OcrEngineFactory
1617
from pipelex.cogt.ocr.ocr_worker_abstract import OcrWorkerAbstract
1718
from pipelex.cogt.ocr.ocr_worker_factory import OcrWorkerFactory
@@ -31,10 +32,16 @@ def __init__(self):
3132
def teardown(self):
3233
self.imgg_worker_factory = ImggWorkerFactory()
3334
self.ocr_worker_factory = OcrWorkerFactory()
34-
self.llm_workers.clear()
35-
self.imgg_workers.clear()
36-
self.ocr_workers.clear()
37-
log.verbose("InferenceManagerAsync reset")
35+
for llm_worker in self.llm_workers.values():
36+
llm_worker.teardown()
37+
self.llm_workers = {}
38+
for imgg_worker in self.imgg_workers.values():
39+
imgg_worker.teardown()
40+
self.imgg_workers = {}
41+
for ocr_worker in self.ocr_workers.values():
42+
ocr_worker.teardown()
43+
self.ocr_workers = {}
44+
log.verbose("InferenceManager teardown done")
3845

3946
def print_workers(self):
4047
log.debug("LLM Workers:")
@@ -60,15 +67,15 @@ def setup_llm_workers(self):
6067
llm_handle_to_llm_engine_blueprint = get_llm_deck().llm_handles
6168
log.verbose(f"{len(llm_handle_to_llm_engine_blueprint)} LLM engine_cards found")
6269
for llm_handle, llm_engine_blueprint in llm_handle_to_llm_engine_blueprint.items():
63-
self._setup_one_llm_worker(llm_engine_blueprint=llm_engine_blueprint, llm_handle=llm_handle)
70+
self._setup_one_internal_llm_worker(llm_engine_blueprint=llm_engine_blueprint, llm_handle=llm_handle)
6471
log.verbose(f"Setup LLM worker for '{llm_handle}' on {llm_engine_blueprint.llm_platform_choice}")
6572
log.debug("Done setting up LLM Workers (async)")
6673

67-
def _setup_one_llm_worker(
74+
def _setup_one_internal_llm_worker(
6875
self,
6976
llm_engine_blueprint: LLMEngineBlueprint,
7077
llm_handle: str,
71-
) -> LLMWorkerAbstract:
78+
) -> LLMWorkerInternalAbstract:
7279
llm_engine = LLMEngineFactory.make_llm_engine(llm_engine_blueprint=llm_engine_blueprint)
7380
llm_worker = LLMWorkerFactory.make_llm_worker(
7481
llm_engine=llm_engine,
@@ -78,27 +85,34 @@ def _setup_one_llm_worker(
7885
return llm_worker
7986

8087
@override
81-
def get_llm_worker(
82-
self,
83-
llm_handle: str,
84-
specific_llm_engine_blueprint: Optional[LLMEngineBlueprint] = None,
85-
) -> LLMWorkerAbstract:
88+
def get_llm_worker(self, llm_handle: str) -> LLMWorkerAbstract:
8689
if llm_worker := self.llm_workers.get(llm_handle):
8790
return llm_worker
8891
if not get_config().cogt.inference_manager_config.is_auto_setup_preset_llm:
8992
raise InferenceManagerWorkerSetupError(
9093
f"No LLM worker for '{llm_handle}', set it up or enable cogt.inference_manager_config.is_auto_setup_preset_llm"
9194
)
9295

93-
if not specific_llm_engine_blueprint:
94-
specific_llm_engine_blueprint = get_llm_deck().get_llm_engine_blueprint(llm_handle=llm_handle)
95-
llm_worker = self._setup_one_llm_worker(
96-
llm_engine_blueprint=specific_llm_engine_blueprint,
96+
llm_engine_blueprint = get_llm_deck().get_llm_engine_blueprint(llm_handle=llm_handle)
97+
llm_worker = self._setup_one_internal_llm_worker(
98+
llm_engine_blueprint=llm_engine_blueprint,
9799
llm_handle=llm_handle,
98100
)
99101

100102
return llm_worker
101103

104+
@override
105+
def set_llm_worker_from_external_plugin(
106+
self,
107+
llm_handle: str,
108+
llm_worker_class: Type[LLMWorkerAbstract],
109+
should_warn_if_already_registered: bool = True,
110+
):
111+
if llm_handle in self.llm_workers:
112+
if should_warn_if_already_registered:
113+
log.warning(f"LLM worker for '{llm_handle}' already registered, skipping")
114+
self.llm_workers[llm_handle] = llm_worker_class(reporting_delegate=get_report_delegate())
115+
102116
####################################################################################################
103117
# Manage IMGG Workers
104118
####################################################################################################

β€Žpipelex/cogt/inference/inference_manager_protocol.pyβ€Ž

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from typing import Optional, Protocol
1+
from typing import Protocol, Type
22

33
from pipelex.cogt.imgg.imgg_worker_abstract import ImggWorkerAbstract
4-
from pipelex.cogt.llm.llm_models.llm_engine_blueprint import LLMEngineBlueprint
54
from pipelex.cogt.llm.llm_worker_abstract import LLMWorkerAbstract
65
from pipelex.cogt.ocr.ocr_worker_abstract import OcrWorkerAbstract
76

@@ -20,11 +19,14 @@ def teardown(self): ...
2019

2120
def setup_llm_workers(self): ...
2221

23-
def get_llm_worker(
22+
def get_llm_worker(self, llm_handle: str) -> LLMWorkerAbstract: ...
23+
24+
def set_llm_worker_from_external_plugin(
2425
self,
2526
llm_handle: str,
26-
specific_llm_engine_blueprint: Optional[LLMEngineBlueprint] = None,
27-
) -> LLMWorkerAbstract: ...
27+
llm_worker_class: Type[LLMWorkerAbstract],
28+
should_warn_if_already_registered: bool = True,
29+
): ...
2830

2931
####################################################################################################
3032
# IMG Generation Workers

β€Žpipelex/cogt/inference/inference_worker_abstract.pyβ€Ž

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ def __init__(
1111
):
1212
self.reporting_delegate = reporting_delegate
1313

14+
def setup(self):
15+
pass
16+
17+
def teardown(self):
18+
pass
19+
1420
@property
1521
@abstractmethod
1622
def desc(self) -> str:

β€Žpipelex/cogt/llm/llm_models/llm_deck.pyβ€Ž

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,14 @@ def get_llm_engine_blueprint(self, llm_handle: str) -> LLMEngineBlueprint:
4242

4343
@override
4444
def get_llm_setting(self, llm_setting_or_preset_id: LLMSettingOrPresetId) -> LLMSetting:
45-
the_llm_setting: LLMSetting
46-
if isinstance(llm_setting_or_preset_id, str):
45+
if isinstance(llm_setting_or_preset_id, LLMSetting):
46+
return llm_setting_or_preset_id
47+
else:
4748
# it's a preset id
4849
the_llm_preset = self.llm_presets.get(llm_setting_or_preset_id)
4950
if not the_llm_preset:
5051
raise LLMPresetNotFoundError(f"LLM preset '{llm_setting_or_preset_id}' not found in deck")
51-
the_llm_setting = the_llm_preset
52-
else:
53-
# it's an explict setting
54-
the_llm_setting = llm_setting_or_preset_id
55-
return the_llm_setting
52+
return the_llm_preset
5653

5754
@override
5855
def get_llm_setting_for_text(self, override: Optional[LLMSettingChoices] = None) -> LLMSetting:
@@ -139,6 +136,19 @@ def find_llm_model(self, llm_handle: str) -> LLMModel:
139136
)
140137
return llm_model
141138

139+
@override
140+
def find_optional_llm_model(self, llm_handle: str) -> Optional[LLMModel]:
141+
llm_models_provider = get_llm_models_provider()
142+
llm_engine_blueprint = self.llm_handles.get(llm_handle)
143+
if not llm_engine_blueprint:
144+
return None
145+
llm_model = llm_models_provider.get_optional_llm_model(
146+
llm_name=llm_engine_blueprint.llm_name,
147+
llm_version=llm_engine_blueprint.llm_version,
148+
llm_platform_choice=llm_engine_blueprint.llm_platform_choice,
149+
)
150+
return llm_model
151+
142152
@override
143153
@classmethod
144154
def final_validate(cls, deck: Self):
@@ -215,11 +225,11 @@ def validate_llm_choice_overrides(cls, value: LLMSettingChoices) -> LLMSettingCh
215225
value.for_object_list_direct = None
216226
return value
217227

218-
def add_llm_handle_to_llm_engine_blueprint(self, llm_handle: str, llm_engine_default: str):
219-
if llm_handle in self.llm_handles:
220-
raise ConfigValidationError(f"LLM engine blueprint for '{llm_handle}' is already defined in llm_handle_to_llm_engine_blueprint")
228+
def add_llm_name_as_handle_with_defaults(self, llm_name: str):
229+
if llm_name in self.llm_handles:
230+
raise ConfigValidationError(f"LLM engine blueprint for '{llm_name}' is already defined in llm deck's llm_handles")
221231
# TODO: sort the defaults by llm family
222-
self.llm_handles[llm_handle] = LLMEngineBlueprint(llm_name=llm_engine_default)
232+
self.llm_handles[llm_name] = LLMEngineBlueprint(llm_name=llm_name)
223233

224234
def validate_llm_presets(self) -> Self:
225235
for llm_preset_id, llm_setting in self.llm_presets.items():

β€Žpipelex/cogt/llm/llm_models/llm_deck_abstract.pyβ€Ž

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import Dict, Optional
2+
from typing import Dict, List, Optional
33

44
from pydantic import Field
55
from typing_extensions import Self
@@ -11,6 +11,7 @@
1111

1212
class LLMDeckAbstract(ABC):
1313
llm_handles: Dict[str, LLMEngineBlueprint] = Field(default_factory=dict)
14+
llm_external_handles: List[str] = Field(default_factory=list)
1415
llm_presets: Dict[str, LLMSetting] = Field(default_factory=dict)
1516
llm_choice_defaults: LLMSettingChoices
1617
llm_choice_overrides: LLMSettingChoices = LLMSettingChoices(
@@ -57,6 +58,10 @@ def get_llm_setting_for_object_list_direct(self, override: Optional[LLMSettingCh
5758
def find_llm_model(self, llm_handle: str) -> LLMModel:
5859
pass
5960

61+
@abstractmethod
62+
def find_optional_llm_model(self, llm_handle: str) -> Optional[LLMModel]:
63+
pass
64+
6065
@classmethod
6166
@abstractmethod
6267
def final_validate(cls, deck: Self):

β€Žpipelex/cogt/llm/llm_models/llm_engine.pyβ€Ž

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def desc(self) -> str:
3131

3232
@property
3333
def tag(self) -> str:
34-
return f"{self.llm_platform} - {self.llm_model.llm_name} - {self.llm_model.version}"
34+
return f"{self.llm_platform}:{self.llm_model.llm_name}:{self.llm_model.version}:id-[{self.llm_id}]"
3535

3636
@property
3737
def is_gen_object_supported(self) -> bool:

β€Žpipelex/cogt/llm/llm_models/llm_engine_factory.pyβ€Ž

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def make_llm_engine(
1515
Create an instance of LLMEngine based on the parameters provided through the LLMEngineCard.
1616
1717
Args:
18-
llm_engine_blueprint: LLMEngineCard
18+
llm_engine_blueprint: LLMEngineBlueprint
1919
2020
"""
2121
llm_platform_choice = llm_engine_blueprint.llm_platform_choice

0 commit comments

Comments
Β (0)