66from dataclasses import dataclass
77import json
88import os
9- from typing import Any
9+ import re
10+ from typing import Any , Literal
1011
1112from openai import OpenAI
1213
1314from roast .analyzer import AnalysisReport , Issue
1415from roast .scanner import FileResult
1516
# Providers the roaster can talk to. "auto" expands to a cascade
# (groq -> groq fast -> nim -> openai) in _build_provider_plan; "none"
# is only meaningful as a backup_provider value and disables the backup.
Provider = Literal["auto", "groq", "nim", "openai", "none"]

# Default provider ordering and per-provider model choices.
DEFAULT_PRIMARY_PROVIDER = "groq"
DEFAULT_BACKUP_PROVIDER = "nim"
DEFAULT_GROQ_MODEL = "llama-3.3-70b-versatile"
# Cheaper/faster Groq model, tried second in "auto" mode.
DEFAULT_GROQ_FAST_MODEL = "llama-3.1-8b-instant"
DEFAULT_NIM_MODEL = "microsoft/phi-4-mini-instruct"
DEFAULT_OPENAI_MODEL = "gpt-4o-mini"
25+
1626
1727@dataclass (slots = True )
1828class RoastResult :
@@ -137,35 +147,140 @@ def _generate_fallback_roast(report: AnalysisReport) -> RoastResult:
137147 )
138148
139149
140- def generate_roast (
150+ def _provider_api_key (provider : Provider ) -> str | None :
151+ if provider == "groq" :
152+ return os .getenv ("GROQ_API_KEY" )
153+ if provider == "nim" :
154+ return os .getenv ("NVIDIA_NIM_API_KEY" ) or os .getenv ("NIM_API_KEY" )
155+ if provider == "openai" :
156+ return os .getenv ("OPENAI_API_KEY" )
157+ return None
158+
159+
def _provider_base_url(provider: Provider) -> str | None:
    """Return the OpenAI-compatible endpoint for *provider*.

    None means "use the OpenAI client's built-in default base URL".
    """
    endpoints = {
        "groq": "https://api.groq.com/openai/v1",
        "nim": "https://integrate.api.nvidia.com/v1",
    }
    return endpoints.get(provider)
166+
167+
def _default_model_for_provider(provider: Provider) -> str:
    """Return the default chat model for *provider*.

    OpenAI's model doubles as the catch-all for any other value.
    """
    if provider == "groq":
        return DEFAULT_GROQ_MODEL
    return DEFAULT_NIM_MODEL if provider == "nim" else DEFAULT_OPENAI_MODEL
174+
175+
176+ def _extract_json_payload (content : str ) -> dict [str , Any ]:
177+ text = content .strip ()
178+ try :
179+ payload = json .loads (text )
180+ if isinstance (payload , dict ):
181+ return payload
182+ except json .JSONDecodeError :
183+ pass
184+
185+ match = re .search (r"\{[\s\S]*\}" , text )
186+ if not match :
187+ raise ValueError ("Model response was not valid JSON." )
188+ payload = json .loads (match .group (0 ))
189+ if not isinstance (payload , dict ):
190+ raise ValueError ("Model JSON payload must be an object." )
191+ return payload
192+
193+
def _build_provider_plan(
    provider: Provider,
    model: str | None,
    backup_provider: Provider,
    backup_model: str | None,
) -> list[tuple[Provider, str]]:
    """Resolve the ordered list of (provider, model) attempts.

    "auto" expands to the full cascade (groq, groq fast, nim, openai).
    Otherwise the explicit provider comes first, followed by the backup
    unless it is "none" or identical to the primary. Exact duplicate
    pairs are dropped while preserving first-seen order.
    """
    if provider == "auto":
        candidates: list[tuple[Provider, str]] = [
            ("groq", model or DEFAULT_GROQ_MODEL),
            ("groq", DEFAULT_GROQ_FAST_MODEL),
            ("nim", backup_model or DEFAULT_NIM_MODEL),
            ("openai", DEFAULT_OPENAI_MODEL),
        ]
    else:
        candidates = [(provider, model or _default_model_for_provider(provider))]
        if backup_provider != "none" and backup_provider != provider:
            candidates.append(
                (backup_provider, backup_model or _default_model_for_provider(backup_provider))
            )

    # dict preserves insertion order, so fromkeys dedups without reordering.
    return list(dict.fromkeys(candidates))
222+
223+
def _call_roast_llm(
    provider: Provider,
    model: str,
    report: AnalysisReport,
    files: list[FileResult],
) -> RoastResult:
    """Run one roast request against a single provider/model pair.

    Raises RuntimeError when the provider's API key is missing from the
    environment; any error from the HTTP client or from parsing the model
    output propagates to the caller (generate_roast collects them).
    """
    api_key = _provider_api_key(provider)
    if not api_key:
        raise RuntimeError(f"Missing API key for provider '{provider}'.")

    # Groq/NIM are OpenAI-compatible; they only differ by base URL.
    client_kwargs: dict[str, Any] = {"api_key": api_key}
    base_url = _provider_base_url(provider)
    if base_url:
        client_kwargs["base_url"] = base_url
    client = OpenAI(**client_kwargs)

    overall_score = report.scores.get("Overall", 0)
    system_prompt = (
        "You are a senior developer who has seen too much bad code. "
        "You are brutally honest but funny, like a Gordon Ramsay for codebases. "
        "Be specific, reference actual file names and issues. Never be generic. "
        "Keep roast lines under 20 words each. "
        "Return strict JSON only."
    )
    response = client.chat.completions.create(
        model=model,
        temperature=0.8,
        max_tokens=500,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": _build_user_prompt(report, files)},
        ],
    )

    content = response.choices[0].message.content or "{}"
    return _normalize_roast_payload(_extract_json_payload(content), overall_score)
262+
263+
def generate_roast(
    report: AnalysisReport,
    files: list[FileResult],
    model: str | None = None,
    no_llm: bool = False,
    provider: Provider = "auto",
    backup_provider: Provider = DEFAULT_BACKUP_PROVIDER,
    backup_model: str | None = None,
) -> RoastResult:
    """Generate a roast using an LLM or deterministic fallback mode."""
    if no_llm:
        return _generate_fallback_roast(report)

    failures: list[str] = []
    plan = _build_provider_plan(provider, model, backup_provider, backup_model)

    # Walk the plan in order; the first provider/model that succeeds wins.
    for attempt_provider, attempt_model in plan:
        try:
            return _call_roast_llm(attempt_provider, attempt_model, report, files)
        except Exception as exc:  # noqa: BLE001 - collect every failure for the final error
            failures.append(f"{attempt_provider}:{attempt_model} -> {exc}")

    raise RuntimeError("All LLM providers failed. " + " | ".join(failures))
0 commit comments