|
| 1 | +import sys |
| 2 | +from pathlib import Path |
| 3 | +from typing import Any, List, Optional, Type |
| 4 | + |
| 5 | +from pipelex.hub import get_class_registry |
| 6 | +from pipelex.tools.typing.module_inspector import find_classes_in_module, import_module_from_file |
| 7 | + |
| 8 | + |
| 9 | +class ClassRegistryUtils: |
| 10 | + @classmethod |
| 11 | + def register_classes_in_file( |
| 12 | + cls, |
| 13 | + file_path: str, |
| 14 | + base_class: Optional[Type[Any]], |
| 15 | + is_include_imported: bool, |
| 16 | + ) -> None: |
| 17 | + """Processes a Python file to find and register classes.""" |
| 18 | + module = import_module_from_file(file_path) |
| 19 | + |
| 20 | + # Find classes that match criteria |
| 21 | + classes_to_register = find_classes_in_module( |
| 22 | + module=module, |
| 23 | + base_class=base_class, |
| 24 | + include_imported=is_include_imported, |
| 25 | + ) |
| 26 | + |
| 27 | + # Clean up sys.modules to prevent memory leaks |
| 28 | + del sys.modules[module.__name__] |
| 29 | + |
| 30 | + get_class_registry().register_classes(classes=classes_to_register) |
| 31 | + |
| 32 | + @classmethod |
| 33 | + def register_classes_in_folder( |
| 34 | + cls, |
| 35 | + folder_path: str, |
| 36 | + base_class: Optional[Type[Any]] = None, |
| 37 | + is_recursive: bool = True, |
| 38 | + is_include_imported: bool = False, |
| 39 | + ) -> None: |
| 40 | + """ |
| 41 | + Registers all classes in Python files within folders that are subclasses of base_class. |
| 42 | + If base_class is None, registers all classes. |
| 43 | +
|
| 44 | + Args: |
| 45 | + folder_paths: List of paths to folders containing Python files |
| 46 | + base_class: Optional base class to filter registerable classes |
| 47 | + recursive: Whether to search recursively in subdirectories |
| 48 | + exclude_files: List of filenames to exclude |
| 49 | + exclude_dirs: List of directory names to exclude |
| 50 | + include_imported: Whether to include classes imported from other modules |
| 51 | + """ |
| 52 | + |
| 53 | + python_files = cls.find_files_in_dir( |
| 54 | + dir_path=folder_path, |
| 55 | + pattern="*.py", |
| 56 | + is_recursive=is_recursive, |
| 57 | + ) |
| 58 | + |
| 59 | + for python_file in python_files: |
| 60 | + cls.register_classes_in_file( |
| 61 | + file_path=str(python_file), |
| 62 | + base_class=base_class, |
| 63 | + is_include_imported=is_include_imported, |
| 64 | + ) |
| 65 | + |
| 66 | + @classmethod |
| 67 | + def find_files_in_dir(cls, dir_path: str, pattern: str, is_recursive: bool) -> List[Path]: |
| 68 | + """ |
| 69 | + Find files matching a pattern in a directory. |
| 70 | +
|
| 71 | + Args: |
| 72 | + dir_path: Directory path to search in |
| 73 | + pattern: File pattern to match (e.g. "*.py") |
| 74 | + recursive: Whether to search recursively in subdirectories |
| 75 | +
|
| 76 | + Returns: |
| 77 | + List of matching Path objects |
| 78 | + """ |
| 79 | + path = Path(dir_path) |
| 80 | + if is_recursive: |
| 81 | + return list(path.rglob(pattern)) |
| 82 | + else: |
| 83 | + return list(path.glob(pattern)) |
0 commit comments