Skip to content

Commit bb6e775

Browse files
committed
Speed up Modal CPU trained-model inference
1 parent 40a6701 commit bb6e775

1 file changed

Lines changed: 16 additions & 24 deletions

File tree

modal_trained_model_service.py

Lines changed: 16 additions & 24 deletions
Original file line number · Diff line number · Diff line change
@@ -71,27 +71,20 @@ def _candidate_prompt(candidate: Dict[str, Any]) -> str:
7171
lines = [
7272
f"TICKER: {symbol}",
7373
f"DATE: {as_of_date}",
74-
"PRICE_ACTION:",
75-
f"- last_close: {candidate.get('last_close')}",
76-
f"- closes_tail: {candidate.get('closes_tail')}",
77-
f"- volume_1d: {candidate.get('volume_1d')}",
78-
f"- volume_20d_avg: {candidate.get('volume_20d_avg')}",
79-
"INDICATORS:",
80-
f"- return_1d: {candidate.get('return_1d')}",
81-
f"- return_5d: {candidate.get('return_5d')}",
82-
f"- return_10d: {candidate.get('return_10d')}",
83-
f"- volatility_20d: {candidate.get('volatility_20d')}",
84-
f"- dist_ma_20: {candidate.get('dist_ma_20')}",
85-
f"- dist_ma_50: {candidate.get('dist_ma_50')}",
86-
f"- rsi_14: {candidate.get('rsi_14')}",
87-
f"- volume_ratio: {candidate.get('volume_ratio')}",
88-
"NEWS_CONTEXT:",
89-
f"- news_count_7d: {candidate.get('news_count_7d')}",
90-
f"- news_sentiment_7d: {candidate.get('news_sentiment_7d')}",
91-
"",
74+
f"LAST_CLOSE: {candidate.get('last_close')}",
75+
f"CLOSES_TAIL: {candidate.get('closes_tail')}",
76+
f"RETURN_1D: {candidate.get('return_1d')}",
77+
f"RETURN_5D: {candidate.get('return_5d')}",
78+
f"RETURN_10D: {candidate.get('return_10d')}",
79+
f"VOLATILITY_20D: {candidate.get('volatility_20d')}",
80+
f"DIST_MA_20: {candidate.get('dist_ma_20')}",
81+
f"DIST_MA_50: {candidate.get('dist_ma_50')}",
82+
f"RSI_14: {candidate.get('rsi_14')}",
83+
f"VOLUME_RATIO: {candidate.get('volume_ratio')}",
84+
f"NEWS_COUNT_7D: {candidate.get('news_count_7d')}",
85+
f"NEWS_SENTIMENT_7D: {candidate.get('news_sentiment_7d')}",
9286
"QUESTION: Classify the expected 5-day return as STRONG_BUY | BUY | NEUTRAL | SELL | STRONG_SELL.",
93-
"Return ONLY JSON using this schema:",
94-
'{"label":"BUY","confidence":0.63,"reason":"..."}',
87+
'Return only compact JSON: {"label":"BUY","confidence":0.63,"reason":"short english phrase"}',
9588
]
9689
return "\n".join(lines)
9790

@@ -130,8 +123,7 @@ def _predict_one(candidate: Dict[str, Any]) -> Dict[str, Any]:
130123
model, tokenizer, torch = _load_runtime()
131124
system = (
132125
"You are the trained AI trading decision engine. "
133-
"Return only valid JSON with label, confidence, and reason. "
134-
"Use the provided market snapshot to classify the next 5-day return."
126+
"Return only valid compact JSON with label, confidence, and a very short reason."
135127
)
136128
prompt = tokenizer.apply_chat_template(
137129
[
@@ -141,12 +133,12 @@ def _predict_one(candidate: Dict[str, Any]) -> Dict[str, Any]:
141133
tokenize=False,
142134
add_generation_prompt=True,
143135
)
144-
encoded = tokenizer(prompt, return_tensors="pt")
136+
encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1536)
145137
input_len = encoded["input_ids"].shape[-1]
146138
with torch.no_grad():
147139
generated = model.generate(
148140
**encoded,
149-
max_new_tokens=64,
141+
max_new_tokens=24,
150142
do_sample=False,
151143
pad_token_id=tokenizer.eos_token_id,
152144
)

0 commit comments

Comments (0)