1- # ruff: noqa: I001
2- # Import order is intentional - unsloth MUST be imported before transformers
31import math
42import random
53from dataclasses import dataclass
@@ -25,6 +23,7 @@ class TokenizedResult:
2523 logprobs : list [float ]
2624 pixel_values : torch .Tensor | None
2725 image_grid_thw : torch .Tensor | None
26+ trajectory : Trajectory
2827 weight : float = 0.0
2928 prompt_id : int = 0
3029 prompt_length : int = 0
@@ -40,6 +39,7 @@ def without_prompt(self) -> "TokenizedResult":
4039 logprobs = self .logprobs [self .prompt_length :],
4140 pixel_values = None ,
4241 image_grid_thw = None ,
42+ trajectory = self .trajectory ,
4343 weight = self .weight ,
4444 prompt_id = self .prompt_id ,
4545 prompt_length = 0 ,
@@ -103,6 +103,7 @@ def tokenize_trajectory_groups(
103103 history ,
104104 advantage ,
105105 allow_training_without_logprobs ,
106+ trajectory ,
106107 ):
107108 trajectory_results .append (result )
108109 weight = 1 / (
@@ -151,6 +152,7 @@ def tokenize_trajectory(
151152 history : History ,
152153 advantage : float ,
153154 allow_training_without_logprobs : bool ,
155+ trajectory : Trajectory ,
154156) -> TokenizedResult | None :
155157 """
156158 Tokenizes a trajectory and returns a TokenizedResult.
@@ -165,7 +167,8 @@ def tokenize_trajectory(
165167 ):
166168 last_assistant_index = i
167169 elif not isinstance (message , dict ) and (
168- message .logprobs or allow_training_without_logprobs # ty:ignore[possibly-missing-attribute]
170+ message .logprobs
171+ or allow_training_without_logprobs # ty:ignore[possibly-missing-attribute]
169172 ):
170173 last_assistant_index = i
171174 # If there are no trainable assistant messages, return None
@@ -238,7 +241,9 @@ def tokenize_trajectory(
238241 continue
239242 if not allow_training_without_logprobs :
240243 continue
241- elif message .logprobs is None and not allow_training_without_logprobs : # ty:ignore[possibly-missing-attribute]
244+ elif (
245+ message .logprobs is None and not allow_training_without_logprobs
246+ ): # ty:ignore[possibly-missing-attribute]
242247 continue
243248 start = token_ids .index (sentinal_token_id )
244249 end = start + 1
@@ -263,12 +268,16 @@ def tokenize_trajectory(
263268 assistant_mask [start :end ] = [1 ] * len (content_token_ids )
264269 else :
265270 choice = message
266- assert choice .logprobs or allow_training_without_logprobs , ( # ty:ignore[possibly-missing-attribute]
271+ assert (
272+ choice .logprobs or allow_training_without_logprobs
273+ ), ( # ty:ignore[possibly-missing-attribute]
267274 "Chat completion choices must have logprobs"
268275 )
269276 if not choice .logprobs : # ty:ignore[possibly-missing-attribute]
270277 continue
271- token_logprobs = choice .logprobs .content or choice .logprobs .refusal or [] # ty:ignore[possibly-missing-attribute]
278+ token_logprobs = (
279+ choice .logprobs .content or choice .logprobs .refusal or []
280+ ) # ty:ignore[possibly-missing-attribute]
272281 if (
273282 bytes (token_logprobs [0 ].bytes or []).decode ("utf-8" )
274283 == "<think>"
@@ -349,6 +358,7 @@ def tokenize_trajectory(
349358 logprobs = logprobs ,
350359 pixel_values = pixel_values ,
351360 image_grid_thw = image_grid_thw ,
361+ trajectory = trajectory ,
352362 )
353363
354364
0 commit comments