@@ -153,7 +153,7 @@ def compute_loss(
153153 inference_mode = True ,
154154 no_grad = False ,
155155 reference_logprobs = True ,
156- top_k_entropy = 0 , # Don't compute entropy for reference model
156+ top_k_entropy = top_k_entropy ,
157157 )
158158 else :
159159 ref_logprobs = None
@@ -265,8 +265,8 @@ def calculate_logprobs(
265265 reference_logprobs : bool ,
266266 top_k_entropy : int = 0 ,
267267) -> tuple [
268- torch .Tensor , torch .Tensor | None
269- ]: # Returns (log_probs, entropy) where entropy is shape [B, S] or None
268+ torch .Tensor , torch .Tensor
269+ ]: # Returns (log_probs, entropy) both shape [B, S]
270270 with (
271271 torch .inference_mode () if inference_mode else nullcontext (),
272272 torch .no_grad () if no_grad else nullcontext (),
@@ -294,24 +294,20 @@ def _calculate_logprobs(
294294 chunk_size : int ,
295295 top_k_entropy : int = 0 ,
296296) -> tuple [
297- torch .Tensor , torch .Tensor | None
298- ]: # Returns (log_probs, entropy) where entropy is shape [B, S] or None
297+ torch .Tensor , torch .Tensor
298+ ]: # Returns (log_probs, entropy) both shape [B, S]
299299 batch_size , seq_len , _ = hidden_states .shape
300300 # Output shape is [B, S]
301301 log_probs = torch .empty (
302302 (batch_size , seq_len ),
303303 dtype = hidden_states .dtype ,
304304 device = hidden_states .device ,
305305 )
306- # Only allocate entropy tensor if we're computing it
307- if top_k_entropy > 0 :
308- entropy = torch .empty (
309- (batch_size , seq_len ),
310- dtype = hidden_states .dtype ,
311- device = hidden_states .device ,
312- )
313- else :
314- entropy = None
306+ entropy = torch .empty (
307+ (batch_size , seq_len ),
308+ dtype = hidden_states .dtype ,
309+ device = hidden_states .device ,
310+ )
315311 # Ensure lm_head_t is in the same dtype as hidden_states
316312 lm_head_t = lm_head_t .to (hidden_states .dtype )
317313
@@ -326,9 +322,9 @@ def _calculate_logprobs(
326322 chunk_logsumexp = torch .logsumexp (chunk_logits , dim = - 1 ) # [B, chunk_size]
327323 log_probs [:, i : i + chunk_size ] = chunk_selected_logits - chunk_logsumexp
328324
329- # Compute entropy for the chunk (only if top_k_entropy > 0)
325+ # Compute entropy for the chunk
330326 if top_k_entropy > 0 :
331- # Use top-k approximation for entropy
327+ # Use top-k approximation for memory-efficient entropy
332328 topk_logits , _ = torch .topk (
333329 chunk_logits , k = min (top_k_entropy , chunk_logits .size (- 1 )), dim = - 1
334330 ) # [B, chunk_size, k]
@@ -341,6 +337,14 @@ def _calculate_logprobs(
341337 ) # [B, chunk_size]
342338 entropy [:, i : i + chunk_size ] = chunk_entropy
343339 del topk_logits , topk_logsumexp , log_probs_topk , chunk_entropy
340+ else :
341+ # Full-vocabulary entropy (original behavior)
342+ log_probs_full = chunk_logits - chunk_logsumexp .unsqueeze (- 1 )
343+ chunk_entropy = (- torch .exp (log_probs_full ) * log_probs_full ).sum (
344+ dim = - 1
345+ ) # [B, chunk_size]
346+ entropy [:, i : i + chunk_size ] = chunk_entropy
347+ del log_probs_full , chunk_entropy
344348
345349 del (
346350 chunk_hs ,