
Commit d9b425c

Add trained model client for AI trading bot

1 parent d0a68d5 commit d9b425c

1 file changed: trained_model_client.py (216 additions, 0 deletions)
@@ -0,0 +1,216 @@
import json
import logging
import os
import re
import threading
from typing import Any, Dict, Optional

import requests

from llm_sentiment import _extract_json

logger = logging.getLogger(__name__)

LABEL_TO_SCORE = {
    "STRONG_BUY": 2.0,
    "BUY": 1.0,
    "NEUTRAL": 0.0,
    "SELL": -1.0,
    "STRONG_SELL": -2.0,
}

_LABEL_RE = re.compile(r"\b(STRONG_BUY|BUY|NEUTRAL|SELL|STRONG_SELL)\b", re.IGNORECASE)


class TrainedModelTradeClient:
    """Score trade candidates with a fine-tuned signal model, either through an
    HTTP inference endpoint or a locally loaded base model plus LoRA adapter."""

    _runtime_lock = threading.Lock()
    _runtime_cache: Dict[str, Dict[str, Any]] = {}

    def __init__(self, ai_cfg: Optional[dict] = None):
        ai_cfg = dict(ai_cfg or {})
        model_cfg = dict(ai_cfg.get("trained_model") or {})
        self.backend = str(model_cfg.get("backend", "http") or "http").strip().lower()
        self.inference_url = str(model_cfg.get("inference_url", "") or "").strip()
        self.api_key_env = str(model_cfg.get("api_key_env", "") or "").strip()
        self.api_key = os.getenv(self.api_key_env).strip() if self.api_key_env and os.getenv(self.api_key_env) else ""
        self.timeout_seconds = int(model_cfg.get("timeout_seconds", 60) or 60)
        self.base_model = str(model_cfg.get("base_model", "Qwen/Qwen2.5-7B-Instruct") or "Qwen/Qwen2.5-7B-Instruct").strip()
        self.adapter_dir = str(model_cfg.get("adapter_dir", "./models/lora_solid_adapter") or "./models/lora_solid_adapter").strip()
        self.max_new_tokens = int(model_cfg.get("max_new_tokens", 64) or 64)
        self.temperature = float(model_cfg.get("temperature", 0.0) or 0.0)
        self.cpu_threads = int(model_cfg.get("cpu_threads", 4) or 4)
        self.last_error = None
        self.last_model_used = None

    @property
    def model_identifier(self) -> str:
        if self.backend == "http":
            return self.inference_url or "trained-model-http"
        adapter_name = os.path.basename(os.path.normpath(self.adapter_dir or "adapter")) or "adapter"
        return f"{self.base_model}+{adapter_name}"

    def is_ready(self) -> bool:
        if self.backend == "http":
            if not self.inference_url:
                self.last_error = "trained_model.inference_url is not configured"
                return False
            return True
        if self.backend == "local":
            if not self.base_model or not self.adapter_dir:
                self.last_error = "trained_model.base_model or trained_model.adapter_dir is missing"
                return False
            return True
        self.last_error = f"Unsupported trained model backend: {self.backend}"
        return False

    def predict_candidate(self, candidate: dict) -> Optional[dict]:
        """Run one candidate through the configured backend and return a
        normalized signal dict, or None on any failure (see last_error)."""
        if not self.is_ready():
            return None
        try:
            raw = self._predict_http(candidate) if self.backend == "http" else self._predict_local(candidate)
        except Exception as exc:
            self.last_error = str(exc)
            logger.warning("Trained model inference failed for %s: %s", candidate.get("symbol"), exc)
            return None
        return self._normalize_prediction(raw)

    def _predict_http(self, candidate: dict):
        payload = {
            "candidate": candidate,
            "task": "trade_signal_classification",
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
        }
        headers = {"Content-Type": "application/json", "Accept": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        response = requests.post(self.inference_url, json=payload, headers=headers, timeout=self.timeout_seconds)
        response.raise_for_status()
        data = response.json()
        self.last_model_used = data.get("model") or data.get("model_used") or self.model_identifier
        # Servers may wrap the signal; unwrap it when present.
        if isinstance(data.get("signal"), dict):
            return data["signal"]
        return data

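    # NOTE (illustrative only, not a server contract defined in this commit):
    # the method above accepts either a bare JSON object or a wrapper such as
    #   {"model": "my-finetune", "signal": {"label": "BUY", "confidence": 0.63, "reason": "..."}}
    # Any response that yields a dict carrying a "label" (or "signal") string,
    # and ideally "confidence"/"reason", will pass _normalize_prediction below.
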
    def _predict_local(self, candidate: dict):
        runtime = self._ensure_local_runtime()
        tokenizer = runtime["tokenizer"]
        model = runtime["model"]
        torch = runtime["torch"]
        messages = self._build_messages(candidate)
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        encoded = tokenizer(prompt, return_tensors="pt")
        input_len = encoded["input_ids"].shape[-1]
        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_new_tokens=self.max_new_tokens,
                do_sample=bool(self.temperature and self.temperature > 0.0),
                temperature=max(self.temperature, 1e-5) if self.temperature and self.temperature > 0 else 1.0,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, skipping the prompt.
        text = tokenizer.decode(generated[0][input_len:], skip_special_tokens=True).strip()
        self.last_model_used = self.model_identifier
        return text

    def _ensure_local_runtime(self):
        # One loaded model is shared per (base_model, adapter_dir) across instances.
        cache_key = f"{self.base_model}|{self.adapter_dir}"
        with self._runtime_lock:
            cached = self._runtime_cache.get(cache_key)
            if cached is not None:
                return cached
            # Heavy dependencies are imported lazily so the HTTP backend
            # does not require torch, peft, or transformers to be installed.
            import torch
            from peft import PeftModel
            from transformers import AutoModelForCausalLM, AutoTokenizer

            torch.set_num_threads(max(1, int(self.cpu_threads or 1)))
            tokenizer = AutoTokenizer.from_pretrained(self.base_model, trust_remote_code=True)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                self.base_model,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )
            model = PeftModel.from_pretrained(model, self.adapter_dir, is_trainable=False)
            model.eval()
            runtime = {"tokenizer": tokenizer, "model": model, "torch": torch}
            self._runtime_cache[cache_key] = runtime
            return runtime

    def _build_messages(self, candidate: dict):
        symbol = str(candidate.get("symbol") or "UNKNOWN").strip().upper()
        as_of_date = candidate.get("as_of_date") or candidate.get("last_date") or "UNKNOWN"
        lines = [
            f"TICKER: {symbol}",
            f"DATE: {as_of_date}",
            "PRICE_ACTION:",
            f"- last_close: {candidate.get('last_close')}",
            f"- closes_tail: {candidate.get('closes_tail')}",
            f"- volume_1d: {candidate.get('volume_1d')}",
            f"- volume_20d_avg: {candidate.get('volume_20d_avg')}",
            "INDICATORS:",
            f"- return_1d: {candidate.get('return_1d')}",
            f"- return_5d: {candidate.get('return_5d')}",
            f"- return_10d: {candidate.get('return_10d')}",
            f"- volatility_20d: {candidate.get('volatility_20d')}",
            f"- dist_ma_20: {candidate.get('dist_ma_20')}",
            f"- dist_ma_50: {candidate.get('dist_ma_50')}",
            f"- rsi_14: {candidate.get('rsi_14')}",
            f"- volume_ratio: {candidate.get('volume_ratio')}",
            "NEWS_CONTEXT:",
            f"- news_count_7d: {candidate.get('news_count_7d')}",
            f"- news_sentiment_7d: {candidate.get('news_sentiment_7d')}",
            "",
            "QUESTION: Classify the expected 5-day return as STRONG_BUY | BUY | NEUTRAL | SELL | STRONG_SELL.",
            "Return ONLY JSON using this schema:",
            '{"label":"BUY","confidence":0.63,"reason":"..."}',
        ]
        system = (
            "You are the trained AI trading decision engine. "
            "Return only valid JSON with label, confidence, and reason. "
            "Use the provided market snapshot to classify the next 5-day return."
        )
        return [{"role": "system", "content": system}, {"role": "user", "content": "\n".join(lines)}]

    def _normalize_prediction(self, raw) -> Optional[dict]:
        parsed = raw
        raw_text = None
        if isinstance(raw, str):
            raw_text = raw
            parsed = _extract_json(raw) or self._parse_plain_label(raw)
        elif isinstance(raw, dict):
            raw_text = json.dumps(raw)
        else:
            raw_text = str(raw)
            parsed = self._parse_plain_label(raw_text)
        if not isinstance(parsed, dict):
            self.last_error = "Trained model response could not be parsed"
            return None
        label = str(parsed.get("label") or parsed.get("signal") or "").strip().upper()
        if label not in LABEL_TO_SCORE:
            self.last_error = f"Unsupported trained model label: {label or 'missing'}"
            return None
        confidence = parsed.get("confidence")
        try:
            confidence = float(confidence)
        except (TypeError, ValueError):
            # Heuristic fallback when the model omits or mangles the confidence field.
            confidence = 0.9 if label.startswith("STRONG_") else (0.65 if label != "NEUTRAL" else 0.5)
        confidence = max(0.0, min(1.0, confidence))
        reason = str(parsed.get("reason") or parsed.get("notes") or f"Model classified {label}.").strip()
        return {
            "label": label,
            "score": LABEL_TO_SCORE[label],
            "confidence": confidence,
            "reason": reason,
            "raw_text": raw_text,
        }

    @staticmethod
    def _parse_plain_label(text: str) -> Optional[dict]:
        if not text:
            return None
        match = _LABEL_RE.search(str(text))
        if not match:
            return None
        return {"label": match.group(1).upper(), "reason": str(text).strip()}
