Skip to content

Commit 520aba6

Browse files
committed
feat: enhance TokenizedResult and tokenize_trajectory functions
- Added trajectory attribute to TokenizedResult for improved data handling. - Updated tokenize_trajectory and tokenize_trajectory_groups functions to incorporate trajectory parameter, enhancing the tokenization process. - Improved code readability by restructuring conditional statements and assertions.
1 parent 176632c commit 520aba6

1 file changed

Lines changed: 16 additions & 6 deletions

File tree

src/art/preprocessing/tokenize.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
# ruff: noqa: I001
2-
# Import order is intentional - unsloth MUST be imported before transformers
31
import math
42
import random
53
from dataclasses import dataclass
@@ -25,6 +23,7 @@ class TokenizedResult:
2523
logprobs: list[float]
2624
pixel_values: torch.Tensor | None
2725
image_grid_thw: torch.Tensor | None
26+
trajectory: Trajectory
2827
weight: float = 0.0
2928
prompt_id: int = 0
3029
prompt_length: int = 0
@@ -40,6 +39,7 @@ def without_prompt(self) -> "TokenizedResult":
4039
logprobs=self.logprobs[self.prompt_length :],
4140
pixel_values=None,
4241
image_grid_thw=None,
42+
trajectory=self.trajectory,
4343
weight=self.weight,
4444
prompt_id=self.prompt_id,
4545
prompt_length=0,
@@ -103,6 +103,7 @@ def tokenize_trajectory_groups(
103103
history,
104104
advantage,
105105
allow_training_without_logprobs,
106+
trajectory,
106107
):
107108
trajectory_results.append(result)
108109
weight = 1 / (
@@ -151,6 +152,7 @@ def tokenize_trajectory(
151152
history: History,
152153
advantage: float,
153154
allow_training_without_logprobs: bool,
155+
trajectory: Trajectory,
154156
) -> TokenizedResult | None:
155157
"""
156158
Tokenizes a trajectory and returns a TokenizedResult.
@@ -165,7 +167,8 @@ def tokenize_trajectory(
165167
):
166168
last_assistant_index = i
167169
elif not isinstance(message, dict) and (
168-
message.logprobs or allow_training_without_logprobs # ty:ignore[possibly-missing-attribute]
170+
message.logprobs
171+
or allow_training_without_logprobs # ty:ignore[possibly-missing-attribute]
169172
):
170173
last_assistant_index = i
171174
# If there are no trainable assistant messages, return None
@@ -238,7 +241,9 @@ def tokenize_trajectory(
238241
continue
239242
if not allow_training_without_logprobs:
240243
continue
241-
elif message.logprobs is None and not allow_training_without_logprobs: # ty:ignore[possibly-missing-attribute]
244+
elif (
245+
message.logprobs is None and not allow_training_without_logprobs
246+
): # ty:ignore[possibly-missing-attribute]
242247
continue
243248
start = token_ids.index(sentinal_token_id)
244249
end = start + 1
@@ -263,12 +268,16 @@ def tokenize_trajectory(
263268
assistant_mask[start:end] = [1] * len(content_token_ids)
264269
else:
265270
choice = message
266-
assert choice.logprobs or allow_training_without_logprobs, ( # ty:ignore[possibly-missing-attribute]
271+
assert (
272+
choice.logprobs or allow_training_without_logprobs
273+
), ( # ty:ignore[possibly-missing-attribute]
267274
"Chat completion choices must have logprobs"
268275
)
269276
if not choice.logprobs: # ty:ignore[possibly-missing-attribute]
270277
continue
271-
token_logprobs = choice.logprobs.content or choice.logprobs.refusal or [] # ty:ignore[possibly-missing-attribute]
278+
token_logprobs = (
279+
choice.logprobs.content or choice.logprobs.refusal or []
280+
) # ty:ignore[possibly-missing-attribute]
272281
if (
273282
bytes(token_logprobs[0].bytes or []).decode("utf-8")
274283
== "<think>"
@@ -349,6 +358,7 @@ def tokenize_trajectory(
349358
logprobs=logprobs,
350359
pixel_values=pixel_values,
351360
image_grid_thw=image_grid_thw,
361+
trajectory=trajectory,
352362
)
353363

354364

0 commit comments

Comments
 (0)