@@ -95,6 +95,9 @@ def _candidate_prompt(candidate: Dict[str, Any]) -> str:
9595 return "\n " .join (lines )
9696
9797
98+ LABEL_RE = re .compile (r"\\b(STRONG_BUY|BUY|NEUTRAL|SELL|STRONG_SELL)\\b" , re .IGNORECASE )
99+
100+
98101def _extract_json (text : str ):
99102 if not text :
100103 return None
@@ -109,10 +112,19 @@ def _extract_json(text: str):
109112 try :
110113 return json .loads (text [start : end + 1 ])
111114 except Exception :
112- return None
115+ pass
113116 return None
114117
115118
119+ def _parse_plain_label (text : str ):
120+ if not text :
121+ return None
122+ match = LABEL_RE .search (str (text ))
123+ if not match :
124+ return None
125+ return {"label" : match .group (1 ).upper (), "confidence" : 0.5 , "reason" : str (text ).strip ()}
126+
127+
116128def _predict_one (candidate : Dict [str , Any ]) -> Dict [str , Any ]:
117129 model , tokenizer , torch = _load_runtime ()
118130 system = (
@@ -138,7 +150,7 @@ def _predict_one(candidate: Dict[str, Any]) -> Dict[str, Any]:
138150 pad_token_id = tokenizer .eos_token_id ,
139151 )
140152 text = tokenizer .decode (generated [0 ][input_len :], skip_special_tokens = True ).strip ()
141- parsed = _extract_json (text ) or {"label" : "NEUTRAL" , "confidence" : 0.5 , "reason" : text or "No parsable output." }
153+ parsed = _extract_json (text ) or _parse_plain_label ( text ) or {"label" : "NEUTRAL" , "confidence" : 0.5 , "reason" : text or "No parsable output." }
142154 parsed ["symbol" ] = candidate .get ("symbol" )
143155 return parsed
144156
0 commit comments