Skip to content

Commit 14b8711

Browse files
saurabhbikram authored and claude committed
feat: add top-k entropy approximation for memory-efficient GRPO training
Preserve full-vocab entropy as default (top_k_entropy=0), only use top-k approximation when explicitly configured. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 0005d7d commit 14b8711

1 file changed

Lines changed: 20 additions & 16 deletions

File tree

src/art/unsloth/train.py

Lines changed: 20 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -153,7 +153,7 @@ def compute_loss(
153153
inference_mode=True,
154154
no_grad=False,
155155
reference_logprobs=True,
156-
top_k_entropy=0, # Don't compute entropy for reference model
156+
top_k_entropy=top_k_entropy,
157157
)
158158
else:
159159
ref_logprobs = None
@@ -265,8 +265,8 @@ def calculate_logprobs(
265265
reference_logprobs: bool,
266266
top_k_entropy: int = 0,
267267
) -> tuple[
268-
torch.Tensor, torch.Tensor | None
269-
]: # Returns (log_probs, entropy) where entropy is shape [B, S] or None
268+
torch.Tensor, torch.Tensor
269+
]: # Returns (log_probs, entropy) both shape [B, S]
270270
with (
271271
torch.inference_mode() if inference_mode else nullcontext(),
272272
torch.no_grad() if no_grad else nullcontext(),
@@ -294,24 +294,20 @@ def _calculate_logprobs(
294294
chunk_size: int,
295295
top_k_entropy: int = 0,
296296
) -> tuple[
297-
torch.Tensor, torch.Tensor | None
298-
]: # Returns (log_probs, entropy) where entropy is shape [B, S] or None
297+
torch.Tensor, torch.Tensor
298+
]: # Returns (log_probs, entropy) both shape [B, S]
299299
batch_size, seq_len, _ = hidden_states.shape
300300
# Output shape is [B, S]
301301
log_probs = torch.empty(
302302
(batch_size, seq_len),
303303
dtype=hidden_states.dtype,
304304
device=hidden_states.device,
305305
)
306-
# Only allocate entropy tensor if we're computing it
307-
if top_k_entropy > 0:
308-
entropy = torch.empty(
309-
(batch_size, seq_len),
310-
dtype=hidden_states.dtype,
311-
device=hidden_states.device,
312-
)
313-
else:
314-
entropy = None
306+
entropy = torch.empty(
307+
(batch_size, seq_len),
308+
dtype=hidden_states.dtype,
309+
device=hidden_states.device,
310+
)
315311
# Ensure lm_head_t is in the same dtype as hidden_states
316312
lm_head_t = lm_head_t.to(hidden_states.dtype)
317313

@@ -326,9 +322,9 @@ def _calculate_logprobs(
326322
chunk_logsumexp = torch.logsumexp(chunk_logits, dim=-1) # [B, chunk_size]
327323
log_probs[:, i : i + chunk_size] = chunk_selected_logits - chunk_logsumexp
328324

329-
# Compute entropy for the chunk (only if top_k_entropy > 0)
325+
# Compute entropy for the chunk
330326
if top_k_entropy > 0:
331-
# Use top-k approximation for entropy
327+
# Use top-k approximation for memory-efficient entropy
332328
topk_logits, _ = torch.topk(
333329
chunk_logits, k=min(top_k_entropy, chunk_logits.size(-1)), dim=-1
334330
) # [B, chunk_size, k]
@@ -341,6 +337,14 @@ def _calculate_logprobs(
341337
) # [B, chunk_size]
342338
entropy[:, i : i + chunk_size] = chunk_entropy
343339
del topk_logits, topk_logsumexp, log_probs_topk, chunk_entropy
340+
else:
341+
# Full-vocabulary entropy (original behavior)
342+
log_probs_full = chunk_logits - chunk_logsumexp.unsqueeze(-1)
343+
chunk_entropy = (-torch.exp(log_probs_full) * log_probs_full).sum(
344+
dim=-1
345+
) # [B, chunk_size]
346+
entropy[:, i : i + chunk_size] = chunk_entropy
347+
del log_probs_full, chunk_entropy
344348

345349
del (
346350
chunk_hs,

0 commit comments

Comments (0)