Skip to content

Update offline distillation and saving top-k teacher logits to be efficient#3990

Open
ajkv-google wants to merge 16 commits into
mainfrom
ajkv/offline-distillation-branch
Open

Update offline distillation and saving top-k teacher logits to be efficient#3990
ajkv-google wants to merge 16 commits into
mainfrom
ajkv/offline-distillation-branch

Conversation

@ajkv-google
Copy link
Copy Markdown
Collaborator

Description

This PR implements an optimized Offline Distillation pipeline in MaxText, allowing student models to train using pre-saved teacher logits to significantly reduce compute costs.

Key changes:

  • The save_top_k_teacher_logits.py script now supports parallel execution across hosts and uses asynchronous GCS uploads to prevent TPU idling
  • Offline distillation now uses grain input pipeline without having to use the custom offline arrayrecord iterator. This speeds up the offline distillation process without having to read logits on a single thread.
  • Introduced sparse KL divergence in distillation_utils.py that computes KL divergence using only the teacher's top-k predictions when running offline distillation

Tests

Used the following to run offline distillation:

Used the following to run saving of top-k teacher logits:

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@ajkv-google ajkv-google requested a review from igorts-git as a code owner May 27, 2026 17:36
@ajkv-google ajkv-google changed the title Update offline distillation and saving top-k teacher logits to be efficient reliable Update offline distillation and saving top-k teacher logits to be efficient May 27, 2026
@codecov
Copy link
Copy Markdown

codecov Bot commented May 27, 2026

log_t_p_T_sparse = jax.nn.log_softmax(t_logits / temperature, axis=-1)

# 2. Student log-probs must be computed over the FULL vocabulary to be mathematically valid
log_s_T_full = jax.nn.log_softmax(s_logits / temperature, axis=-1)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we also compute teacher logits softmax over entire vocabulary before saving to files?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We save the raw logits instead of the full softmax so we can still tweak the temperature on the fly during training. Also, saving full-vocab probabilities would make the files bigger, which could reduce read speeds during offline training. I think as of now, it would be good to stick to the current approach where we store the information that is needed, and read it quickly during the offline training.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"lso, saving full-vocab probabilities would make the files bigger" - why do you consider saving the full vocab? the entire idea of offline is to use a limited set of logits.
My concern is purely mathematical - you normalize only over top-k while the student logits will be normalized over entire vocabulary, and then you calculate kl divergence over those distributions with completely different normalization scales.

Comment thread src/maxtext/trainers/post_train/distillation/distillation_utils.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants