Commit be1f6aa

feat: Update litellm dependency constraints and enhance pricing functionality
- Updated the litellm dependency in pyproject.toml and uv.lock to constrain versions to >=1.71.1,<=1.82.0.
- Added functions in api_costs.py to retrieve per-token pricing from litellm and normalize it to per-million rates.
- Wired litellm pricing in as a fallback in _configured_token_pricing, covered by a new test in test_track_api_cost.py to ensure accurate cost calculations.
- Introduced normalization of tool calls in server.py (and of tool schemas in tokenize.py) for Qwen3.5 model compatibility.
1 parent 67e81fa commit be1f6aa

7 files changed

Lines changed: 159 additions & 13 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ requires-python = ">=3.11"
 dependencies = [
     "openai>=2.14.0",
     "typer>=0.15.2",
-    "litellm>=1.71.1",
+    "litellm>=1.71.1,<=1.82.0",
     "weave>=0.52.24",
     "polars>=1.26.0",
     "tblib>=3.0.0",

src/art/api_costs.py

Lines changed: 49 additions & 0 deletions
@@ -57,11 +57,60 @@ class _AnthropicTokenUsage:
     }


+def _litellm_price_per_million(
+    model_info: Mapping[str, Any], field: str
+) -> float | None:
+    value = model_info.get(field)
+    if value is None or isinstance(value, bool):
+        return None
+    try:
+        return float(value) * 1_000_000
+    except (TypeError, ValueError):
+        return None
+
+
+def _litellm_token_pricing(model_name: str) -> TokenPricing | None:
+    try:
+        from litellm import get_model_info
+
+        model_info = get_model_info(model_name)
+    except Exception:
+        return None
+
+    if not isinstance(model_info, Mapping):
+        return None
+
+    prompt_per_million = _litellm_price_per_million(model_info, "input_cost_per_token")
+    completion_per_million = _litellm_price_per_million(
+        model_info, "output_cost_per_token"
+    )
+    if prompt_per_million is None or completion_per_million is None:
+        return None
+
+    cache_read_per_million = _litellm_price_per_million(
+        model_info, "cache_read_input_token_cost"
+    )
+    cache_creation_per_million = _litellm_price_per_million(
+        model_info, "cache_creation_input_token_cost"
+    )
+    return TokenPricing(
+        prompt_per_million=prompt_per_million,
+        completion_per_million=completion_per_million,
+        cached_prompt_per_million=cache_read_per_million,
+        cache_creation_per_million=cache_creation_per_million,
+        cache_read_per_million=cache_read_per_million,
+    )
+
+
 def _configured_token_pricing(model_name: str) -> TokenPricing | None:
     explicit = MODEL_TOKEN_PRICING.get(model_name)
     if explicit is not None:
         return explicit

+    litellm_pricing = _litellm_token_pricing(model_name)
+    if litellm_pricing is not None:
+        return litellm_pricing
+
     pricing = get_model_pricing(model_name)
     if pricing is None:
         return None
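Note: litellm reports prices in USD per single token, which is why the helper above scales by 1,000,000. A minimal sketch of what the fallback consumes (the model name and the example figures are illustrative, not taken from this commit):

from litellm import get_model_info

info = get_model_info("gpt-4o")  # any model litellm knows about
# e.g. {"input_cost_per_token": 2.5e-06, "output_cost_per_token": 1e-05, ...}
print(info["input_cost_per_token"] * 1_000_000)   # prompt cost per 1M tokens
print(info["output_cost_per_token"] * 1_000_000)  # completion cost per 1M tokens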

src/art/preprocessing/tokenize.py

Lines changed: 18 additions & 5 deletions
@@ -1,3 +1,4 @@
+from collections.abc import Callable
 from dataclasses import dataclass, field
 from functools import cached_property
 from itertools import takewhile
@@ -12,6 +13,22 @@

 from ..trajectories import History, Trajectory, TrajectoryGroup, get_messages

+ChatTemplateTool = dict[Any, Any] | Callable[..., Any]
+
+
+def _normalize_tools_for_chat_template(tools: Any) -> list[ChatTemplateTool] | None:
+    if tools is None:
+        return None
+    normalized_tools: list[ChatTemplateTool] = []
+    for tool in tools:
+        if callable(tool):
+            normalized_tools.append(tool)
+        elif isinstance(tool, dict) and "type" in tool:
+            normalized_tools.append(cast(dict[Any, Any], tool))
+        else:
+            normalized_tools.append({"type": "function", "function": tool})
+    return normalized_tools
+

 @dataclass
 class TokenizedResult:
@@ -199,11 +216,7 @@ def tokenize_trajectory(
         return None
     messages_and_choices = history.messages_and_choices[: last_assistant_index + 1]
     messages = get_messages(messages_and_choices)
-    tools: Any = (
-        [{"type": "function", "function": tool} for tool in history.tools]
-        if history.tools is not None
-        else None
-    )
+    tools = _normalize_tools_for_chat_template(history.tools)
     chat = cast(
         str,
         tokenizer.apply_chat_template(
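A quick illustration of the three branches in _normalize_tools_for_chat_template (the tool names and schemas below are made up for the example):

def get_weather(city: str) -> str: ...  # a callable tool

tools = [
    get_weather,                                      # callable: passed through unchanged
    {"type": "function", "function": {"name": "a"}},  # already wrapped: kept as-is
    {"name": "b", "parameters": {}},                  # bare schema: wrapped in the OpenAI shape
]
normalized = _normalize_tools_for_chat_template(tools)
# normalized[2] == {"type": "function", "function": {"name": "b", "parameters": {}}}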

src/art/tinker/renderers.py

Lines changed: 3 additions & 3 deletions
@@ -2,9 +2,9 @@ def get_renderer_name(base_model: str) -> str:
     if base_model.startswith("meta-llama/"):
         return "llama3"
     elif base_model.startswith("Qwen/Qwen3.5-"):
-        print("Defaulting to Qwen3.5 renderer with thinking for", base_model)
-        print(renderer_name_message)
-        return "qwen3_5"
+        # print("Defaulting to Qwen3.5 renderer with thinking for", base_model)
+        # print(renderer_name_message)
+        return "qwen3_5_disable_thinking"
     elif base_model.startswith("Qwen/Qwen3-"):
         if "Instruct" in base_model:
             return "qwen3_instruct"

src/art/tinker/server.py

Lines changed: 45 additions & 3 deletions
@@ -7,7 +7,7 @@
 import os
 import socket
 import time
-from typing import Annotated, AsyncGenerator, Literal
+from typing import Annotated, Any, AsyncGenerator, Literal, cast
 import uuid

 from fastapi import FastAPI, HTTPException, Request
@@ -47,6 +47,47 @@ class ModelUpsert(BaseModel):
     target: str


+def _normalize_qwen3_5_messages(
+    base_model: str, messages: list[ChatCompletionMessageParam]
+) -> list[dict[str, Any]]:
+    normalized_messages = [cast(dict[str, Any], message) for message in messages]
+    if not base_model.startswith("Qwen/Qwen3.5"):
+        return normalized_messages
+    for i, message in enumerate(normalized_messages):
+        tool_calls = message.get("tool_calls")
+        if not isinstance(tool_calls, list):
+            continue
+        normalized_tool_calls: list[Any] = []
+        changed = False
+        for tool_call in tool_calls:
+            if not isinstance(tool_call, dict):
+                normalized_tool_calls.append(tool_call)
+                continue
+            function = tool_call.get("function")
+            if not isinstance(function, dict):
+                normalized_tool_calls.append(tool_call)
+                continue
+            arguments_json = function.get("arguments")
+            if not isinstance(arguments_json, str):
+                normalized_tool_calls.append(tool_call)
+                continue
+            try:
+                arguments = json.loads(arguments_json)
+            except json.JSONDecodeError:
+                normalized_tool_calls.append(tool_call)
+                continue
+            if not isinstance(arguments, dict):
+                normalized_tool_calls.append(tool_call)
+                continue
+            changed = True
+            normalized_tool_calls.append(
+                {**tool_call, "function": {**function, "arguments": arguments}}
+            )
+        if changed:
+            normalized_messages[i] = {**message, "tool_calls": normalized_tool_calls}
+    return normalized_messages
+
+
 @dataclass
 class OpenAICompatibleTinkerServer:
     host: str | None = None
@@ -389,9 +430,10 @@ async def prompt_tokens(
         messages: list[ChatCompletionMessageParam],
         tools: list[ChatCompletionToolUnionParam] | None,
     ) -> list[int]:
+        normalized_messages = _normalize_qwen3_5_messages(base_model, messages)
         encoding = self._get_renderer(base_model).tokenizer.apply_chat_template(
-            messages,  # type: ignore
-            tools=tools,  # type: ignore
+            cast(Any, normalized_messages),
+            tools=cast(Any, tools),
             add_generation_prompt=True,
         )
         if isinstance(encoding, BatchEncoding):
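To make the intent concrete: the OpenAI wire format carries tool-call arguments as a JSON string, while Qwen3.5 chat templates expect a mapping, so the helper decodes the string before tokenization. A sketch (the model id and message contents are made up):

message = {
    "role": "assistant",
    "tool_calls": [
        {
            "id": "call_1",
            "type": "function",
            # OpenAI wire format: arguments arrive as a JSON string
            "function": {"name": "lookup", "arguments": '{"city": "Paris"}'},
        }
    ],
}
[normalized] = _normalize_qwen3_5_messages("Qwen/Qwen3.5-4B", [message])
# The chat template now receives a mapping instead of a string:
assert normalized["tool_calls"][0]["function"]["arguments"] == {"city": "Paris"}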

tests/unit/test_track_api_cost.py

Lines changed: 42 additions & 0 deletions
@@ -346,6 +346,48 @@ async def _judge() -> _AnthropicResponse:
             0.0021
         )

+    @pytest.mark.asyncio
+    async def test_explicit_model_name_uses_litellm_pricing_fallback(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        import litellm
+
+        builder = MetricsBuilder(cost_context="train")
+
+        def _fake_get_model_info(model_name: str) -> dict[str, float]:
+            assert model_name == "openai/fallback-model"
+            return {
+                "input_cost_per_token": 2.5e-06,
+                "output_cost_per_token": 1.5e-05,
+                "cache_read_input_token_cost": 2.5e-07,
+            }
+
+        monkeypatch.setattr(litellm, "get_model_info", _fake_get_model_info)
+
+        @track_api_cost(
+            source="llm_judge/litellm_fallback",
+            provider="openai",
+            model_name="openai/fallback-model",
+        )
+        async def _judge() -> _OpenAIResponse:
+            return _OpenAIResponse(
+                prompt_tokens=100,
+                completion_tokens=50,
+                cached_tokens=80,
+            )
+
+        token = builder.activate()
+        try:
+            await _judge()
+        finally:
+            token.var.reset(token)
+
+        metrics = await builder.flush()
+        expected = ((20 * 2.5) + (80 * 0.25) + (50 * 15.0)) / 1_000_000
+        assert metrics["costs/train/llm_judge/litellm_fallback"] == pytest.approx(
+            expected
+        )
+
     @pytest.mark.asyncio
     async def test_explicit_model_name_does_not_depend_on_response_model(self) -> None:
         builder = MetricsBuilder(cost_context="train")
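For clarity, the expected figure in the new test works out as follows: of the 100 prompt tokens, 80 are cached and billed at the cache-read rate ($0.25 per 1M), the remaining 20 at the full prompt rate ($2.50 per 1M), plus 50 completion tokens at $15.00 per 1M:

expected = (20 * 2.5 + 80 * 0.25 + 50 * 15.0) / 1_000_000
#        = (50 + 20 + 750) / 1_000_000
#        = 0.00082 USD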

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.
