Skip to content

Commit eb6a50d

Browse files
committed
feat: Enhance tokenization and add logprobs support in Tinker
- Introduced `choice_offsets` and `extra_logprobs` to `TokenizedResult` for improved tokenization handling.
- Updated `tokenize_trajectory` function to process `Choice` instances and capture log probabilities.
- Added new `TinkerAsyncMessagesAndChoices` class to handle messages and choices with log probabilities in the Tinker server.
- Implemented a new API endpoint `/v1/messages_and_choices/with_logprobs` to retrieve messages and choices along with their log probabilities.
- Refactored server logic to integrate log probability handling for enhanced model interactions.
1 parent be1f6aa commit eb6a50d

3 files changed

Lines changed: 309 additions & 4 deletions

File tree

src/art/preprocessing/tokenize.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import random
77
from typing import Any, Generator, cast
88

9+
from openai.types.chat.chat_completion import Choice
910
from PIL import Image
1011
import torch
1112
from transformers.image_processing_utils import BaseImageProcessor
@@ -41,6 +42,8 @@ class TokenizedResult:
4142
pixel_values: torch.Tensor | None
4243
image_grid_thw: torch.Tensor | None
4344
trajectory: Trajectory
45+
choice_offsets: list[int]
46+
extra_logprobs: dict[str, list[float]]
4447
_tokenizer: "PreTrainedTokenizerBase" = field(repr=False, compare=False)
4548
weight: float = 0.0
4649
prompt_id: int = 0
@@ -63,6 +66,11 @@ def without_prompt(self) -> "TokenizedResult":
6366
pixel_values=None,
6467
image_grid_thw=None,
6568
trajectory=self.trajectory,
69+
choice_offsets=self.choice_offsets,
70+
extra_logprobs={
71+
key: values[self.prompt_length :]
72+
for key, values in self.extra_logprobs.items()
73+
},
6674
_tokenizer=self._tokenizer,
6775
weight=self.weight,
6876
prompt_id=self.prompt_id,
@@ -207,8 +215,8 @@ def tokenize_trajectory(
207215
and allow_training_without_logprobs
208216
):
209217
last_assistant_index = i
210-
elif not isinstance(message, dict) and (
211-
message.logprobs or allow_training_without_logprobs # ty:ignore[possibly-missing-attribute]
218+
elif isinstance(message, Choice) and (
219+
message.logprobs or allow_training_without_logprobs
212220
):
213221
last_assistant_index = i
214222
# If there are no trainable assistant messages, return None
@@ -265,6 +273,8 @@ def tokenize_trajectory(
265273
)
266274
assistant_mask: list[int] = [0] * len(token_ids)
267275
logprobs = [float("nan")] * len(token_ids)
276+
choice_offsets, choice_token_logprobs = [], []
277+
268278
for message in messages_and_choices:
269279
if isinstance(message, dict):
270280
if message["role"] != "assistant":
@@ -304,12 +314,14 @@ def tokenize_trajectory(
304314
if not choice.logprobs: # ty:ignore[possibly-missing-attribute]
305315
continue
306316
token_logprobs = choice.logprobs.content or choice.logprobs.refusal or [] # ty:ignore[possibly-missing-attribute]
307-
if (
317+
if token_logprobs and (
308318
bytes(token_logprobs[0].bytes or []).decode("utf-8")
309319
== "<think>"
310320
== tokenizer.decode(token_ids[start - 4])
311321
):
312322
start -= 4
323+
choice_offsets.append(start)
324+
choice_token_logprobs.append(token_logprobs)
313325
try:
314326
token_ids[start:end] = (
315327
int(token_logprob.token.split(":")[1])
@@ -336,6 +348,18 @@ def tokenize_trajectory(
336348
token_ids.pop(start + len(token_logprobs))
337349
logprobs.pop(start + len(token_logprobs))
338350
assistant_mask.pop(start + len(token_logprobs))
351+
extra_logprobs: dict[str, list[float]] = {}
352+
for start, token_logprobs in zip(choice_offsets, choice_token_logprobs):
353+
for i, token_logprob in enumerate(token_logprobs):
354+
token_extra_logprobs = (token_logprob.model_extra or {}).get(
355+
"extra_logprobs"
356+
)
357+
if not isinstance(token_extra_logprobs, dict):
358+
continue
359+
for key, value in token_extra_logprobs.items():
360+
extra_logprobs.setdefault(key, [float("nan")] * len(token_ids))[
361+
start + i
362+
] = float("nan") if value is None else float(value)
339363
if image_processor:
340364
images: list[Image.Image] = []
341365
for message in messages_and_choices:
@@ -369,6 +393,8 @@ def tokenize_trajectory(
369393
token_ids[start:end] = [image_token_id] * num_image_tokens
370394
logprobs[start:end] = [float("nan")] * num_image_tokens
371395
assistant_mask[start:end] = [0] * num_image_tokens
396+
for values in extra_logprobs.values():
397+
values[start:end] = [float("nan")] * num_image_tokens
372398
pixel_values = result["pixel_values"]
373399
image_grid_thw = result["image_grid_thw"]
374400
else:
@@ -387,6 +413,8 @@ def tokenize_trajectory(
387413
pixel_values=pixel_values,
388414
image_grid_thw=image_grid_thw,
389415
trajectory=trajectory,
416+
choice_offsets=choice_offsets,
417+
extra_logprobs=extra_logprobs,
390418
_tokenizer=tokenizer,
391419
)
392420

src/art/tinker/client.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
from __future__ import annotations
2+
3+
from functools import cached_property
4+
from typing import Any, Iterable, Mapping, cast
5+
6+
import httpx
7+
from openai import AsyncOpenAI, BaseModel, _legacy_response
8+
from openai._base_client import make_request_options
9+
from openai._resource import AsyncAPIResource
10+
from openai._response import async_to_streamed_response_wrapper
11+
from openai._types import Body, Headers, NotGiven, Query, not_given
12+
from openai.resources.models import AsyncModels
13+
from openai.types import Model
14+
from openai.types.chat.chat_completion import Choice
15+
from openai.types.completion_usage import CompletionUsage
16+
17+
from art.types import MessageOrChoice, MessagesAndChoices, Tools
18+
19+
ParsedMessageOrChoice = Choice | dict[str, Any]
20+
ParsedMessagesAndChoices = list[ParsedMessageOrChoice]
21+
22+
23+
def _message_or_choice_to_dict(message_or_choice: MessageOrChoice) -> dict[str, Any]:
24+
if isinstance(message_or_choice, dict):
25+
return cast(dict[str, Any], message_or_choice)
26+
return cast(dict[str, Any], message_or_choice.to_dict())
27+
28+
29+
class MessagesAndChoicesWithLogprobs(BaseModel):
    """Response payload for ``/messages_and_choices/with_logprobs``.

    ``messages_and_choices`` carries the conversation items (plain dicts
    or parsed ``Choice`` objects); ``usages`` carries the accompanying
    ``CompletionUsage`` records.
    # NOTE(review): assumes one usage entry per requested model — confirm
    # against the server implementation.
    """

    messages_and_choices: ParsedMessagesAndChoices
    usages: list[CompletionUsage]
32+
33+
34+
class TinkerAsyncModels(AsyncModels):
    """``AsyncModels`` resource extended with a ``PUT /models/{model}`` call."""

    @cached_property
    def with_raw_response(self) -> "TinkerAsyncModelsWithRawResponse":
        return TinkerAsyncModelsWithRawResponse(self)

    @cached_property
    def with_streaming_response(self) -> "TinkerAsyncModelsWithStreamingResponse":
        return TinkerAsyncModelsWithStreamingResponse(self)

    async def put(
        self,
        model: str,
        *,
        target: str,
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = not_given,
    ) -> Model:
        """Point ``model`` at ``target`` via ``PUT /models/{model}``.

        Returns the updated ``Model``.

        Raises:
            ValueError: if ``model`` is empty.
        """
        if not model:
            raise ValueError(
                f"Expected a non-empty value for `model` but received {model!r}"
            )

        # Assemble per-request options separately from the request itself.
        request_options = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
        )
        return await self._put(
            f"/models/{model}",
            body={"target": target},
            options=request_options,
            cast_to=Model,
        )
69+
70+
71+
class TinkerAsyncModelsWithRawResponse:
    """Raw-response view over the ``TinkerAsyncModels`` endpoints."""

    def __init__(self, models: TinkerAsyncModels) -> None:
        self._models = models

        # Bind the wrapper once; each endpoint gets the same treatment.
        wrap = _legacy_response.async_to_raw_response_wrapper
        self.put = wrap(models.put)
        self.retrieve = wrap(models.retrieve)
        self.list = wrap(models.list)
        self.delete = wrap(models.delete)
79+
80+
81+
class TinkerAsyncModelsWithStreamingResponse:
    """Streamed-response view over the ``TinkerAsyncModels`` endpoints."""

    def __init__(self, models: TinkerAsyncModels) -> None:
        self._models = models

        # Bind the wrapper once; each endpoint gets the same treatment.
        wrap = async_to_streamed_response_wrapper
        self.put = wrap(models.put)
        self.retrieve = wrap(models.retrieve)
        self.list = wrap(models.list)
        self.delete = wrap(models.delete)
89+
90+
91+
class TinkerAsyncMessagesAndChoices(AsyncAPIResource):
    """Resource wrapping the ``/messages_and_choices`` Tinker endpoints."""

    @cached_property
    def with_raw_response(self) -> "TinkerAsyncMessagesAndChoicesWithRawResponse":
        return TinkerAsyncMessagesAndChoicesWithRawResponse(self)

    @cached_property
    def with_streaming_response(
        self,
    ) -> "TinkerAsyncMessagesAndChoicesWithStreamingResponse":
        return TinkerAsyncMessagesAndChoicesWithStreamingResponse(self)

    async def with_logprobs(
        self,
        messages_and_choices: MessagesAndChoices,
        *,
        models: Iterable[str],
        model_aliases: Mapping[str, str] | None = None,
        tools: Tools | None,
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = not_given,
    ) -> MessagesAndChoicesWithLogprobs:
        """POST the conversation to ``/messages_and_choices/with_logprobs``.

        Serializes each item to a plain dict, then asks the server to
        return the messages and choices annotated with log probabilities
        for the given ``models``.
        """
        # Normalize every item (dict or Choice) into JSON-ready dicts.
        serialized_items = [
            _message_or_choice_to_dict(item) for item in messages_and_choices
        ]
        request_body = {
            "messages_and_choices": serialized_items,
            "models": list(models),
            "model_aliases": dict(model_aliases or {}),
            "tools": tools,
        }
        request_options = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
        )
        return await self._post(
            "/messages_and_choices/with_logprobs",
            body=request_body,
            options=request_options,
            cast_to=MessagesAndChoicesWithLogprobs,
        )
132+
133+
134+
class TinkerAsyncMessagesAndChoicesWithRawResponse:
    """Raw-response view over ``TinkerAsyncMessagesAndChoices``."""

    def __init__(self, messages_and_choices: TinkerAsyncMessagesAndChoices) -> None:
        self._messages_and_choices = messages_and_choices
        wrap = _legacy_response.async_to_raw_response_wrapper
        self.with_logprobs = wrap(messages_and_choices.with_logprobs)
141+
142+
143+
class TinkerAsyncMessagesAndChoicesWithStreamingResponse:
    """Streamed-response view over ``TinkerAsyncMessagesAndChoices``."""

    def __init__(self, messages_and_choices: TinkerAsyncMessagesAndChoices) -> None:
        self._messages_and_choices = messages_and_choices
        wrap = async_to_streamed_response_wrapper
        self.with_logprobs = wrap(messages_and_choices.with_logprobs)
150+
151+
152+
class TinkerAsyncOpenAI(AsyncOpenAI):
    """``AsyncOpenAI`` client wired to the Tinker-specific resources.

    Overrides ``models`` with the PUT-capable resource and adds the
    ``messages_and_choices`` resource; both are created lazily and
    cached per client instance.
    """

    @cached_property
    def models(self) -> TinkerAsyncModels:
        return TinkerAsyncModels(self)

    @cached_property
    def messages_and_choices(self) -> TinkerAsyncMessagesAndChoices:
        return TinkerAsyncMessagesAndChoices(self)

0 commit comments

Comments
 (0)