Commit 52207f9

Add batched Modal CPU inference path for AI trading bot
1 parent f525f32 commit 52207f9

3 files changed

Lines changed: 192 additions & 14 deletions

llm_trader.py

Lines changed: 2 additions & 2 deletions
@@ -153,8 +153,8 @@ def propose_trades_with_llm(config, candidates, max_positions=10, allow_shorts=T
 
     predictions = []
     failures = []
-    for candidate in prompt_candidates:
-        prediction = client.predict_candidate(candidate)
+    batch_predictions = client.predict_candidates(prompt_candidates)
+    for candidate, prediction in zip(prompt_candidates, batch_predictions):
         if prediction is None:
             failures.append(
                 {
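
For context, a minimal sketch of the alignment contract the new loop relies on (simplified; the bookkeeping around failures is abbreviated here): predict_candidates is expected to return one entry per submitted candidate, with None marking a failed inference, so the zip pairs every candidate with its own result.

    # Simplified sketch, not the full function body from this commit.
    batch_predictions = client.predict_candidates(prompt_candidates)

    predictions, failures = [], []
    for candidate, prediction in zip(prompt_candidates, batch_predictions):
        if prediction is None:
            failures.append({"symbol": candidate.get("symbol")})  # abbreviated failure record
        else:
            predictions.append(prediction)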

modal_trained_model_service.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
+import json
+import os
+from typing import Any, Dict, List
+
+import modal
+
+APP_NAME = os.getenv("TRAINED_MODEL_MODAL_APP", "trading-bot-trained-model-inference")
+BASE_MODEL = os.getenv("TRAINED_MODEL_BASE_MODEL", "Qwen/Qwen2.5-7B-Instruct")
+VOLUME_NAME = os.getenv("TRAINED_MODEL_VOLUME", "train-once-artifacts")
+ADAPTER_PATH = os.getenv("TRAINED_MODEL_ADAPTER_PATH", "/artifacts/lora_solid_adapter")
+MODEL_NAME = os.getenv("TRAINED_MODEL_NAME", "quant-trained-trading-model")
+CPU_COUNT = int(os.getenv("TRAINED_MODEL_CPU", "8"))
+MEMORY_MB = int(os.getenv("TRAINED_MODEL_MEMORY_MB", "65536"))
+
+app = modal.App(APP_NAME)
+image = (
+    modal.Image.debian_slim(python_version="3.11")
+    .pip_install(
+        "fastapi>=0.115.0",
+        "pydantic>=2.9.2",
+        "torch>=2.4.1",
+        "transformers>=4.46.0",
+        "peft>=0.13.2",
+        "accelerate>=1.0.1",
+        "sentencepiece>=0.2.0",
+    )
+)
+volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=False)
+
+_MODEL = None
+_TOKENIZER = None
+_TORCH = None
+
+
+def _load_runtime():
+    global _MODEL, _TOKENIZER, _TORCH
+    if _MODEL is not None and _TOKENIZER is not None and _TORCH is not None:
+        return _MODEL, _TOKENIZER, _TORCH
+
+    import torch
+    from peft import PeftModel
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    torch.set_num_threads(max(1, CPU_COUNT))
+    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    model = AutoModelForCausalLM.from_pretrained(
+        BASE_MODEL,
+        trust_remote_code=True,
+        low_cpu_mem_usage=True,
+    )
+    model = PeftModel.from_pretrained(model, ADAPTER_PATH, is_trainable=False)
+    model.eval()
+
+    _MODEL = model
+    _TOKENIZER = tokenizer
+    _TORCH = torch
+    return _MODEL, _TOKENIZER, _TORCH
+
+
+def _candidate_prompt(candidate: Dict[str, Any]) -> str:
+    symbol = str(candidate.get("symbol") or "UNKNOWN").strip().upper()
+    as_of_date = candidate.get("as_of_date") or candidate.get("last_date") or "UNKNOWN"
+    lines = [
+        f"TICKER: {symbol}",
+        f"DATE: {as_of_date}",
+        "PRICE_ACTION:",
+        f"- last_close: {candidate.get('last_close')}",
+        f"- closes_tail: {candidate.get('closes_tail')}",
+        f"- volume_1d: {candidate.get('volume_1d')}",
+        f"- volume_20d_avg: {candidate.get('volume_20d_avg')}",
+        "INDICATORS:",
+        f"- return_1d: {candidate.get('return_1d')}",
+        f"- return_5d: {candidate.get('return_5d')}",
+        f"- return_10d: {candidate.get('return_10d')}",
+        f"- volatility_20d: {candidate.get('volatility_20d')}",
+        f"- dist_ma_20: {candidate.get('dist_ma_20')}",
+        f"- dist_ma_50: {candidate.get('dist_ma_50')}",
+        f"- rsi_14: {candidate.get('rsi_14')}",
+        f"- volume_ratio: {candidate.get('volume_ratio')}",
+        "NEWS_CONTEXT:",
+        f"- news_count_7d: {candidate.get('news_count_7d')}",
+        f"- news_sentiment_7d: {candidate.get('news_sentiment_7d')}",
+        "",
+        "QUESTION: Classify the expected 5-day return as STRONG_BUY | BUY | NEUTRAL | SELL | STRONG_SELL.",
+        "Return ONLY JSON using this schema:",
+        '{"label":"BUY","confidence":0.63,"reason":"..."}',
+    ]
+    return "\n".join(lines)
+
+
+def _extract_json(text: str):
+    if not text:
+        return None
+    text = str(text).strip()
+    try:
+        return json.loads(text)
+    except Exception:
+        pass
+    start = text.find("{")
+    end = text.rfind("}")
+    if start >= 0 and end > start:
+        try:
+            return json.loads(text[start : end + 1])
+        except Exception:
+            return None
+    return None
+
+
+def _predict_one(candidate: Dict[str, Any]) -> Dict[str, Any]:
+    model, tokenizer, torch = _load_runtime()
+    system = (
+        "You are the trained AI trading decision engine. "
+        "Return only valid JSON with label, confidence, and reason. "
+        "Use the provided market snapshot to classify the next 5-day return."
+    )
+    prompt = tokenizer.apply_chat_template(
+        [
+            {"role": "system", "content": system},
+            {"role": "user", "content": _candidate_prompt(candidate)},
+        ],
+        tokenize=False,
+        add_generation_prompt=True,
+    )
+    encoded = tokenizer(prompt, return_tensors="pt")
+    input_len = encoded["input_ids"].shape[-1]
+    with torch.no_grad():
+        generated = model.generate(
+            **encoded,
+            max_new_tokens=64,
+            do_sample=False,
+            pad_token_id=tokenizer.eos_token_id,
+        )
+    text = tokenizer.decode(generated[0][input_len:], skip_special_tokens=True).strip()
+    parsed = _extract_json(text) or {"label": "NEUTRAL", "confidence": 0.5, "reason": text or "No parsable output."}
+    parsed["symbol"] = candidate.get("symbol")
+    return parsed
+
+
+@app.function(
+    image=image,
+    cpu=CPU_COUNT,
+    memory=MEMORY_MB,
+    scaledown_window=300,
+    timeout=3600,
+    volumes={"/artifacts": volume},
+)
+@modal.web_endpoint(method="POST")
+def predict_trade_candidates(payload: Dict[str, Any]):
+    candidates = payload.get("candidates") or []
+    if not isinstance(candidates, list):
+        candidate = payload.get("candidate")
+        candidates = [candidate] if isinstance(candidate, dict) else []
+    signals = [_predict_one(c) for c in candidates if isinstance(c, dict)]
+    return {
+        "model": MODEL_NAME,
+        "model_used": MODEL_NAME,
+        "signals": signals,
+    }
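
The new service loads the base model plus LoRA adapter once per container, builds a structured prompt per candidate, and serves a single POST web endpoint accepting either a candidates list or a single candidate object. A hedged sketch of deploying and calling it follows; the endpoint URL is a placeholder (Modal prints the real URL on deploy) and the candidate values are illustrative only.

    # Deploy with the Modal CLI:
    #   modal deploy modal_trained_model_service.py
    import requests

    # Placeholder URL; use the web endpoint URL printed by `modal deploy`.
    ENDPOINT_URL = "https://<workspace>--trading-bot-trained-model-inference-predict-trade-candidates.modal.run"

    payload = {
        "task": "trade_signal_classification",
        "candidates": [
            {"symbol": "AAPL", "last_close": 227.5, "return_5d": 0.012, "rsi_14": 58.3},  # illustrative values
        ],
    }
    resp = requests.post(ENDPOINT_URL, json=payload, timeout=60)
    resp.raise_for_status()
    data = resp.json()
    # Response shape: {"model": ..., "model_used": ..., "signals": [{"label", "confidence", "reason", "symbol"}, ...]}
    for signal in data["signals"]:
        print(signal.get("symbol"), signal.get("label"), signal.get("confidence"))

The trained_model_client.py changes below wrap this same request/response contract.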

trained_model_client.py

Lines changed: 29 additions & 12 deletions
@@ -2,7 +2,7 @@
 import logging
 import os
 import re
-from typing import Optional
+from typing import List, Optional
 
 import requests
 
@@ -33,7 +33,7 @@ def __init__(self, ai_cfg: Optional[dict] = None):
         self.api_key_env = str(model_cfg.get("api_key_env", "") or "").strip()
         self.api_key = os.getenv(self.api_key_env).strip() if self.api_key_env and os.getenv(self.api_key_env) else ""
         self.timeout_seconds = int(model_cfg.get("timeout_seconds", 60) or 60)
-        self.model_name = str(model_cfg.get("model_name", "trained-trading-model") or "trained-trading-model").strip()
+        self.model_name = str(model_cfg.get("model_name", "quant-trained-trading-model") or "quant-trained-trading-model").strip()
         self.last_error = None
         self.last_model_used = None
 
@@ -51,19 +51,32 @@ def is_ready(self) -> bool:
         return True
 
     def predict_candidate(self, candidate: dict) -> Optional[dict]:
+        results = self.predict_candidates([candidate])
+        return results[0] if results else None
+
+    def predict_candidates(self, candidates: List[dict]) -> List[Optional[dict]]:
         if not self.is_ready():
-            return None
+            return [None for _ in list(candidates or [])]
+        payload_candidates = [dict(c or {}) for c in list(candidates or []) if isinstance(c, dict)]
+        if not payload_candidates:
+            return []
         try:
-            raw = self._predict_http(candidate)
+            raw_signals = self._predict_batch_http(payload_candidates)
         except Exception as exc:
             self.last_error = str(exc)
-            logger.warning("Trained model inference failed for %s: %s", candidate.get("symbol"), exc)
-            return None
-        return self._normalize_prediction(raw)
+            logger.warning("Trained model batch inference failed: %s", exc)
+            return [None for _ in payload_candidates]
+
+        out = []
+        for signal in raw_signals:
+            out.append(self._normalize_prediction(signal))
+        while len(out) < len(payload_candidates):
+            out.append(None)
+        return out[: len(payload_candidates)]
 
-    def _predict_http(self, candidate: dict):
+    def _predict_batch_http(self, candidates: List[dict]):
         payload = {
-            "candidate": candidate,
+            "candidates": candidates,
            "task": "trade_signal_classification",
         }
         headers = {"Content-Type": "application/json", "Accept": "application/json"}
@@ -73,9 +86,13 @@ def _predict_http(self, candidate: dict):
         response.raise_for_status()
         data = response.json()
         self.last_model_used = data.get("model") or data.get("model_used") or self.model_identifier
-        if isinstance(data.get("signal"), dict):
-            return data["signal"]
-        return data
+        signals = data.get("signals")
+        if isinstance(signals, list):
+            return signals
+        signal = data.get("signal")
+        if signal is not None:
+            return [signal]
+        return []
 
     def _normalize_prediction(self, raw) -> Optional[dict]:
         parsed = raw
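
A minimal usage sketch of the updated client (the class name TrainedModelClient and the shape of ai_cfg are assumptions; only __init__, predict_candidate, predict_candidates, and _predict_batch_http are visible in this diff): the batch method returns one entry per candidate, None where inference failed, and predict_candidate is now a thin wrapper around it.

    # Hypothetical usage; class name and config keys are assumed, methods match this diff.
    client = TrainedModelClient(ai_cfg)  # ai_cfg: dict carrying the trained-model settings

    candidates = [
        {"symbol": "AAPL", "rsi_14": 58.3},
        {"symbol": "MSFT", "rsi_14": 44.1},
    ]
    signals = client.predict_candidates(candidates)  # aligned with input; None on failure

    for candidate, signal in zip(candidates, signals):
        if signal is None:
            print(candidate["symbol"], "-> no signal (see client.last_error)")
        else:
            print(candidate["symbol"], "->", signal.get("label"), signal.get("confidence"))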
