Skip to content

Commit f525f32

Browse files
committed
Remove local AI inference path and keep trained model remote-only
1 parent 1f68d0c commit f525f32

3 files changed

Lines changed: 12 additions & 120 deletions

File tree

config.yaml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,7 @@ ai_trading:
118118
inference_url_env: "TRAINED_MODEL_INFERENCE_URL"
119119
api_key_env: "TRAINED_MODEL_API_KEY"
120120
timeout_seconds: 60
121-
base_model: "Qwen/Qwen2.5-7B-Instruct"
122-
adapter_dir: "./models/lora_solid_adapter"
123-
max_new_tokens: 64
124-
temperature: 0.0
125-
cpu_threads: 4
121+
model_name: "quant-trained-trading-model"
126122

127123

128124

requirements-ai-local.txt

Lines changed: 0 additions & 5 deletions
This file was deleted.

trained_model_client.py

Lines changed: 11 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
import logging
33
import os
44
import re
5-
import threading
6-
from typing import Any, Dict, Optional
5+
from typing import Optional
76

87
import requests
98

@@ -23,9 +22,6 @@
2322

2423

2524
class TrainedModelTradeClient:
26-
_runtime_lock = threading.Lock()
27-
_runtime_cache: Dict[str, Dict[str, Any]] = {}
28-
2925
def __init__(self, ai_cfg: Optional[dict] = None):
3026
ai_cfg = dict(ai_cfg or {})
3127
model_cfg = dict(ai_cfg.get("trained_model") or {})
@@ -37,40 +33,28 @@ def __init__(self, ai_cfg: Optional[dict] = None):
3733
self.api_key_env = str(model_cfg.get("api_key_env", "") or "").strip()
3834
self.api_key = os.getenv(self.api_key_env).strip() if self.api_key_env and os.getenv(self.api_key_env) else ""
3935
self.timeout_seconds = int(model_cfg.get("timeout_seconds", 60) or 60)
40-
self.base_model = str(model_cfg.get("base_model", "Qwen/Qwen2.5-7B-Instruct") or "Qwen/Qwen2.5-7B-Instruct").strip()
41-
self.adapter_dir = str(model_cfg.get("adapter_dir", "./models/lora_solid_adapter") or "./models/lora_solid_adapter").strip()
42-
self.max_new_tokens = int(model_cfg.get("max_new_tokens", 64) or 64)
43-
self.temperature = float(model_cfg.get("temperature", 0.0) or 0.0)
44-
self.cpu_threads = int(model_cfg.get("cpu_threads", 4) or 4)
36+
self.model_name = str(model_cfg.get("model_name", "trained-trading-model") or "trained-trading-model").strip()
4537
self.last_error = None
4638
self.last_model_used = None
4739

4840
@property
4941
def model_identifier(self) -> str:
50-
if self.backend == "http":
51-
return self.inference_url or "trained-model-http"
52-
adapter_name = os.path.basename(os.path.normpath(self.adapter_dir or "adapter")) or "adapter"
53-
return f"{self.base_model}+{adapter_name}"
42+
return self.model_name or self.inference_url or "trained-model-http"
5443

5544
def is_ready(self) -> bool:
56-
if self.backend == "http":
57-
if not self.inference_url:
58-
self.last_error = "trained_model.inference_url is not configured"
59-
return False
60-
return True
61-
if self.backend == "local":
62-
if not self.base_model or not self.adapter_dir:
63-
self.last_error = "trained_model.base_model or trained_model.adapter_dir is missing"
64-
return False
65-
return True
66-
self.last_error = f"Unsupported trained model backend: {self.backend}"
67-
return False
45+
if self.backend != "http":
46+
self.last_error = f"Unsupported trained model backend: {self.backend}. Use remote HTTP inference only."
47+
return False
48+
if not self.inference_url:
49+
self.last_error = "trained_model.inference_url is not configured"
50+
return False
51+
return True
6852

6953
def predict_candidate(self, candidate: dict) -> Optional[dict]:
7054
if not self.is_ready():
7155
return None
7256
try:
73-
raw = self._predict_http(candidate) if self.backend == "http" else self._predict_local(candidate)
57+
raw = self._predict_http(candidate)
7458
except Exception as exc:
7559
self.last_error = str(exc)
7660
logger.warning("Trained model inference failed for %s: %s", candidate.get("symbol"), exc)
@@ -81,8 +65,6 @@ def _predict_http(self, candidate: dict):
8165
payload = {
8266
"candidate": candidate,
8367
"task": "trade_signal_classification",
84-
"max_new_tokens": self.max_new_tokens,
85-
"temperature": self.temperature,
8668
}
8769
headers = {"Content-Type": "application/json", "Accept": "application/json"}
8870
if self.api_key:
@@ -95,87 +77,6 @@ def _predict_http(self, candidate: dict):
9577
return data["signal"]
9678
return data
9779

