Skip to content

Plumb rl.loss_agg_mode to tunix GrpoConfig#4025

Open
py4 wants to merge 1 commit into
mainfrom
pr/loss-agg-mode-plumbing
Open

Plumb rl.loss_agg_mode to tunix GrpoConfig#4025
py4 wants to merge 1 commit into
mainfrom
pr/loss-agg-mode-plumbing

Conversation

@py4
Copy link
Copy Markdown
Collaborator

@py4 py4 commented May 30, 2026

tunix's GrpoConfig defaults loss_agg_mode to 'sequence-mean-token-mean', but GPU NeMo-RL stacks use 'token-mean'. With group-normalized advantages the two modes produce materially different losses (~10× gradient magnitude difference), breaking GPU↔TPU recipe parity when reproducing a GPU recipe on TPU.

This PR exposes the existing tunix knob via the maxtext RL config so users can override on the cmdline: rl.loss_agg_mode=token-mean.

Changes:

  • New Pydantic field rl.loss_agg_mode (Literal: "token-mean", "sequence-mean", "sequence-mean-token-mean")
  • Default 'sequence-mean-token-mean' in rl.yml matches tunix's default → backward compatible (no behavior change for users who don't set it)
  • Plumbed through to GrpoConfig(...) construction in train_rl.py

Empirical impact (Qwen3-1.7B GRPO on GSM8K, GPU recipe LR=2e-6 β=0.04, 50 outer steps):

  • Without override (default sequence-mean-token-mean): grad_norm ~0.5 → underflows, training stalls
  • With rl.loss_agg_mode=token-mean: grad_norm ~5 → matches GPU stack, training converges normally

Checklist

  • Tested locally on TPU v6e 8×8 with Qwen3-1.7B GSM8K
  • Verified token-mean override produces gradient magnitude matching GPU reference
  • Backward compatible: default value preserves existing tunix behavior
  • No effect on non-RL paths (only RL Pydantic config + GRPO training driver touched)

@codecov
Copy link
Copy Markdown

codecov Bot commented May 30, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

Comment thread src/maxtext/configs/types.py Outdated
tunix's GrpoConfig defaults loss_agg_mode to 'sequence-mean-token-mean',
but GPU NeMo-RL stacks use 'token-mean'. With group-normalized advantages
the two modes produce materially different losses, breaking GPU<->TPU
recipe parity.

Adds the field to the RL Pydantic schema + rl.yml default + passes it
through to GrpoConfig construction so users can override via cmdline:
'rl.loss_agg_mode=token-mean'.
@py4 py4 force-pushed the pr/loss-agg-mode-plumbing branch from 0d7759b to ff7eb45 Compare May 30, 2026 00:45
@py4
Copy link
Copy Markdown
Collaborator Author

py4 commented May 30, 2026

Updated both types.py description and rl.yml comment to the simpler form as suggested. PTAL.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants