diff --git a/applications/ColossalChat/start_code_verifier.py b/applications/ColossalChat/start_code_verifier.py index d1924f610698..f49d2d421c7c 100644 --- a/applications/ColossalChat/start_code_verifier.py +++ b/applications/ColossalChat/start_code_verifier.py @@ -1,11 +1,15 @@ +import os from typing import List, Optional from coati.distributed.reward.code_reward.utils import check_correctness # Assuming utils.py is in the same directory -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, Header, HTTPException from pydantic import BaseModel app = FastAPI() +_API_KEY = os.environ.get("CODE_VERIFIER_API_KEY", "") +_MAX_TIMEOUT = 30 + class CheckCorrectnessRequest(BaseModel): in_outs: Optional[dict] @@ -21,12 +25,14 @@ class CheckCorrectnessResponse(BaseModel): @app.post("/check_correctness", response_model=CheckCorrectnessResponse) -def check_correctness_api(request: CheckCorrectnessRequest): +def check_correctness_api(request: CheckCorrectnessRequest, x_api_key: str = Header(...)): + if not _API_KEY or x_api_key != _API_KEY: + raise HTTPException(status_code=401, detail="Unauthorized") try: result, metadata = check_correctness( in_outs=request.in_outs, generation=request.generation, - timeout=request.timeout, + timeout=min(request.timeout, _MAX_TIMEOUT), debug=request.debug, eval_mode=request.eval_mode, ) diff --git a/tests/test_invariant_start_code_verifier.py b/tests/test_invariant_start_code_verifier.py new file mode 100644 index 000000000000..95c179c5ee6f --- /dev/null +++ b/tests/test_invariant_start_code_verifier.py @@ -0,0 +1,48 @@ +import importlib.util +from pathlib import Path + +import pytest +from fastapi.testclient import TestClient + +# Load the actual production app from the file under test +spec = importlib.util.spec_from_file_location( + "start_code_verifier", Path(__file__).parent / "applications/ColossalChat/start_code_verifier.py" +) +module = importlib.util.load_from_spec(spec) +spec.loader.exec_module(module) +app = module.app + +client = TestClient(app, raise_server_exceptions=False) + +VALID_PAYLOAD = {"generation": "def f(x): return x", "in_outs": "{}", "timeout": 5, "debug": False, "eval_mode": "exec"} + + +@pytest.mark.parametrize( + "headers,payload", + [ + # Exact exploit: no auth, arbitrary code execution attempt + ( + {}, + { + "generation": "__import__('os').system('id')", + "in_outs": "{}", + "timeout": 10, + "debug": False, + "eval_mode": "exec", + }, + ), + # Missing token (boundary: empty Authorization header) + ({"Authorization": ""}, VALID_PAYLOAD), + # Malformed token + ({"Authorization": "Bearer not.a.valid.jwt"}, VALID_PAYLOAD), + # Expired/fake token + ({"Authorization": "Bearer eyJhbGciOiJIUzI1NiJ9.eyJleHAiOjF9.invalid"}, VALID_PAYLOAD), + ], +) +def test_check_correctness_requires_authentication(headers, payload): + """Invariant: /check_correctness must reject unauthenticated or invalidly-authenticated requests with 401 or 403.""" + response = client.post("/check_correctness", json=payload, headers=headers) + assert response.status_code in (401, 403), ( + f"Expected 401 or 403 for unauthenticated request, got {response.status_code}. " + "The endpoint must not be publicly accessible without valid credentials." + )