7 | 7 | import os
8 | 8 | import socket |
9 | 9 | import time |
10 | | -from typing import Annotated, AsyncGenerator, Literal |
| 10 | +from typing import Annotated, Any, AsyncGenerator, Literal, cast |
11 | 11 | import uuid |
12 | 12 |
13 | 13 | from fastapi import FastAPI, HTTPException, Request |
@@ -47,6 +47,47 @@ class ModelUpsert(BaseModel):
47 | 47 | target: str |
48 | 48 |
49 | 49 |
| 50 | +def _normalize_qwen3_5_messages( |
| 51 | + base_model: str, messages: list[ChatCompletionMessageParam] |
| 52 | +) -> list[dict[str, Any]]: |
| 53 | + normalized_messages = [cast(dict[str, Any], message) for message in messages] |
| 54 | + if not base_model.startswith("Qwen/Qwen3.5"): |
| 55 | + return normalized_messages |
| 56 | + for i, message in enumerate(normalized_messages): |
| 57 | + tool_calls = message.get("tool_calls") |
| 58 | + if not isinstance(tool_calls, list): |
| 59 | + continue |
| 60 | + normalized_tool_calls: list[Any] = [] |
| 61 | + changed = False |
| 62 | + for tool_call in tool_calls: |
| 63 | + if not isinstance(tool_call, dict): |
| 64 | + normalized_tool_calls.append(tool_call) |
| 65 | + continue |
| 66 | + function = tool_call.get("function") |
| 67 | + if not isinstance(function, dict): |
| 68 | + normalized_tool_calls.append(tool_call) |
| 69 | + continue |
| 70 | + arguments_json = function.get("arguments") |
| 71 | + if not isinstance(arguments_json, str): |
| 72 | + normalized_tool_calls.append(tool_call) |
| 73 | + continue |
| 74 | + try: |
| 75 | + arguments = json.loads(arguments_json) |
| 76 | + except json.JSONDecodeError: |
| 77 | + normalized_tool_calls.append(tool_call) |
| 78 | + continue |
| 79 | + if not isinstance(arguments, dict): |
| 80 | + normalized_tool_calls.append(tool_call) |
| 81 | + continue |
| 82 | + changed = True |
| 83 | + normalized_tool_calls.append( |
| 84 | + {**tool_call, "function": {**function, "arguments": arguments}} |
| 85 | + ) |
| 86 | + if changed: |
| 87 | + normalized_messages[i] = {**message, "tool_calls": normalized_tool_calls} |
| 88 | + return normalized_messages |
| 89 | + |
| 90 | +
50 | 91 | @dataclass |
51 | 92 | class OpenAICompatibleTinkerServer: |
52 | 93 | host: str | None = None |
@@ -389,9 +430,10 @@ async def prompt_tokens( |
389 | 430 | messages: list[ChatCompletionMessageParam], |
390 | 431 | tools: list[ChatCompletionToolUnionParam] | None, |
391 | 432 | ) -> list[int]: |
| 433 | + normalized_messages = _normalize_qwen3_5_messages(base_model, messages) |
392 | 434 | encoding = self._get_renderer(base_model).tokenizer.apply_chat_template( |
393 | | - messages, # type: ignore |
394 | | - tools=tools, # type: ignore |
| 435 | + cast(Any, normalized_messages), |
| 436 | + tools=cast(Any, tools), |
395 | 437 | add_generation_prompt=True, |
396 | 438 | ) |
397 | 439 | if isinstance(encoding, BatchEncoding): |
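
Illustrative sketch (not part of the commit) of what the new helper does: for a base model whose name starts with "Qwen/Qwen3.5", any assistant tool call whose function.arguments field is a JSON-encoded object is rewritten so that arguments holds the parsed dict before the chat template is applied. The model name, tool name, and arguments below are hypothetical placeholders.

# Hypothetical input message; only the "Qwen/Qwen3.5" prefix of the model name matters here.
message = {
    "role": "assistant",
    "tool_calls": [
        {
            "id": "call_0",
            "type": "function",
            # OpenAI-style tool calls carry arguments as a JSON string.
            "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
        }
    ],
}

normalized = _normalize_qwen3_5_messages("Qwen/Qwen3.5-Example", [message])

# The helper decodes the JSON string into a dict so apply_chat_template receives
# structured arguments; all other fields of the message are left untouched.
assert normalized[0]["tool_calls"][0]["function"]["arguments"] == {"city": "Paris"}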