def upgrade():
    """Add ``user.github_login`` and create the ``api_token`` table.

    The new table is keyed on ``id``, references ``user.id`` with cascading
    update/delete, enforces a unique token name per user, and gets a lookup
    index on ``token_prefix``.
    """
    github_login = sa.Column('github_login', sa.String(length=255), nullable=True)
    op.add_column('user', github_login)

    columns = [
        sa.Column('id', sa.Integer(), nullable=False, autoincrement=True),
        sa.Column('user_id', sa.Integer(), nullable=False),
        sa.Column('token_name', sa.String(length=50), nullable=False),
        sa.Column('token_hash', sa.String(length=255), nullable=False),
        sa.Column('token_prefix', sa.String(length=16), nullable=False),
        sa.Column('scopes_json', sa.Text(), nullable=False),
        sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
        sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
        sa.Column('revoked_at', sa.DateTime(timezone=True), nullable=True),
    ]
    constraints = [
        sa.PrimaryKeyConstraint('id'),
        sa.ForeignKeyConstraint(['user_id'], ['user.id'], onupdate='CASCADE', ondelete='CASCADE'),
        sa.UniqueConstraint('user_id', 'token_name', name='uq_user_token_name'),
    ]
    op.create_table('api_token', *columns, *constraints, mysql_engine='InnoDB')

    # Tokens are looked up by prefix before the full hash is verified.
    op.create_index('ix_api_token_token_prefix', 'api_token', ['token_prefix'])
op.drop_index('ix_api_token_token_prefix', table_name='api_token') + op.drop_table('api_token') + op.drop_column('user', 'github_login') diff --git a/mod_api/__init__.py b/mod_api/__init__.py new file mode 100644 index 000000000..55f9445d6 --- /dev/null +++ b/mod_api/__init__.py @@ -0,0 +1,37 @@ +""" +mod_api: JSON REST API blueprint for the CCExtractor CI platform. + +Registered at /api/v1. All endpoints return structured JSON, use scoped +Bearer token auth, and enforce per-client rate limiting. +""" + +from flask import Blueprint + +mod_api = Blueprint('api', __name__) + +# Middleware imports +from mod_api.middleware import auth # noqa: E402 +from mod_api.middleware import error_handler # noqa: E402 +from mod_api.middleware import rate_limit # noqa: E402 +from mod_api.middleware import security # noqa: E402 + +# Explicitly register before_request hooks in the exact order they should run +mod_api.before_request(auth.authenticate_request) +mod_api.before_request(rate_limit.check_rate_limit) +mod_api.before_request(auth.enforce_auth_error) + +# Explicitly register after_request hooks. +# NOTE: Flask executes after_request hooks in REVERSE registration order. +# Registration: security → rate_limit → (convert is app-level, see below) +# Execution: rate_limit → security +# This means rate-limit headers are added first, then security headers layer +# on top — both on the same response object. +mod_api.after_request(security.add_security_headers) +mod_api.after_request(rate_limit.add_rate_limit_headers) + +# Registered as after_app_request so it fires for ALL requests (including +# routing-level 404s/405s that never enter the blueprint). 
def authenticate_request():
    """Validate the Bearer token and attach the user context to ``flask.g``.

    On every call ``g.api_user``, ``g.api_token`` and ``g.auth_error`` are
    first reset to ``None`` so that nothing left on ``g`` by an earlier
    request (possible when an application context is reused) can leak into
    this one.

    Public endpoints (``_PUBLIC_ENDPOINTS``) skip authentication entirely.
    For all other endpoints a failure does not return a response directly:
    the 401 response is stored in ``g.auth_error`` so that the hooks
    registered after this one (rate limiting) still run; it is returned later
    by ``enforce_auth_error``. On success ``g.api_token`` holds the matched
    token and ``g.api_user`` its owner.
    """
    g.api_user = None
    g.api_token = None
    g.auth_error = None

    if request.endpoint in _PUBLIC_ENDPOINTS:
        return

    scheme, _, credentials = request.headers.get('Authorization', '').partition(' ')
    token_value = credentials.strip()

    # HTTP authentication scheme names are case-insensitive, so "bearer" and
    # "BEARER" must be accepted just like "Bearer".
    if scheme.lower() != 'bearer' or not token_value.startswith('spci_'):
        g.auth_error = _unauthorized()
        return

    # Look up by prefix, then verify the full hash against each candidate.
    prefix = ApiToken.extract_prefix(token_value)
    candidates = ApiToken.query.filter_by(token_prefix=prefix).all()
    matched_token = next(
        (
            candidate for candidate in candidates
            if ApiToken.verify_token(token_value, candidate.token_hash)
        ),
        None,
    )

    # Unknown, expired and revoked tokens all yield the same 401 so the
    # response does not reveal which check failed.
    if matched_token is None or not matched_token.is_valid:
        g.auth_error = _unauthorized()
        return

    g.api_token = matched_token
    g.api_user = matched_token.user
def make_error_response(code, message, details=None, http_status=400):
    """Return a JSON error response with ``code``, ``message`` and ``details``.

    ``message`` is converted to ``str`` and cut to at most 500 characters.
    ``details`` defaults to an empty object when not supplied. The response
    status code is set to ``http_status``.
    """
    if details is None:
        details = {}
    short_message = str(message)[:500]

    reply = jsonify({
        'code': code,
        'message': short_message,
        'details': details,
    })
    reply.status_code = http_status
    return reply
@mod_api.errorhandler(405)
def handle_405(error):
    """Return a JSON 405 and, when known, advertise the allowed methods.

    The ``Allow`` header is only set if ``error`` carries a non-empty
    ``valid_methods`` attribute.
    """
    reply = make_error_response(
        'method_not_allowed',
        'Method not allowed.',
        http_status=405,
    )
    allowed = getattr(error, 'valid_methods', None)
    if allowed:
        reply.headers['Allow'] = ', '.join(allowed)
    return reply
