|
import json
import math
import os
import re
from datetime import datetime, timezone

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from llm_trader import _enforce_english_reason, _pick_predictions, _weights_from_predictions
from run_ai_trading_smoke import build_candidates, load_config
| 12 | + |
# Ordinal score per classification label: the sign encodes direction
# (positive = long bias, negative = short bias), the magnitude encodes strength.
LABEL_TO_SCORE = {
    "STRONG_BUY": 2.0,
    "BUY": 1.0,
    "NEUTRAL": 0.0,
    "SELL": -1.0,
    "STRONG_SELL": -2.0,
}
# First recognizable label token in free-form model output. The \b boundaries
# plus underscore being a word character keep BUY/SELL from matching inside
# STRONG_BUY/STRONG_SELL.
LABEL_RE = re.compile(r"\b(STRONG_BUY|BUY|NEUTRAL|SELL|STRONG_SELL)\b", re.IGNORECASE)
| 21 | + |
| 22 | + |
| 23 | +def _extract_json(text: str): |
| 24 | + if not text: |
| 25 | + return None |
| 26 | + text = str(text).strip() |
| 27 | + try: |
| 28 | + return json.loads(text) |
| 29 | + except Exception: |
| 30 | + pass |
| 31 | + start = text.find("{") |
| 32 | + end = text.rfind("}") |
| 33 | + if start >= 0 and end > start: |
| 34 | + try: |
| 35 | + return json.loads(text[start : end + 1]) |
| 36 | + except Exception: |
| 37 | + pass |
| 38 | + return None |
| 39 | + |
| 40 | + |
def _parse_label(text: str):
    """Return the first recognized trading label in *text*, uppercased, or None."""
    hit = LABEL_RE.search(str(text or ""))
    return hit.group(1).upper() if hit else None
| 46 | + |
| 47 | + |
| 48 | +def _candidate_prompt(candidate): |
| 49 | + return "\n".join( |
| 50 | + [ |
| 51 | + f"TICKER: {candidate.get('symbol')}", |
| 52 | + f"DATE: {candidate.get('as_of_date') or candidate.get('last_date')}", |
| 53 | + f"LAST_CLOSE: {candidate.get('last_close')}", |
| 54 | + f"CLOSES_TAIL: {candidate.get('closes_tail')}", |
| 55 | + f"RETURN_1D: {candidate.get('return_1d')}", |
| 56 | + f"RETURN_5D: {candidate.get('return_5d')}", |
| 57 | + f"RETURN_10D: {candidate.get('return_10d')}", |
| 58 | + f"VOLATILITY_20D: {candidate.get('volatility_20d')}", |
| 59 | + f"DIST_MA_20: {candidate.get('dist_ma_20')}", |
| 60 | + f"DIST_MA_50: {candidate.get('dist_ma_50')}", |
| 61 | + f"RSI_14: {candidate.get('rsi_14')}", |
| 62 | + f"VOLUME_RATIO: {candidate.get('volume_ratio')}", |
| 63 | + f"NEWS_COUNT_7D: {candidate.get('news_count_7d')}", |
| 64 | + f"NEWS_SENTIMENT_7D: {candidate.get('news_sentiment_7d')}", |
| 65 | + "QUESTION: Classify the expected 5-day return as STRONG_BUY | BUY | NEUTRAL | SELL | STRONG_SELL.", |
| 66 | + 'Return only compact JSON: {"label":"BUY","confidence":0.63,"reason":"short english phrase"}', |
| 67 | + ] |
| 68 | + ) |
| 69 | + |
| 70 | + |
def _load_runtime():
    """Load the base causal LM plus the frozen LoRA adapter for CPU inference.

    Environment variables read:
      TRAINED_MODEL_BASE_MODEL   — HF model id of the base model.
      TRAINED_MODEL_ADAPTER_PATH — local path of the PEFT adapter.
      TRAINED_MODEL_CPU_THREADS  — torch intra-op thread count (minimum 1).

    Returns:
      (model, tokenizer) with the adapted model in eval mode.
    """
    base_model = os.getenv("TRAINED_MODEL_BASE_MODEL", "Qwen/Qwen2.5-7B-Instruct")
    adapter_path = os.getenv("TRAINED_MODEL_ADAPTER_PATH", "_smoke_artifacts/lora_solid_adapter")
    thread_count = max(1, int(os.getenv("TRAINED_MODEL_CPU_THREADS", "8") or 8))
    torch.set_num_threads(thread_count)

    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    # Some causal-LM tokenizers ship without a pad token; reuse EOS so padding
    # during generation does not fail.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype="auto",
    )
    adapted = PeftModel.from_pretrained(base, adapter_path, is_trainable=False)
    adapted.eval()
    return adapted, tokenizer
