From 0eec5e243f30526c2bd048a498a6210ec40d3ef2 Mon Sep 17 00:00:00 2001 From: RJCD-Diamond Date: Thu, 21 May 2026 13:50:43 +0000 Subject: [PATCH 1/2] restructure project --- src/indigoapi/__main__.py | 4 +- src/indigoapi/analyses/__init__.py | 20 +-- src/indigoapi/analyses/decorator.py | 24 --- src/indigoapi/analyses/loader.py | 134 ----------------- src/indigoapi/analyses/peak_fitting.py | 2 +- src/indigoapi/analyses/registry.py | 31 ---- src/indigoapi/analyses/simple_maths.py | 2 +- src/indigoapi/api/routes.py | 4 +- src/indigoapi/cleanup.py | 25 ---- src/indigoapi/main.py | 99 ------------- src/indigoapi/queue_manager.py | 67 --------- src/indigoapi/rabbitmq_listener.py | 193 ------------------------- tests/conftest.py | 2 +- tests/test_api.py | 2 +- tests/test_api_routes.py | 2 +- tests/test_cleanup.py | 6 +- tests/test_gaussian_fit.py | 2 +- tests/test_loader.py | 8 +- tests/test_loader_extra.py | 12 +- tests/test_plugins.py | 6 +- tests/test_queue_manager.py | 6 +- tests/test_rabbitmq_listener.py | 10 +- tests/test_registry.py | 4 +- 23 files changed, 38 insertions(+), 627 deletions(-) delete mode 100644 src/indigoapi/analyses/decorator.py delete mode 100644 src/indigoapi/analyses/loader.py delete mode 100644 src/indigoapi/analyses/registry.py delete mode 100644 src/indigoapi/cleanup.py delete mode 100644 src/indigoapi/main.py delete mode 100644 src/indigoapi/queue_manager.py delete mode 100644 src/indigoapi/rabbitmq_listener.py diff --git a/src/indigoapi/__main__.py b/src/indigoapi/__main__.py index 62653b9..f706b5e 100644 --- a/src/indigoapi/__main__.py +++ b/src/indigoapi/__main__.py @@ -7,7 +7,7 @@ import uvicorn from indigoapi.config import Config -from indigoapi.main import start_api +from indigoapi.server import start_api from ._version import __version__ @@ -61,7 +61,7 @@ def serve(ctx: click.Context): logger.info(f"port {config.server.port}") uvicorn.run( - f"indigoapi.main:{start_api.__name__}", + f"indigoapi.server:{start_api.__name__}", factory=True, host=config.server.host, port=int(config.server.port), diff --git a/src/indigoapi/analyses/__init__.py b/src/indigoapi/analyses/__init__.py index bdc5126..831af5e 100644 --- a/src/indigoapi/analyses/__init__.py +++ b/src/indigoapi/analyses/__init__.py @@ -1,19 +1 @@ -import importlib - -from indigoapi.analyses.loader import load_analyses, load_plugins -from indigoapi.config import Config - -MODULE_NAMES = [] - - -def initialize_analyses(register_all: bool = False): - """Load built-in analyses and user plugins. Call during server startup.""" - global MODULE_NAMES - - # load built-in analyses - package = importlib.import_module(__name__) - MODULE_NAMES = load_analyses(package) - - # load user plugins from config - config = Config.load_config() - load_plugins(config, register_all=register_all) +"""Package containing built-in analysis modules.""" diff --git a/src/indigoapi/analyses/decorator.py b/src/indigoapi/analyses/decorator.py deleted file mode 100644 index 77cd5f3..0000000 --- a/src/indigoapi/analyses/decorator.py +++ /dev/null @@ -1,24 +0,0 @@ -from collections.abc import Awaitable, Callable -from typing import ParamSpec, TypeVar - -from indigoapi.analyses.loader import get_async_function -from indigoapi.analyses.registry import register_analysis - -P = ParamSpec("P") -R = TypeVar("R") - - -def analysis( - name: str | None = None, -) -> Callable[[Callable[P, R]], Callable[P, Awaitable[R]]]: - """Decorator to register a function as an analysis. - Converts sync functions to async.""" - - def decorator(func: Callable[P, R]) -> Callable[P, Awaitable[R]]: - async_fn = get_async_function(func) - name_to_register = name or func.__name__ - - register_analysis(name_to_register, async_fn) - return async_fn - - return decorator diff --git a/src/indigoapi/analyses/loader.py b/src/indigoapi/analyses/loader.py deleted file mode 100644 index cf57309..0000000 --- a/src/indigoapi/analyses/loader.py +++ /dev/null @@ -1,134 +0,0 @@ -import asyncio -import importlib -import inspect -import logging -import pkgutil -from collections.abc import Awaitable, Callable -from functools import wraps -from pathlib import Path -from typing import ParamSpec, TypeVar - -from git import Repo - -from indigoapi.analyses.registry import register_analysis -from indigoapi.config import Config - -logger = logging.getLogger(__name__) - - -def load_analyses(package): - - module_names = [] - - for _, module_name, _ in pkgutil.iter_modules(package.__path__): - importlib.import_module(f"{package.__name__}.{module_name}") - module_names.append(module_name) - - return module_names - - -P = ParamSpec("P") -R = TypeVar("R") - - -def get_async_function(func: Callable[P, R]) -> Callable[P, Awaitable[R]]: - if inspect.iscoroutinefunction(func): - return func # type: ignore[return-value] - - @wraps(func) - async def async_fn(*args: P.args, **kwargs: P.kwargs) -> R: - return await asyncio.to_thread(func, *args, **kwargs) - - return async_fn - - -def register_module_functions(module): - - for name, obj in vars(module).items(): - if name.startswith("_"): - continue - if not inspect.isfunction(obj): - continue - if obj.__module__ != module.__name__: - continue - try: - register_analysis(name, get_async_function(obj)) - except ValueError: - logger.debug(f"Analysis '{name}' already registered") - except Exception as e: - logger.error(f"Unable to register {name} from {module.__name__}: {e}") - - -def load_plugins_from_dir(path: str | Path, register_all: bool = False): - """Load user plugins recursively from a folder and all subfolders.""" - path = Path(path) - assert isinstance(path, Path) - if not path.exists() or not path.is_dir(): - return - - for pyfile in path.rglob("*.py"): - if pyfile.stem.startswith("_") or pyfile.stem.startswith("test_"): - continue - - module_name = f"plugin.{pyfile.relative_to(path).with_suffix('').as_posix().replace('/', '.')}" # noqa - try: - spec = importlib.util.spec_from_file_location(module_name, pyfile) # type: ignore - module = importlib.util.module_from_spec(spec) # type: ignore - spec.loader.exec_module(module) - # logger.info(f"Loaded plugin: {pyfile}") - if register_all: - register_module_functions(module) - - except Exception: - # logger.error(f"Failed to read plugin {pyfile}: {e}") - pass - - -def clone_github_repo(repo_url: str, dest_dir: str, force: bool = False) -> Path: - """Clone a repo if not already cloned. Returns path to cloned repo.""" - dest_path = Path(dest_dir) / Path(repo_url).stem - - if not dest_path.exists() or force: - Repo.clone_from(repo_url, dest_path) - - return dest_path - - -def load_plugins(config: Config, register_all: bool = False): - """ - Load all user plugins from configured paths and GitHub repos. - - Built-in analyses (in indigoapi.analyses) are already loaded via decorators. - This function loads external plugins. - - Args: - config: Configuration object with plugin paths and GitHub repos - register_all: If False, only load @analysis-decorated functions. - If True, also auto-register any top-level functions. - """ - # Load from local plugin paths - for p in config.plugins.paths: - load_plugins_from_dir(p, register_all=register_all) - - # Load from GitHub repos - if config.plugins.github_repos is not None: - for repo in config.plugins.github_repos: - logger.info(f"Loading from {repo}") - - try: - repo_path = clone_github_repo( - repo, config.plugins.paths[0] - ) # cloned into plugins/ - source_path = repo_path / "src" - load_plugins_from_dir(source_path, register_all=register_all) - - except Exception as e: - logger.error(f"Unable to load {repo}: {e}") - - -# if __name__ == "__main__": -# from indigoapi.analyses.registry import list_analyses - -# load_plugins(Config.load_config()) - -# print(list_analyses()[0:4]) diff --git a/src/indigoapi/analyses/peak_fitting.py b/src/indigoapi/analyses/peak_fitting.py index 40e1907..6ecc549 100644 --- a/src/indigoapi/analyses/peak_fitting.py +++ b/src/indigoapi/analyses/peak_fitting.py @@ -1,7 +1,7 @@ import numpy as np from scipy.optimize import curve_fit -from indigoapi.analyses.decorator import analysis +from indigoapi.analysis_core.decorator import analysis def gaussian(x: np.ndarray, amplitude: float, x0: float, sigma: float) -> np.ndarray: diff --git a/src/indigoapi/analyses/registry.py b/src/indigoapi/analyses/registry.py deleted file mode 100644 index 6922cc7..0000000 --- a/src/indigoapi/analyses/registry.py +++ /dev/null @@ -1,31 +0,0 @@ -import logging -from collections.abc import Callable -from typing import Any - -logger = logging.getLogger(__name__) - - -class AnalysisNotFoundError(Exception): - """Raised when a requested analysis cannot be found or imported.""" - - -ANALYSIS_REGISTRY: dict[str, Callable[..., Any]] = {} - - -def register_analysis(name: str, fn: Callable) -> None: - if name in ANALYSIS_REGISTRY: - raise ValueError(f"Analysis '{name}' already registered") - ANALYSIS_REGISTRY[name] = fn - logger.info(f"Registered analysis: {name}") - - -def list_analyses() -> list[str]: - return list(ANALYSIS_REGISTRY.keys()) - - -def get_analysis(name: str) -> Callable: - if name not in ANALYSIS_REGISTRY: - msg = f"Unknown analysis '{name}': analysis not found" - raise AnalysisNotFoundError(msg) - - return ANALYSIS_REGISTRY[name] diff --git a/src/indigoapi/analyses/simple_maths.py b/src/indigoapi/analyses/simple_maths.py index cf4b635..e943706 100644 --- a/src/indigoapi/analyses/simple_maths.py +++ b/src/indigoapi/analyses/simple_maths.py @@ -2,7 +2,7 @@ import numpy as np -from indigoapi.analyses.decorator import analysis +from indigoapi.analysis_core.decorator import analysis @analysis() diff --git a/src/indigoapi/api/routes.py b/src/indigoapi/api/routes.py index 8ae776f..c42b742 100644 --- a/src/indigoapi/api/routes.py +++ b/src/indigoapi/api/routes.py @@ -5,9 +5,9 @@ from fastapi import APIRouter, HTTPException, Request from fastapi.routing import APIRoute -from indigoapi.analyses.registry import get_analysis, list_analyses +from indigoapi.analysis_core.registry import get_analysis, list_analyses from indigoapi.models import AnalysisRequest, AnalysisResult -from indigoapi.queue_manager import QueueManager +from indigoapi.queue import QueueManager ROUTER = APIRouter() diff --git a/src/indigoapi/cleanup.py b/src/indigoapi/cleanup.py deleted file mode 100644 index 6898430..0000000 --- a/src/indigoapi/cleanup.py +++ /dev/null @@ -1,25 +0,0 @@ -import asyncio -import time - - -async def cleanup_results(queue_manager, ttl: int, interval: int): - """ - Remove expired results from memory. - ttl = time to live - interval = poll period - - checks every interval - if live time > ttl. delete - """ - - while True: - now = time.time() - - expired = [ - rid for rid, (_, ts) in queue_manager.results.items() if now - ts > ttl - ] - - for rid in expired: - del queue_manager.results[rid] - - await asyncio.sleep(interval) diff --git a/src/indigoapi/main.py b/src/indigoapi/main.py deleted file mode 100644 index 83d8b28..0000000 --- a/src/indigoapi/main.py +++ /dev/null @@ -1,99 +0,0 @@ -"""Interface for `python -m indigoapi`.""" - -import asyncio -import logging -from contextlib import asynccontextmanager - -from fastapi import FastAPI -from xrpd_toolbox.utils.messenger import Messenger - -from indigoapi.analyses import MODULE_NAMES, initialize_analyses -from indigoapi.api.routes import ROUTER -from indigoapi.cleanup import cleanup_results -from indigoapi.config import Config -from indigoapi.queue_manager import QueueManager -from indigoapi.rabbitmq_listener import RabbitMQListener - -from . import __version__ - -config: Config = Config.load_config() - - -@asynccontextmanager -async def lifespan(app: FastAPI): - - rabbit_task = None - - if config.rabbitmq.enabled: - messenger = Messenger( - host=config.rabbitmq.host, - port=config.rabbitmq.port, - username=config.rabbitmq.username, - password=config.rabbitmq.password, - auto_subscribe=False, - ) - else: - messenger = None - - queue_manager = QueueManager(workers=config.queue.workers, messenger=messenger) - - workers = [ - asyncio.create_task(queue_manager.worker()) - for _ in range(queue_manager.workers) - ] - - cleanup_task = asyncio.create_task( - cleanup_results( - queue_manager, - ttl=config.results.ttl_seconds, - interval=config.cleanup.interval_seconds, - ) - ) - - if config.rabbitmq.enabled: - rabbit_listener = RabbitMQListener( - queue_manager=queue_manager, - host=config.rabbitmq.host, - port=config.rabbitmq.port, - username=config.rabbitmq.username, - password=config.rabbitmq.password, - destinations=config.rabbitmq.destinations, - ) - - rabbit_task = asyncio.create_task(rabbit_listener.start()) - - app.state.queue_manager = queue_manager - app.state.config = config - - logging.info("API started") - - yield - - logging.info("Shutting down") - - for task in workers: - task.cancel() - - cleanup_task.cancel() - - if rabbit_task is not None: - rabbit_task.cancel() - - -def start_api() -> FastAPI: - - logger = logging.getLogger(__name__) - initialize_analyses(register_all=config.plugins.register_all) - logger.info(f"{MODULE_NAMES} have been loaded") - logger.info(f"version: {__version__}") - - app = FastAPI( - title="IndigoAPI", - version=__version__, - description="An API for fast data analysis jobs", - lifespan=lifespan, - ) - - app.include_router(ROUTER) - - return app diff --git a/src/indigoapi/queue_manager.py b/src/indigoapi/queue_manager.py deleted file mode 100644 index ced563d..0000000 --- a/src/indigoapi/queue_manager.py +++ /dev/null @@ -1,67 +0,0 @@ -import asyncio -import logging -import time -from datetime import datetime -from uuid import UUID - -from xrpd_toolbox.utils.messenger import DEFAULT_DII_PROCESSED_DESTINATION, Messenger - -from indigoapi.analyses.registry import get_analysis -from indigoapi.models import AnalysisRequest, AnalysisResult - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -class QueueManager: - def __init__(self, workers: int = 2, messenger: Messenger | None = None): - self.queue: asyncio.Queue[AnalysisRequest] = asyncio.Queue(maxsize=0) # 0 = inf - self.results: dict[UUID, tuple[AnalysisResult, float]] = {} - self.workers = workers - self.latest_result: AnalysisResult | None = None - self.messenger = messenger - - logger.info(self.queue) - - async def enqueue(self, job: AnalysisRequest): - job.created_at = datetime.now() - logger.info(job) - await self.queue.put(job) - - async def worker(self): - while True: - job = await self.queue.get() - - try: - analysis_fn = get_analysis(job.analysis_name) - - result_value = await analysis_fn(**job.inputs) - - analysis_result = AnalysisResult( - request_id=job.request_id, - analysis_name=job.analysis_name, - status="completed", - result=result_value, - created_at=job.created_at, - finished_at=datetime.now(), - ) - - except Exception as e: - analysis_result = AnalysisResult( - request_id=job.request_id, - analysis_name=job.analysis_name, - status="failed", - result=str(e), - created_at=job.created_at, - finished_at=datetime.now(), - ) - - if self.messenger is not None: - self.messenger.send_message( - DEFAULT_DII_PROCESSED_DESTINATION, - analysis_result.model_dump_json(), - ) - - self.results[job.request_id] = (analysis_result, time.time()) - # store latest result - self.latest_result = analysis_result diff --git a/src/indigoapi/rabbitmq_listener.py b/src/indigoapi/rabbitmq_listener.py deleted file mode 100644 index 2d92ce2..0000000 --- a/src/indigoapi/rabbitmq_listener.py +++ /dev/null @@ -1,193 +0,0 @@ -import asyncio -import json -import logging -import threading -import time -from typing import Any - -import stomp -from pydantic import BaseModel, Field - -from indigoapi.models import AnalysisRequest -from indigoapi.queue_manager import QueueManager - -logger = logging.getLogger(__name__) - -TIMEOUT = 10 - - -class ProcessingRequest(BaseModel): - # empty dict in your example, so keep flexible - model_config = {"extra": "allow"} - - -class ScanMessage(BaseModel): - status: str - filePath: str # noqa: N815 - because this is gda - visitDirectory: str # noqa: N815 - because this is gda - swmrStatus: str # noqa: N815 - because this is gda - - scanNumber: int # noqa: N815 - because this is gda - scanDimensions: list[int] # noqa: N815 - because this is gda - - scannables: list[Any] = Field(default_factory=list) - detectors: list[Any] = Field(default_factory=list) - - percentageComplete: float # noqa: N815 - because this is gda - - processingRequest: ProcessingRequest = Field(default_factory=ProcessingRequest) # noqa: N815 - because this is gda - - -def worker_event_to_job(worker_event) -> AnalysisRequest: - - # TODO: This is a placeholder - - # need to define how WorkerEvents map to AnalysisRequests - - return AnalysisRequest(analysis_name="", inputs={}) - - -class _StompListener(stomp.ConnectionListener): - def __init__(self, queue_manager: QueueManager, loop: asyncio.AbstractEventLoop): - self.queue_manager = queue_manager - self.loop = loop - - def parse_job(self, data: dict) -> AnalysisRequest | None: - - if "analysis_name" in data: - return AnalysisRequest.model_validate(data) - - elif "event_type" in data and "task_id" in data: - data_event = data - logger.info(f"Received data event: {data_event}") - logger.info("Will ignore...") - return None - - elif "status" in data and "filePath" in data and "visitDirectory" in data: - gda_scan_message = ScanMessage.model_validate( - data - ) # just to validate the message format - logger.info(f"Received GDA scan message: {gda_scan_message}") - logger.info("Will ignore...") - return None - - elif "state" in data and "task_status" in data: - worker_event = data - - if ( - worker_event["task_status"] is not None - and worker_event["task_status"]["task_complete"] - ): - return worker_event_to_job(worker_event) - else: - logger.info( - f"Received non-complete WorkerEvent: {worker_event['task_status']}" - ) - return None - - else: - logger.info(f"Not a valid job received: {data}") - - def on_connected(self, frame): - logger.info("RabbitMQ connected") - - def on_disconnected(self): - logger.warning("RabbitMQ connection lost") - - def on_error(self, frame): - logger.error(f"STOMP error: {frame.body}") - - def on_message(self, frame): - try: - data = json.loads(frame.body) - - job = AnalysisRequest.model_validate(data) - - logger.info(f"RabbitMQ job received: {job.request_id}") - - asyncio.run_coroutine_threadsafe( - self.queue_manager.enqueue(job), - self.loop, - ) - - except Exception as e: - logger.error(f"Failed to process message: {e}") - logger.error(f"Failed message: {frame.body}") - - -class RabbitMQListener: - def __init__( - self, - queue_manager: QueueManager, - host: str, - port: int, - username: str, - password: str, - destinations: list[str], - ): - self.queue_manager = queue_manager - self.host = host - self.port = port - self.username = username - self.password = password - self.destinations = destinations - - self.running = True - self.thread: threading.Thread | None = None - - async def start(self): - loop = asyncio.get_running_loop() - - self.thread = threading.Thread( - target=self._run, - args=(loop,), - daemon=True, - ) - - self.thread.start() - - logger.info("RabbitMQ listener thread started") - - def _run(self, loop: asyncio.AbstractEventLoop): - - attempt = 0 - - while self.running: - attempt += 1 - - logger.info( - f"RabbitMQ connection attempt {attempt} to {self.host}:{self.port}" - ) - - try: - conn = stomp.Connection( - [(self.host, self.port)], - heartbeats=(TIMEOUT * 1000, TIMEOUT * 1000), # heartbeat in in ms - timeout=TIMEOUT, - ) - - listener = _StompListener( - self.queue_manager, - loop, - ) - conn.set_listener("", listener) - conn.connect(self.username, self.password, wait=True) - - for i, dest in enumerate(self.destinations): - conn.subscribe(destination=dest, id=str(i), ack="auto") - logger.info(f"Subscribed to {dest}") - - if conn.is_connected(): - attempt = 0 # reset attempt to 0 after successful connection - - while conn.is_connected(): - time.sleep(1) - - except Exception as e: - logger.warning(f"RabbitMQ connection failed: {e}") - logger.info( - f"RabbitMQ connection attempt {attempt} to {self.host}:{self.port}" - ) - delay_time = TIMEOUT + attempt - - logger.info(f" Waiting {delay_time}s before next reconnect") - time.sleep(delay_time) diff --git a/tests/conftest.py b/tests/conftest.py index bf5fc73..61caa6f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import importlib -from indigoapi.analyses.loader import load_analyses +from indigoapi.analysis_core.loader import load_analyses def pytest_configure(): diff --git a/tests/test_api.py b/tests/test_api.py index a55ed70..661ed79 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -4,8 +4,8 @@ from fastapi.testclient import TestClient from indigoapi.config import Config -from indigoapi.main import start_api from indigoapi.models import AnalysisRequest +from indigoapi.server import start_api def test_analysis_flow_with_post(): diff --git a/tests/test_api_routes.py b/tests/test_api_routes.py index 91e4d0d..ad57ff6 100644 --- a/tests/test_api_routes.py +++ b/tests/test_api_routes.py @@ -3,8 +3,8 @@ from fastapi.testclient import TestClient -from indigoapi.main import start_api from indigoapi.models import AnalysisResult +from indigoapi.server import start_api def test_api_health_and_endpoints_routes(): diff --git a/tests/test_cleanup.py b/tests/test_cleanup.py index f005e43..92b8a65 100644 --- a/tests/test_cleanup.py +++ b/tests/test_cleanup.py @@ -4,7 +4,7 @@ import pytest -from indigoapi.cleanup import cleanup_results +from indigoapi.queue import cleanup_results @pytest.mark.asyncio @@ -19,7 +19,7 @@ class FakeQueue: async def fake_sleep(interval): raise asyncio.CancelledError - monkeypatch.setattr("indigoapi.cleanup.asyncio.sleep", fake_sleep) + monkeypatch.setattr("indigoapi.queue.cleanup.asyncio.sleep", fake_sleep) with pytest.raises(asyncio.CancelledError): await cleanup_results(fake_queue, ttl=1, interval=0) @@ -39,7 +39,7 @@ class FakeQueue: async def fake_sleep(interval): raise asyncio.CancelledError - monkeypatch.setattr("indigoapi.cleanup.asyncio.sleep", fake_sleep) + monkeypatch.setattr("indigoapi.queue.cleanup.asyncio.sleep", fake_sleep) with pytest.raises(asyncio.CancelledError): await cleanup_results(fake_queue, ttl=60, interval=0) diff --git a/tests/test_gaussian_fit.py b/tests/test_gaussian_fit.py index bd20d5e..cc6821a 100644 --- a/tests/test_gaussian_fit.py +++ b/tests/test_gaussian_fit.py @@ -3,7 +3,7 @@ from indigoapi.analyses.peak_fitting import gaussian, gaussian_fit from indigoapi.client import AnalysisClient -from indigoapi.main import start_api +from indigoapi.server import start_api def test_gaussian_fit_with_client(): diff --git a/tests/test_loader.py b/tests/test_loader.py index 48acd1f..c2fdfb3 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -1,11 +1,11 @@ import sys -from indigoapi.analyses.loader import ( +from indigoapi.analysis_core.loader import ( clone_github_repo, load_plugins, load_plugins_from_dir, ) -from indigoapi.analyses.registry import ANALYSIS_REGISTRY, list_analyses +from indigoapi.analysis_core.registry import ANALYSIS_REGISTRY, list_analyses from indigoapi.config import Config @@ -28,7 +28,7 @@ def test_loader_load_plugins_from_dir(tmp_path): def test_loader_load_plugins_registers_decorated_functions(tmp_path): plugin_path = tmp_path / "custom_plugin.py" plugin_path.write_text( - "from indigoapi.analyses.decorator import analysis\n" + "from indigoapi.analysis_core.decorator import analysis\n" "@analysis()\n" "def hello(name: str) -> str:\n" " return f'hello {name}'\n" @@ -71,5 +71,5 @@ def test_loader_load_plugins_handles_clone_error(monkeypatch): def fake_clone(repo_url, dest_dir): raise RuntimeError("unable to clone") - monkeypatch.setattr("indigoapi.analyses.loader.clone_github_repo", fake_clone) + monkeypatch.setattr("indigoapi.analysis_core.loader.clone_github_repo", fake_clone) load_plugins(cfg) diff --git a/tests/test_loader_extra.py b/tests/test_loader_extra.py index 1e5d857..0069275 100644 --- a/tests/test_loader_extra.py +++ b/tests/test_loader_extra.py @@ -1,12 +1,12 @@ from unittest.mock import Mock -from indigoapi.analyses.loader import ( +from indigoapi.analysis_core.loader import ( clone_github_repo, get_async_function, load_plugins, load_plugins_from_dir, ) -from indigoapi.analyses.registry import list_analyses +from indigoapi.analysis_core.registry import list_analyses from indigoapi.config import Config @@ -20,7 +20,9 @@ async def coro(): def test_clone_github_repo_force(monkeypatch, tmp_path): dest = tmp_path / "repo" clone_from_mock = Mock() - monkeypatch.setattr("indigoapi.analyses.loader.Repo.clone_from", clone_from_mock) + monkeypatch.setattr( + "indigoapi.analysis_core.loader.Repo.clone_from", clone_from_mock + ) result = clone_github_repo( "https://example.com/repo.git", str(tmp_path), force=True ) @@ -46,7 +48,7 @@ def track_spec(name, location): return real_spec(name, location) monkeypatch.setattr( - "indigoapi.analyses.loader.importlib.util.spec_from_file_location", + "indigoapi.analysis_core.loader.importlib.util.spec_from_file_location", track_spec, ) @@ -68,7 +70,7 @@ def test_load_plugins_with_git_repo(monkeypatch, tmp_path): fake_file.write_text("def hello():\n return 'hello'\n") monkeypatch.setattr( - "indigoapi.analyses.loader.clone_github_repo", lambda url, dest: fake_path + "indigoapi.analysis_core.loader.clone_github_repo", lambda url, dest: fake_path ) load_plugins(cfg, register_all=True) assert "hello" in list_analyses() or True diff --git a/tests/test_plugins.py b/tests/test_plugins.py index cbd13e5..32c1546 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -2,8 +2,8 @@ import pytest from indigoapi.analyses.peak_fitting import gaussian -from indigoapi.analyses.registry import ( - AnalysisNotFoundError, +from indigoapi.analysis_core.registry import ( + AnalysisNotFound, get_analysis, list_analyses, ) @@ -67,5 +67,5 @@ async def test_async_with_gauss(): @pytest.mark.asyncio async def test_invalid_analysis_name(): - with pytest.raises(AnalysisNotFoundError): + with pytest.raises(AnalysisNotFound): get_analysis("nonexistent") diff --git a/tests/test_queue_manager.py b/tests/test_queue_manager.py index b083a4d..1091ae6 100644 --- a/tests/test_queue_manager.py +++ b/tests/test_queue_manager.py @@ -5,7 +5,7 @@ from xrpd_toolbox.utils.messenger import Messenger from indigoapi.models import AnalysisRequest -from indigoapi.queue_manager import QueueManager +from indigoapi.queue import QueueManager async def wait_for_result(queue_manager, request_id, timeout=1.0): @@ -26,7 +26,7 @@ async def fake_analysis(number): return number * 2 monkeypatch.setattr( - "indigoapi.analyses.registry.get_analysis", lambda name: fake_analysis + "indigoapi.analysis_core.registry.get_analysis", lambda name: fake_analysis ) job = AnalysisRequest(analysis_name="double", inputs={"number": 2}) @@ -60,7 +60,7 @@ def send_message(self, destination, message): ) monkeypatch.setattr( - "indigoapi.analyses.registry.get_analysis", + "indigoapi.analysis_core.registry.get_analysis", lambda name: (_ for _ in ()).throw(KeyError("missing")), ) diff --git a/tests/test_rabbitmq_listener.py b/tests/test_rabbitmq_listener.py index 9966896..aa80751 100644 --- a/tests/test_rabbitmq_listener.py +++ b/tests/test_rabbitmq_listener.py @@ -7,8 +7,8 @@ import pytest from indigoapi.models import AnalysisRequest -from indigoapi.queue_manager import QueueManager -from indigoapi.rabbitmq_listener import _StompListener +from indigoapi.queue import QueueManager +from indigoapi.queue.rabbitmq import _StompListener @pytest.mark.filterwarnings("ignore::ResourceWarning") @@ -28,7 +28,7 @@ def fake_run_coro_threadsafe(coro, event_loop): return None monkeypatch.setattr( - "indigoapi.rabbitmq_listener.asyncio.run_coroutine_threadsafe", + "indigoapi.queue.rabbitmq.asyncio.run_coroutine_threadsafe", fake_run_coro_threadsafe, ) @@ -54,7 +54,7 @@ def test_stomp_listener_invalid_json(monkeypatch): listener = _StompListener(queue_manager, loop) # type: ignore monkeypatch.setattr( - "indigoapi.rabbitmq_listener.asyncio.run_coroutine_threadsafe", + "indigoapi.queue.rabbitmq.asyncio.run_coroutine_threadsafe", lambda coro, event_loop: None, ) frame = SimpleNamespace(body="not-a-json") @@ -141,7 +141,7 @@ class BadFrame: def fake_error(msg): recorded.append(msg) - monkeypatch.setattr("indigoapi.rabbitmq_listener.logger.error", fake_error) + monkeypatch.setattr("indigoapi.queue.rabbitmq.logger.error", fake_error) listener.on_message(BadFrame()) assert recorded diff --git a/tests/test_registry.py b/tests/test_registry.py index 1c16cac..ba62cee 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -3,8 +3,8 @@ import pytest -from indigoapi.analyses.decorator import analysis -from indigoapi.analyses.registry import ( +from indigoapi.analysis_core.decorator import analysis +from indigoapi.analysis_core.registry import ( ANALYSIS_REGISTRY, get_analysis, register_analysis, From 033c29a7f4c90d2b05cceb4dabb20b43235740fc Mon Sep 17 00:00:00 2001 From: RJCD-Diamond Date: Thu, 21 May 2026 13:53:34 +0000 Subject: [PATCH 2/2] restructure project after adding changed name file --- src/indigoapi/analysis_core/__init__.py | 41 +++++ src/indigoapi/analysis_core/decorator.py | 24 +++ src/indigoapi/analysis_core/loader.py | 134 ++++++++++++++++ src/indigoapi/analysis_core/registry.py | 31 ++++ src/indigoapi/queue/__init__.py | 5 + src/indigoapi/queue/cleanup.py | 25 +++ src/indigoapi/queue/manager.py | 67 ++++++++ src/indigoapi/queue/rabbitmq.py | 193 +++++++++++++++++++++++ src/indigoapi/server.py | 97 ++++++++++++ tests/test_plugins.py | 4 +- 10 files changed, 619 insertions(+), 2 deletions(-) create mode 100644 src/indigoapi/analysis_core/__init__.py create mode 100644 src/indigoapi/analysis_core/decorator.py create mode 100644 src/indigoapi/analysis_core/loader.py create mode 100644 src/indigoapi/analysis_core/registry.py create mode 100644 src/indigoapi/queue/__init__.py create mode 100644 src/indigoapi/queue/cleanup.py create mode 100644 src/indigoapi/queue/manager.py create mode 100644 src/indigoapi/queue/rabbitmq.py create mode 100644 src/indigoapi/server.py diff --git a/src/indigoapi/analysis_core/__init__.py b/src/indigoapi/analysis_core/__init__.py new file mode 100644 index 0000000..b7d05e8 --- /dev/null +++ b/src/indigoapi/analysis_core/__init__.py @@ -0,0 +1,41 @@ +import importlib + +from indigoapi.config import Config + +from .decorator import analysis +from .loader import get_async_function, load_analyses, load_plugins +from .registry import ( + ANALYSIS_REGISTRY, + AnalysisNotFoundError, + get_analysis, + list_analyses, + register_analysis, +) + +MODULE_NAMES: list[str] = [] + + +def initialize_analyses(register_all: bool = False): + """Load built-in analyses and user plugins. Call during server startup.""" + global MODULE_NAMES + + package = importlib.import_module("indigoapi.analyses") + MODULE_NAMES = load_analyses(package) + + config = Config.load_config() + load_plugins(config, register_all=register_all) + + +__all__ = [ + "analysis", + "get_async_function", + "load_analyses", + "load_plugins", + "ANALYSIS_REGISTRY", + "AnalysisNotFoundError", + "get_analysis", + "list_analyses", + "register_analysis", + "initialize_analyses", + "MODULE_NAMES", +] diff --git a/src/indigoapi/analysis_core/decorator.py b/src/indigoapi/analysis_core/decorator.py new file mode 100644 index 0000000..078fc69 --- /dev/null +++ b/src/indigoapi/analysis_core/decorator.py @@ -0,0 +1,24 @@ +from collections.abc import Awaitable, Callable +from typing import ParamSpec, TypeVar + +from indigoapi.analysis_core.loader import get_async_function +from indigoapi.analysis_core.registry import register_analysis + +P = ParamSpec("P") +R = TypeVar("R") + + +def analysis( + name: str | None = None, +) -> Callable[[Callable[P, R]], Callable[P, Awaitable[R]]]: + """Decorator to register a function as an analysis. + Converts sync functions to async.""" + + def decorator(func: Callable[P, R]) -> Callable[P, Awaitable[R]]: + async_fn = get_async_function(func) + name_to_register = name or func.__name__ + + register_analysis(name_to_register, async_fn) + return async_fn + + return decorator diff --git a/src/indigoapi/analysis_core/loader.py b/src/indigoapi/analysis_core/loader.py new file mode 100644 index 0000000..d6d213e --- /dev/null +++ b/src/indigoapi/analysis_core/loader.py @@ -0,0 +1,134 @@ +import asyncio +import importlib +import inspect +import logging +import pkgutil +from collections.abc import Awaitable, Callable +from functools import wraps +from pathlib import Path +from typing import ParamSpec, TypeVar + +from git import Repo + +from indigoapi.analysis_core.registry import register_analysis +from indigoapi.config import Config + +logger = logging.getLogger(__name__) + + +def load_analyses(package): + + module_names = [] + + for _, module_name, _ in pkgutil.iter_modules(package.__path__): + importlib.import_module(f"{package.__name__}.{module_name}") + module_names.append(module_name) + + return module_names + + +P = ParamSpec("P") +R = TypeVar("R") + + +def get_async_function(func: Callable[P, R]) -> Callable[P, Awaitable[R]]: + if inspect.iscoroutinefunction(func): + return func # type: ignore[return-value] + + @wraps(func) + async def async_fn(*args: P.args, **kwargs: P.kwargs) -> R: + return await asyncio.to_thread(func, *args, **kwargs) + + return async_fn + + +def register_module_functions(module): + + for name, obj in vars(module).items(): + if name.startswith("_"): + continue + if not inspect.isfunction(obj): + continue + if obj.__module__ != module.__name__: + continue + try: + register_analysis(name, get_async_function(obj)) + except ValueError: + logger.debug(f"Analysis '{name}' already registered") + except Exception as e: + logger.error(f"Unable to register {name} from {module.__name__}: {e}") + + +def load_plugins_from_dir(path: str | Path, register_all: bool = False): + """Load user plugins recursively from a folder and all subfolders.""" + path = Path(path) + assert isinstance(path, Path) + if not path.exists() or not path.is_dir(): + return + + for pyfile in path.rglob("*.py"): + if pyfile.stem.startswith("_") or pyfile.stem.startswith("test_"): + continue + + module_name = f"plugin.{pyfile.relative_to(path).with_suffix('').as_posix().replace('/', '.')}" # noqa + try: + spec = importlib.util.spec_from_file_location(module_name, pyfile) # type: ignore + module = importlib.util.module_from_spec(spec) # type: ignore + spec.loader.exec_module(module) + # logger.info(f"Loaded plugin: {pyfile}") + if register_all: + register_module_functions(module) + + except Exception: + # logger.error(f"Failed to read plugin {pyfile}: {e}") + pass + + +def clone_github_repo(repo_url: str, dest_dir: str, force: bool = False) -> Path: + """Clone a repo if not already cloned. Returns path to cloned repo.""" + dest_path = Path(dest_dir) / Path(repo_url).stem + + if not dest_path.exists() or force: + Repo.clone_from(repo_url, dest_path) + + return dest_path + + +def load_plugins(config: Config, register_all: bool = False): + """ + Load all user plugins from configured paths and GitHub repos. + + Built-in analyses (in indigoapi.analyses) are already loaded via decorators. + This function loads external plugins. + + Args: + config: Configuration object with plugin paths and GitHub repos + register_all: If False, only load @analysis-decorated functions. + If True, also auto-register any top-level functions. + """ + # Load from local plugin paths + for p in config.plugins.paths: + load_plugins_from_dir(p, register_all=register_all) + + # Load from GitHub repos + if config.plugins.github_repos is not None: + for repo in config.plugins.github_repos: + logger.info(f"Loading from {repo}") + + try: + repo_path = clone_github_repo( + repo, config.plugins.paths[0] + ) # cloned into plugins/ + source_path = repo_path / "src" + load_plugins_from_dir(source_path, register_all=register_all) + + except Exception as e: + logger.error(f"Unable to load {repo}: {e}") + + +# if __name__ == "__main__": +# from indigoapi.analysis_core.registry import list_analyses + +# load_plugins(Config.load_config()) + +# print(list_analyses()[0:4]) diff --git a/src/indigoapi/analysis_core/registry.py b/src/indigoapi/analysis_core/registry.py new file mode 100644 index 0000000..6922cc7 --- /dev/null +++ b/src/indigoapi/analysis_core/registry.py @@ -0,0 +1,31 @@ +import logging +from collections.abc import Callable +from typing import Any + +logger = logging.getLogger(__name__) + + +class AnalysisNotFoundError(Exception): + """Raised when a requested analysis cannot be found or imported.""" + + +ANALYSIS_REGISTRY: dict[str, Callable[..., Any]] = {} + + +def register_analysis(name: str, fn: Callable) -> None: + if name in ANALYSIS_REGISTRY: + raise ValueError(f"Analysis '{name}' already registered") + ANALYSIS_REGISTRY[name] = fn + logger.info(f"Registered analysis: {name}") + + +def list_analyses() -> list[str]: + return list(ANALYSIS_REGISTRY.keys()) + + +def get_analysis(name: str) -> Callable: + if name not in ANALYSIS_REGISTRY: + msg = f"Unknown analysis '{name}': analysis not found" + raise AnalysisNotFoundError(msg) + + return ANALYSIS_REGISTRY[name] diff --git a/src/indigoapi/queue/__init__.py b/src/indigoapi/queue/__init__.py new file mode 100644 index 0000000..1b7b008 --- /dev/null +++ b/src/indigoapi/queue/__init__.py @@ -0,0 +1,5 @@ +from .cleanup import cleanup_results +from .manager import QueueManager +from .rabbitmq import RabbitMQListener + +__all__ = ["cleanup_results", "QueueManager", "RabbitMQListener"] diff --git a/src/indigoapi/queue/cleanup.py b/src/indigoapi/queue/cleanup.py new file mode 100644 index 0000000..6898430 --- /dev/null +++ b/src/indigoapi/queue/cleanup.py @@ -0,0 +1,25 @@ +import asyncio +import time + + +async def cleanup_results(queue_manager, ttl: int, interval: int): + """ + Remove expired results from memory. + ttl = time to live + interval = poll period + + checks every interval + if live time > ttl. delete + """ + + while True: + now = time.time() + + expired = [ + rid for rid, (_, ts) in queue_manager.results.items() if now - ts > ttl + ] + + for rid in expired: + del queue_manager.results[rid] + + await asyncio.sleep(interval) diff --git a/src/indigoapi/queue/manager.py b/src/indigoapi/queue/manager.py new file mode 100644 index 0000000..d1b55c2 --- /dev/null +++ b/src/indigoapi/queue/manager.py @@ -0,0 +1,67 @@ +import asyncio +import logging +import time +from datetime import datetime +from uuid import UUID + +from xrpd_toolbox.utils.messenger import DEFAULT_DII_PROCESSED_DESTINATION, Messenger + +from indigoapi.analysis_core.registry import get_analysis +from indigoapi.models import AnalysisRequest, AnalysisResult + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class QueueManager: + def __init__(self, workers: int = 2, messenger: Messenger | None = None): + self.queue: asyncio.Queue[AnalysisRequest] = asyncio.Queue(maxsize=0) # 0 = inf + self.results: dict[UUID, tuple[AnalysisResult, float]] = {} + self.workers = workers + self.latest_result: AnalysisResult | None = None + self.messenger = messenger + + logger.info(self.queue) + + async def enqueue(self, job: AnalysisRequest): + job.created_at = datetime.now() + logger.info(job) + await self.queue.put(job) + + async def worker(self): + while True: + job = await self.queue.get() + + try: + analysis_fn = get_analysis(job.analysis_name) + + result_value = await analysis_fn(**job.inputs) + + analysis_result = AnalysisResult( + request_id=job.request_id, + analysis_name=job.analysis_name, + status="completed", + result=result_value, + created_at=job.created_at, + finished_at=datetime.now(), + ) + + except Exception as e: + analysis_result = AnalysisResult( + request_id=job.request_id, + analysis_name=job.analysis_name, + status="failed", + result=str(e), + created_at=job.created_at, + finished_at=datetime.now(), + ) + + if self.messenger is not None: + self.messenger.send_message( + DEFAULT_DII_PROCESSED_DESTINATION, + analysis_result.model_dump_json(), + ) + + self.results[job.request_id] = (analysis_result, time.time()) + # store latest result + self.latest_result = analysis_result diff --git a/src/indigoapi/queue/rabbitmq.py b/src/indigoapi/queue/rabbitmq.py new file mode 100644 index 0000000..9873ab1 --- /dev/null +++ b/src/indigoapi/queue/rabbitmq.py @@ -0,0 +1,193 @@ +import asyncio +import json +import logging +import threading +import time +from typing import Any + +import stomp +from pydantic import BaseModel, Field + +from indigoapi.models import AnalysisRequest +from indigoapi.queue import QueueManager + +logger = logging.getLogger(__name__) + +TIMEOUT = 10 + + +class ProcessingRequest(BaseModel): + # empty dict in your example, so keep flexible + model_config = {"extra": "allow"} + + +class ScanMessage(BaseModel): + status: str + filePath: str # noqa: N815 - because this is gda + visitDirectory: str # noqa: N815 - because this is gda + swmrStatus: str # noqa: N815 - because this is gda + + scanNumber: int # noqa: N815 - because this is gda + scanDimensions: list[int] # noqa: N815 - because this is gda + + scannables: list[Any] = Field(default_factory=list) + detectors: list[Any] = Field(default_factory=list) + + percentageComplete: float # noqa: N815 - because this is gda + + processingRequest: ProcessingRequest = Field(default_factory=ProcessingRequest) # noqa: N815 - because this is gda + + +def worker_event_to_job(worker_event) -> AnalysisRequest: + + # TODO: This is a placeholder - + # need to define how WorkerEvents map to AnalysisRequests + + return AnalysisRequest(analysis_name="", inputs={}) + + +class _StompListener(stomp.ConnectionListener): + def __init__(self, queue_manager: QueueManager, loop: asyncio.AbstractEventLoop): + self.queue_manager = queue_manager + self.loop = loop + + def parse_job(self, data: dict) -> AnalysisRequest | None: + + if "analysis_name" in data: + return AnalysisRequest.model_validate(data) + + elif "event_type" in data and "task_id" in data: + data_event = data + logger.info(f"Received data event: {data_event}") + logger.info("Will ignore...") + return None + + elif "status" in data and "filePath" in data and "visitDirectory" in data: + gda_scan_message = ScanMessage.model_validate( + data + ) # just to validate the message format + logger.info(f"Received GDA scan message: {gda_scan_message}") + logger.info("Will ignore...") + return None + + elif "state" in data and "task_status" in data: + worker_event = data + + if ( + worker_event["task_status"] is not None + and worker_event["task_status"]["task_complete"] + ): + return worker_event_to_job(worker_event) + else: + logger.info( + f"Received non-complete WorkerEvent: {worker_event['task_status']}" + ) + return None + + else: + logger.info(f"Not a valid job received: {data}") + + def on_connected(self, frame): + logger.info("RabbitMQ connected") + + def on_disconnected(self): + logger.warning("RabbitMQ connection lost") + + def on_error(self, frame): + logger.error(f"STOMP error: {frame.body}") + + def on_message(self, frame): + try: + data = json.loads(frame.body) + + job = AnalysisRequest.model_validate(data) + + logger.info(f"RabbitMQ job received: {job.request_id}") + + asyncio.run_coroutine_threadsafe( + self.queue_manager.enqueue(job), + self.loop, + ) + + except Exception as e: + logger.error(f"Failed to process message: {e}") + logger.error(f"Failed message: {frame.body}") + + +class RabbitMQListener: + def __init__( + self, + queue_manager: QueueManager, + host: str, + port: int, + username: str, + password: str, + destinations: list[str], + ): + self.queue_manager = queue_manager + self.host = host + self.port = port + self.username = username + self.password = password + self.destinations = destinations + + self.running = True + self.thread: threading.Thread | None = None + + async def start(self): + loop = asyncio.get_running_loop() + + self.thread = threading.Thread( + target=self._run, + args=(loop,), + daemon=True, + ) + + self.thread.start() + + logger.info("RabbitMQ listener thread started") + + def _run(self, loop: asyncio.AbstractEventLoop): + + attempt = 0 + + while self.running: + attempt += 1 + + logger.info( + f"RabbitMQ connection attempt {attempt} to {self.host}:{self.port}" + ) + + try: + conn = stomp.Connection( + [(self.host, self.port)], + heartbeats=(TIMEOUT * 1000, TIMEOUT * 1000), # heartbeat in in ms + timeout=TIMEOUT, + ) + + listener = _StompListener( + self.queue_manager, + loop, + ) + conn.set_listener("", listener) + conn.connect(self.username, self.password, wait=True) + + for i, dest in enumerate(self.destinations): + conn.subscribe(destination=dest, id=str(i), ack="auto") + logger.info(f"Subscribed to {dest}") + + if conn.is_connected(): + attempt = 0 # reset attempt to 0 after successful connection + + while conn.is_connected(): + time.sleep(1) + + except Exception as e: + logger.warning(f"RabbitMQ connection failed: {e}") + logger.info( + f"RabbitMQ connection attempt {attempt} to {self.host}:{self.port}" + ) + delay_time = TIMEOUT + attempt + + logger.info(f" Waiting {delay_time}s before next reconnect") + time.sleep(delay_time) diff --git a/src/indigoapi/server.py b/src/indigoapi/server.py new file mode 100644 index 0000000..57f3b5f --- /dev/null +++ b/src/indigoapi/server.py @@ -0,0 +1,97 @@ +"""Interface for `python -m indigoapi`.""" + +import asyncio +import logging +from contextlib import asynccontextmanager + +from fastapi import FastAPI +from xrpd_toolbox.utils.messenger import Messenger + +from indigoapi.analysis_core import MODULE_NAMES, initialize_analyses +from indigoapi.api.routes import ROUTER +from indigoapi.config import Config +from indigoapi.queue import QueueManager, RabbitMQListener, cleanup_results + +from . import __version__ + +config: Config = Config.load_config() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + + rabbit_task = None + + if config.rabbitmq.enabled: + messenger = Messenger( + host=config.rabbitmq.host, + port=config.rabbitmq.port, + username=config.rabbitmq.username, + password=config.rabbitmq.password, + auto_subscribe=False, + ) + else: + messenger = None + + queue_manager = QueueManager(workers=config.queue.workers, messenger=messenger) + + workers = [ + asyncio.create_task(queue_manager.worker()) + for _ in range(queue_manager.workers) + ] + + cleanup_task = asyncio.create_task( + cleanup_results( + queue_manager, + ttl=config.results.ttl_seconds, + interval=config.cleanup.interval_seconds, + ) + ) + + if config.rabbitmq.enabled: + rabbit_listener = RabbitMQListener( + queue_manager=queue_manager, + host=config.rabbitmq.host, + port=config.rabbitmq.port, + username=config.rabbitmq.username, + password=config.rabbitmq.password, + destinations=config.rabbitmq.destinations, + ) + + rabbit_task = asyncio.create_task(rabbit_listener.start()) + + app.state.queue_manager = queue_manager + app.state.config = config + + logging.info("API started") + + yield + + logging.info("Shutting down") + + for task in workers: + task.cancel() + + cleanup_task.cancel() + + if rabbit_task is not None: + rabbit_task.cancel() + + +def start_api() -> FastAPI: + + logger = logging.getLogger(__name__) + initialize_analyses(register_all=config.plugins.register_all) + logger.info(f"{MODULE_NAMES} have been loaded") + logger.info(f"version: {__version__}") + + app = FastAPI( + title="indigoapi", + version=__version__, + description="An API for fast data analysis jobs", + lifespan=lifespan, + ) + + app.include_router(ROUTER) + + return app diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 32c1546..c2642ce 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -3,7 +3,7 @@ from indigoapi.analyses.peak_fitting import gaussian from indigoapi.analysis_core.registry import ( - AnalysisNotFound, + AnalysisNotFoundError, get_analysis, list_analyses, ) @@ -67,5 +67,5 @@ async def test_async_with_gauss(): @pytest.mark.asyncio async def test_invalid_analysis_name(): - with pytest.raises(AnalysisNotFound): + with pytest.raises(AnalysisNotFoundError): get_analysis("nonexistent")