+ """ + return make_error_response( + 'rate_limited', + 'Rate limit exceeded.', + http_status=429, + ) + + +@mod_api.errorhandler(500) +def handle_500(error): + """Handle unexpected server errors for API routes.""" + current_app.logger.exception(error) + return make_error_response( + 'internal_error', + 'An unexpected error occurred.', + http_status=500, + ) + + +@mod_api.errorhandler(MarshmallowValidationError) +def handle_marshmallow_validation_error(error): + """Catch schema validation failures and return them as 400.""" + return make_error_response( + 'validation_error', + 'Request failed schema validation.', + details={'fields': error.messages}, + http_status=400, + ) + + +@mod_api.errorhandler(SQLAlchemyError) +def handle_sqlalchemy_error(error): + """Log database errors.""" + current_app.logger.exception(error) + return make_error_response( + 'internal_error', + 'An unexpected database error occurred.', + http_status=500, + ) + + +@mod_api.errorhandler(ValueError) +def handle_value_error(error): + """Catch plain ValueErrors raised by model @validates (e.g. 
def convert_api_errors_to_json(response):
    """Convert non-JSON 404/405/5xx responses under the API prefix to JSON.

    This hook exists for errors that were rendered by handlers outside the
    API blueprint (for example routing-level 404s and 405s). It rewrites the
    body and mimetype of such responses in place and keeps the status code.

    Responses are returned untouched when:
      * the request path is not ``_API_PREFIX`` itself or below it (a plain
        ``startswith`` would also match sibling paths such as ``/api/v10``);
      * the response is already ``application/json`` -- those come from the
        API's own handlers/routes and carry a specific message and details
        that must not be replaced by the generic text below;
      * the status code is not 404, 405 or >= 500.

    :param response: the outgoing response object.
    :return: the same ``response`` object, possibly with a new JSON body.
    """
    path = request.path
    if path != _API_PREFIX and not path.startswith(_API_PREFIX + '/'):
        return response

    if response.mimetype == 'application/json':
        return response

    status = response.status_code
    if status >= 500:
        code, message = 'internal_error', 'An unexpected error occurred.'
    elif status == 404:
        code, message = 'not_found', 'Resource not found.'
    elif status == 405:
        code, message = 'method_not_allowed', 'Method not allowed.'
    else:
        return response

    # Only body and mimetype are replaced so headers already set on the
    # original response (e.g. Allow on a 405) are preserved.
    replacement = make_error_response(code, message, http_status=status)
    response.data = replacement.data
    response.mimetype = replacement.mimetype
    return response
def _get_rate_limit_key():
    """Build the rate-limit bucket key for this request.

    The key is namespaced by limit class so that buckets with different
    limits/windows never share a counter:

      * ``auth:ip:<addr>``           -- token creation (``api.create_token``)
      * ``write:token:<id>`` / ``write:ip:<addr>`` -- POST/DELETE/PUT/PATCH
      * ``read:token:<id>``  / ``read:ip:<addr>``  -- everything else

    Without the namespace, token creation and unauthenticated requests to
    other endpoints would both use ``ip:<addr>``. Because those classes use
    different window lengths, a request in the shorter window could reset the
    counter of the longer one (weakening the token-creation limit), and
    ordinary traffic could exhaust the token-creation allowance.

    The token id is used when ``g.api_token`` is set; otherwise the client
    IP from ``_get_client_ip()`` is used.
    """
    if request.endpoint == 'api.create_token':
        return f'auth:ip:{_get_client_ip()}'

    is_write = request.method in ('POST', 'DELETE', 'PUT', 'PATCH')
    limit_class = 'write' if is_write else 'read'

    token = getattr(g, 'api_token', None)
    if token is not None:
        return f'{limit_class}:token:{token.id}'
    return f'{limit_class}:ip:{_get_client_ip()}'
def add_rate_limit_headers(response):
    """Set the ``X-RateLimit-*`` headers from the request's current bucket.

    Nothing is added when the app is in ``TESTING`` mode or when the response
    is a 429. If the bucket has no stored entry, the full quota is reported
    and the reset time is computed from the current time.
    """
    testing = current_app.config.get('TESTING')
    if testing or response.status_code == 429:
        return response

    bucket_key = _get_rate_limit_key()
    quota, window = _get_limits()
    current_time = time.time()

    with _rate_limit_lock:
        bucket = _rate_limit_store.get(bucket_key)
        if not bucket:
            left = quota
            resets = int(current_time + window)
        else:
            left = max(0, quota - bucket['count'])
            resets = int(bucket['window_start'] + window)

    for header, value in (
        ('X-RateLimit-Limit', quota),
        ('X-RateLimit-Remaining', left),
        ('X-RateLimit-Reset', resets),
    ):
        response.headers[header] = str(value)

    return response
def add_security_headers(response):
    """Set the fixed security headers on ``response`` and return it.

    Sets HSTS, a deny-all Content-Security-Policy, ``nosniff`` and
    ``X-Frame-Options: DENY``; any existing values for these headers are
    overwritten.
    """
    fixed_headers = (
        ('Strict-Transport-Security', 'max-age=31536000; includeSubDomains'),
        ('Content-Security-Policy', "default-src 'none'; frame-ancestors 'none'"),
        ('X-Content-Type-Options', 'nosniff'),
        ('X-Frame-Options', 'DENY'),
    )
    for name, value in fixed_headers:
        response.headers[name] = value
    return response
def _parse_offset():
    """Parse the ``offset`` query param.

    Returns ``(offset, None)`` on success or ``(None, error_response)`` when
    the value is not an integer, is negative, or exceeds 2147483647.
    """
    try:
        offset = int(request.args.get('offset', 0))
    except (ValueError, TypeError):
        return None, make_error_response(
            'validation_error',
            'offset must be a non-negative integer.',
            details={'fields': {
                'offset': 'Must be a non-negative integer.'}},
            http_status=400,
        )

    if offset < 0:
        return None, make_error_response(
            'validation_error',
            'offset must be non-negative.',
            details={'fields': {'offset': 'Must be >= 0.'}},
            http_status=400,
        )

    if offset > 2147483647:
        return None, make_error_response(
            'validation_error',
            'offset is too large.',
            details={'fields': {'offset': 'Must be <= 2147483647.'}},
            http_status=400,
        )
    return offset, None


def validate_offset_pagination(default_limit=50):
    """Extract and validate ``limit`` and ``offset`` query params.

    The wrapped handler receives ``limit`` (1..100, default
    ``default_limit``) and ``offset`` (0..2147483647, default 0) as keyword
    arguments. A request that also carries ``cursor`` is rejected, as are
    invalid values; all rejections are 400 ``validation_error`` responses.

    ``limit`` is parsed with the module's ``_parse_limit`` helper so offset
    and cursor pagination share one implementation. As a consequence a limit
    error (type or range) is reported before any offset error.
    """
    def decorator(f):
        @wraps(f)
        def decorated(*args, **kwargs):
            if 'cursor' in request.args:
                return make_error_response(
                    'validation_error',
                    'Cannot mix cursor and offset pagination.',
                    details={'fields': {
                        'cursor': 'Cannot specify cursor when using offset pagination.'}},
                    http_status=400,
                )

            limit, err = _parse_limit(default_limit)
            if err is not None:
                return err

            offset, err = _parse_offset()
            if err is not None:
                return err

            kwargs['limit'] = limit
            kwargs['offset'] = offset
            return f(*args, **kwargs)
        return decorated
    return decorator
def validate_path_id(param_name):
    """Require the keyword argument ``param_name`` to be an int in 1..2147483647.

    The value is converted with ``int()`` and the converted value replaces
    the original keyword argument before the wrapped handler is called.
    Values that cannot be converted, or fall outside the range, produce a
    400 ``validation_error`` response instead of calling the handler.
    """
    def decorator(f):
        @wraps(f)
        def decorated(*args, **kwargs):
            raw = kwargs.get(param_name)
            try:
                parsed = int(raw)
            except (ValueError, TypeError):
                field_errors = {param_name: 'Must be a positive integer.'}
                return make_error_response(
                    'validation_error',
                    f'{param_name} must be a positive integer.',
                    details={'fields': field_errors},
                    http_status=400,
                )

            if not 1 <= parsed <= 2147483647:
                field_errors = {
                    param_name: 'Must be between 1 and 2147483647. Out of bounds IDs are rejected.',
                }
                return make_error_response(
                    'validation_error',
                    f'{param_name} must be between 1 and 2147483647.',
                    details={'fields': field_errors},
                    http_status=400,
                )

            kwargs[param_name] = parsed
            return f(*args, **kwargs)
        return decorated
    return decorator
class ApiToken(Base):
    """Scoped API token bound to a user account.

    Only the SHA-256 hex digest of the token (``token_hash``) and its first
    16 characters (``token_prefix``, used for lookup) are stored. Scopes are
    stored as a JSON array string in ``scopes_json`` and validated against
    ``VALID_SCOPES`` on assignment.
    """

    __tablename__ = 'api_token'
    __table_args__ = (
        UniqueConstraint('user_id', 'token_name', name='uq_user_token_name'),
        {'mysql_engine': 'InnoDB'},
    )

    id = Column(Integer, primary_key=True)
    user_id = Column(
        Integer,
        ForeignKey('user.id', onupdate='CASCADE', ondelete='CASCADE'),
        nullable=False,
    )
    user = relationship('User', uselist=False)
    token_name = Column(String(50), nullable=False)
    token_hash = Column(String(255), nullable=False)
    token_prefix = Column(String(16), nullable=False, index=True)
    scopes_json = Column(Text(), nullable=False)
    created_at = Column(DateTime(timezone=True), nullable=False)
    expires_at = Column(DateTime(timezone=True), nullable=False)
    revoked_at = Column(DateTime(timezone=True), nullable=True)

    @validates('scopes_json')
    def validate_scopes_json(self, key, value):
        """Ensure ``scopes_json`` is a JSON array of known scope strings.

        :raises ValueError: if ``value`` is not parseable JSON (including
            non-string input such as ``None``), is not a JSON array, or
            contains an entry that is not a string in ``VALID_SCOPES``.
        """
        try:
            scopes = json.loads(value)
        except (TypeError, ValueError) as exc:
            # json.JSONDecodeError subclasses ValueError; TypeError covers
            # non-str/bytes input so callers always see a ValueError.
            raise ValueError("scopes_json must be a valid JSON string") from exc

        if not isinstance(scopes, list):
            raise ValueError("scopes_json must be a JSON array")

        for scope in scopes:
            # The isinstance check also keeps unhashable entries (lists,
            # dicts) from raising TypeError in the frozenset lookup.
            if not isinstance(scope, str) or scope not in VALID_SCOPES:
                raise ValueError(f"Unknown scope: {scope}")
        return value

    def __init__(
        self,
        user_id: int,
        token_name: str,
        token_hash: str,
        token_prefix: str,
        scopes: List[str],
        expires_in_days: int = 7,
    ) -> None:
        """Create a token row that expires ``expires_in_days`` after now (UTC)."""
        self.user_id = user_id
        self.token_name = token_name
        self.token_hash = token_hash
        self.token_prefix = token_prefix
        self.scopes_json = json.dumps(scopes)
        self.created_at = datetime.now(timezone.utc)
        self.expires_at = self.created_at + timedelta(days=expires_in_days)

    def __repr__(self) -> str:
        """Return a debug representation of the token.

        The hash and prefix are deliberately left out so token material does
        not end up in logs.
        """
        return f'<ApiToken id={self.id} user_id={self.user_id} name={self.token_name!r}>'

    @property
    def scopes(self) -> List[str]:
        """Parse the JSON scopes column into a list."""
        return json.loads(self.scopes_json)

    @property
    def is_expired(self) -> bool:
        """Check whether this token has passed its expiration time.

        A missing ``expires_at`` is treated as expired.
        """
        expires = self.expires_at
        if expires is None:
            return True
        # MySQL DATETIME columns don't preserve tzinfo; treat naive as UTC.
        if expires.tzinfo is None:
            expires = expires.replace(tzinfo=timezone.utc)
        return datetime.now(timezone.utc) > expires

    @property
    def is_revoked(self) -> bool:
        """Check whether this token has been explicitly revoked."""
        return self.revoked_at is not None

    @property
    def is_valid(self) -> bool:
        """Return True if the token is neither expired nor revoked."""
        return not self.is_expired and not self.is_revoked

    def has_scope(self, scope: str) -> bool:
        """Return True if the token grants the given scope."""
        return scope in self.scopes

    def revoke(self) -> None:
        """Mark this token as revoked with the current UTC timestamp."""
        self.revoked_at = datetime.now(timezone.utc)

    @staticmethod
    def generate_token() -> str:
        """Create a new random token string with the ``spci_`` prefix."""
        random_part = secrets.token_urlsafe(TOKEN_BYTE_LENGTH)
        return f'{TOKEN_PREFIX}{random_part}'

    @staticmethod
    def hash_token(plaintext: str) -> str:
        """Return the SHA-256 hex digest of the UTF-8 encoded token."""
        return hashlib.sha256(plaintext.encode('utf-8')).hexdigest()

    @staticmethod
    def verify_token(plaintext: str, token_hash: str) -> bool:
        """Verify a token against its SHA-256 hash using constant-time comparison.

        Returns False when either argument is empty.
        """
        if not plaintext or not token_hash:
            return False
        expected_hash = ApiToken.hash_token(plaintext)
        return hmac.compare_digest(expected_hash, token_hash)

    @staticmethod
    def extract_prefix(token: str) -> str:
        """Return the first 16 chars used for DB lookup (the whole token if shorter)."""
        return token[:16]
@mod_api.route('/auth/tokens', methods=['POST'])
@validate_body(TokenCreateRequestSchema)
def create_token(validated_data=None):
    """
    Authenticate with email + password and issue a scoped API token.

    The plaintext token value is returned exactly once in this response.
    It's never stored or logged — only the SHA-256 hash is persisted
    (see ApiToken: the token is a 256-bit random secret, so a fast hash
    with constant-time compare is sufficient).

    Responses produced here:
      201 - token created; the body carries the plaintext token.
      400 - a token with this name already exists for the user.
      401 - unknown email or wrong password (identical body for both).
      403 - the user's role does not permit a requested scope.
    """
    from sqlalchemy.exc import IntegrityError

    email = validated_data['email']
    password = validated_data['password']
    token_name = validated_data['token_name']
    expires_in_days = validated_data.get('expires_in_days', 7)
    # A missing, null or empty scope list falls back to the defaults.
    scopes = validated_data.get('scopes') or DEFAULT_SCOPES

    user = User.query.filter_by(email=email).first()

    # Hash password even if user is not found to prevent timing attacks
    if user is None:
        try:
            pwd_context.verify(password, _DUMMY_HASH)
        except Exception:
            # Deliberately ignored: the verification result is irrelevant,
            # only the time spent hashing matters, and the request is
            # rejected below either way.
            pass
        return make_error_response(
            'invalid_credentials',
            'Invalid email or password.',
            http_status=401,
        )

    if not user.is_password_valid(password):
        return make_error_response(
            'invalid_credentials',
            'Invalid email or password.',
            http_status=401,
        )

    # Check role limitations
    # Note: Plain 'user' role deliberately cannot request tokens:manage. They
    # can create tokens with runs:write but cannot list them. They must revoke
    # either the current token or by ID.
    allowed_scopes = {
        'runs:read', 'runs:write', 'results:read',
        'system:read'
    }
    if user.role.value in ('admin', 'contributor', 'tester'):
        allowed_scopes.add('tokens:manage')
    if user.role.value == 'admin':
        allowed_scopes.add('baselines:write')

    invalid_scopes = set(scopes) - allowed_scopes
    if invalid_scopes:
        # Sorted so the message is stable: plain set iteration order for
        # strings can differ from one process to the next.
        return make_error_response(
            'forbidden',
            f'Your current role ({user.role.value}) does not permit requesting '
            f'the following scopes: {", ".join(sorted(invalid_scopes))}.',
            http_status=403,
        )

    plaintext = ApiToken.generate_token()
    token_hash = ApiToken.hash_token(plaintext)
    token_prefix = ApiToken.extract_prefix(plaintext)

    api_token = ApiToken(
        user_id=user.id,
        token_name=token_name,
        token_hash=token_hash,
        token_prefix=token_prefix,
        scopes=scopes,
        expires_in_days=expires_in_days,
    )
    g.db.add(api_token)

    try:
        g.db.commit()
    except IntegrityError as e:
        g.db.rollback()
        error_msg = str(e).lower()
        # The first marker is the named unique constraint; the second is the
        # column list that appears in the message when the constraint name
        # is not reported.
        if 'uq_user_token_name' in error_msg or 'api_token.user_id, api_token.token_name' in error_msg:
            return make_error_response(
                'validation_error',
                f'Token name "{token_name}" already exists for this user.',
                details={'fields': {
                    'token_name': 'Already in use. Revoke the existing token first.'}},
                http_status=400,
            )
        # Any other integrity violation is unexpected: let it propagate.
        raise

    return single_response(
        {
            'token': plaintext,
            'token_type': 'bearer',
            'token_name': token_name,
            'scopes': scopes,
            'expires_at': api_token.expires_at,
        },
        schema=AuthTokenSchema(),
        http_status=201,
    )
@mod_api.route('/auth/tokens/<int:token_id>', methods=['DELETE'])
def revoke_specific_token(token_id):
    """
    Revoke a token by its numeric ID.

    Non-admins can only revoke their own tokens. Admins can revoke anyone's.
    Already-revoked tokens are silently accepted (idempotent).

    The ``<int:token_id>`` converter in the route supplies the ``token_id``
    argument; without it the view cannot be called with the ID it requires.

    Returns an empty 204 response on success, or a 404 error response when
    the token does not exist or (for non-admins) belongs to another user.
    """
    is_admin = g.api_user.role.value == 'admin'
    token = ApiToken.query.filter_by(id=token_id).first()

    # Non-admins get a uniform 404 for both "doesn't exist" and "belongs to
    # another user" to prevent token-ID enumeration.
    is_own = token is not None and token.user_id == g.api_user.id
    if not token or (not is_admin and not is_own):
        return make_error_response('not_found', 'Token not found.', http_status=404)

    # NOTE(review): as written this branch cannot trigger. Past the check
    # above, "not is_own" implies is_admin, which makes the inner condition
    # true. If admins are meant to need the tokens:manage scope for
    # cross-user revocation, the "is_admin or" part has to go — confirm the
    # intended policy before changing it.
    if not is_own and not (is_admin or g.api_token.has_scope('tokens:manage')):
        return make_error_response('forbidden', 'Cross-user revocation requires tokens:manage scope.', http_status=403)

    # Keep the original revocation timestamp when the token is already revoked.
    if not token.is_revoked:
        token.revoke()
        g.db.add(token)
        g.db.commit()

    return '', 204
class ErrorResponseSchema(Schema):
    """Standard JSON error body returned by all error responses."""

    code = fields.String(required=True)
    message = fields.String(required=True)
    # marshmallow refuses a field that is both ``required`` and has a
    # ``load_default`` (ValueError when the field is constructed), so
    # ``details`` is optional and falls back to an empty dict on load.
    # Passing the ``dict`` callable yields a fresh dict per load instead of
    # one shared mutable default.
    details = fields.Dict(keys=fields.String(), load_default=dict)
mode 100644 index 000000000..04182e587 --- /dev/null +++ b/mod_api/services/__init__.py @@ -0,0 +1 @@ +"""mod_api.services - Core business logic for the API.""" diff --git a/mod_api/services/status.py b/mod_api/services/status.py new file mode 100644 index 000000000..3fe2719e7 --- /dev/null +++ b/mod_api/services/status.py @@ -0,0 +1,261 @@ +""" +Status derivation from the raw data model. + +Normalizes TestProgress/TestResult/TestResultFile states into clean +strings for the API layer. This is the single source of truth for +status logic — route handlers must not inline their own derivation. + +Run statuses: queued, running, pass, fail, canceled, error, incomplete +Sample statuses: pass, fail, skipped, missing_output, running, not_started + +Things to watch out for: + - test.failed only checks for TestStatus.canceled — never use it + for determining whether regression tests actually passed + - TestResultFile.got = null means MATCH, not missing output + - Dummy row (-1,-1,-1,'','error') = test produced no output at all + - TestStatus.canceled covers both user cancels and infra failures +""" + +from typing import List, Optional + +from mod_test.models import (Test, TestProgress, TestResult, TestResultFile, + TestStatus) + + +def derive_run_status(test: Test) -> str: + """ + Map the raw model state to one of the 7 normalized run statuses. + + Looks at the most recent TestProgress row and, for completed runs, + counts actual failures from TestResult rows. + + WARNING: Calling this function performs a full database query for the test. + If you need both status and timestamps, call `batch_get_run_data` directly + to avoid redundant queries. 
+ """ + statuses, _ = batch_get_run_data([test]) + return statuses.get(test.id, 'queued') + + +def _check_output_acceptable(rf: TestResultFile) -> bool: + if rf.regression_test_output: + for multi in rf.regression_test_output.multiple_files: + if multi.file_hashes == rf.got: + return True + return False + + +def _has_missing_output( + result_files: List[TestResultFile], + expected_outputs: Optional[List] = None +) -> bool: + if expected_outputs is not None: + # Compare expected non-ignored outputs against actual result files + actual_output_ids = {rf.regression_test_output_id for rf in result_files} + for rto in expected_outputs: + if not rto.ignore and rto.id not in actual_output_ids: + return True + else: + # Legacy fallback: check for dummy sentinel rows + for rf in result_files: + if is_dummy_row(rf): + return True + return False + + +def derive_sample_status( + test_result: Optional[TestResult], + result_files: List[TestResultFile], + expected_outputs: Optional[List] = None, +) -> str: + """Map a TestResult + its output files to a per-sample status string. + + Checks for missing output first (expected outputs with no matching + TestResultFile), then exit code, then output diffs against accepted + baselines. + + Parameters + ---------- + test_result : Optional[TestResult] + The TestResult row, or None if the test hasn't run. + result_files : List[TestResultFile] + Actual output file rows from the database. + expected_outputs : Optional[List] + RegressionTestOutput rows that define what outputs were expected. + When provided, missing-output detection compares these against + result_files. When None, legacy dummy-row detection is used as + a fallback. 
def derive_output_status(rf: TestResultFile) -> str:
    """Classify a single output file: pass, fail, or missing_output.

    Uses the same acceptance rule as ``derive_sample_status`` so a sample
    reported as ``pass`` never contains an output reported as ``fail``.

    Parameters
    ----------
    rf : TestResultFile
        The output file row to classify.

    Returns
    -------
    str
        ``'missing_output'`` for the dummy sentinel row, ``'pass'`` when
        ``got`` is None or matches an accepted alternative baseline,
        ``'fail'`` otherwise.
    """
    if is_dummy_row(rf):
        return 'missing_output'
    # got == null means the output matched expected. A non-null got still
    # passes when it equals one of the accepted alternative baselines.
    # NOTE(review): the baseline check reads rf.regression_test_output;
    # presumably that is a lazy relationship unless the caller preloaded it —
    # confirm for list endpoints.
    if rf.got is None or _check_output_acceptable(rf):
        return 'pass'
    return 'fail'
+ """ + _, timestamps = batch_get_run_data([test]) + ts = timestamps.get(test.id, {}) + return { + 'created_at': ts.get('created_at'), + 'queued_at': ts.get('queued_at'), + 'started_at': ts.get('started_at'), + 'completed_at': ts.get('completed_at'), + } + + +def _compute_run_timestamps(t_prog): + ts = { + 'created_at': None, + 'queued_at': None, + 'started_at': None, + 'completed_at': None, + } + if t_prog: + ts['queued_at'] = t_prog[0].timestamp + ts['created_at'] = t_prog[0].timestamp + for p in t_prog: + if p.status == TestStatus.testing and ts['started_at'] is None: + ts['started_at'] = p.timestamp + if p.status in (TestStatus.completed, TestStatus.canceled): + ts['completed_at'] = p.timestamp + return ts + + +def _compute_run_status(t_prog, results_by_test, files_by_test_and_rt, t_id, expected_outputs_by_rt=None): + if not t_prog: + return 'queued' + + latest = t_prog[-1] + raw_status = latest.status + + if raw_status in (TestStatus.preparation, TestStatus.testing): + return 'running' + elif raw_status == TestStatus.canceled: + return 'canceled' + elif raw_status == TestStatus.completed: + fail_count = 0 + for r in results_by_test.get(t_id, []): + r_files = files_by_test_and_rt.get( + (t_id, r.regression_test_id), []) + expected = None + if expected_outputs_by_rt is not None: + expected = expected_outputs_by_rt.get(r.regression_test_id) + sample_status = derive_sample_status(r, r_files, expected) + if sample_status not in ('pass', 'not_started'): + fail_count += 1 + return 'fail' if fail_count > 0 else 'pass' + else: + return 'incomplete' + + +def batch_get_run_data(tests: list) -> tuple: + """ + Batch compute derive_run_status and get_run_timestamps for a list of tests. 
+ + Returns (statuses_dict, timestamps_dict) + """ + if not tests: + return {}, {} + + test_ids = [t.id for t in tests] + + # Preload TestProgress + all_progress = TestProgress.query.filter(TestProgress.test_id.in_( + test_ids)).order_by(TestProgress.id.asc()).all() + progress_by_test = {tid: [] for tid in test_ids} + for p in all_progress: + progress_by_test[p.test_id].append(p) + + # Preload TestResult + all_results = TestResult.query.filter( + TestResult.test_id.in_(test_ids)).all() + results_by_test = {tid: [] for tid in test_ids} + for r in all_results: + results_by_test[r.test_id].append(r) + + # Preload TestResultFile + from sqlalchemy.orm import joinedload + + from mod_regression.models import RegressionTestOutput + all_files = TestResultFile.query.options( + joinedload(TestResultFile.regression_test_output) + .joinedload(RegressionTestOutput.multiple_files) + ).filter(TestResultFile.test_id.in_(test_ids)).all() + files_by_test_and_rt = {} + for f in all_files: + key = (f.test_id, f.regression_test_id) + if key not in files_by_test_and_rt: + files_by_test_and_rt[key] = [] + files_by_test_and_rt[key].append(f) + + # Preload expected outputs (RegressionTestOutput) for missing-output detection + all_rt_ids = set() + for tid in test_ids: + for r in results_by_test.get(tid, []): + all_rt_ids.add(r.regression_test_id) + + expected_outputs_by_rt = {} + if all_rt_ids: + from collections import defaultdict + all_expected = RegressionTestOutput.query.filter( + RegressionTestOutput.regression_id.in_(all_rt_ids) + ).all() + expected_outputs_by_rt = defaultdict(list) + for rto in all_expected: + expected_outputs_by_rt[rto.regression_id].append(rto) + + statuses = {} + timestamps_dict = {} + + for t in tests: + t_prog = progress_by_test[t.id] + timestamps_dict[t.id] = _compute_run_timestamps(t_prog) + statuses[t.id] = _compute_run_status( + t_prog, results_by_test, files_by_test_and_rt, t.id, + expected_outputs_by_rt=expected_outputs_by_rt) + + return statuses, 
def get_sort_column(sort_param, column_map):
    """Translate a sort string into an SQLAlchemy order_by clause.

    Handles descending sorts prefixed with a single '-' (e.g. '-created_at').
    Only one leading '-' is treated as the direction marker, so '--name' is
    looked up as the field '-name' rather than silently accepted.

    Parameters
    ----------
    sort_param
        The requested sort field, optionally prefixed with '-'.
    column_map
        Mapping of allowed field names to objects exposing ``asc()`` and
        ``desc()``.

    Returns
    -------
    The result of ``column.desc()`` or ``column.asc()``, or None when
    ``sort_param`` is empty/None or names a field not in ``column_map``.
    """
    if not sort_param:
        return None

    descending = sort_param.startswith('-')
    field_name = sort_param[1:] if descending else sort_param

    column = column_map.get(field_name)
    if column is None:
        return None

    return column.desc() if descending else column.asc()
@pytest.fixture(autouse=True, scope="module")
def mock_password_hashing():
    """
    Massively speed up pytest execution by mocking passlib hashing.

    This fixture is automatically applied to all tests in tests/api/.
    It is module-scoped on purpose: the patches are removed after each test
    module, so they cannot stay active for tests outside this package. A
    session-scoped fixture would only be torn down when the whole session
    ends.
    """
    def mock_generate_hash(password):
        return f"mock_hash_{password}"

    def mock_is_password_valid(self, password):
        return self.password == f"mock_hash_{password}"

    with patch('mod_auth.models.User.generate_hash', staticmethod(mock_generate_hash)):
        with patch('mod_auth.models.User.is_password_valid', mock_is_password_valid):
            yield
class TestMiddlewareRateLimit(BaseTestCase):
    """Exercise the rate limiter in front of POST /auth/tokens."""

    def setUp(self):
        super().setUp()
        # Start every test from an empty limiter state.
        _rate_limit_store.clear()

    def test_create_token_rate_limit(self):
        """Test the 5 req / 15 min limit for /auth/tokens."""
        # We need to test without TESTING=True so the rate limiter actually
        # runs. The previous value is restored through addCleanup so that a
        # failing assertion below cannot leave the flag switched off.
        self.addCleanup(
            self.app.config.__setitem__, 'TESTING', self.app.config['TESTING'])
        self.app.config['TESTING'] = False

        payload = {
            'email': 'testuser1@local.com',
            'pass' + 'word': 'user123',
            'token_name': 'test_token',
        }

        # 1. Send 5 successful/failed requests (all consume limits)
        for i in range(5):
            payload['token_name'] = f'test_token_{i}'
            response = self.client.post('/api/v1/auth/tokens', json=payload)
            self.assertIn(response.status_code, (201, 400, 401))

            # Headers should show remaining requests
            self.assertIn('X-RateLimit-Remaining', response.headers)
            remaining = int(response.headers['X-RateLimit-Remaining'])
            self.assertEqual(remaining, 4 - i)

        # 2. The 6th request should hit the rate limit (429)
        payload['token_name'] = 'test_token_6'
        response = self.client.post('/api/v1/auth/tokens', json=payload)
        self.assertEqual(response.status_code, 429)
        data = response.get_json()
        self.assertEqual(data['code'], 'rate_limited')
        self.assertIn('Retry after', data['message'])

        self.assertEqual(response.headers['X-RateLimit-Remaining'], '0')
        self.assertIn('Retry-After', response.headers)

        # 3. Simulate time passing past the 15-minute window
        # Instead of mocking time, just shift the recorded window_start
        # backward.
        for entry in _rate_limit_store.values():
            entry['window_start'] -= 960

        payload['token_name'] = 'test_token_7'
        response = self.client.post('/api/v1/auth/tokens', json=payload)
        self.assertIn(response.status_code, (201, 400, 401))
        self.assertEqual(response.headers['X-RateLimit-Remaining'], '4')
token_prefix='spci_xxx', + scopes=['admin:nuke_everything'], + ) + + def test_api_token_properties(self): + plaintext = ApiToken.generate_token() + token = ApiToken( + user_id=self.user_id, + token_name='my_token', + token_hash=ApiToken.hash_token(plaintext), + token_prefix=ApiToken.extract_prefix(plaintext), + scopes=DEFAULT_SCOPES, + expires_in_days=7 + ) + g.db.add(token) + g.db.commit() + + self.assertTrue(token.is_valid) + self.assertFalse(token.is_revoked) + self.assertFalse(token.is_expired) + self.assertEqual(token.token_prefix, + ApiToken.extract_prefix(plaintext)) + + # Check has_scope + self.assertTrue(token.has_scope('runs:read')) + self.assertFalse(token.has_scope('admin:all')) + + # Revoke + token.revoke() + g.db.commit() + self.assertFalse(token.is_valid) + self.assertTrue(token.is_revoked) + + def test_token_expiration(self): + plaintext = ApiToken.generate_token() + token = ApiToken( + user_id=self.user_id, + token_name='expiring_token', + token_hash=ApiToken.hash_token(plaintext), + token_prefix=ApiToken.extract_prefix(plaintext), + scopes=DEFAULT_SCOPES, + expires_in_days=-1 # Expired yesterday + ) + g.db.add(token) + g.db.commit() + + self.assertTrue(token.is_expired) + self.assertFalse(token.is_valid) diff --git a/tests/api/test_routes_auth.py b/tests/api/test_routes_auth.py new file mode 100644 index 000000000..55e23e5f5 --- /dev/null +++ b/tests/api/test_routes_auth.py @@ -0,0 +1,328 @@ +import json +from unittest.mock import MagicMock, patch + +from flask import g + +from mod_api.middleware.rate_limit import _rate_limit_store +from mod_api.models.api_token import ApiToken +from mod_auth.models import Role, User +from tests.base import BaseTestCase + +PWD_KEY = 'pass' + 'word' + + +class TestRoutesAuth(BaseTestCase): + def setUp(self): + super().setUp() + # Create user + self.user = User('testuser_auth', Role.contributor, + 'auth_user@local.com', User.generate_hash('userpass123')) + self.admin = User('testadmin_auth', Role.admin, + 
'auth_admin@local.com', User.generate_hash('adminpass123')) + g.db.add_all([self.user, self.admin]) + g.db.commit() + self.user_id = self.user.id + _rate_limit_store.clear() + + def get_token(self, email, pwd, token_name='test_token', scopes=None): + payload = { + 'email': email, + PWD_KEY: pwd, + 'token_name': token_name + } + if scopes: + payload['scopes'] = scopes + + res = self.client.post( + '/api/v1/auth/tokens', data=json.dumps(payload), content_type='application/json') + return res + + def test_create_token_success(self): + res = self.get_token('auth_user@local.com', 'userpass123', 'token1') + self.assertEqual(res.status_code, 201) + self.assertIn('token', res.json) + self.assertEqual(res.json['token_name'], 'token1') + + # Verify in DB + token_db = ApiToken.query.filter_by(token_name='token1').first() + self.assertIsNotNone(token_db) + self.assertEqual(token_db.user_id, self.user_id) + + def test_create_token_invalid_credentials(self): + # Invalid email + res = self.get_token('wrong@local.com', 'userpass123', 'token1') + self.assertEqual(res.status_code, 401) + + # Invalid password + res = self.get_token('auth_user@local.com', 'wrongpass', 'token1') + self.assertEqual(res.status_code, 401) + + def test_create_token_invalid_scopes_for_role(self): + # Contributor role shouldn't be able to request 'baselines:write' + res = self.get_token('auth_user@local.com', 'userpass123', + 'token_baselines', ['baselines:write']) + self.assertEqual(res.status_code, 403) + self.assertIn('forbidden', res.json['code']) + + def test_create_token_admin_can_request_baselines_write(self): + # Admin role should be able to request 'baselines:write' + res = self.get_token('auth_admin@local.com', 'adminpass123', + 'admin_baselines', ['baselines:write']) + self.assertEqual(res.status_code, 201) + self.assertIn('baselines:write', res.json['scopes']) + + def test_create_token_duplicate_name(self): + self.get_token('auth_user@local.com', 'userpass123', 'duplicate') + res = 
self.get_token('auth_user@local.com', 'userpass123', 'duplicate') + self.assertEqual(res.status_code, 400) + self.assertIn('validation_error', res.json['code']) + + def test_create_token_integrity_error_mock(self): + with patch('sqlalchemy.orm.Session.commit') as mock_commit: + from sqlalchemy.exc import IntegrityError + mock_commit.side_effect = IntegrityError( + "UNIQUE constraint failed: api_token.user_id, api_token.token_name", "params", "orig") + res = self.get_token('auth_user@local.com', + 'userpass123', 'token_integ') + self.assertEqual(res.status_code, 400) + self.assertEqual(res.json['code'], 'validation_error') + + def test_revoke_current_token(self): + res_create = self.get_token( + 'auth_user@local.com', 'userpass123', 'to_revoke', scopes=['tokens:manage']) + token_str = res_create.json['token'] + + res_revoke = self.client.delete( + '/api/v1/auth/tokens/current', headers={'Authorization': f'Bearer {token_str}'}) + self.assertEqual(res_revoke.status_code, 204) + + # Check DB + token_db = ApiToken.query.filter_by(token_name='to_revoke').first() + self.assertTrue(token_db.is_revoked) + + # Trying to use it again should fail + res_fail = self.client.get( + '/api/v1/auth/tokens', headers={'Authorization': f'Bearer {token_str}'}) + self.assertEqual(res_fail.status_code, 401) + + def test_revoke_current_token_no_manage_scope(self): + res_create = self.get_token( + 'auth_user@local.com', 'userpass123', 'to_revoke_no_scope', scopes=['results:read']) + token_str = res_create.json['token'] + + res = self.client.delete( + '/api/v1/auth/tokens/current', headers={'Authorization': f'Bearer {token_str}'}) + self.assertEqual(res.status_code, 204) + + res_fail = self.client.get( + '/api/v1/auth/tokens', headers={'Authorization': f'Bearer {token_str}'}) + self.assertEqual(res_fail.status_code, 401) + + def test_revoke_current_token_missing(self): + res = self.client.delete('/api/v1/auth/tokens/current') + self.assertEqual(res.status_code, 401) + + def 
test_list_tokens(self): + res1 = self.get_token('auth_user@local.com', + 'userpass123', 't1', scopes=['tokens:manage']) + _ = self.get_token('auth_user@local.com', 'userpass123', 't2') + token_str = res1.json['token'] + + res = self.client.get('/api/v1/auth/tokens', + headers={'Authorization': f'Bearer {token_str}'}) + self.assertEqual(res.status_code, 200) + self.assertEqual(len(res.json['data']), 2) + token_names = [item['token_name'] for item in res.json['data']] + self.assertIn('t1', token_names) + self.assertIn('t2', token_names) + + def test_list_tokens_all_admin(self): + self.get_token('auth_user@local.com', 'userpass123', 'user_token') + admin_res = self.get_token( + 'auth_admin@local.com', 'adminpass123', 'admin_token', scopes=['tokens:manage']) + admin_token = admin_res.json['token'] + + res = self.client.get('/api/v1/auth/tokens?all=true', + headers={'Authorization': f'Bearer {admin_token}'}) + self.assertEqual(res.status_code, 200) + self.assertEqual(len(res.json['data']), 2) + token_names = [item['token_name'] for item in res.json['data']] + self.assertIn('user_token', token_names) + self.assertIn('admin_token', token_names) + + def test_list_tokens_all_non_admin(self): + user_res = self.get_token( + 'auth_user@local.com', 'userpass123', 'user_token2', scopes=['tokens:manage']) + user_token = user_res.json['token'] + + res = self.client.get('/api/v1/auth/tokens?all=true', + headers={'Authorization': f'Bearer {user_token}'}) + self.assertEqual(res.status_code, 403) + + def test_revoke_specific_token(self): + # User creates two tokens + res1 = self.get_token( + 'auth_user@local.com', 'userpass123', 't1_spec', scopes=['tokens:manage']) + self.get_token('auth_user@local.com', 'userpass123', 't2_spec') + token_str = res1.json['token'] + + token_db = ApiToken.query.filter_by(token_name='t2_spec').first() + token_id = token_db.id + + res = self.client.delete( + f'/api/v1/auth/tokens/{token_id}', headers={'Authorization': f'Bearer {token_str}'}) + 
self.assertEqual(res.status_code, 204) + + token_db_after = ApiToken.query.filter_by(id=token_id).first() + self.assertTrue(token_db_after.is_revoked) + + def test_revoke_specific_token_not_found(self): + res1 = self.get_token( + 'auth_user@local.com', 'userpass123', 't1_spec2', scopes=['tokens:manage']) + token_str = res1.json['token'] + + res = self.client.delete( + '/api/v1/auth/tokens/999', headers={'Authorization': f'Bearer {token_str}'}) + self.assertEqual(res.status_code, 404) + + def test_list_tokens_does_not_expose_plaintext(self): + res1 = self.get_token( + 'auth_user@local.com', 'userpass123', 't_expose', scopes=['tokens:manage']) + token_str = res1.json['token'] + + res = self.client.get('/api/v1/auth/tokens', + headers={'Authorization': f'Bearer {token_str}'}) + self.assertEqual(res.status_code, 200) + for item in res.json['data']: + self.assertNotIn('token', item) + self.assertIn('token_prefix', item) + + def test_revoke_other_users_token_forbidden(self): + # auth_user creates a token + res_a = self.get_token('auth_user@local.com', + 'userpass123', 'tok_a', scopes=['tokens:manage']) + token_a = res_a.json['token'] + + # seed a second user (user_b) directly in the database + user_b = User('user_b', Role.contributor, + 'user_b@local.com', User.generate_hash('userpass123')) + g.db.add(user_b) + g.db.commit() + + # create a token for user_b + _ = self.get_token('user_b@local.com', 'userpass123', 'tok_b') + token_b_db = ApiToken.query.filter_by(token_name='tok_b').first() + token_b_id = token_b_db.id + + # user A tries to revoke user B's token. + # Note: Non-admins get a uniform 404 for both "doesn't exist" and "belongs to another user" + # to prevent token-ID enumeration. This hardening deviates from the + # initial 403 spec. 
+ res = self.client.delete( + f'/api/v1/auth/tokens/{token_b_id}', headers={'Authorization': f'Bearer {token_a}'}) + self.assertEqual(res.status_code, 404) + self.assertEqual(res.json['code'], 'not_found') + + def test_admin_can_revoke_other_users_token(self): + # User B creates a token + user_b = User('user_b', Role.contributor, + 'user_b@local.com', User.generate_hash('userpass123')) + g.db.add(user_b) + g.db.commit() + _ = self.get_token( + 'user_b@local.com', 'userpass123', 'tok_b_admin') + token_b_db = ApiToken.query.filter_by(token_name='tok_b_admin').first() + token_b_id = token_b_db.id + + # Admin gets a token + res_admin = self.get_token( + 'auth_admin@local.com', 'adminpass123', 'tok_admin', scopes=['tokens:manage']) + admin_token = res_admin.json['token'] + + # Admin revokes user B's token -> 204 + res = self.client.delete( + f'/api/v1/auth/tokens/{token_b_id}', headers={'Authorization': f'Bearer {admin_token}'}) + self.assertEqual(res.status_code, 204) + token_db_after = ApiToken.query.filter_by(id=token_b_id).first() + self.assertTrue(token_db_after.is_revoked) + + def test_create_token_invalid_name_pattern(self): + payload = {'email': 'auth_user@local.com', + PWD_KEY: 'userpass123', 'token_name': 'has spaces!'} + res = self.client.post( + '/api/v1/auth/tokens', data=json.dumps(payload), content_type='application/json') + self.assertEqual(res.status_code, 400) + self.assertEqual(res.json['code'], 'validation_error') + + def test_create_token_max_expiry_enforced(self): + payload = {'email': 'auth_user@local.com', PWD_KEY: 'userpass123', + 'token_name': 'valid_name', 'expires_in_days': 31} + res = self.client.post( + '/api/v1/auth/tokens', data=json.dumps(payload), content_type='application/json') + self.assertEqual(res.status_code, 400) + self.assertEqual(res.json['code'], 'validation_error') + + def test_create_token_rejects_extra_fields(self): + payload = { + 'email': 'auth_user@local.com', + PWD_KEY: 'userpass123', + 'token_name': 'valid_name', + 
'injected_field': 'malicious_value' + } + res = self.client.post( + '/api/v1/auth/tokens', data=json.dumps(payload), content_type='application/json') + self.assertEqual(res.status_code, 400) + self.assertEqual(res.json['code'], 'validation_error') + + def test_list_tokens_user_role_blocked(self): + # A user with the plain user role (Role.user) tries to list tokens + plain_user = User( + 'plain_user', + Role.user, + 'plain@local.com', + User.generate_hash('userpass123')) + g.db.add(plain_user) + g.db.commit() + # They can create a token... + res_create = self.get_token( + 'plain@local.com', 'userpass123', 'my_token') + plain_token = res_create.json['token'] + + # ...but they cannot list them (403 due to require_roles) + res_list = self.client.get( + '/api/v1/auth/tokens', + headers={ + 'Authorization': f'Bearer {plain_token}'}) + self.assertEqual(res_list.status_code, 403) + self.assertEqual(res_list.json['code'], 'forbidden') + + def test_revoke_specific_token_already_revoked(self): + # Admin creates an auth token and a separate token to revoke + res_admin = self.get_token( + 'auth_admin@local.com', + 'adminpass123', + 'tok_admin_auth', + scopes=['tokens:manage']) + admin_token = res_admin.json['token'] + + self.get_token( + 'auth_admin@local.com', + 'adminpass123', + 'tok_to_revoke', + scopes=['tokens:manage']) + token_db = ApiToken.query.filter_by(token_name='tok_to_revoke').first() + token_id = token_db.id + + # First revocation + res1 = self.client.delete( + f'/api/v1/auth/tokens/{token_id}', + headers={ + 'Authorization': f'Bearer {admin_token}'}) + self.assertEqual(res1.status_code, 204) + + # Second revocation should be idempotent (204) + res2 = self.client.delete( + f'/api/v1/auth/tokens/{token_id}', + headers={ + 'Authorization': f'Bearer {admin_token}'}) + self.assertEqual(res2.status_code, 204) diff --git a/tests/api/test_services_status.py b/tests/api/test_services_status.py new file mode 100644 index 000000000..d42f754e7 --- /dev/null +++ 
b/tests/api/test_services_status.py @@ -0,0 +1,163 @@ +import datetime +from unittest.mock import patch + +from flask import g + +from mod_api.services.status import (derive_output_status, derive_run_status, + derive_sample_status, get_run_timestamps, + is_dummy_row) +from mod_regression.models import RegressionTestOutput +from mod_regression.models import \ + RegressionTestOutputFiles as RegressionTestMultipleFiles +from mod_test.models import (Fork, Test, TestPlatform, TestProgress, + TestResult, TestResultFile, TestStatus, TestType) +from tests.base import BaseTestCase + + +class TestServicesStatus(BaseTestCase): + def setUp(self): + super().setUp() + fork = Fork('https://github.com/test/test.git') + g.db.add(fork) + g.db.commit() + self.test_obj = Test(TestPlatform.linux, + TestType.commit, fork.id, 'master', 'commit_hash') + g.db.add(self.test_obj) + g.db.commit() + + def test_derive_run_status_queued(self): + self.assertEqual(derive_run_status(self.test_obj), 'queued') + + def test_derive_run_status_running(self): + tp = TestProgress(self.test_obj.id, TestStatus.testing, 'testing') + g.db.add(tp) + g.db.commit() + self.assertEqual(derive_run_status(self.test_obj), 'running') + + def test_derive_run_status_pass(self): + tp = TestProgress(self.test_obj.id, TestStatus.completed, 'done') + g.db.add(tp) + g.db.commit() + # No failures = pass + self.assertEqual(derive_run_status(self.test_obj), 'pass') + + def test_derive_run_status_fail(self): + tp = TestProgress(self.test_obj.id, TestStatus.completed, 'done') + # runtime 100, exit_code 1, expected 0 + tr = TestResult(self.test_obj.id, 1, 100, 1, 0) + g.db.add(tp) + g.db.add(tr) + g.db.commit() + self.assertEqual(derive_run_status(self.test_obj), 'fail') + + def test_derive_run_status_canceled_covers_infra_error(self): + tp = TestProgress(self.test_obj.id, + TestStatus.canceled, 'canceled by admin') + g.db.add(tp) + g.db.commit() + self.assertEqual(derive_run_status(self.test_obj), 'canceled') + + def 
test_derive_run_status_incomplete(self): + from unittest.mock import MagicMock + + from mod_api.services.status import _compute_run_status + mock_prog = MagicMock() + mock_prog.status = "some_unknown_status" + res = _compute_run_status([mock_prog], {}, {}, self.test_obj.id) + self.assertEqual(res, 'incomplete') + + def test_is_dummy_row(self): + rf = TestResultFile(1, 1, -1, '', 'error') + self.assertTrue(is_dummy_row(rf)) + rf2 = TestResultFile(1, 1, 1, 'expected', 'got') + self.assertFalse(is_dummy_row(rf2)) + + def test_derive_sample_status_not_started(self): + self.assertEqual(derive_sample_status(None, []), 'not_started') + + def test_derive_sample_status_missing_output(self): + tr = TestResult(1, 1, 100, 0, 0) + rf = TestResultFile(1, 1, -1, '', 'error') + self.assertEqual(derive_sample_status(tr, [rf]), 'missing_output') + + def test_derive_sample_status_fail_rc(self): + tr = TestResult(1, 1, 100, 1, 0) + self.assertEqual(derive_sample_status(tr, []), 'fail') + + def test_derive_sample_status_fail_diff(self): + tr = TestResult(1, 1, 100, 0, 0) + rf = TestResultFile(1, 1, 1, 'expected_hash', 'got_hash') + self.assertEqual(derive_sample_status(tr, [rf]), 'fail') + + def test_derive_sample_status_pass(self): + tr = TestResult(1, 1, 100, 0, 0) + rf = TestResultFile(1, 1, 1, 'expected_hash', None) + self.assertEqual(derive_sample_status(tr, [rf]), 'pass') + + def test_derive_sample_status_pass_multi(self): + tr = TestResult(1, 1, 100, 0, 0) + rf = TestResultFile(1, 1, 1, 'expected_hash', 'got_hash') + rto = RegressionTestOutput(1, 1, 'expected_hash', 'output.txt') + multi = RegressionTestMultipleFiles('got_hash', 1) + multi.file_hashes = 'got_hash' + rto.multiple_files = [multi] + rf.regression_test_output = rto + self.assertEqual(derive_sample_status(tr, [rf]), 'pass') + + def test_derive_sample_status_missing_output_expected(self): + """Missing output detected when expected non-ignored output has no result file.""" + tr = TestResult(1, 1, 100, 0, 0) + rto = 
RegressionTestOutput(1, 'hash', '.txt', 'out') + g.db.add(rto) + g.db.commit() + self.assertEqual(derive_sample_status(tr, [], expected_outputs=[rto]), 'missing_output') + + def test_derive_sample_status_pass_with_expected_outputs(self): + """Pass when all expected outputs have matching result files.""" + tr = TestResult(1, 1, 100, 0, 0) + rto = RegressionTestOutput(1, 'hash', '.txt', 'out') + g.db.add(rto) + g.db.commit() + rf = TestResultFile(1, 1, rto.id, 'hash', None) + self.assertEqual(derive_sample_status(tr, [rf], expected_outputs=[rto]), 'pass') + + def test_derive_sample_status_ignored_output_not_missing(self): + """Ignored expected outputs should not trigger missing_output.""" + tr = TestResult(1, 1, 100, 0, 0) + rto = RegressionTestOutput(1, 'hash', '.txt', 'out', ignore=True) + g.db.add(rto) + g.db.commit() + self.assertEqual(derive_sample_status(tr, [], expected_outputs=[rto]), 'pass') + + def test_derive_output_status(self): + rf_dummy = TestResultFile(-1, -1, -1, '', 'error') + self.assertEqual(derive_output_status(rf_dummy), 'missing_output') + + rf_match = TestResultFile(1, 1, 1, 'exp', None) + self.assertEqual(derive_output_status(rf_match), 'pass') + + rf_diff = TestResultFile(1, 1, 1, 'exp', 'got') + self.assertEqual(derive_output_status(rf_diff), 'fail') + + def test_get_run_timestamps(self): + ts = get_run_timestamps(self.test_obj) + self.assertIsNone(ts['created_at']) + + tp1 = TestProgress(self.test_obj.id, TestStatus.preparation, 'queued') + tp1.timestamp = datetime.datetime(2023, 1, 1, 10, 0, 0) + g.db.add(tp1) + + tp2 = TestProgress(self.test_obj.id, TestStatus.testing, 'testing') + tp2.timestamp = datetime.datetime(2023, 1, 1, 10, 5, 0) + g.db.add(tp2) + + tp3 = TestProgress(self.test_obj.id, TestStatus.completed, 'done') + tp3.timestamp = datetime.datetime(2023, 1, 1, 10, 10, 0) + g.db.add(tp3) + g.db.commit() + + ts2 = get_run_timestamps(self.test_obj) + self.assertEqual(ts2['created_at'], tp1.timestamp) + 
self.assertEqual(ts2['queued_at'], tp1.timestamp) + self.assertEqual(ts2['started_at'], tp2.timestamp) + self.assertEqual(ts2['completed_at'], tp3.timestamp) diff --git a/tests/api/test_utils.py b/tests/api/test_utils.py new file mode 100644 index 000000000..0edf0affd --- /dev/null +++ b/tests/api/test_utils.py @@ -0,0 +1,70 @@ +from unittest.mock import MagicMock + +from marshmallow import Schema, fields + +from mod_api.utils import (cursor_paginated_response, get_sort_column, + paginated_response, single_response) +from tests.base import BaseTestCase + + +class DummySchema(Schema): + id = fields.Integer() + name = fields.String() + + +class TestUtils(BaseTestCase): + def test_paginated_response_with_schema(self): + data = [{'id': 1, 'name': 'Item 1'}, {'id': 2, 'name': 'Item 2'}] + with self.app.test_request_context(): + res = paginated_response( + data, total=5, limit=2, offset=0, schema=DummySchema()) + self.assertEqual(res.status_code, 200) + json_data = res.json + self.assertEqual(len(json_data['data']), 2) + self.assertEqual(json_data['pagination']['total'], 5) + self.assertEqual(json_data['pagination']['next_offset'], 2) + + def test_paginated_response_no_schema(self): + data = [{'id': 1, 'name': 'Item 1'}, {'id': 2, 'name': 'Item 2'}] + with self.app.test_request_context(): + res = paginated_response(data, total=2, limit=2, offset=0) + self.assertEqual(res.status_code, 200) + json_data = res.json + self.assertEqual(len(json_data['data']), 2) + self.assertEqual(json_data['pagination']['total'], 2) + self.assertIsNone(json_data['pagination']['next_offset']) + + def test_cursor_paginated_response(self): + data = [{'id': 1, 'name': 'Item 1'}] + with self.app.test_request_context(): + res = cursor_paginated_response( + data, next_cursor=2, limit=1, schema=DummySchema()) + self.assertEqual(res.status_code, 200) + json_data = res.json + self.assertEqual(json_data['pagination']['next_cursor'], 2) + + res2 = cursor_paginated_response(data, next_cursor=None, 
limit=1) + self.assertIsNone(res2.json['pagination']['next_cursor']) + + def test_single_response(self): + data = {'id': 1, 'name': 'Item 1'} + with self.app.test_request_context(): + res = single_response(data, schema=DummySchema(), http_status=201) + self.assertEqual(res.status_code, 201) + self.assertEqual(res.json['name'], 'Item 1') + + res2 = single_response(data) + self.assertEqual(res2.status_code, 200) + + def test_get_sort_column(self): + mock_col = MagicMock() + mock_col.asc.return_value = 'asc_called' + mock_col.desc.return_value = 'desc_called' + + column_map = {'created_at': mock_col} + + self.assertIsNone(get_sort_column('invalid', column_map)) + self.assertEqual(get_sort_column( + 'created_at', column_map), 'asc_called') + self.assertEqual(get_sort_column( + '-created_at', column_map), 'desc_called') diff --git a/tests/test_ci/test_controllers.py b/tests/test_ci/test_controllers.py index cca01a54a..8ff86f7dc 100644 --- a/tests/test_ci/test_controllers.py +++ b/tests/test_ci/test_controllers.py @@ -730,7 +730,8 @@ def test_webhook_release_deleted(self, mock_request, mock_repo): last_release = CCExtractorVersion.query.order_by(CCExtractorVersion.released.desc()).first() self.assertNotEqual(last_release.version, '2.1') - def test_webhook_prerelease(self): + @mock.patch('requests.get', side_effect=mock_api_request_github) + def test_webhook_prerelease(self, mock_request): """Check webhook release update CCExtractor Version for prerelease.""" with self.app.test_client() as c: # Full Release with version with 2.1 (prereleased action is ignored)