Skip to content

Commit 0005d7d

Browse files
committed
feat: add top-k entropy approximation for memory-efficient GRPO training
When training models with large vocabularies (128k+ tokens), computing entropy over the full vocabulary is a major memory bottleneck. This adds a `top_k_entropy` config parameter (default 0 = disabled) that computes entropy over only the top-k logits, dramatically reducing memory usage. Also skips entropy computation entirely for reference model logprobs since entropy is unused in the KL divergence calculation. https://claude.ai/code/session_017Y9KNNQX2RyVWnqpj3A4hh
1 parent e57ab66 commit 0005d7d

1 file changed

Lines changed: 36 additions & 18 deletions

File tree

src/art/unsloth/train.py

Lines changed: 36 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -113,6 +113,7 @@ def compute_loss(
113113
) # Shape [H, V]
114114
next_input_ids = shift_tensor(inputs["tokens"], 0)
115115
chunk_size = _config.get("logprob_calculation_chunk_size", 1024)
116+
top_k_entropy = _config.get("top_k_entropy", 0)
116117
# Assert that sequence length is evenly divisible by the chunk size
117118
assert seq_len % chunk_size == 0, (
118119
f"Sequence length ({seq_len}) must be evenly divisible by chunk size ({chunk_size})"
@@ -135,6 +136,7 @@ def compute_loss(
135136
inference_mode=return_new_logprobs,
136137
no_grad=return_new_logprobs,
137138
reference_logprobs=False,
139+
top_k_entropy=top_k_entropy,
138140
)
139141
if return_new_logprobs:
140142
return torch.nn.functional.pad(new_logprobs[:, :-1], (1, 0), value=0.0)
@@ -151,6 +153,7 @@ def compute_loss(
151153
inference_mode=True,
152154
no_grad=False,
153155
reference_logprobs=True,
156+
top_k_entropy=0, # Don't compute entropy for reference model
154157
)
155158
else:
156159
ref_logprobs = None
@@ -260,9 +263,10 @@ def calculate_logprobs(
260263
inference_mode: bool,
261264
no_grad: bool,
262265
reference_logprobs: bool,
266+
top_k_entropy: int = 0,
263267
) -> tuple[
264-
torch.Tensor, torch.Tensor
265-
]: # Returns (log_probs, entropy) both shape [B, S]
268+
torch.Tensor, torch.Tensor | None
269+
]: # Returns (log_probs, entropy) where entropy is shape [B, S] or None
266270
with (
267271
torch.inference_mode() if inference_mode else nullcontext(),
268272
torch.no_grad() if no_grad else nullcontext(),
@@ -278,29 +282,36 @@ def calculate_logprobs(
278282
hidden_states = trainer.model( # type: ignore
279283
input_ids=input_ids, causal_mask=causal_mask, **forward_kwargs
280284
).logits # Shape [B, S, H]
281-
return _calculate_logprobs(lm_head_t, hidden_states, next_input_ids, chunk_size)
285+
return _calculate_logprobs(
286+
lm_head_t, hidden_states, next_input_ids, chunk_size, top_k_entropy
287+
)
282288

283289

284290
def _calculate_logprobs(
285291
lm_head_t: torch.Tensor, # Shape [H, V]
286292
hidden_states: torch.Tensor, # Shape [B, S, H]
287293
next_input_ids: torch.Tensor, # Shape [B, S]
288294
chunk_size: int,
295+
top_k_entropy: int = 0,
289296
) -> tuple[
290-
torch.Tensor, torch.Tensor
291-
]: # Returns (log_probs, entropy) both shape [B, S]
297+
torch.Tensor, torch.Tensor | None
298+
]: # Returns (log_probs, entropy) where entropy is shape [B, S] or None
292299
batch_size, seq_len, _ = hidden_states.shape
293300
# Output shape is [B, S]
294301
log_probs = torch.empty(
295302
(batch_size, seq_len),
296303
dtype=hidden_states.dtype,
297304
device=hidden_states.device,
298305
)
299-
entropy = torch.empty(
300-
(batch_size, seq_len),
301-
dtype=hidden_states.dtype,
302-
device=hidden_states.device,
303-
)
306+
# Only allocate entropy tensor if we're computing it
307+
if top_k_entropy > 0:
308+
entropy = torch.empty(
309+
(batch_size, seq_len),
310+
dtype=hidden_states.dtype,
311+
device=hidden_states.device,
312+
)
313+
else:
314+
entropy = None
304315
# Ensure lm_head_t is in the same dtype as hidden_states
305316
lm_head_t = lm_head_t.to(hidden_states.dtype)
306317

@@ -315,21 +326,28 @@ def _calculate_logprobs(
315326
chunk_logsumexp = torch.logsumexp(chunk_logits, dim=-1) # [B, chunk_size]
316327
log_probs[:, i : i + chunk_size] = chunk_selected_logits - chunk_logsumexp
317328

318-
# Compute entropy for the chunk
319-
log_probs_full = chunk_logits - chunk_logsumexp.unsqueeze(-1)
320-
chunk_entropy = (-torch.exp(log_probs_full) * log_probs_full).sum(
321-
dim=-1
322-
) # [B, chunk_size]
323-
entropy[:, i : i + chunk_size] = chunk_entropy
329+
# Compute entropy for the chunk (only if top_k_entropy > 0)
330+
if top_k_entropy > 0:
331+
# Use top-k approximation for entropy
332+
topk_logits, _ = torch.topk(
333+
chunk_logits, k=min(top_k_entropy, chunk_logits.size(-1)), dim=-1
334+
) # [B, chunk_size, k]
335+
topk_logsumexp = torch.logsumexp(
336+
topk_logits, dim=-1, keepdim=True
337+
) # [B, chunk_size, 1]
338+
log_probs_topk = topk_logits - topk_logsumexp # [B, chunk_size, k]
339+
chunk_entropy = (-torch.exp(log_probs_topk) * log_probs_topk).sum(
340+
dim=-1
341+
) # [B, chunk_size]
342+
entropy[:, i : i + chunk_size] = chunk_entropy
343+
del topk_logits, topk_logsumexp, log_probs_topk, chunk_entropy
324344

325345
del (
326346
chunk_hs,
327347
chunk_input_ids,
328348
chunk_logits,
329349
chunk_selected_logits,
330350
chunk_logsumexp,
331-
log_probs_full,
332-
chunk_entropy,
333351
)
334352
del hidden_states
335353
return log_probs, entropy

0 commit comments

Comments (0)