22import logging
33import os
44import re
5- import threading
6- from typing import Any , Dict , Optional
5+ from typing import Optional
76
87import requests
98
2322
2423
2524class TrainedModelTradeClient :
26- _runtime_lock = threading .Lock ()
27- _runtime_cache : Dict [str , Dict [str , Any ]] = {}
28-
2925 def __init__ (self , ai_cfg : Optional [dict ] = None ):
3026 ai_cfg = dict (ai_cfg or {})
3127 model_cfg = dict (ai_cfg .get ("trained_model" ) or {})
@@ -37,40 +33,28 @@ def __init__(self, ai_cfg: Optional[dict] = None):
3733 self .api_key_env = str (model_cfg .get ("api_key_env" , "" ) or "" ).strip ()
3834 self .api_key = os .getenv (self .api_key_env ).strip () if self .api_key_env and os .getenv (self .api_key_env ) else ""
3935 self .timeout_seconds = int (model_cfg .get ("timeout_seconds" , 60 ) or 60 )
40- self .base_model = str (model_cfg .get ("base_model" , "Qwen/Qwen2.5-7B-Instruct" ) or "Qwen/Qwen2.5-7B-Instruct" ).strip ()
41- self .adapter_dir = str (model_cfg .get ("adapter_dir" , "./models/lora_solid_adapter" ) or "./models/lora_solid_adapter" ).strip ()
42- self .max_new_tokens = int (model_cfg .get ("max_new_tokens" , 64 ) or 64 )
43- self .temperature = float (model_cfg .get ("temperature" , 0.0 ) or 0.0 )
44- self .cpu_threads = int (model_cfg .get ("cpu_threads" , 4 ) or 4 )
36+ self .model_name = str (model_cfg .get ("model_name" , "trained-trading-model" ) or "trained-trading-model" ).strip ()
4537 self .last_error = None
4638 self .last_model_used = None
4739
4840 @property
4941 def model_identifier (self ) -> str :
50- if self .backend == "http" :
51- return self .inference_url or "trained-model-http"
52- adapter_name = os .path .basename (os .path .normpath (self .adapter_dir or "adapter" )) or "adapter"
53- return f"{ self .base_model } +{ adapter_name } "
42+ return self .model_name or self .inference_url or "trained-model-http"
5443
5544 def is_ready (self ) -> bool :
56- if self .backend == "http" :
57- if not self .inference_url :
58- self .last_error = "trained_model.inference_url is not configured"
59- return False
60- return True
61- if self .backend == "local" :
62- if not self .base_model or not self .adapter_dir :
63- self .last_error = "trained_model.base_model or trained_model.adapter_dir is missing"
64- return False
65- return True
66- self .last_error = f"Unsupported trained model backend: { self .backend } "
67- return False
45+ if self .backend != "http" :
46+ self .last_error = f"Unsupported trained model backend: { self .backend } . Use remote HTTP inference only."
47+ return False
48+ if not self .inference_url :
49+ self .last_error = "trained_model.inference_url is not configured"
50+ return False
51+ return True
6852
6953 def predict_candidate (self , candidate : dict ) -> Optional [dict ]:
7054 if not self .is_ready ():
7155 return None
7256 try :
73- raw = self ._predict_http (candidate ) if self . backend == "http" else self . _predict_local ( candidate )
57+ raw = self ._predict_http (candidate )
7458 except Exception as exc :
7559 self .last_error = str (exc )
7660 logger .warning ("Trained model inference failed for %s: %s" , candidate .get ("symbol" ), exc )
@@ -81,8 +65,6 @@ def _predict_http(self, candidate: dict):
8165 payload = {
8266 "candidate" : candidate ,
8367 "task" : "trade_signal_classification" ,
84- "max_new_tokens" : self .max_new_tokens ,
85- "temperature" : self .temperature ,
8668 }
8769 headers = {"Content-Type" : "application/json" , "Accept" : "application/json" }
8870 if self .api_key :
@@ -95,87 +77,6 @@ def _predict_http(self, candidate: dict):
9577 return data ["signal" ]
9678 return data
9779
98- def _predict_local (self , candidate : dict ):
99- runtime = self ._ensure_local_runtime ()
100- tokenizer = runtime ["tokenizer" ]
101- model = runtime ["model" ]
102- torch = runtime ["torch" ]
103- messages = self ._build_messages (candidate )
104- prompt = tokenizer .apply_chat_template (messages , tokenize = False , add_generation_prompt = True )
105- encoded = tokenizer (prompt , return_tensors = "pt" )
106- input_len = encoded ["input_ids" ].shape [- 1 ]
107- with torch .no_grad ():
108- generated = model .generate (
109- ** encoded ,
110- max_new_tokens = self .max_new_tokens ,
111- do_sample = bool (self .temperature and self .temperature > 0.0 ),
112- temperature = max (self .temperature , 1e-5 ) if self .temperature and self .temperature > 0 else 1.0 ,
113- pad_token_id = tokenizer .eos_token_id ,
114- )
115- text = tokenizer .decode (generated [0 ][input_len :], skip_special_tokens = True ).strip ()
116- self .last_model_used = self .model_identifier
117- return text
118-
119- def _ensure_local_runtime (self ):
120- cache_key = f"{ self .base_model } |{ self .adapter_dir } "
121- with self ._runtime_lock :
122- cached = self ._runtime_cache .get (cache_key )
123- if cached is not None :
124- return cached
125- import torch
126- from peft import PeftModel
127- from transformers import AutoModelForCausalLM , AutoTokenizer
128-
129- torch .set_num_threads (max (1 , int (self .cpu_threads or 1 )))
130- tokenizer = AutoTokenizer .from_pretrained (self .base_model , trust_remote_code = True )
131- if tokenizer .pad_token is None :
132- tokenizer .pad_token = tokenizer .eos_token
133- model = AutoModelForCausalLM .from_pretrained (
134- self .base_model ,
135- trust_remote_code = True ,
136- low_cpu_mem_usage = True ,
137- )
138- model = PeftModel .from_pretrained (model , self .adapter_dir , is_trainable = False )
139- model .eval ()
140- runtime = {"tokenizer" : tokenizer , "model" : model , "torch" : torch }
141- self ._runtime_cache [cache_key ] = runtime
142- return runtime
143-
144- def _build_messages (self , candidate : dict ):
145- symbol = str (candidate .get ("symbol" ) or "UNKNOWN" ).strip ().upper ()
146- as_of_date = candidate .get ("as_of_date" ) or candidate .get ("last_date" ) or "UNKNOWN"
147- lines = [
148- f"TICKER: { symbol } " ,
149- f"DATE: { as_of_date } " ,
150- "PRICE_ACTION:" ,
151- f"- last_close: { candidate .get ('last_close' )} " ,
152- f"- closes_tail: { candidate .get ('closes_tail' )} " ,
153- f"- volume_1d: { candidate .get ('volume_1d' )} " ,
154- f"- volume_20d_avg: { candidate .get ('volume_20d_avg' )} " ,
155- "INDICATORS:" ,
156- f"- return_1d: { candidate .get ('return_1d' )} " ,
157- f"- return_5d: { candidate .get ('return_5d' )} " ,
158- f"- return_10d: { candidate .get ('return_10d' )} " ,
159- f"- volatility_20d: { candidate .get ('volatility_20d' )} " ,
160- f"- dist_ma_20: { candidate .get ('dist_ma_20' )} " ,
161- f"- dist_ma_50: { candidate .get ('dist_ma_50' )} " ,
162- f"- rsi_14: { candidate .get ('rsi_14' )} " ,
163- f"- volume_ratio: { candidate .get ('volume_ratio' )} " ,
164- "NEWS_CONTEXT:" ,
165- f"- news_count_7d: { candidate .get ('news_count_7d' )} " ,
166- f"- news_sentiment_7d: { candidate .get ('news_sentiment_7d' )} " ,
167- "" ,
168- "QUESTION: Classify the expected 5-day return as STRONG_BUY | BUY | NEUTRAL | SELL | STRONG_SELL." ,
169- "Return ONLY JSON using this schema:" ,
170- '{"label":"BUY","confidence":0.63,"reason":"..."}' ,
171- ]
172- system = (
173- "You are the trained AI trading decision engine. "
174- "Return only valid JSON with label, confidence, and reason. "
175- "Use the provided market snapshot to classify the next 5-day return."
176- )
177- return [{"role" : "system" , "content" : system }, {"role" : "user" , "content" : "\n " .join (lines )}]
178-
17980 def _normalize_prediction (self , raw ) -> Optional [dict ]:
18081 parsed = raw
18182 raw_text = None
0 commit comments