98-
def _predict_local(self, candidate: dict):
99-
runtime = self._ensure_local_runtime()
100-
tokenizer = runtime["tokenizer"]
101-
model = runtime["model"]
102-
torch = runtime["torch"]
103-
messages = self._build_messages(candidate)
104-
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
105-
encoded = tokenizer(prompt, return_tensors="pt")
106-
input_len = encoded["input_ids"].shape[-1]
107-
with torch.no_grad():
108-
generated = model.generate(
109-
**encoded,
110-
max_new_tokens=self.max_new_tokens,
111-
do_sample=bool(self.temperature and self.temperature > 0.0),
112-
temperature=max(self.temperature, 1e-5) if self.temperature and self.temperature > 0 else 1.0,
113-
pad_token_id=tokenizer.eos_token_id,
114-
)
115-
text = tokenizer.decode(generated[0][input_len:], skip_special_tokens=True).strip()
116-
self.last_model_used = self.model_identifier
117-
return text
118-
119-
def _ensure_local_runtime(self):
120-
cache_key = f"{self.base_model}|{self.adapter_dir}"
121-
with self._runtime_lock:
122-
cached = self._runtime_cache.get(cache_key)
123-
if cached is not None:
124-
return cached
125-
import torch
126-
from peft import PeftModel
127-
from transformers import AutoModelForCausalLM, AutoTokenizer
128-
129-
torch.set_num_threads(max(1, int(self.cpu_threads or 1)))
130-
tokenizer = AutoTokenizer.from_pretrained(self.base_model, trust_remote_code=True)
131-
if tokenizer.pad_token is None:
132-
tokenizer.pad_token = tokenizer.eos_token
133-
model = AutoModelForCausalLM.from_pretrained(
134-
self.base_model,
135-
trust_remote_code=True,
136-
low_cpu_mem_usage=True,
137-
)
138-
model = PeftModel.from_pretrained(model, self.adapter_dir, is_trainable=False)
139-
model.eval()
140-
runtime = {"tokenizer": tokenizer, "model": model, "torch": torch}
141-
self._runtime_cache[cache_key] = runtime
142-
return runtime
143-
144-
def _build_messages(self, candidate: dict):
145-
symbol = str(candidate.get("symbol") or "UNKNOWN").strip().upper()
146-
as_of_date = candidate.get("as_of_date") or candidate.get("last_date") or "UNKNOWN"
147-
lines = [
148-
f"TICKER: {symbol}",
149-
f"DATE: {as_of_date}",
150-
"PRICE_ACTION:",
151-
f"- last_close: {candidate.get('last_close')}",
152-
f"- closes_tail: {candidate.get('closes_tail')}",
153-
f"- volume_1d: {candidate.get('volume_1d')}",
154-
f"- volume_20d_avg: {candidate.get('volume_20d_avg')}",
155-
"INDICATORS:",
156-
f"- return_1d: {candidate.get('return_1d')}",
157-
f"- return_5d: {candidate.get('return_5d')}",
158-
f"- return_10d: {candidate.get('return_10d')}",
159-
f"- volatility_20d: {candidate.get('volatility_20d')}",
160-
f"- dist_ma_20: {candidate.get('dist_ma_20')}",
161-
f"- dist_ma_50: {candidate.get('dist_ma_50')}",
162-
f"- rsi_14: {candidate.get('rsi_14')}",
163-
f"- volume_ratio: {candidate.get('volume_ratio')}",
164-
"NEWS_CONTEXT:",
165-
f"- news_count_7d: {candidate.get('news_count_7d')}",
166-
f"- news_sentiment_7d: {candidate.get('news_sentiment_7d')}",
167-
"",
168-
"QUESTION: Classify the expected 5-day return as STRONG_BUY | BUY | NEUTRAL | SELL | STRONG_SELL.",
169-
"Return ONLY JSON using this schema:",
170-
'{"label":"BUY","confidence":0.63,"reason":"..."}',
171-
]
172-
system = (
173-
"You are the trained AI trading decision engine. "
174-
"Return only valid JSON with label, confidence, and reason. "
175-
"Use the provided market snapshot to classify the next 5-day return."
176-
)
177-
return [{"role": "system", "content": system}, {"role": "user", "content": "\n".join(lines)}]
178-
17980
def _normalize_prediction(self, raw) -> Optional[dict]:
18081
parsed = raw
18182
raw_text = None

0 commit comments

Comments (0)