Skip to content

Commit 527cc80

Browse files
committed
Add direct CPU AI smoke test runner
1 parent 655182a commit 527cc80

1 file changed

Lines changed: 201 additions & 0 deletions

File tree

run_ai_trading_smoke_direct.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
import json
2+
import os
3+
import re
4+
from datetime import datetime
5+
6+
import torch
7+
from peft import PeftModel
8+
from transformers import AutoModelForCausalLM, AutoTokenizer
9+
10+
from llm_trader import _enforce_english_reason, _pick_predictions, _weights_from_predictions
11+
from run_ai_trading_smoke import build_candidates, load_config
12+
13+
# Maps the model's classification labels to signed trade scores: the sign
# encodes direction (positive -> LONG, negative -> SHORT in main()) and the
# magnitude encodes conviction strength.
LABEL_TO_SCORE = {
    "STRONG_BUY": 2.0,
    "BUY": 1.0,
    "NEUTRAL": 0.0,
    "SELL": -1.0,
    "STRONG_SELL": -2.0,
}
# Case-insensitive fallback matcher used to salvage a label from free text
# when the model's response is not (fully) valid JSON.
LABEL_RE = re.compile(r"\b(STRONG_BUY|BUY|NEUTRAL|SELL|STRONG_SELL)\b", re.IGNORECASE)
21+
22+
23+
def _extract_json(text: str):
24+
if not text:
25+
return None
26+
text = str(text).strip()
27+
try:
28+
return json.loads(text)
29+
except Exception:
30+
pass
31+
start = text.find("{")
32+
end = text.rfind("}")
33+
if start >= 0 and end > start:
34+
try:
35+
return json.loads(text[start : end + 1])
36+
except Exception:
37+
pass
38+
return None
39+
40+
41+
def _parse_label(text: str):
    """Return the first trading label found in *text* (uppercased), else None."""
    found = LABEL_RE.search(str(text or ""))
    return found.group(1).upper() if found else None
46+
47+
48+
def _candidate_prompt(candidate):
49+
return "\n".join(
50+
[
51+
f"TICKER: {candidate.get('symbol')}",
52+
f"DATE: {candidate.get('as_of_date') or candidate.get('last_date')}",
53+
f"LAST_CLOSE: {candidate.get('last_close')}",
54+
f"CLOSES_TAIL: {candidate.get('closes_tail')}",
55+
f"RETURN_1D: {candidate.get('return_1d')}",
56+
f"RETURN_5D: {candidate.get('return_5d')}",
57+
f"RETURN_10D: {candidate.get('return_10d')}",
58+
f"VOLATILITY_20D: {candidate.get('volatility_20d')}",
59+
f"DIST_MA_20: {candidate.get('dist_ma_20')}",
60+
f"DIST_MA_50: {candidate.get('dist_ma_50')}",
61+
f"RSI_14: {candidate.get('rsi_14')}",
62+
f"VOLUME_RATIO: {candidate.get('volume_ratio')}",
63+
f"NEWS_COUNT_7D: {candidate.get('news_count_7d')}",
64+
f"NEWS_SENTIMENT_7D: {candidate.get('news_sentiment_7d')}",
65+
"QUESTION: Classify the expected 5-day return as STRONG_BUY | BUY | NEUTRAL | SELL | STRONG_SELL.",
66+
'Return only compact JSON: {"label":"BUY","confidence":0.63,"reason":"short english phrase"}',
67+
]
68+
)
69+
70+
71+
def _load_runtime():
    """Assemble the CPU inference runtime.

    Reads the base model id, LoRA adapter path, and CPU thread count from
    environment variables; returns ``(model, tokenizer)`` with the frozen
    adapter applied and the model switched to eval mode.
    """
    base_model = os.getenv("TRAINED_MODEL_BASE_MODEL", "Qwen/Qwen2.5-7B-Instruct")
    adapter_path = os.getenv("TRAINED_MODEL_ADAPTER_PATH", "_smoke_artifacts/lora_solid_adapter")
    requested_threads = int(os.getenv("TRAINED_MODEL_CPU_THREADS", "8") or 8)
    torch.set_num_threads(max(1, requested_threads))

    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Some tokenizers ship without a pad token; reuse EOS for padding.
        # NOTE(review): assumes eos_token is set on the chosen base model — confirm.
        tokenizer.pad_token = tokenizer.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype="auto",
    )
    # is_trainable=False keeps the adapter frozen — inference only.
    model = PeftModel.from_pretrained(base, adapter_path, is_trainable=False)
    model.eval()
    return model, tokenizer