| 88 | + |
| 89 | + |
def _predict_one(model, tokenizer, candidate):
    """Run one greedy generation for *candidate* and normalize the output.

    Returns a dict with symbol, label (one of LABEL_TO_SCORE or regex-parsed
    fallback), score, confidence clamped to [0, 1], an English reason, and the
    raw generated text.
    """
    system = "You are the trained AI trading decision engine. Return only valid compact JSON with label, confidence, and a very short reason."
    prompt = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": system},
            {"role": "user", "content": _candidate_prompt(candidate)},
        ],
        tokenize=False,
        add_generation_prompt=True,
    )
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1536)
    prompt_len = encoded["input_ids"].shape[-1]
    with torch.no_grad():
        output = model.generate(
            **encoded,
            max_new_tokens=24,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    text = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()

    parsed = _extract_json(text)
    if not isinstance(parsed, dict):
        # _extract_json may return any JSON value (list, number, string);
        # only a dict carries the fields read below.
        parsed = {}

    # Normalize the label: accept the JSON field only when it maps to a known
    # score after uppercasing and space->underscore (so "Strong Buy" works);
    # otherwise fall back to regex extraction from the reason/raw text.
    # Previously an unknown label (e.g. "HOLD") was kept verbatim, silently
    # scored 0.0, and mislabeled the side hint below as "SHORT".
    label = str(parsed.get("label") or "").strip().upper().replace(" ", "_")
    if label not in LABEL_TO_SCORE:
        label = _parse_label(parsed.get("reason") or text) or "NEUTRAL"

    confidence = parsed.get("confidence")
    try:
        confidence = float(confidence)
    except (TypeError, ValueError):
        confidence = None
    # Reject NaN/inf as well, so the clamp below is meaningful.
    if confidence is None or not math.isfinite(confidence):
        confidence = 0.5 if label == "NEUTRAL" else 0.65

    score = LABEL_TO_SCORE.get(label, 0.0)
    return {
        "symbol": candidate["symbol"],
        "label": label,
        "score": score,
        "confidence": max(0.0, min(1.0, confidence)),
        "reason": _enforce_english_reason(parsed.get("reason") or text, "LONG" if score > 0 else "SHORT"),
        "raw_text": text,
    }
| 125 | + |
| 126 | + |
def main():
    """Build candidates, score each with the LoRA-adapted model, pick and
    weight trades, then write the full payload to results/ as JSON.

    Raises:
      SystemExit(1) when no candidates could be built (after the payload,
      including the failure list, has been written for diagnosis).
    """
    config = load_config()
    tickers = [s.strip().upper() for s in os.getenv("AI_SMOKE_TICKERS", "AAPL").split(",") if s.strip()]
    candidates, failures = build_candidates(config, tickers)
    model, tokenizer = _load_runtime()

    predictions = []
    for candidate in candidates:
        pred = _predict_one(model, tokenizer, candidate)
        # Zero-score (NEUTRAL) predictions carry no trade signal; drop them.
        if float(pred.get("score", 0.0)) == 0.0:
            continue
        side = "LONG" if pred["score"] > 0 else "SHORT"
        predictions.append(
            {
                "symbol": candidate["symbol"],
                "side": side,
                "score": float(pred["score"]),
                "confidence": float(pred["confidence"]),
                # Strength floors at 0.01 so every kept prediction gets weight;
                # confidence floors at 0.05 so strength never collapses to 0.
                "strength": max(0.01, abs(float(pred["score"])) * max(float(pred["confidence"]), 0.05)),
                "reason": pred["reason"],
                "label": pred["label"],
                "raw_text": pred["raw_text"],
            }
        )

    ai_cfg = config.get("ai_trading", {}) if isinstance(config, dict) else {}
    picked = _pick_predictions(
        predictions,
        # Never ask for more positions than we have predictions (min 1).
        max_positions=min(int(ai_cfg.get("max_positions", 10) or 10), max(1, len(predictions) or 1)),
        allow_shorts=bool(ai_cfg.get("allow_shorts", True)),
        max_shorts=int(ai_cfg.get("max_shorts", 5) or 5),
    )
    weighted = _weights_from_predictions(picked, min_total_weight=float(ai_cfg.get("min_total_weight", 0.90) or 0.90)) if picked else []
    trades = [
        {
            "symbol": row["symbol"],
            "side": row["side"],
            "weight": float(row["weight"]),
            "reason": row["reason"],
            "label": row["label"],
        }
        for row in weighted
    ]

    # Capture the timestamp once so the payload timestamp and the output
    # filename agree even when execution straddles a second boundary.
    # Timezone-aware now() replaces the deprecated datetime.utcnow().
    now = datetime.now(timezone.utc)
    payload = {
        "timestamp": now.replace(tzinfo=None).isoformat() + "Z",
        "tickers": tickers,
        "candidates_built": len(candidates),
        "candidate_failures": failures,
        "status": {
            "enabled": True,
            "ok": True,
            "error": None,
            "decision_engine": "trained_model",
            "backend": "github_actions_cpu_direct",
            "model": os.getenv("TRAINED_MODEL_BASE_MODEL", "Qwen/Qwen2.5-7B-Instruct"),
            "model_used": os.getenv("TRAINED_MODEL_BASE_MODEL", "Qwen/Qwen2.5-7B-Instruct"),
            "candidates_seen": len(candidates),
            "candidates_scored": len(predictions),
        },
        "predictions": predictions,
        "trades": trades,
    }

    os.makedirs("results", exist_ok=True)
    out_path = os.path.join("results", f"ai_smoke_direct_{now.strftime('%Y%m%dT%H%M%SZ')}.json")
    with open(out_path, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2)
    print(json.dumps(payload, indent=2))
    if not candidates:
        raise SystemExit(1)
| 198 | + |
| 199 | + |
# Script entry point: run the smoke pipeline when executed directly.
if __name__ == "__main__":
    main()
0 commit comments