Skip to content

Commit 145d4f2

Browse files
committed
Improve Modal inference parsing for malformed model JSON
1 parent 07ea033 commit 145d4f2

1 file changed

Lines changed: 14 additions & 2 deletions

File tree

modal_trained_model_service.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ def _candidate_prompt(candidate: Dict[str, Any]) -> str:
9595
return "\n".join(lines)
9696

9797

98+
LABEL_RE = re.compile(r"\\b(STRONG_BUY|BUY|NEUTRAL|SELL|STRONG_SELL)\\b", re.IGNORECASE)
99+
100+
98101
def _extract_json(text: str):
99102
if not text:
100103
return None
@@ -109,10 +112,19 @@ def _extract_json(text: str):
109112
try:
110113
return json.loads(text[start : end + 1])
111114
except Exception:
112-
return None
115+
pass
113116
return None
114117

115118

119+
def _parse_plain_label(text: str):
120+
if not text:
121+
return None
122+
match = LABEL_RE.search(str(text))
123+
if not match:
124+
return None
125+
return {"label": match.group(1).upper(), "confidence": 0.5, "reason": str(text).strip()}
126+
127+
116128
def _predict_one(candidate: Dict[str, Any]) -> Dict[str, Any]:
117129
model, tokenizer, torch = _load_runtime()
118130
system = (
@@ -138,7 +150,7 @@ def _predict_one(candidate: Dict[str, Any]) -> Dict[str, Any]:
138150
pad_token_id=tokenizer.eos_token_id,
139151
)
140152
text = tokenizer.decode(generated[0][input_len:], skip_special_tokens=True).strip()
141-
parsed = _extract_json(text) or {"label": "NEUTRAL", "confidence": 0.5, "reason": text or "No parsable output."}
153+
parsed = _extract_json(text) or _parse_plain_label(text) or {"label": "NEUTRAL", "confidence": 0.5, "reason": text or "No parsable output."}
142154
parsed["symbol"] = candidate.get("symbol")
143155
return parsed
144156

0 commit comments

Comments
 (0)