Skip to content

Commit 07ea033

Browse files
committed
Harden Modal CPU inference with cached HF downloads and longer timeout
1 parent aaf93ae commit 07ea033

2 files changed

Lines changed: 11 additions & 4 deletions

File tree

config.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -117,7 +117,7 @@ ai_trading:
     inference_url: ""
     inference_url_env: "TRAINED_MODEL_INFERENCE_URL"
     api_key_env: "TRAINED_MODEL_API_KEY"
-    timeout_seconds: 60
+    timeout_seconds: 600
     model_name: "quant-trained-trading-model"
```

modal_trained_model_service.py

Lines changed: 10 additions & 3 deletions
```diff
@@ -13,12 +13,16 @@
 MEMORY_MB = int(os.getenv("TRAINED_MODEL_MEMORY_MB", "65536"))

 app = modal.App(APP_NAME)
+os.environ.setdefault("HF_HOME", "/artifacts/hf_home")
+os.environ.setdefault("TRANSFORMERS_CACHE", "/artifacts/hf_home/transformers")
+os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/artifacts/hf_home/hub")
+
 image = (
     modal.Image.debian_slim(python_version="3.11")
+    .pip_install("torch==2.4.1", index_url="https://download.pytorch.org/whl/cpu")
     .pip_install(
         "fastapi>=0.115.0",
         "pydantic>=2.9.2",
-        "torch>=2.4.1",
         "transformers>=4.46.0",
         "peft>=0.13.2",
         "accelerate>=1.0.1",
@@ -144,15 +148,18 @@ def _predict_one(candidate: Dict[str, Any]) -> Dict[str, Any]:
     cpu=CPU_COUNT,
     memory=MEMORY_MB,
     scaledown_window=300,
-    timeout=3600,
+    timeout=7200,
+    startup_timeout=1800,
     volumes={"/artifacts": volume},
 )
 @modal.fastapi_endpoint(method="POST")
 def predict_trade_candidates(payload: Dict[str, Any]):
-    candidates = payload.get("candidates") or []
+    candidates = payload.get("candidates")
     if not isinstance(candidates, list):
         candidate = payload.get("candidate")
         candidates = [candidate] if isinstance(candidate, dict) else []
+    else:
+        candidates = list(candidates or [])
     signals = [_predict_one(c) for c in candidates if isinstance(c, dict)]
     return {
         "model": MODEL_NAME,
```

0 commit comments

Comments (0)