Update offline distillation and saving top-k teacher logits to be efficient #3990
ajkv-google wants to merge 16 commits into
    log_t_p_T_sparse = jax.nn.log_softmax(t_logits / temperature, axis=-1)

    # 2. Student log-probs must be computed over the FULL vocabulary to be mathematically valid
    log_s_T_full = jax.nn.log_softmax(s_logits / temperature, axis=-1)
Should we also compute the teacher softmax over the entire vocabulary before saving the logits to files?
We save the raw logits instead of the full softmax so we can still adjust the temperature on the fly during training. Also, saving full-vocab probabilities would make the files larger, which could slow reads during offline training. For now, I think it is best to keep the current approach: store only the information that is needed and read it quickly during offline training.
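To make the trade-off concrete, here is a minimal sketch of the idea, not the code in this PR: keep only the raw top-k logits and their vocabulary indices, and apply the temperature when the values are read back. The helper names `top_k_teacher_logits` and `sparse_teacher_log_probs` and the toy logits are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def top_k_teacher_logits(t_logits, k):
    """Return the k largest raw teacher logits and their vocabulary indices."""
    # jax.lax.top_k operates along the last axis and returns (values, indices).
    values, indices = jax.lax.top_k(t_logits, k)
    return values, indices


def sparse_teacher_log_probs(top_k_values, temperature):
    """Temperature-scaled log-softmax over the stored top-k logits only."""
    return jax.nn.log_softmax(top_k_values / temperature, axis=-1)


# Toy example: one position, vocabulary of 5 tokens (hypothetical values).
t_logits = jnp.array([[2.0, 0.5, -1.0, 3.0, 0.0]])
values, indices = top_k_teacher_logits(t_logits, k=2)

# Because raw logits were kept, the temperature can be changed at training time.
log_p_T1 = sparse_teacher_log_probs(values, temperature=1.0)
log_p_T2 = sparse_teacher_log_probs(values, temperature=2.0)
```

Storing `values` and `indices` costs 2k numbers per position regardless of vocabulary size, and the same stored file can serve runs with different temperatures.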
"Also, saving full-vocab probabilities would make the files bigger": why do you consider saving the full vocab? The entire idea of offline distillation is to use a limited set of logits.
My concern is purely mathematical: you normalize the teacher only over the top-k entries, while the student logits are normalized over the entire vocabulary, and then you compute the KL divergence between distributions with completely different normalization scales.
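A small sketch of this concern, assuming the loss is computed the way the snippet under review suggests (teacher log-softmax over the top-k logits, student log-softmax over the full vocabulary, gathered at the top-k indices). The function name `sparse_kl` and the toy logits are illustrative assumptions, not the PR's implementation. With identical teacher and student logits, a full-vocabulary KL would be zero, but this sparse version equals the negative log of the probability mass the teacher places on its top-k tokens.

```python
import jax
import jax.numpy as jnp


def sparse_kl(t_logits, s_logits, k, temperature=1.0):
    """KL term with the teacher normalized over top-k and the student over the full vocab."""
    top_vals, top_idx = jax.lax.top_k(t_logits, k)
    # Teacher: renormalized over the k retained logits only.
    log_t_sparse = jax.nn.log_softmax(top_vals / temperature, axis=-1)
    # Student: normalized over the full vocabulary, then gathered at the top-k indices.
    log_s_full = jax.nn.log_softmax(s_logits / temperature, axis=-1)
    log_s_at_top_k = jnp.take_along_axis(log_s_full, top_idx, axis=-1)
    return jnp.sum(jnp.exp(log_t_sparse) * (log_t_sparse - log_s_at_top_k), axis=-1)


# Student logits identical to teacher logits (hypothetical values).
logits = jnp.array([[2.0, 0.5, -1.0, 3.0, 0.0]])
kl_same = sparse_kl(logits, logits, k=2)

# Probability mass the full-vocabulary teacher assigns to its top-2 tokens.
full_probs = jax.nn.softmax(logits, axis=-1)
top_k_mass = jnp.sum(jax.lax.top_k(full_probs, 2)[0], axis=-1)
expected_offset = -jnp.log(top_k_mass)
```

Here `kl_same` matches `expected_offset` rather than zero, so the size of the mismatch depends on how much teacher mass falls outside the top-k entries.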
…e new efficient version
Description
This PR implements an optimized Offline Distillation pipeline in MaxText, allowing student models to train using pre-saved teacher logits to significantly reduce compute costs.
Key changes:
- The save_top_k_teacher_logits.py script now supports parallel execution across hosts and uses asynchronous GCS uploads to prevent TPU idling
- distillation_utils.py: computes KL divergence using only the teacher's top-k predictions when running offline distillation
Tests
Used the following to run offline distillation:
Used the following to run saving of top-k teacher logits: