
Commit 7307c5a

Add Groq primary + NIM backup provider support

1 parent 77dfcfe
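
This change removes the hard-wired OpenAI client from roast/roaster.py and routes roast generation through a provider plan. In "auto" mode the plan is Groq's default model, then Groq's fast model, then the NVIDIA NIM backup, then OpenAI; generate_roast walks the plan in order and raises only after every provider/model pair has failed. Responses are also parsed more defensively: _extract_json_payload accepts strict JSON or, failing that, the first {...} object embedded in the reply.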

1 file changed: roast/roaster.py (127 additions, 12 deletions)
@@ -6,13 +6,23 @@
 from dataclasses import dataclass
 import json
 import os
-from typing import Any
+import re
+from typing import Any, Literal
 
 from openai import OpenAI
 
 from roast.analyzer import AnalysisReport, Issue
 from roast.scanner import FileResult
 
+Provider = Literal["auto", "groq", "nim", "openai", "none"]
+
+DEFAULT_PRIMARY_PROVIDER = "groq"
+DEFAULT_BACKUP_PROVIDER = "nim"
+DEFAULT_GROQ_MODEL = "llama-3.3-70b-versatile"
+DEFAULT_GROQ_FAST_MODEL = "llama-3.1-8b-instant"
+DEFAULT_NIM_MODEL = "microsoft/phi-4-mini-instruct"
+DEFAULT_OPENAI_MODEL = "gpt-4o-mini"
+
 
 @dataclass(slots=True)
 class RoastResult:
@@ -137,35 +147,140 @@ def _generate_fallback_roast(report: AnalysisReport) -> RoastResult:
     )
 
 
-def generate_roast(
+def _provider_api_key(provider: Provider) -> str | None:
+    if provider == "groq":
+        return os.getenv("GROQ_API_KEY")
+    if provider == "nim":
+        return os.getenv("NVIDIA_NIM_API_KEY") or os.getenv("NIM_API_KEY")
+    if provider == "openai":
+        return os.getenv("OPENAI_API_KEY")
+    return None
+
+
+def _provider_base_url(provider: Provider) -> str | None:
+    if provider == "groq":
+        return "https://api.groq.com/openai/v1"
+    if provider == "nim":
+        return "https://integrate.api.nvidia.com/v1"
+    return None
+
+
+def _default_model_for_provider(provider: Provider) -> str:
+    if provider == "groq":
+        return DEFAULT_GROQ_MODEL
+    if provider == "nim":
+        return DEFAULT_NIM_MODEL
+    return DEFAULT_OPENAI_MODEL
+
+
+def _extract_json_payload(content: str) -> dict[str, Any]:
+    text = content.strip()
+    try:
+        payload = json.loads(text)
+        if isinstance(payload, dict):
+            return payload
+    except json.JSONDecodeError:
+        pass
+
+    match = re.search(r"\{[\s\S]*\}", text)
+    if not match:
+        raise ValueError("Model response was not valid JSON.")
+    payload = json.loads(match.group(0))
+    if not isinstance(payload, dict):
+        raise ValueError("Model JSON payload must be an object.")
+    return payload
+
+
+def _build_provider_plan(
+    provider: Provider,
+    model: str | None,
+    backup_provider: Provider,
+    backup_model: str | None,
+) -> list[tuple[Provider, str]]:
+    plan: list[tuple[Provider, str]] = []
+
+    if provider == "auto":
+        plan = [
+            ("groq", model or DEFAULT_GROQ_MODEL),
+            ("groq", DEFAULT_GROQ_FAST_MODEL),
+            ("nim", backup_model or DEFAULT_NIM_MODEL),
+            ("openai", DEFAULT_OPENAI_MODEL),
+        ]
+    else:
+        plan = [(provider, model or _default_model_for_provider(provider))]
+        if backup_provider not in {"none", provider}:
+            plan.append((backup_provider, backup_model or _default_model_for_provider(backup_provider)))
+
+    seen: set[tuple[Provider, str]] = set()
+    deduped: list[tuple[Provider, str]] = []
+    for item in plan:
+        if item in seen:
+            continue
+        seen.add(item)
+        deduped.append(item)
+    return deduped
+
+
+def _call_roast_llm(
+    provider: Provider,
+    model: str,
     report: AnalysisReport,
     files: list[FileResult],
-    model: str = "gpt-4o-mini",
-    no_llm: bool = False,
 ) -> RoastResult:
-    """Generate a roast using an LLM or deterministic fallback mode."""
-    if no_llm:
-        return _generate_fallback_roast(report)
+    api_key = _provider_api_key(provider)
+    if not api_key:
+        raise RuntimeError(f"Missing API key for provider '{provider}'.")
+
+    base_url = _provider_base_url(provider)
+    if base_url:
+        client = OpenAI(api_key=api_key, base_url=base_url)
+    else:
+        client = OpenAI(api_key=api_key)
 
-    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
     overall_score = report.scores.get("Overall", 0)
     response = client.chat.completions.create(
         model=model,
-        temperature=0.9,
-        response_format={"type": "json_object"},
+        temperature=0.8,
+        max_tokens=500,
         messages=[
             {
                 "role": "system",
                 "content": (
                     "You are a senior developer who has seen too much bad code. "
                     "You are brutally honest but funny, like a Gordon Ramsay for codebases. "
                     "Be specific, reference actual file names and issues. Never be generic. "
-                    "Keep roast lines under 20 words each."
+                    "Keep roast lines under 20 words each. "
+                    "Return strict JSON only."
                 ),
             },
             {"role": "user", "content": _build_user_prompt(report, files)},
         ],
     )
     content = response.choices[0].message.content or "{}"
-    payload = json.loads(content)
+    payload = _extract_json_payload(content)
     return _normalize_roast_payload(payload, overall_score)
+
+
+def generate_roast(
+    report: AnalysisReport,
+    files: list[FileResult],
+    model: str | None = None,
+    no_llm: bool = False,
+    provider: Provider = "auto",
+    backup_provider: Provider = DEFAULT_BACKUP_PROVIDER,
+    backup_model: str | None = None,
+) -> RoastResult:
+    """Generate a roast using an LLM or deterministic fallback mode."""
+    if no_llm:
+        return _generate_fallback_roast(report)
+
+    plan = _build_provider_plan(provider, model, backup_provider, backup_model)
+    errors: list[str] = []
+
+    for provider_name, model_name in plan:
+        try:
+            return _call_roast_llm(provider_name, model_name, report, files)
+        except Exception as exc:  # noqa: BLE001
+            errors.append(f"{provider_name}:{model_name} -> {exc}")
+
+    raise RuntimeError("All LLM providers failed. " + " | ".join(errors))
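
For reference, a minimal sketch of how the new fallback plan expands, assuming only what this diff defines (_build_provider_plan and the DEFAULT_* constants). The helper is private, so treat this as illustration rather than a supported API:

    from roast.roaster import _build_provider_plan

    # "auto" ignores backup_provider, expands to the full chain, and
    # deduplicates repeated (provider, model) pairs:
    assert _build_provider_plan("auto", None, "nim", None) == [
        ("groq", "llama-3.3-70b-versatile"),
        ("groq", "llama-3.1-8b-instant"),
        ("nim", "microsoft/phi-4-mini-instruct"),
        ("openai", "gpt-4o-mini"),
    ]

    # An explicit primary keeps its backup unless it is "none" or identical:
    assert _build_provider_plan("openai", None, "nim", None) == [
        ("openai", "gpt-4o-mini"),
        ("nim", "microsoft/phi-4-mini-instruct"),
    ]

generate_roast walks this plan pair by pair, so a caller only needs the matching API key in the environment (GROQ_API_KEY, NVIDIA_NIM_API_KEY or NIM_API_KEY, OPENAI_API_KEY) for at least one entry to succeed.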
