|
1 | 1 | from collections import defaultdict |
2 | 2 | from typing import ClassVar, Dict, List, Optional, Type |
3 | 3 |
|
| 4 | +from kajson.class_registry_abstract import ClassRegistryAbstract |
| 5 | + |
4 | 6 | from pipelex import log |
5 | 7 | from pipelex.cogt.content_generation.content_generator_protocol import ContentGeneratorProtocol |
6 | 8 | from pipelex.cogt.imgg.imgg_worker_abstract import ImggWorkerAbstract |
@@ -43,6 +45,7 @@ def __init__(self): |
43 | 45 | self._config: Optional[ConfigRoot] = None |
44 | 46 | self._secrets_provider: Optional[SecretsProviderAbstract] = None |
45 | 47 | self._template_provider: Optional[TemplateProviderAbstract] = None |
| 48 | + self._class_registry: Optional[ClassRegistryAbstract] = None |
46 | 49 | # cogt |
47 | 50 | self._llm_models_provider: Optional[LLMModelProviderAbstract] = None |
48 | 51 | self._llm_deck_provider: Optional[LLMDeckAbstract] = None |
@@ -112,6 +115,9 @@ def set_secrets_provider(self, secrets_provider: SecretsProviderAbstract): |
112 | 115 | def set_template_provider(self, template_provider: TemplateProviderAbstract): |
113 | 116 | self._template_provider = template_provider |
114 | 117 |
|
| 118 | + def set_class_registry(self, class_registry: ClassRegistryAbstract): |
| 119 | + self._class_registry = class_registry |
| 120 | + |
115 | 121 | # cogt |
116 | 122 |
|
117 | 123 | def set_llm_models_provider(self, llm_models_provider: LLMModelProviderAbstract): |
@@ -187,6 +193,11 @@ def get_required_template_provider(self) -> TemplateProviderAbstract: |
187 | 193 | raise RuntimeError("Template provider is not set. You must initialize Pipelex first.") |
188 | 194 | return self._template_provider |
189 | 195 |
|
| 196 | + def get_required_class_registry(self) -> ClassRegistryAbstract: |
| 197 | + if self._class_registry is None: |
| 198 | + raise RuntimeError("ClassRegistry is not initialized") |
| 199 | + return self._class_registry |
| 200 | + |
190 | 201 | # cogt |
191 | 202 |
|
192 | 203 | def get_required_llm_models_provider(self) -> LLMModelProviderAbstract: |
@@ -294,6 +305,10 @@ def get_template(template_name: str) -> str: |
294 | 305 | return get_template_provider().get_template(template_name=template_name) |
295 | 306 |
|
296 | 307 |
|
| 308 | +def get_class_registry() -> ClassRegistryAbstract: |
| 309 | + return get_pipelex_hub().get_required_class_registry() |
| 310 | + |
| 311 | + |
297 | 312 | # cogt |
298 | 313 |
|
299 | 314 |
|
|
0 commit comments