@@ -113,6 +113,7 @@ def compute_loss(
113113 ) # Shape [H, V]
114114 next_input_ids = shift_tensor (inputs ["tokens" ], 0 )
115115 chunk_size = _config .get ("logprob_calculation_chunk_size" , 1024 )
116+ top_k_entropy = _config .get ("top_k_entropy" , 0 )
116117 # Assert that sequence length is evenly divisible by the chunk size
117118 assert seq_len % chunk_size == 0 , (
118119 f"Sequence length ({ seq_len } ) must be evenly divisible by chunk size ({ chunk_size } )"
@@ -135,6 +136,7 @@ def compute_loss(
135136 inference_mode = return_new_logprobs ,
136137 no_grad = return_new_logprobs ,
137138 reference_logprobs = False ,
139+ top_k_entropy = top_k_entropy ,
138140 )
139141 if return_new_logprobs :
140142 return torch .nn .functional .pad (new_logprobs [:, :- 1 ], (1 , 0 ), value = 0.0 )
@@ -151,6 +153,7 @@ def compute_loss(
151153 inference_mode = True ,
152154 no_grad = False ,
153155 reference_logprobs = True ,
156+ top_k_entropy = 0 , # Don't compute entropy for reference model
154157 )
155158 else :
156159 ref_logprobs = None
@@ -260,9 +263,10 @@ def calculate_logprobs(
260263 inference_mode : bool ,
261264 no_grad : bool ,
262265 reference_logprobs : bool ,
266+ top_k_entropy : int = 0 ,
263267) -> tuple [
264- torch .Tensor , torch .Tensor
265- ]: # Returns (log_probs, entropy) both shape [B, S]
268+ torch .Tensor , torch .Tensor | None
269+ ]: # Returns (log_probs, entropy); entropy is shape [B, S] when top_k_entropy > 0, otherwise None
266270 with (
267271 torch .inference_mode () if inference_mode else nullcontext (),
268272 torch .no_grad () if no_grad else nullcontext (),
@@ -278,29 +282,36 @@ def calculate_logprobs(
278282 hidden_states = trainer .model ( # type: ignore
279283 input_ids = input_ids , causal_mask = causal_mask , ** forward_kwargs
280284 ).logits # Shape [B, S, H]
281- return _calculate_logprobs (lm_head_t , hidden_states , next_input_ids , chunk_size )
285+ return _calculate_logprobs (
286+ lm_head_t , hidden_states , next_input_ids , chunk_size , top_k_entropy
287+ )
282288
283289
284290def _calculate_logprobs (
285291 lm_head_t : torch .Tensor , # Shape [H, V]
286292 hidden_states : torch .Tensor , # Shape [B, S, H]
287293 next_input_ids : torch .Tensor , # Shape [B, S]
288294 chunk_size : int ,
295+ top_k_entropy : int = 0 ,
289296) -> tuple [
290- torch .Tensor , torch .Tensor
291- ]: # Returns (log_probs, entropy) both shape [B, S]
297+ torch .Tensor , torch .Tensor | None
298+ ]: # Returns (log_probs, entropy); entropy is shape [B, S] when top_k_entropy > 0, otherwise None
292299 batch_size , seq_len , _ = hidden_states .shape
293300 # Output shape is [B, S]
294301 log_probs = torch .empty (
295302 (batch_size , seq_len ),
296303 dtype = hidden_states .dtype ,
297304 device = hidden_states .device ,
298305 )
299- entropy = torch .empty (
300- (batch_size , seq_len ),
301- dtype = hidden_states .dtype ,
302- device = hidden_states .device ,
303- )
306+ # Only allocate entropy tensor if we're computing it
307+ if top_k_entropy > 0 :
308+ entropy = torch .empty (
309+ (batch_size , seq_len ),
310+ dtype = hidden_states .dtype ,
311+ device = hidden_states .device ,
312+ )
313+ else :
314+ entropy = None
304315 # Ensure lm_head_t is in the same dtype as hidden_states
305316 lm_head_t = lm_head_t .to (hidden_states .dtype )
306317
@@ -315,21 +326,28 @@ def _calculate_logprobs(
315326 chunk_logsumexp = torch .logsumexp (chunk_logits , dim = - 1 ) # [B, chunk_size]
316327 log_probs [:, i : i + chunk_size ] = chunk_selected_logits - chunk_logsumexp
317328
318- # Compute entropy for the chunk
319- log_probs_full = chunk_logits - chunk_logsumexp .unsqueeze (- 1 )
320- chunk_entropy = (- torch .exp (log_probs_full ) * log_probs_full ).sum (
321- dim = - 1
322- ) # [B, chunk_size]
323- entropy [:, i : i + chunk_size ] = chunk_entropy
329+ # Compute entropy for the chunk (only if top_k_entropy > 0)
330+ if top_k_entropy > 0 :
331+ # Approximate the entropy from the top-k logits only, renormalized over those k entries
332+ topk_logits , _ = torch .topk (
333+ chunk_logits , k = min (top_k_entropy , chunk_logits .size (- 1 )), dim = - 1
334+ ) # [B, chunk_size, k]
335+ topk_logsumexp = torch .logsumexp (
336+ topk_logits , dim = - 1 , keepdim = True
337+ ) # [B, chunk_size, 1]
338+ log_probs_topk = topk_logits - topk_logsumexp # [B, chunk_size, k]
339+ chunk_entropy = (- torch .exp (log_probs_topk ) * log_probs_topk ).sum (
340+ dim = - 1
341+ ) # [B, chunk_size]
342+ entropy [:, i : i + chunk_size ] = chunk_entropy
343+ del topk_logits , topk_logsumexp , log_probs_topk , chunk_entropy
324344
325345 del (
326346 chunk_hs ,
327347 chunk_input_ids ,
328348 chunk_logits ,
329349 chunk_selected_logits ,
330350 chunk_logsumexp ,
331- log_probs_full ,
332- chunk_entropy ,
333351 )
334352 del hidden_states
335353 return log_probs , entropy
0 commit comments