88+
89+
90+
def _predict_one(model, tokenizer, candidate):
    """Score one candidate with a single greedy generation.

    Returns a dict with symbol/label/score/confidence/reason/raw_text.
    Falls back to NEUTRAL when no recognizable label can be extracted,
    and clamps confidence into [0.0, 1.0].
    """
    system_msg = (
        "You are the trained AI trading decision engine. Return only valid "
        "compact JSON with label, confidence, and a very short reason."
    )
    chat = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": _candidate_prompt(candidate)},
    ]
    rendered = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(rendered, return_tensors="pt", truncation=True, max_length=1536)
    n_prompt_tokens = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        generated = model.generate(
            **inputs,
            max_new_tokens=24,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated continuation, not the echoed prompt.
    completion = tokenizer.decode(generated[0][n_prompt_tokens:], skip_special_tokens=True).strip()

    parsed = _extract_json(completion) or {}
    label = str(parsed.get("label") or "").strip().upper()
    if not label:
        # JSON lacked a usable label: scan the reason text / raw output.
        label = _parse_label(parsed.get("reason") or completion) or "NEUTRAL"

    try:
        confidence = float(parsed.get("confidence"))
    except Exception:
        # Non-numeric or missing confidence: pick a conservative default.
        confidence = 0.5 if label == "NEUTRAL" else 0.65

    score = LABEL_TO_SCORE.get(label, 0.0)
    return {
        "symbol": candidate["symbol"],
        "label": label,
        "score": score,
        "confidence": max(0.0, min(1.0, confidence)),
        "reason": _enforce_english_reason(parsed.get("reason") or completion, "LONG" if score > 0 else "SHORT"),
        "raw_text": completion,
    }
125+
126+
127+
def main():
    """Run the direct CPU AI trading smoke test end to end.

    Builds candidate features for the configured tickers, scores each with
    the trained adapter, converts directional predictions into weighted
    trades, writes a JSON report under results/, and exits non-zero when no
    candidate could be built.
    """
    # Local import keeps this fix self-contained; timezone-aware "now"
    # replaces the deprecated datetime.utcnow() (deprecated in Python 3.12).
    from datetime import timezone

    config = load_config()
    tickers = [s.strip().upper() for s in os.getenv("AI_SMOKE_TICKERS", "AAPL").split(",") if s.strip()]
    candidates, failures = build_candidates(config, tickers)
    model, tokenizer = _load_runtime()

    predictions = []
    for candidate in candidates:
        pred = _predict_one(model, tokenizer, candidate)
        # NEUTRAL predictions (score 0.0) carry no directional signal; skip.
        if float(pred.get("score", 0.0)) == 0.0:
            continue
        side = "LONG" if pred["score"] > 0 else "SHORT"
        predictions.append(
            {
                "symbol": candidate["symbol"],
                "side": side,
                "score": float(pred["score"]),
                "confidence": float(pred["confidence"]),
                # Floor the strength at 0.01 so downstream weighting never
                # receives an exactly-zero signal.
                "strength": max(0.01, abs(float(pred["score"])) * max(float(pred["confidence"]), 0.05)),
                "reason": pred["reason"],
                "label": pred["label"],
                "raw_text": pred["raw_text"],
            }
        )

    ai_cfg = config.get("ai_trading", {}) if isinstance(config, dict) else {}
    picked = _pick_predictions(
        predictions,
        # Cap positions at the configured max, but never below 1 so the
        # helper always gets a sane bound even with zero predictions.
        max_positions=min(int(ai_cfg.get("max_positions", 10) or 10), max(1, len(predictions) or 1)),
        allow_shorts=bool(ai_cfg.get("allow_shorts", True)),
        max_shorts=int(ai_cfg.get("max_shorts", 5) or 5),
    )
    weighted = _weights_from_predictions(picked, min_total_weight=float(ai_cfg.get("min_total_weight", 0.90) or 0.90)) if picked else []
    trades = [
        {
            "symbol": row["symbol"],
            "side": row["side"],
            "weight": float(row["weight"]),
            "reason": row["reason"],
            "label": row["label"],
        }
        for row in weighted
    ]

    # Read the clock once so the payload timestamp and the output filename
    # always agree (the old code called utcnow() twice and could straddle a
    # second boundary). Dropping tzinfo preserves the naive-isoformat + "Z"
    # output format byte-for-byte.
    now_utc = datetime.now(timezone.utc)
    payload = {
        "timestamp": now_utc.replace(tzinfo=None).isoformat() + "Z",
        "tickers": tickers,
        "candidates_built": len(candidates),
        "candidate_failures": failures,
        "status": {
            "enabled": True,
            "ok": True,
            "error": None,
            "decision_engine": "trained_model",
            "backend": "github_actions_cpu_direct",
            "model": os.getenv("TRAINED_MODEL_BASE_MODEL", "Qwen/Qwen2.5-7B-Instruct"),
            "model_used": os.getenv("TRAINED_MODEL_BASE_MODEL", "Qwen/Qwen2.5-7B-Instruct"),
            "candidates_seen": len(candidates),
            "candidates_scored": len(predictions),
        },
        "predictions": predictions,
        "trades": trades,
    }

    os.makedirs("results", exist_ok=True)
    out_path = os.path.join("results", f"ai_smoke_direct_{now_utc.strftime('%Y%m%dT%H%M%SZ')}.json")
    # Explicit encoding so the report is UTF-8 regardless of platform locale.
    with open(out_path, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2)
    print(json.dumps(payload, indent=2))
    if not candidates:
        raise SystemExit(1)
198+
199+
200+
# Script entry point: run the full direct CPU smoke test.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)