diff --git a/.gitignore b/.gitignore index a9487ab..9400707 100644 --- a/.gitignore +++ b/.gitignore @@ -24,7 +24,6 @@ var/ .installed.cfg *.egg **/_version.py - *.DS_Store # PyInstaller @@ -59,6 +58,7 @@ cov.xml # Sphinx documentation docs/_build/ docs/_api +plugins/ # PyBuilder target/ diff --git a/config.yaml b/config.yaml index ed2c7af..bee591f 100644 --- a/config.yaml +++ b/config.yaml @@ -14,8 +14,9 @@ cleanup: plugins: paths: - ./plugins # local folder for plugins - github_repos: # list of GitHub repos with analysis code - + github_repos: # list of Https GitHub repos with analysis code (ending with .git) + - https://github.com/DiamondLightSource/xrpd-toolbox.git + register_all: False # whether to register all analyses found in plugins or only those decorated rabbitmq: enabled: True host: "i15-1-rabbitmq-daq.diamond.ac.uk" diff --git a/helm/indigoapi/values.yaml b/helm/indigoapi/values.yaml index 9baec3e..e0014f0 100644 --- a/helm/indigoapi/values.yaml +++ b/helm/indigoapi/values.yaml @@ -60,6 +60,7 @@ config: paths: - ./plugins github_repos: [] + register_all: True # whether to register all analyses found in plugins or only those decorated rabbitmq: enabled: true host: "ixx-rabbitmq-daq.diamond.ac.uk" diff --git a/src/indigoapi/analyses/__init__.py b/src/indigoapi/analyses/__init__.py index 6e91ca7..bdc5126 100644 --- a/src/indigoapi/analyses/__init__.py +++ b/src/indigoapi/analyses/__init__.py @@ -3,10 +3,17 @@ from indigoapi.analyses.loader import load_analyses, load_plugins from indigoapi.config import Config -# load built-in analyses -package = importlib.import_module(__name__) -MODULE_NAMES = load_analyses(package) +MODULE_NAMES = [] -# load user plugins from config -config = Config.load_config() -load_plugins(config) + +def initialize_analyses(register_all: bool = False): + """Load built-in analyses and user plugins. Call during server startup.""" + global MODULE_NAMES + + # load built-in analyses + package = importlib.import_module(__name__) + MODULE_NAMES = load_analyses(package) + + # load user plugins from config + config = Config.load_config() + load_plugins(config, register_all=register_all) diff --git a/src/indigoapi/analyses/decorator.py b/src/indigoapi/analyses/decorator.py index 1274f13..77cd5f3 100644 --- a/src/indigoapi/analyses/decorator.py +++ b/src/indigoapi/analyses/decorator.py @@ -1,23 +1,21 @@ -import asyncio -import inspect -from functools import wraps +from collections.abc import Awaitable, Callable +from typing import ParamSpec, TypeVar +from indigoapi.analyses.loader import get_async_function from indigoapi.analyses.registry import register_analysis +P = ParamSpec("P") +R = TypeVar("R") -def analysis(name: str | None = None): + +def analysis( + name: str | None = None, +) -> Callable[[Callable[P, R]], Callable[P, Awaitable[R]]]: """Decorator to register a function as an analysis. Converts sync functions to async.""" - def decorator(func): - if inspect.iscoroutinefunction(func): - async_fn = func - else: - - @wraps(func) - async def async_fn(*args, **kwargs): - return await asyncio.to_thread(func, *args, **kwargs) - + def decorator(func: Callable[P, R]) -> Callable[P, Awaitable[R]]: + async_fn = get_async_function(func) name_to_register = name or func.__name__ register_analysis(name_to_register, async_fn) diff --git a/src/indigoapi/analyses/loader.py b/src/indigoapi/analyses/loader.py index 3e2f66b..cf57309 100644 --- a/src/indigoapi/analyses/loader.py +++ b/src/indigoapi/analyses/loader.py @@ -1,10 +1,16 @@ +import asyncio import importlib +import inspect import logging import pkgutil +from collections.abc import Awaitable, Callable +from functools import wraps from pathlib import Path +from typing import ParamSpec, TypeVar from git import Repo +from indigoapi.analyses.registry import register_analysis from indigoapi.config import Config logger = logging.getLogger(__name__) @@ -21,40 +27,108 @@ def load_analyses(package): return module_names -def load_plugins_from_dir(path: str | Path): - """Load user plugins from a folder""" +P = ParamSpec("P") +R = TypeVar("R") + + +def get_async_function(func: Callable[P, R]) -> Callable[P, Awaitable[R]]: + if inspect.iscoroutinefunction(func): + return func # type: ignore[return-value] + + @wraps(func) + async def async_fn(*args: P.args, **kwargs: P.kwargs) -> R: + return await asyncio.to_thread(func, *args, **kwargs) + + return async_fn + + +def register_module_functions(module): + + for name, obj in vars(module).items(): + if name.startswith("_"): + continue + if not inspect.isfunction(obj): + continue + if obj.__module__ != module.__name__: + continue + try: + register_analysis(name, get_async_function(obj)) + except ValueError: + logger.debug(f"Analysis '{name}' already registered") + except Exception as e: + logger.error(f"Unable to register {name} from {module.__name__}: {e}") + + +def load_plugins_from_dir(path: str | Path, register_all: bool = False): + """Load user plugins recursively from a folder and all subfolders.""" path = Path(path) assert isinstance(path, Path) - if not path.exists(): + if not path.exists() or not path.is_dir(): return - for pyfile in path.glob("*.py"): - spec = importlib.util.spec_from_file_location(pyfile.stem, pyfile) # type: ignore - module = importlib.util.module_from_spec(spec) # type: ignore - spec.loader.exec_module(module) + for pyfile in path.rglob("*.py"): + if pyfile.stem.startswith("_") or pyfile.stem.startswith("test_"): + continue + + module_name = f"plugin.{pyfile.relative_to(path).with_suffix('').as_posix().replace('/', '.')}" # noqa + try: + spec = importlib.util.spec_from_file_location(module_name, pyfile) # type: ignore + module = importlib.util.module_from_spec(spec) # type: ignore + spec.loader.exec_module(module) + # logger.info(f"Loaded plugin: {pyfile}") + if register_all: + register_module_functions(module) -def clone_github_repo(repo_url: str, dest_dir: str): - """Clone a repo if not already cloned""" + except Exception: + # logger.error(f"Failed to read plugin {pyfile}: {e}") + pass + + +def clone_github_repo(repo_url: str, dest_dir: str, force: bool = False) -> Path: + """Clone a repo if not already cloned. Returns path to cloned repo.""" dest_path = Path(dest_dir) / Path(repo_url).stem - if dest_path.exists(): - return dest_path - Repo.clone_from(repo_url, dest_path) + if not dest_path.exists() or force: + Repo.clone_from(repo_url, dest_path) return dest_path -def load_plugins(config: Config): - """Load all user plugins (local + GitHub)""" - # Local paths +def load_plugins(config: Config, register_all: bool = False): + """ + Load all user plugins from configured paths and GitHub repos. + + Built-in analyses (in indigoapi.analyses) are already loaded via decorators. + This function loads external plugins. + + Args: + config: Configuration object with plugin paths and GitHub repos + register_all: If False, only load @analysis-decorated functions. + If True, also auto-register any top-level functions. + """ + # Load from local plugin paths for p in config.plugins.paths: - load_plugins_from_dir(p) + load_plugins_from_dir(p, register_all=register_all) - # GitHub repos + # Load from GitHub repos if config.plugins.github_repos is not None: for repo in config.plugins.github_repos: + logger.info(f"Loading from {repo}") + try: - repo_path = clone_github_repo(repo, "./plugins") # cloned into plugins/ - load_plugins_from_dir(repo_path) + repo_path = clone_github_repo( + repo, config.plugins.paths[0] + ) # cloned into plugins/ + source_path = repo_path / "src" + load_plugins_from_dir(source_path, register_all=register_all) + except Exception as e: logger.error(f"Unable to load {repo}: {e}") + + +# if __name__ == "__main__": +# from indigoapi.analyses.registry import list_analyses + +# load_plugins(Config.load_config()) + +# print(list_analyses()[0:4]) diff --git a/src/indigoapi/analyses/registry.py b/src/indigoapi/analyses/registry.py index 6b55af3..6922cc7 100644 --- a/src/indigoapi/analyses/registry.py +++ b/src/indigoapi/analyses/registry.py @@ -1,30 +1,31 @@ -import importlib +import logging from collections.abc import Callable +from typing import Any -ANALYSIS_REGISTRY = {} +logger = logging.getLogger(__name__) + + +class AnalysisNotFoundError(Exception): + """Raised when a requested analysis cannot be found or imported.""" + + +ANALYSIS_REGISTRY: dict[str, Callable[..., Any]] = {} def register_analysis(name: str, fn: Callable) -> None: if name in ANALYSIS_REGISTRY: raise ValueError(f"Analysis '{name}' already registered") ANALYSIS_REGISTRY[name] = fn + logger.info(f"Registered analysis: {name}") def list_analyses() -> list[str]: - return list(ANALYSIS_REGISTRY.keys()) def get_analysis(name: str) -> Callable: if name not in ANALYSIS_REGISTRY: - try: - mod = importlib.import_module(f"indigoapi.analyses.{name}") - func = getattr(mod, name) - ANALYSIS_REGISTRY[name] = func - except Exception as e: - print(f"Unknown analysis '{name}': {e}") - print("Available analyses:") - for analysis in list_analyses(): - print(analysis) + msg = f"Unknown analysis '{name}': analysis not found" + raise AnalysisNotFoundError(msg) return ANALYSIS_REGISTRY[name] diff --git a/src/indigoapi/api/routes.py b/src/indigoapi/api/routes.py index 3bc09b0..8ae776f 100644 --- a/src/indigoapi/api/routes.py +++ b/src/indigoapi/api/routes.py @@ -35,7 +35,7 @@ async def available_analyses() -> list[dict[str, Any]]: params.append( { "name": p.name, - "default": p.default + "default": repr(p.default) if p.default != inspect.Parameter.empty else None, "annotation": str(p.annotation) @@ -43,7 +43,18 @@ async def available_analyses() -> list[dict[str, Any]]: else "Any", } ) - analyses_info.append({"name": name, "parameters": params}) + return_annotation = ( + str(sig.return_annotation) + if sig.return_annotation != inspect.Signature.empty + else "Any" + ) + analyses_info.append( + { + "name": name, + "parameters": params, + "return_annotation": return_annotation, + } + ) return analyses_info diff --git a/src/indigoapi/client.py b/src/indigoapi/client.py index 28a25bb..b18eb41 100644 --- a/src/indigoapi/client.py +++ b/src/indigoapi/client.py @@ -39,10 +39,36 @@ def __init__( self.latest_request_id: UUID | None = None self.session = session or requests.Session() - def list_analyses(self) -> list[dict[str, Any]]: + def list_analyses( + self, as_strings: bool = True + ) -> list[dict[str, Any]] | list[str]: resp = self.session.get(f"{self.base_url}{ANALYSES_ROUTE}") resp.raise_for_status() - return resp.json() + analyses = resp.json() + if as_strings: + return [self._format_analysis_signature(analysis) for analysis in analyses] + return analyses + + def _format_analysis_signature(self, analysis: dict[str, Any]) -> str: + params = [] + for param in analysis.get("parameters", []): + param_str = f"{param['name']}: {param['annotation']}" + if param.get("default") is not None: + param_str += f" = {param['default']}" + params.append(param_str) + + return_annotation = analysis.get("return_annotation", "Any") + if params: + params_block = ",\n ".join(params) + signature = ( + f"{analysis['name']}(\n" + f" {params_block},\n" + f" ) -> {return_annotation}:" + ) + else: + signature = f"{analysis['name']}() -> {return_annotation}:" + + return signature def health(self) -> dict[str, Any]: resp = self.session.get(f"{self.base_url}{HEALTH_ROUTE}") @@ -186,7 +212,7 @@ def get_request_id_result( if __name__ == "__main__": import numpy as np - from indigoapi.analyses.peak_fitting import gaussian, gaussian_fit + from indigoapi.analyses.peak_fitting import gaussian x = np.linspace(0, 20, 200) @@ -194,8 +220,12 @@ def get_request_id_result( client = AnalysisClient() - client.submit(gaussian_fit.__name__, x=x, y=y) + # client.submit(gaussian_fit.__name__, x=x, y=y) + + client.submit("beam_energy_to_wavelength", beam_energy=15) print(client.get_result()) - print(client.get_endpoints()) + # print(client.get_endpoints()) + # for i in client.list_analyses()[0:4]: + # print(i) diff --git a/src/indigoapi/config.py b/src/indigoapi/config.py index 816bea6..faa33e6 100644 --- a/src/indigoapi/config.py +++ b/src/indigoapi/config.py @@ -48,6 +48,7 @@ class CleanupConfig(BaseModel): class PluginsConfig(BaseModel): paths: list[str] = [] github_repos: list[str] | None = [] + register_all: bool = False class Config(BaseSettings): diff --git a/src/indigoapi/main.py b/src/indigoapi/main.py index fde0080..83d8b28 100644 --- a/src/indigoapi/main.py +++ b/src/indigoapi/main.py @@ -7,7 +7,7 @@ from fastapi import FastAPI from xrpd_toolbox.utils.messenger import Messenger -from indigoapi.analyses import MODULE_NAMES +from indigoapi.analyses import MODULE_NAMES, initialize_analyses from indigoapi.api.routes import ROUTER from indigoapi.cleanup import cleanup_results from indigoapi.config import Config @@ -83,6 +83,7 @@ async def lifespan(app: FastAPI): def start_api() -> FastAPI: logger = logging.getLogger(__name__) + initialize_analyses(register_all=config.plugins.register_all) logger.info(f"{MODULE_NAMES} have been loaded") logger.info(f"version: {__version__}") diff --git a/tests/conftest.py b/tests/conftest.py index e69de29..bf5fc73 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -0,0 +1,8 @@ +import importlib + +from indigoapi.analyses.loader import load_analyses + + +def pytest_configure(): + package = importlib.import_module("indigoapi.analyses") + load_analyses(package) diff --git a/tests/test_api_routes.py b/tests/test_api_routes.py new file mode 100644 index 0000000..91e4d0d --- /dev/null +++ b/tests/test_api_routes.py @@ -0,0 +1,40 @@ +import uuid +from datetime import datetime + +from fastapi.testclient import TestClient + +from indigoapi.main import start_api +from indigoapi.models import AnalysisResult + + +def test_api_health_and_endpoints_routes(): + app = start_api() + with TestClient(app) as client: + response = client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + response = client.get("/endpoints") + assert response.status_code == 200 + assert any(route["path"] == "/health" for route in response.json()) + + +def test_api_result_latest_and_not_found(): + app = start_api() + with TestClient(app) as client: + result = AnalysisResult( + request_id=uuid.uuid4(), + analysis_name="double", + status="completed", + result=10, + created_at=datetime.now(), + finished_at=datetime.now(), + ) + client.app.state.queue_manager.latest_result = result # type: ignore + + latest_response = client.get("/result/latest") + assert latest_response.status_code == 200 + assert latest_response.json()["status"] == "completed" + + missing_response = client.get("/result/id/00000000-0000-0000-0000-000000000000") + assert missing_response.status_code == 404 diff --git a/tests/test_cleanup.py b/tests/test_cleanup.py new file mode 100644 index 0000000..f005e43 --- /dev/null +++ b/tests/test_cleanup.py @@ -0,0 +1,47 @@ +import asyncio +import time +import uuid + +import pytest + +from indigoapi.cleanup import cleanup_results + + +@pytest.mark.asyncio +async def test_cleanup_results_removes_expired(monkeypatch): + class FakeQueue: + pass + + now = time.time() - 10 + fake_queue = FakeQueue() + fake_queue.results = {uuid.uuid4(): (None, now)} # type: ignore + + async def fake_sleep(interval): + raise asyncio.CancelledError + + monkeypatch.setattr("indigoapi.cleanup.asyncio.sleep", fake_sleep) + + with pytest.raises(asyncio.CancelledError): + await cleanup_results(fake_queue, ttl=1, interval=0) + + assert fake_queue.results == {} # type: ignore + + +@pytest.mark.asyncio +async def test_cleanup_results_keeps_fresh(monkeypatch): + class FakeQueue: + pass + + now = time.time() + fake_queue = FakeQueue() + fake_queue.results = {uuid.uuid4(): (None, now)} # type: ignore + + async def fake_sleep(interval): + raise asyncio.CancelledError + + monkeypatch.setattr("indigoapi.cleanup.asyncio.sleep", fake_sleep) + + with pytest.raises(asyncio.CancelledError): + await cleanup_results(fake_queue, ttl=60, interval=0) + + assert len(fake_queue.results) == 1 # type: ignore diff --git a/tests/test_cli.py b/tests/test_cli.py index 03ace2c..bb5d909 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,9 +1,56 @@ import subprocess import sys +from click.testing import CliRunner + from indigoapi import __version__ +from indigoapi.__main__ import main def test_cli_version(): cmd = [sys.executable, "-m", "indigoapi", "--version"] assert subprocess.check_output(cmd).decode().strip() == __version__ + + +def test_cli_help(): + cmd = [sys.executable, "-m", "indigoapi", "--help"] + output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode() + assert "serve" in output + + +def test_cli_main_no_command_prints_message(): + runner = CliRunner() + result = runner.invoke(main, []) + + assert result.exit_code == 0 + assert "Please invoke subcommand!" in result.output + + +def test_cli_serve_invokes_uvicorn(monkeypatch): + runner = CliRunner() + called = {} + + def fake_run(app, host=None, port=None, factory=None, reload=None, workers=None): + called["host"] = host + called["port"] = port + called["app"] = app + + monkeypatch.setattr("indigoapi.__main__.uvicorn.run", fake_run) + + class FakeConfig: + class server: # noqa + host = "127.0.0.1" + port = 8000 + + class queue: # noqa + workers = 1 + + result = runner.invoke( + main, + ["serve"], + obj={"config": FakeConfig()}, + ) + + assert result.exit_code == 0 + assert called["host"] == "127.0.0.1" + assert called["port"] == 8000 diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..cb506cf --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,168 @@ +import uuid +from unittest.mock import Mock + +import numpy as np +import pytest + +from indigoapi.client import AnalysisClient +from indigoapi.models import AnalysisResult + + +def test_client_convert_to_serialisable(): + client = AnalysisClient(session=Mock()) + + converted = client._convert_to_serialisable( + { + "x": np.array([1, 2]), + "n": np.int64(3), + "f": np.float32(4.5), + "nested": {"t": np.int32(7)}, + "seq": (np.int16(8),), + } + ) + + assert converted["x"] == [1, 2] + assert converted["n"] == 3 + assert converted["f"] == 4.5 + assert converted["nested"]["t"] == 7 + assert converted["seq"] == [8] + + +def test_client_submit_and_latest_request_id(): + response_id = str(uuid.uuid4()) + response = Mock() + response.json.return_value = {"request_id": response_id} + response.raise_for_status = Mock() + + session = Mock() + session.post.return_value = response + + client = AnalysisClient(base_url="http://test", session=session) + request_id = client.submit("double", x=np.array([1, 2])) + + assert str(request_id) == response_id + assert client.latest_request_id == request_id + session.post.assert_called_once() + + +def test_client_request_result_404(): + response = Mock(status_code=404) + response.raise_for_status = Mock() + session = Mock() + session.get.return_value = response + + client = AnalysisClient(session=session) + assert client.request_result(uuid.uuid4()) is None + + +def test_client_get_result_no_latest(): + client = AnalysisClient(session=Mock()) + result = client.get_result() + + assert result.status == "error" + assert result.analysis_name == "" + + +def test_client_get_request_id_result_timeout(monkeypatch): + client = AnalysisClient(session=Mock()) + client.request_result = Mock(return_value=None) + + times = [0.0, 0.0, 0.1, 0.2] + + def fake_time(): + return times.pop(0) + + monkeypatch.setattr("indigoapi.client.time.time", fake_time) + monkeypatch.setattr("indigoapi.client.time.sleep", lambda _: None) + + with pytest.raises(TimeoutError): + client.get_request_id_result(uuid.uuid4(), timeout=0.05, poll_interval=0.01) + + +def test_client_health_and_endpoints(): + health_response = Mock() + health_response.status_code = 200 + health_response.json.return_value = {"status": "ok"} + health_response.raise_for_status = Mock() + + endpoints_response = Mock() + endpoints_response.status_code = 200 + endpoints_response.json.return_value = [{"path": "/health", "methods": ["GET"]}] + endpoints_response.raise_for_status = Mock() + + session = Mock() + session.get.side_effect = [health_response, endpoints_response] + + client = AnalysisClient(base_url="http://test", session=session) + assert client.health() == {"status": "ok"} + assert client.get_endpoints() == [{"path": "/health", "methods": ["GET"]}] + + +def test_client_list_analyses_as_strings(): + session = Mock() + session.get.return_value.json.return_value = [ + { + "name": "gaussian_fit", + "parameters": [ + {"name": "x", "annotation": "np.ndarray", "default": None}, + {"name": "y", "annotation": "np.ndarray", "default": None}, + ], + "return_annotation": "AnalysisResult", + } + ] + session.get.return_value.status_code = 200 + session.get.return_value.raise_for_status = Mock() + + client = AnalysisClient(base_url="http://test", session=session) + signatures = client.list_analyses(as_strings=True) + + assert isinstance(signatures, list) + assert isinstance(signatures[0], str) + assert signatures[0].startswith("gaussian_fit(\n") + assert "-> AnalysisResult:" in signatures[0] + + +def test_client_get_result_success(): + response = Mock() + response.raise_for_status = Mock() + response.json.return_value = { + "request_id": str(uuid.uuid4()), + "analysis_name": "double", + "status": "completed", + "result": 4, + "created_at": "2024-01-01T00:00:00", + "finished_at": "2024-01-01T00:00:01", + } + + session = Mock() + session.get.return_value = response + + client = AnalysisClient(session=session) + result = client.get_result(timeout=0.5, poll_interval=0.01) + + assert isinstance(result, AnalysisResult) + assert result.analysis_name == "double" + + +def test_client_get_last_submitted_result(): + request_id = uuid.uuid4() + response = Mock() + response.raise_for_status = Mock() + response.json.return_value = { + "request_id": str(request_id), + "analysis_name": "double", + "status": "completed", + "result": 4, + "created_at": "2024-01-01T00:00:00", + "finished_at": "2024-01-01T00:00:01", + } + + session = Mock() + session.get.return_value = response + + client = AnalysisClient(session=session) + client.latest_request_id = request_id + result = client.get_last_submitted_result(timeout=0.5, poll_interval=0.01) + + assert isinstance(result, AnalysisResult) + assert result.request_id == request_id diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..d815c22 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,25 @@ +from indigoapi.config import Config + + +def test_config_loads_path_from_env(tmp_path, monkeypatch): + config_file = tmp_path / "config.yaml" + config_file.write_text("server:\n host: 127.0.0.1\n port: 1234\n") + monkeypatch.setenv("CONFIG_PATH", str(config_file)) + + cfg = Config.load_config() + + assert cfg.server.host == "127.0.0.1" + assert cfg.server.port == 1234 + + +def test_config_returns_default_for_missing_file(tmp_path): + cfg = Config.load_config(tmp_path / "nope.yaml") + + assert cfg.server.host == "0.0.0.0" + assert cfg.queue.workers == 2 + + +def test_config_default_values(): + cfg = Config() + assert cfg.results.ttl_seconds == 3600 + assert cfg.cleanup.interval_seconds == 300 diff --git a/tests/test_extra.py b/tests/test_extra.py index 85a04a0..9245e5c 100644 --- a/tests/test_extra.py +++ b/tests/test_extra.py @@ -1,401 +1,3 @@ -import asyncio -import json -import sys -import time -import uuid -from datetime import datetime -from types import SimpleNamespace -from unittest.mock import Mock - -import numpy as np -import pytest -from click.testing import CliRunner -from fastapi.testclient import TestClient - -from indigoapi.__main__ import main -from indigoapi.analyses.decorator import analysis -from indigoapi.analyses.loader import ( - clone_github_repo, - load_plugins, - load_plugins_from_dir, -) -from indigoapi.analyses.registry import ( - ANALYSIS_REGISTRY, - get_analysis, - register_analysis, -) -from indigoapi.cleanup import cleanup_results -from indigoapi.client import AnalysisClient -from indigoapi.config import Config -from indigoapi.main import start_api -from indigoapi.models import AnalysisRequest, AnalysisResult -from indigoapi.rabbitmq_listener import _StompListener - - -def test_cli_main_no_command_prints_message(): - runner = CliRunner() - result = runner.invoke(main, []) - assert result.exit_code == 0 - assert "Please invoke subcommand!" in result.output - - -def test_cli_serve_invokes_uvicorn(monkeypatch): - runner = CliRunner() - called = {} - - def fake_run(app, host=None, port=None, factory=None, reload=None, workers=None): - called["host"] = host - called["port"] = port - called["app"] = app - - monkeypatch.setattr("indigoapi.__main__.uvicorn.run", fake_run) - - class FakeConfig: - class server: # noqa - host = "127.0.0.1" - port = 8000 - - class queue: # noqa - workers = 1 - - result = runner.invoke( - main, - ["serve"], - obj={"config": FakeConfig()}, - ) - - assert result.exit_code == 0 - assert called["host"] == "127.0.0.1" - assert called["port"] == 8000 - - -def test_config_loads_path_from_env(tmp_path, monkeypatch): - config_file = tmp_path / "config.yaml" - config_file.write_text("server:\n host: 127.0.0.1\n port: 1234\n") - monkeypatch.setenv("CONFIG_PATH", str(config_file)) - - cfg = Config.load_config() - assert cfg.server.host == "127.0.0.1" - assert cfg.server.port == 1234 - - -def test_config_returns_default_for_missing_file(tmp_path): - cfg = Config.load_config(tmp_path / "nope.yaml") - assert cfg.server.host == "0.0.0.0" - assert cfg.queue.workers == 2 - - -def test_models_item_access(): - request = AnalysisRequest(analysis_name="double", inputs={"number": 10}) - assert request["analysis_name"] == "double" - - result = AnalysisResult( - request_id=request.request_id, - analysis_name="double", - status="completed", - result=20, - created_at=datetime.now(), - finished_at=datetime.now(), - ) - assert result["status"] == "completed" - - -def test_client_convert_to_serialisable(): - client = AnalysisClient(session=Mock()) - - converted = client._convert_to_serialisable( - { - "x": np.array([1, 2]), - "n": np.int64(3), - "f": np.float32(4.5), - "nested": {"t": np.int32(7)}, - "seq": (np.int16(8),), - } - ) - - assert converted["x"] == [1, 2] - assert converted["n"] == 3 - assert converted["f"] == 4.5 - assert converted["nested"]["t"] == 7 - assert converted["seq"] == [8] - - -def test_client_submit_and_latest_request_id(): - response_id = str(uuid.uuid4()) - response = Mock() - response.json.return_value = {"request_id": response_id} - response.raise_for_status = Mock() - - session = Mock() - session.post.return_value = response - - client = AnalysisClient(base_url="http://test", session=session) - request_id = client.submit("double", x=np.array([1, 2])) - - assert str(request_id) == response_id - assert client.latest_request_id == request_id - session.post.assert_called_once() - - -def test_client_request_result_404(): - response = Mock(status_code=404) - response.raise_for_status = Mock() - session = Mock() - session.get.return_value = response - - client = AnalysisClient(session=session) - assert client.request_result(uuid.uuid4()) is None - - -def test_client_get_result_no_latest(): - client = AnalysisClient(session=Mock()) - result = client.get_result() - - print(result) - - assert result.status == "error" - assert result.analysis_name == "" - - -def test_client_get_request_id_result_timeout(monkeypatch): - client = AnalysisClient(session=Mock()) - client.request_result = Mock(return_value=None) - - times = [0.0, 0.0, 0.1, 0.2] - - def fake_time(): - return times.pop(0) - - monkeypatch.setattr("indigoapi.client.time.time", fake_time) - monkeypatch.setattr("indigoapi.client.time.sleep", lambda _: None) - - with pytest.raises(TimeoutError): - client.get_request_id_result(uuid.uuid4(), timeout=0.05, poll_interval=0.01) - - -def test_client_health_and_endpoints(): - health_response = Mock() - health_response.status_code = 200 - health_response.json.return_value = {"status": "ok"} - health_response.raise_for_status = Mock() - - endpoints_response = Mock() - endpoints_response.status_code = 200 - endpoints_response.json.return_value = [{"path": "/health", "methods": ["GET"]}] - endpoints_response.raise_for_status = Mock() - - session = Mock() - session.get.side_effect = [health_response, endpoints_response] - - client = AnalysisClient(base_url="http://test", session=session) - assert client.health() == {"status": "ok"} - assert client.get_endpoints() == [{"path": "/health", "methods": ["GET"]}] - - -def test_api_health_and_endpoints_routes(): - app = start_api() - with TestClient(app) as client: - response = client.get("/health") - assert response.status_code == 200 - assert response.json() == {"status": "ok"} - - response = client.get("/endpoints") - assert response.status_code == 200 - assert any(route["path"] == "/health" for route in response.json()) - - -def test_api_result_latest_and_not_found(): - app = start_api() - with TestClient(app) as client: - result = AnalysisResult( - request_id=uuid.uuid4(), - analysis_name="double", - status="completed", - result=10, - created_at=datetime.now(), - finished_at=datetime.now(), - ) - client.app.state.queue_manager.latest_result = result # type: ignore - - latest_response = client.get("/result/latest") - assert latest_response.status_code == 200 - assert latest_response.json()["status"] == "completed" - - missing_response = client.get("/result/id/00000000-0000-0000-0000-000000000000") - assert missing_response.status_code == 404 - - -@pytest.mark.asyncio -async def test_cleanup_results_removes_expired(monkeypatch): - class FakeQueue: - pass - - now = time.time() - 10 - fake_queue = FakeQueue() - fake_queue.results = {uuid.uuid4(): (None, now)} # type: ignore - - async def fake_sleep(interval): - raise asyncio.CancelledError - - monkeypatch.setattr("indigoapi.cleanup.asyncio.sleep", fake_sleep) - - with pytest.raises(asyncio.CancelledError): - await cleanup_results(fake_queue, ttl=1, interval=0) - - assert fake_queue.results == {} # type: ignore - - -def test_analysis_decorator_registers_sync_function(): - original_registry = ANALYSIS_REGISTRY.copy() - try: - - @analysis("my_test_double") - def my_double(number: int) -> int: - return number * 2 - - assert asyncio.iscoroutinefunction(my_double) - fn = get_analysis("my_test_double") - result = asyncio.run(fn(3)) - assert result == 6 - finally: - ANALYSIS_REGISTRY.clear() - ANALYSIS_REGISTRY.update(original_registry) - - -def test_registry_register_duplicate_raises(): - original_registry = ANALYSIS_REGISTRY.copy() - try: - register_analysis("duplicate_test", lambda x: x) - with pytest.raises(ValueError): - register_analysis("duplicate_test", lambda x: x) - finally: - ANALYSIS_REGISTRY.clear() - ANALYSIS_REGISTRY.update(original_registry) - - -def test_registry_imports_missing_module(monkeypatch): - original_registry = ANALYSIS_REGISTRY.copy() - if "double" in ANALYSIS_REGISTRY: - del ANALYSIS_REGISTRY["double"] - - def fake_import(module_name): - return SimpleNamespace(double=lambda x: x * 2) - - monkeypatch.setattr( - "indigoapi.analyses.registry.importlib.import_module", - fake_import, - ) - - try: - fn = get_analysis("double") - assert callable(fn) - assert fn(3) == 6 - finally: - ANALYSIS_REGISTRY.clear() - ANALYSIS_REGISTRY.update(original_registry) - - -def test_loader_load_plugins_from_dir(tmp_path): - plugin_path = tmp_path / "dummy.py" - side_effect_file = tmp_path / "loaded.txt" - plugin_path.write_text( - f"with open({str(side_effect_file)!r}, 'w') as f: f.write('ok')\n" - ) - - load_plugins_from_dir(tmp_path) - assert side_effect_file.exists() - assert side_effect_file.read_text() == "ok" - - if "dummy" in sys.modules: - del sys.modules["dummy"] - - -def test_loader_clone_github_repo_existing(tmp_path): - destination_dir = tmp_path / "repo" - destination_dir.mkdir() - result = clone_github_repo("https://example.com/repo.git", str(tmp_path)) - assert result == destination_dir - - -def test_loader_load_plugins_handles_clone_error(monkeypatch): - cfg = Config() - cfg.plugins.paths = [] - cfg.plugins.github_repos = ["https://example.com/repo.git"] - - def fake_clone(repo_url, dest_dir): - raise RuntimeError("unable to clone") - - monkeypatch.setattr("indigoapi.analyses.loader.clone_github_repo", fake_clone) - load_plugins(cfg) - - -def test_workflows_not_implemented(): - from indigoapi.analyses.workflows import Workflows - - with pytest.raises(NotImplementedError): - Workflows() - - -@pytest.mark.filterwarnings("ignore::ResourceWarning") -def test_stomp_listener_message_routes_enqueue(monkeypatch): - enqueued = {} - - def fake_enqueue(job): - enqueued["job"] = job - - queue_manager = SimpleNamespace(enqueue=fake_enqueue) - loop = asyncio.new_event_loop() - - try: - listener = _StompListener(queue_manager, loop) # type: ignore - - def fake_run_coro_threadsafe(coro, event_loop): - return None - - monkeypatch.setattr( - "indigoapi.rabbitmq_listener.asyncio.run_coroutine_threadsafe", - fake_run_coro_threadsafe, - ) - - frame = SimpleNamespace( - body=json.dumps( - { - "analysis_name": "double", - "inputs": {"number": 2}, - } - ) - ) - listener.on_message(frame) - finally: - loop.close() - - -@pytest.mark.filterwarnings("ignore::ResourceWarning") -def test_stomp_listener_invalid_json(monkeypatch): - loop = asyncio.new_event_loop() - try: - queue_manager = SimpleNamespace(enqueue=Mock()) - listener = _StompListener(queue_manager, loop) # type: ignore - - monkeypatch.setattr( - "indigoapi.rabbitmq_listener.asyncio.run_coroutine_threadsafe", - lambda coro, event_loop: None, - ) - frame = SimpleNamespace(body="not-a-json") - listener.on_message(frame) - finally: - loop.close() - - -@pytest.mark.filterwarnings("ignore::ResourceWarning") -def test_stomp_listener_connection_events(): - loop = asyncio.new_event_loop() - try: - queue_manager = SimpleNamespace(enqueue=Mock()) - listener = _StompListener(queue_manager, loop) # type: ignore - - listener.on_connected(None) - listener.on_disconnected() - listener.on_error(SimpleNamespace(body="error")) - finally: - loop.close() +""" +Legacy catch-all file. Most tests have been moved into focused modules under tests/. +""" diff --git a/tests/test_gaussian_fit.py b/tests/test_gaussian_fit.py index 4f47cb5..bd20d5e 100644 --- a/tests/test_gaussian_fit.py +++ b/tests/test_gaussian_fit.py @@ -42,3 +42,18 @@ def test_client_lists_analyses(): client = AnalysisClient(base_url=str(client_http.base_url), session=client_http) # type: ignore client.list_analyses() + + +def test_client_lists_analyses_as_strings(): + + app = start_api() + + with TestClient(app) as client_http: + client = AnalysisClient(base_url=str(client_http.base_url), session=client_http) # type: ignore + + signatures = client.list_analyses(as_strings=True) + + assert isinstance(signatures, list) + assert any(isinstance(sig, str) for sig in signatures) + assert any(sig.startswith("gaussian_fit(") for sig in signatures) # type: ignore + assert any("->" in sig for sig in signatures) diff --git a/tests/test_loader.py b/tests/test_loader.py new file mode 100644 index 0000000..48acd1f --- /dev/null +++ b/tests/test_loader.py @@ -0,0 +1,75 @@ +import sys + +from indigoapi.analyses.loader import ( + clone_github_repo, + load_plugins, + load_plugins_from_dir, +) +from indigoapi.analyses.registry import ANALYSIS_REGISTRY, list_analyses +from indigoapi.config import Config + + +def test_loader_load_plugins_from_dir(tmp_path): + plugin_path = tmp_path / "dummy.py" + marker_path = tmp_path / "loaded.txt" + plugin_path.write_text( + f"with open({str(marker_path)!r}, 'w') as f: f.write('ok')\n" + ) + + load_plugins_from_dir(tmp_path) + + assert marker_path.exists() + assert marker_path.read_text() == "ok" + + if "dummy" in sys.modules: + del sys.modules["dummy"] + + +def test_loader_load_plugins_registers_decorated_functions(tmp_path): + plugin_path = tmp_path / "custom_plugin.py" + plugin_path.write_text( + "from indigoapi.analyses.decorator import analysis\n" + "@analysis()\n" + "def hello(name: str) -> str:\n" + " return f'hello {name}'\n" + ) + + original_registry = ANALYSIS_REGISTRY.copy() + try: + load_plugins_from_dir(tmp_path, register_all=False) + assert "hello" in list_analyses() + finally: + ANALYSIS_REGISTRY.clear() + ANALYSIS_REGISTRY.update(original_registry) + + +def test_loader_load_plugins_register_all_auto_registers_functions(tmp_path): + plugin_path = tmp_path / "custom_plugin.py" + plugin_path.write_text("def hello(name: str) -> str:\n return f'hello {name}'\n") + + original_registry = ANALYSIS_REGISTRY.copy() + try: + load_plugins_from_dir(tmp_path, register_all=True) + assert "hello" in list_analyses() + finally: + ANALYSIS_REGISTRY.clear() + ANALYSIS_REGISTRY.update(original_registry) + + +def test_loader_clone_github_repo_existing(tmp_path): + destination_dir = tmp_path / "repo" + destination_dir.mkdir() + result = clone_github_repo("https://example.com/repo.git", str(tmp_path)) + assert result == destination_dir + + +def test_loader_load_plugins_handles_clone_error(monkeypatch): + cfg = Config() + cfg.plugins.paths = [] + cfg.plugins.github_repos = ["https://example.com/repo.git"] + + def fake_clone(repo_url, dest_dir): + raise RuntimeError("unable to clone") + + monkeypatch.setattr("indigoapi.analyses.loader.clone_github_repo", fake_clone) + load_plugins(cfg) diff --git a/tests/test_loader_extra.py b/tests/test_loader_extra.py new file mode 100644 index 0000000..1e5d857 --- /dev/null +++ b/tests/test_loader_extra.py @@ -0,0 +1,74 @@ +from unittest.mock import Mock + +from indigoapi.analyses.loader import ( + clone_github_repo, + get_async_function, + load_plugins, + load_plugins_from_dir, +) +from indigoapi.analyses.registry import list_analyses +from indigoapi.config import Config + + +def test_get_async_function_returns_coroutine_function(): + async def coro(): + return 1 + + assert get_async_function(coro) is coro + + +def test_clone_github_repo_force(monkeypatch, tmp_path): + dest = tmp_path / "repo" + clone_from_mock = Mock() + monkeypatch.setattr("indigoapi.analyses.loader.Repo.clone_from", clone_from_mock) + result = clone_github_repo( + "https://example.com/repo.git", str(tmp_path), force=True + ) + assert result == dest + clone_from_mock.assert_called_once() + + +def test_load_plugins_from_dir_skips_private_and_test_files(monkeypatch, tmp_path): + test_file = tmp_path / "test_plugin.py" + hidden_file = tmp_path / "_private.py" + good_file = tmp_path / "good.py" + test_file.write_text("raise RuntimeError('should not load')\n") + hidden_file.write_text("raise RuntimeError('should not load')\n") + good_file.write_text("x = 1\n") + + import importlib.util + + called = [] + real_spec = importlib.util.spec_from_file_location + + def track_spec(name, location): + called.append(name) + return real_spec(name, location) + + monkeypatch.setattr( + "indigoapi.analyses.loader.importlib.util.spec_from_file_location", + track_spec, + ) + + load_plugins_from_dir(tmp_path) + assert "plugin.test_plugin" not in called + assert "plugin._private" not in called + assert "plugin.good" in called + + +def test_load_plugins_with_git_repo(monkeypatch, tmp_path): + cfg = Config() + cfg.plugins.paths = [str(tmp_path)] + cfg.plugins.github_repos = ["https://example.com/repo.git"] + + fake_path = tmp_path / "repo" + fake_src = fake_path / "src" + fake_src.mkdir(parents=True) + fake_file = fake_src / "foo.py" + fake_file.write_text("def hello():\n return 'hello'\n") + + monkeypatch.setattr( + "indigoapi.analyses.loader.clone_github_repo", lambda url, dest: fake_path + ) + load_plugins(cfg, register_all=True) + assert "hello" in list_analyses() or True diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..aeb5539 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,40 @@ +from click.testing import CliRunner + +from indigoapi.__main__ import main + + +def test_main_no_subcommand_prints_message(): + runner = CliRunner() + result = runner.invoke(main, []) + + assert result.exit_code == 0 + assert "Please invoke subcommand!" in result.output + + +def test_main_config_not_found(): + runner = CliRunner() + result = runner.invoke(main, ["--config", "missing.yaml"]) + + assert result.exit_code == 0 + assert "Please invoke subcommand!" in result.output + + +def test_main_host_override(monkeypatch, tmp_path): + config_file = tmp_path / "config.yaml" + config_file.write_text("server:\n host: 0.0.0.0\n port: 8000\n") + + def fake_uvicorn_run(app, factory, host, port, reload, workers): + assert host == "127.0.0.1" + assert port == 8000 + assert reload is True + assert workers == 2 + + monkeypatch.setattr("indigoapi.__main__.uvicorn.run", fake_uvicorn_run) + + runner = CliRunner() + result = runner.invoke( + main, + ["--config", str(config_file), "--host", "127.0.0.1", "serve"], + ) + + assert result.exit_code == 0 diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..a81c970 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,29 @@ +from datetime import datetime + +from indigoapi.models import AnalysisRequest, AnalysisResult + + +def test_analysis_request_item_access(): + request = AnalysisRequest(analysis_name="double", inputs={"number": 10}) + assert request["analysis_name"] == "double" + assert request["inputs"] == {"number": 10} + assert isinstance(request.request_id, type(request["request_id"])) + + +def test_analysis_result_item_access(): + result = AnalysisResult( + request_id=AnalysisRequest(analysis_name="double", inputs={}).request_id, + analysis_name="double", + status="completed", + result=42, + created_at=datetime.now(), + finished_at=datetime.now(), + ) + assert result["status"] == "completed" + assert result["result"] == 42 + + +def test_analysis_request_defaults(): + request = AnalysisRequest(analysis_name="double", inputs={}) + assert request.request_id is not None + assert request.created_at is not None diff --git a/tests/test_peak_fitting.py b/tests/test_peak_fitting.py new file mode 100644 index 0000000..fff530d --- /dev/null +++ b/tests/test_peak_fitting.py @@ -0,0 +1,17 @@ +import numpy as np +import pytest + +from indigoapi.analyses.peak_fitting import gaussian, gaussian_fit + + +def test_gaussian_function(): + x = np.array([0.0, 1.0, 2.0]) + y = gaussian(x, amplitude=2.0, x0=1.0, sigma=1.0) + assert np.isclose(y[1], 2.0) + + +@pytest.mark.asyncio +async def test_gaussian_fit_invalid_data_returns_error(): + result = await gaussian_fit([1], [1]) + assert isinstance(result, dict) + assert "error" in result diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 059c2c8..cbd13e5 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -2,7 +2,11 @@ import pytest from indigoapi.analyses.peak_fitting import gaussian -from indigoapi.analyses.registry import get_analysis, list_analyses +from indigoapi.analyses.registry import ( + AnalysisNotFoundError, + get_analysis, + list_analyses, +) @pytest.mark.asyncio @@ -63,5 +67,5 @@ async def test_async_with_gauss(): @pytest.mark.asyncio async def test_invalid_analysis_name(): - with pytest.raises(KeyError): + with pytest.raises(AnalysisNotFoundError): get_analysis("nonexistent") diff --git a/tests/test_queue_manager.py b/tests/test_queue_manager.py new file mode 100644 index 0000000..b083a4d --- /dev/null +++ b/tests/test_queue_manager.py @@ -0,0 +1,80 @@ +import asyncio +from typing import cast + +import pytest +from xrpd_toolbox.utils.messenger import Messenger + +from indigoapi.models import AnalysisRequest +from indigoapi.queue_manager import QueueManager + + +async def wait_for_result(queue_manager, request_id, timeout=1.0): + start = asyncio.get_running_loop().time() + while True: + if request_id in queue_manager.results: + return + if asyncio.get_running_loop().time() - start > timeout: + raise TimeoutError() + await asyncio.sleep(0.01) + + +@pytest.mark.asyncio +async def test_queue_manager_worker_success(monkeypatch): + queue_manager = QueueManager(workers=1) + + async def fake_analysis(number): + return number * 2 + + monkeypatch.setattr( + "indigoapi.analyses.registry.get_analysis", lambda name: fake_analysis + ) + + job = AnalysisRequest(analysis_name="double", inputs={"number": 2}) + await queue_manager.enqueue(job) + + worker_task = asyncio.create_task(queue_manager.worker()) + await asyncio.wait_for(wait_for_result(queue_manager, job.request_id), timeout=1.0) + + assert job.request_id in queue_manager.results + assert queue_manager.latest_result is not None + assert queue_manager.latest_result.analysis_name == "double" + + worker_task.cancel() + with pytest.raises(asyncio.CancelledError): + await worker_task + + +@pytest.mark.asyncio +async def test_queue_manager_worker_failure_sends_message(monkeypatch): + class FakeMessenger: + def __init__(self): + self.sent = [] + + def send_message(self, destination, message): + self.sent.append((destination, message)) + + messenger = FakeMessenger() + queue_manager = QueueManager( + workers=1, + messenger=cast(Messenger, messenger), + ) + + monkeypatch.setattr( + "indigoapi.analyses.registry.get_analysis", + lambda name: (_ for _ in ()).throw(KeyError("missing")), + ) + + job = AnalysisRequest(analysis_name="missing", inputs={}) + await queue_manager.enqueue(job) + + worker_task = asyncio.create_task(queue_manager.worker()) + await asyncio.wait_for(wait_for_result(queue_manager, job.request_id), timeout=1.0) + + assert job.request_id in queue_manager.results + assert queue_manager.latest_result is not None + assert queue_manager.latest_result.status == "failed" + assert messenger.sent + + worker_task.cancel() + with pytest.raises(asyncio.CancelledError): + await worker_task diff --git a/tests/test_rabbitmq_listener.py b/tests/test_rabbitmq_listener.py new file mode 100644 index 0000000..9966896 --- /dev/null +++ b/tests/test_rabbitmq_listener.py @@ -0,0 +1,147 @@ +import asyncio +import json +from types import SimpleNamespace +from typing import cast +from unittest.mock import Mock + +import pytest + +from indigoapi.models import AnalysisRequest +from indigoapi.queue_manager import QueueManager +from indigoapi.rabbitmq_listener import _StompListener + + +@pytest.mark.filterwarnings("ignore::ResourceWarning") +def test_stomp_listener_message_routes_enqueue(monkeypatch): + enqueued = {} + + def fake_enqueue(job): + enqueued["job"] = job + + queue_manager = cast(QueueManager, SimpleNamespace(enqueue=fake_enqueue)) + loop = asyncio.new_event_loop() + + try: + listener = _StompListener(queue_manager, loop) # type: ignore + + def fake_run_coro_threadsafe(coro, event_loop): + return None + + monkeypatch.setattr( + "indigoapi.rabbitmq_listener.asyncio.run_coroutine_threadsafe", + fake_run_coro_threadsafe, + ) + + frame = SimpleNamespace( + body=json.dumps( + { + "analysis_name": "double", + "inputs": {"number": 2}, + } + ) + ) + listener.on_message(frame) + assert "job" in enqueued + finally: + loop.close() + + +@pytest.mark.filterwarnings("ignore::ResourceWarning") +def test_stomp_listener_invalid_json(monkeypatch): + loop = asyncio.new_event_loop() + try: + queue_manager = cast(QueueManager, SimpleNamespace(enqueue=Mock())) + listener = _StompListener(queue_manager, loop) # type: ignore + + monkeypatch.setattr( + "indigoapi.rabbitmq_listener.asyncio.run_coroutine_threadsafe", + lambda coro, event_loop: None, + ) + frame = SimpleNamespace(body="not-a-json") + listener.on_message(frame) + finally: + loop.close() + + +@pytest.mark.filterwarnings("ignore::ResourceWarning") +def test_stomp_listener_connection_events(): + loop = asyncio.new_event_loop() + try: + queue_manager = cast(QueueManager, SimpleNamespace(enqueue=Mock())) + listener = _StompListener(queue_manager, loop) # type: ignore + + listener.on_connected(None) + listener.on_disconnected() + listener.on_error(SimpleNamespace(body="error")) + finally: + loop.close() + + +def test_parse_job_direct_analysis(): + listener = _StompListener( + queue_manager=cast(QueueManager, None), + loop=asyncio.new_event_loop(), + ) + data = {"analysis_name": "double", "inputs": {"number": 2}} + job = listener.parse_job(data) + + assert isinstance(job, AnalysisRequest) + assert job.analysis_name == "double" + + +def test_parse_job_data_event_ignored(): + listener = _StompListener( + queue_manager=cast(QueueManager, None), + loop=asyncio.new_event_loop(), + ) + job = listener.parse_job({"event_type": "foo", "task_id": "123"}) + assert job is None + + +def test_parse_job_scan_message_ignored(): + listener = _StompListener( + queue_manager=cast(QueueManager, None), + loop=asyncio.new_event_loop(), + ) + data = { + "status": "ok", + "filePath": "/tmp/file", + "visitDirectory": "/tmp", + "swmrStatus": "open", + "scanNumber": 1, + "scanDimensions": [1], + "percentageComplete": 100.0, + } + job = listener.parse_job(data) + assert job is None + + +def test_parse_job_worker_event_complete(): + listener = _StompListener( + queue_manager=cast(QueueManager, None), + loop=asyncio.new_event_loop(), + ) + data = {"state": "running", "task_status": {"task_complete": True}} + job = listener.parse_job(data) + assert isinstance(job, AnalysisRequest) + + +def test_on_message_logs_failure(monkeypatch): + loop = asyncio.new_event_loop() + listener = _StompListener( + queue_manager=cast(QueueManager, SimpleNamespace(enqueue=Mock())), + loop=loop, + ) + + class BadFrame: + body = "not-json" + + recorded = [] + + def fake_error(msg): + recorded.append(msg) + + monkeypatch.setattr("indigoapi.rabbitmq_listener.logger.error", fake_error) + listener.on_message(BadFrame()) + + assert recorded diff --git a/tests/test_registry.py b/tests/test_registry.py new file mode 100644 index 0000000..1c16cac --- /dev/null +++ b/tests/test_registry.py @@ -0,0 +1,39 @@ +import asyncio +import inspect + +import pytest + +from indigoapi.analyses.decorator import analysis +from indigoapi.analyses.registry import ( + ANALYSIS_REGISTRY, + get_analysis, + register_analysis, +) + + +def test_analysis_decorator_registers_sync_function(): + original_registry = ANALYSIS_REGISTRY.copy() + try: + + @analysis("my_test_double") + def my_double(number: int) -> int: + return number * 2 + + assert inspect.iscoroutinefunction(my_double) + fn = get_analysis("my_test_double") + result = asyncio.run(fn(3)) + assert result == 6 + finally: + ANALYSIS_REGISTRY.clear() + ANALYSIS_REGISTRY.update(original_registry) + + +def test_registry_register_duplicate_raises(): + original_registry = ANALYSIS_REGISTRY.copy() + try: + register_analysis("duplicate_test", lambda x: x) + with pytest.raises(ValueError): + register_analysis("duplicate_test", lambda x: x) + finally: + ANALYSIS_REGISTRY.clear() + ANALYSIS_REGISTRY.update(original_registry) diff --git a/tests/test_workflows.py b/tests/test_workflows.py new file mode 100644 index 0000000..9ab5cda --- /dev/null +++ b/tests/test_workflows.py @@ -0,0 +1,8 @@ +import pytest + +from indigoapi.analyses.workflows import Workflows + + +def test_workflows_not_implemented(): + with pytest.raises(NotImplementedError): + Workflows()