[NNX] NNX migration prep (9/N): NNX-aware QK-Clip + checkpoint utilities #3836

Draft
ecnal-cienet wants to merge 3 commits into main from feat/nnx-qk-clip-and-checkpoint-utils


@ecnal-cienet ecnal-cienet commented May 7, 2026

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)
  2. ✅ NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR NNX migration prep (2/N): NNX utils and sharding utilities #3470)
  3. ✅ NNX fully supported end-to-end: TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500)
  4. ✅ Sharding diagnostics on NNX, plus post-training bugfixes that surfaced once the NNX path got exercised end-to-end. (PR [NNX] NNX migration prep (4/N): sharding tools and post-training fixes #3652)
    4.5. ✅ Linen↔NNX checkpoint converter. (PR [NNX] NNX migration prep (4.5/N): Linen<->NNX checkpoint converter #3843)
    4.6. ❌ Linen↔NNX checkpoint comparator (sibling branch on PR4.5).
  5. ✅ NNX correctness fixes, feature enablements, and vocab tiling on NNX.
  6. ✅ NNX-native DPO.
  7. ✅ NNX-native MaxEngine inference. (PR [NNX] NNX migration prep (7/N): NNX-native MaxEngine inference #3821)
  8. ✅ NNX-native LoRA + GRPO. (PR [NNX] NNX migration prep (8/N): NNX native lora grpo #3824)
  9. 🔄 [This PR] NNX-aware QK-Clip + checkpoint utilities. apply_qk_clip_nnx mutates state.model in place (resolves the train.py:517 TODO); NNX paths added to standalone_checkpointer, generate_param_only_checkpoint, convert_gpt3_ckpt_from_paxml. Originally bundled with NNX-AQT and a gpt3 prefill fix; on 2026-05-07 those were split into PR9.5. Stacks on PR8.
    9.5. ❌ NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix (stacked follow-up).
  10. ❌ Vocab tiling custom_vjp for NNX.
  11. ❌ Set NNX defaults to True; regenerate sharding goldens; flip back integration-test pure_nnx=False annotations.
  12. ❌ Delete Linen-specific code paths and NNX compatibility flags.

Description

This PR closes the QK-Clip TODO and migrates three Linen-only checkpoint utilities (standalone_checkpointer, generate_param_only_checkpoint, convert_gpt3_ckpt_from_paxml) to NNX. NNX-shape walkers sit alongside the existing Linen ones, dispatching on config.pure_nnx (or runtime state-shape detection); every Linen path is preserved byte-for-byte.

The original PR9 also bundled NNX-AQT + a gpt3 prefill fix; on 2026-05-07 those were split into a stacked follow-up to keep each PR narrowly reviewable. This PR's diff: +907 / −161 across 8 files (5 source + 3 tests, of which 2 are new).

Part 1: NNX-aware QK-Clip

src/maxtext/utils/qk_clip_utils.py factors shared math out of the existing Linen helper and adds an NNX sibling:

  • _max_logits_at, _scale_from_max_logits, _clip_mla_weight — shared across Linen and NNX paths.
  • apply_qk_clip_nnx(state, intermediate_outputs, config) mutates state.model in place via nnx.split → pure-dict tree_map → nnx.replace_by_pure_dict → nnx.update. Accepts both the production NNX intermediate shape (self_attention.attention_op.max_logits, sown inside AttentionOp) and the synthetic-fixture shape used by the existing Linen tests (self_attention.max_logits).
  • calculate_max_logit_metric recognizes both the Linen (array,)-tuple shape and the bare-array NNX shape.
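As a rough illustration of the shared math, the sketch below computes a per-head scale from max logits and applies it to one projection. It uses plain Python lists rather than JAX arrays. `scale_from_max_logits` and `clip_head_columns` are hypothetical stand-ins, not the signatures of `_scale_from_max_logits` or `_clip_mla_weight`, and splitting the scale as a square root across query and key is assumed from the generic QK-Clip rule rather than taken from this PR.

```python
import math


def scale_from_max_logits(max_logits, threshold):
    """Per-head scale: 1.0 if the head is at or below the threshold, else threshold / max."""
    return [min(1.0, threshold / m) if m > 0 else 1.0 for m in max_logits]


def clip_head_columns(weight, scales, exponent=0.5):
    """Scale column h of a [rows][heads] weight by scales[h] ** exponent.

    exponent=0.5 is an assumption: query and key each absorb sqrt(scale),
    so their product (the logit) is scaled by the full factor.
    """
    return [[w * (s ** exponent) for w, s in zip(row, scales)] for row in weight]


max_logits = [50.0, 200.0]          # head 0 is below the threshold, head 1 is above
scales = scale_from_max_logits(max_logits, threshold=100.0)

wq = [[1.0, 2.0], [3.0, 4.0]]       # toy [rows][heads] projection
wq_clipped = clip_head_columns(wq, scales)

# Applying sqrt(scale) to both q and k brings the worst head back to the threshold.
clipped_logits = [m * math.sqrt(s) * math.sqrt(s) for m, s in zip(max_logits, scales)]
```

Head 0 is left untouched (the no-clip-below-threshold case), while head 1 is scaled so that its maximum logit lands on the threshold.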

src/maxtext/trainers/pre_train/train.py::train_step now branches on isinstance(model, nn.Module) to call apply_qk_clip for Linen and apply_qk_clip_nnx for NNX. The TODO at the QK-Clip call site is removed.

Part 2: NNX-format Checkpoint Utilities

Each utility gets an explicit NNX path; every routing-to-Linen comment is gone.

  • src/maxtext/utils/standalone_checkpointer.py: checkpoint_loop builds an NNX init_state_fn under pure_nnx (mirroring PR8's GRPO trainer). add_entropy_to_checkpoint dispatches across three input shapes (Linen TrainState, NNX TrainStateNNX Module, post-split nnx.State). All three produce identical cos(1000*p) / sin(1000*p) mu/nu replacements.

  • src/maxtext/utils/generate_param_only_checkpoint.py: _read_train_checkpoint builds an NNX init_state_fn under pure_nnx. New _possibly_unroll_params_nnx slices scanned NNX layers via dict-style mutation on state.model.decoder (uses pop / __setitem__ since nnx.State is dict-like). New _save_decode_checkpoint_nnx writes a bf16 pure-dict tree to orbax (matches the NNX-detection branch in from_pretrained). Parallel LoRA decode flow (_generate_lora_decode_checkpoints_nnx + _possibly_unroll_lora_params_nnx + _save_lora_decode_checkpoint_nnx) operates on the single-nested LoRA delta tree from PR8's get_lora_abstract_state_nnx ({"decoder": {...}}, no outer params wrap).

  • src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py: parallel NNX state_map keystr translation (.params['params']<rest> → .model<rest>.value, .opt_state.mu['params']<rest> → .optimizer.opt_state.mu<rest>.value, etc.). Save uses state.optimizer.step.value for the step number on NNX. End-to-end paxml → NNX conversion is wired but not yet validated on hardware (needs a paxml gpt3 checkpoint plus TPU/GPU); only the import / shape-walk parts are exercised in this PR.
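The keystr translation in the last bullet can be sketched as a prefix rewrite. This is a minimal illustration covering only the two rules quoted above; the function name, the rule table, and the example key are hypothetical, and whatever the "etc." covers in the real script is deliberately left out.

```python
# Only the two prefix rules quoted in the description; further rules are elided there.
PREFIX_RULES = [
    (".params['params']", ".model"),
    (".opt_state.mu['params']", ".optimizer.opt_state.mu"),
]


def linen_keystr_to_nnx(keystr):
    """Rewrite a Linen-style keystr into the NNX-style form `<nnx_prefix><rest>.value`."""
    for linen_prefix, nnx_prefix in PREFIX_RULES:
        if keystr.startswith(linen_prefix):
            rest = keystr[len(linen_prefix):]
            return f"{nnx_prefix}{rest}.value"
    raise KeyError(f"no NNX mapping for {keystr!r}")


# Hypothetical parameter path, for illustration only.
example = linen_keystr_to_nnx(".params['params']['decoder']['kernel']")
```

Raising on an unmatched prefix, rather than passing the key through, is a design choice of this sketch so that an unmapped entry fails loudly during conversion.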

Part 3: Carve-outs

| Feature | Tracked In |
|---|---|
| NNX-AQT in MaxEngine + serve-mode reload | PR9.5 (stacked follow-up on this branch) |
| gpt3 prefill / non-TRAIN inference (pre-existing bug) | PR9.5 |
| End-to-end paxml → NNX gpt3 conversion validation | Follow-up (needs paxml gpt3 checkpoint + TPU/GPU) |

Tests

New unit tests (tests/unit/qk_clip_test.py — 7 new NNX cases on top of existing Linen suite):

  • QKClipNNXTest: attention-type guard, MLA wq_b / wkv_b math, both intermediate shapes, no-clip-below-threshold, missing-stats resilience, Linen↔NNX numeric parity on identical weights.
  • CalculateMaxLogitNNXTest: bare-array intermediate shape recognition.
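The shape tolerance these cases exercise can be sketched with nested dicts. The helper below is hypothetical (it is not the code in qk_clip_utils.py) and uses Python lists in place of arrays; it only shows the lookup order and the tuple unwrapping described earlier.

```python
def find_max_logits(layer_intermediates):
    """Return max_logits from either supported layout, or None when stats are missing."""
    attn = layer_intermediates.get("self_attention", {})
    # Production NNX layout: the value is sown inside AttentionOp.
    value = attn.get("attention_op", {}).get("max_logits")
    if value is None:
        # Synthetic-fixture layout used by the existing Linen tests.
        value = attn.get("max_logits")
    # Linen-style sow yields a 1-tuple (array,); the NNX path yields the bare array.
    if isinstance(value, tuple):
        value = value[0] if value else None
    return value


production = {"self_attention": {"attention_op": {"max_logits": [3.0, 7.0]}}}
fixture = {"self_attention": {"max_logits": ([3.0, 7.0],)}}
missing = {}
```

Returning `None` for a layer without statistics lets the caller skip that layer instead of raising, which is the behavior the missing-stats case checks for.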

New unit tests (tests/unit/standalone_checkpointer_nnx_test.py, 3 tests): adam mu/nu overwrite on TrainStateNNX Module shape, no mutation of state.model params, post-split nnx.State shape from setup_training_state.
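A minimal sketch of what these tests assert, using a plain nested dict as a stand-in for the train state. It assumes `p` in cos(1000*p) / sin(1000*p) is the parameter value; `map_leaves` and `add_entropy` are hypothetical names, not the real add_entropy_to_checkpoint.

```python
import math


def map_leaves(fn, tree):
    """Apply fn to every non-dict leaf of a nested dict, returning a new tree."""
    if isinstance(tree, dict):
        return {k: map_leaves(fn, v) for k, v in tree.items()}
    return fn(tree)


def add_entropy(state):
    """Overwrite Adam mu/nu with cos(1000*p) / sin(1000*p); params are left untouched."""
    params = state["params"]
    return {
        "params": params,  # same object: the model parameters are not mutated
        "mu": map_leaves(lambda p: math.cos(1000 * p), params),
        "nu": map_leaves(lambda p: math.sin(1000 * p), params),
    }


state = {
    "params": {"w": 0.001, "b": {"x": 0.0}},
    "mu": {"w": 0.0, "b": {"x": 0.0}},
    "nu": {"w": 0.0, "b": {"x": 0.0}},
}
new_state = add_entropy(state)
```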

New unit tests (tests/unit/generate_param_only_checkpoint_nnx_test.py, 3 tests): Llama-style scanned-layer slicing (single layers group), DeepSeek-style scanned-layer slicing (dense_layers + moe_layers split), LoRA delta unroll on the single-nested NNX-derived shape.
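The scanned-layer slicing can be pictured with a dict-of-lists stand-in. This sketch assumes the layer axis is the leading axis of every leaf and that unrolled groups are named `<group>_<i>`; both are illustrative assumptions, and `unroll_scanned_group` is a hypothetical helper rather than `_possibly_unroll_params_nnx` itself.

```python
def slice_leaves(tree, index):
    """Take element `index` along the leading (layer) axis of every leaf."""
    if isinstance(tree, dict):
        return {k: slice_leaves(v, index) for k, v in tree.items()}
    return tree[index]


def unroll_scanned_group(decoder, group, num_layers):
    """Replace the stacked decoder[group] with one entry per layer, in place."""
    stacked = decoder.pop(group)  # dict-style mutation, as in the pop / __setitem__ note
    for i in range(num_layers):
        decoder[f"{group}_{i}"] = slice_leaves(stacked, i)


# DeepSeek-style split: two dense layers and three MoE layers, stacked on axis 0.
decoder = {
    "dense_layers": {"mlp": {"kernel": [[1.0], [2.0]]}},
    "moe_layers": {"router": {"kernel": [[3.0], [4.0], [5.0]]}},
}
unroll_scanned_group(decoder, "dense_layers", 2)
unroll_scanned_group(decoder, "moe_layers", 3)
```

A Llama-style model would call the helper once with its single `layers` group.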

Existing Linen tests: untouched and still pass; pure_nnx=False stays default.

Test results: 23 passed, 2 skipped across the PR9 surface — qk_clip_test, standalone_checkpointer_nnx_test, generate_param_only_checkpoint_nnx_test.

Linting: bash lint.sh — pyink + pylint clean.

Stats

  • Diff: +907 / −161 across 8 files (2 new, 6 modified).
  • Production-code impact: Linen behavior preserved; every NNX edit is gated on config.pure_nnx or runtime state-shape detection.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.


@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-qk-clip-and-checkpoint-utils branch 2 times, most recently from 68eb7ce to 02ff5f7 Compare May 7, 2026 20:12
@ecnal-cienet ecnal-cienet changed the title [NNX] NNX migration prep (9/N): NNX-aware QK-Clip + checkpoint utilities + NNX-AQT [NNX] NNX migration prep (9/N): NNX-aware QK-Clip + checkpoint utilities May 7, 2026
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-qk-clip-and-checkpoint-utils branch from 02ff5f7 to b5fd654 Compare May 7, 2026 21:53
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-qk-clip-and-checkpoint-utils branch 12 times, most recently from 2a7775a to 6748af8 Compare May 14, 2026 22:51
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-qk-clip-and-checkpoint-utils branch 9 times, most recently from 99b7f9d to ee99e98 Compare May 22, 2026 21:10
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-qk-clip-and-checkpoint-utils branch from ee99e98 to a4b9db9 Compare May 25, 2026 15:26
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-qk-clip-and-checkpoint-utils branch 8 times, most recently from 55c710e to 8c6d306 Compare May 29, 2026 22:10
…acked prefill cache)

PR7 (NNX-native MaxEngine inference) made the core prefill/generate/insert
path work under pure_nnx=True but left three serving features raising
NotImplementedError on the NNX path. This promotes all three to NNX-native.
Linen is preserved byte-for-byte: the original model.apply(..., mutable=["cache"])
calls are unchanged, just moved into else: branches, and every NNX edit is
gated `if config.pure_nnx:`.

maxengine.py:
- _prefill_multisampling_jit: drops the NotImplementedError; adds a pure_nnx
  branch that runs prefill through _nnx_run_model (MODEL_MODE_PREFILL, batch=1)
  with a fresh _nnx_init_cache_dict. The loop that draws num_samples first
  tokens from the shared logits is unchanged.
- prefill_concat: same swap; the packed positions and segment ids thread
  through _nnx_run_model unchanged.
- stack_prefill_result_cache=True: now supported for both scan_layers values.
  scan_layers=True already stacks the per-layer KV cache on axis 0 (the Linen
  post-stack shape), so _maybe_stack/_maybe_unstack_prefill_result_cache are
  no-ops and prefill_kv_cache_shardings stays the full tree. scan_layers=False
  keeps unstacked per-layer subtrees under cache["decoder"]["layers"][i] (int
  keys), so _maybe_stack stacks them into one subtree with a leading layer axis,
  _maybe_unstack splits it back into the int-keyed per-layer dict that
  bulk_insert/_insert_jit walk, and _load_params_nnx prepends a layer axis to
  each prefix-sharding spec (the NNX analog of the Linen P(None, *spec) +
  ["decoder"]["layers_0"] reshape).
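The scan_layers=False round-trip described above can be sketched in plain Python. Lists stand in for arrays, and `stack_layers` / `unstack_layers` are hypothetical names that only illustrate the int-keyed per-layer dict versus the leading-layer-axis layout, not the actual _maybe_stack / _maybe_unstack code.

```python
def stack_layers(per_layer):
    """{0: tree, 1: tree, ...} -> one tree whose leaves gain a leading layer axis."""
    indices = sorted(per_layer)
    first = per_layer[indices[0]]
    if isinstance(first, dict):
        return {k: stack_layers({i: per_layer[i][k] for i in indices}) for k in first}
    return [per_layer[i] for i in indices]


def unstack_layers(stacked, num_layers):
    """Inverse of stack_layers: split the layer axis back into int-keyed subtrees."""
    def take(tree, i):
        if isinstance(tree, dict):
            return {k: take(v, i) for k, v in tree.items()}
        return tree[i]
    return {i: take(stacked, i) for i in range(num_layers)}


# Toy per-layer KV cache for two layers.
cache_layers = {
    0: {"key": [1.0, 2.0], "value": [3.0]},
    1: {"key": [4.0, 5.0], "value": [6.0]},
}
stacked = stack_layers(cache_layers)
restored = unstack_layers(stacked, num_layers=2)
```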

tests/integration/maxengine_test.py:
- New _build_linen_params helper and a shared _stack_prefill_roundtrip helper.
- test_prefill_multisampling_nnx, test_prefill_concat_nnx: NNX vs Linen
  result-shape parity, finite logits + cache.
- test_stack_prefill_result_cache_nnx (scan_layers=True) and
  test_stack_prefill_result_cache_scan_layers_false_nnx (scan_layers=False):
  prefill -> insert -> generate round-trip, layer-stacked leaves, finite
  logits, next_pos advances.

Remaining NNX MaxEngine carve-outs are quantization (PR9) and LoRA (PR8),
which are other PRs' scope.

Closes the QK-Clip TODO and migrates the remaining Linen-only
checkpoint utilities to NNX. Linen paths preserved byte-for-byte
(every NNX edit is gated on `config.pure_nnx` or runtime state-shape
detection).

QK-Clip:
- qk_clip_utils.apply_qk_clip_nnx mutates state.model in place via
  nnx.split -> pure-dict tree_map -> nnx.replace_by_pure_dict ->
  nnx.update. Accepts both the production NNX intermediate shape
  (self_attention.attention_op.max_logits) and the synthetic-fixture
  shape from the existing Linen tests (self_attention.max_logits).
- train.py train_step dispatches to apply_qk_clip_nnx for NNX,
  removing the prior TODO at the QK-Clip call site.

Checkpoint utilities (NNX paths added):
- standalone_checkpointer.checkpoint_loop builds an NNX init_state_fn
  under pure_nnx; add_entropy_to_checkpoint dispatches across Linen
  TrainState, NNX TrainStateNNX Module, and post-split nnx.State
  shapes.
- generate_param_only_checkpoint: NNX init_state_fn under pure_nnx;
  _possibly_unroll_params_nnx slices scanned NNX layers via dict-style
  state mutation; _save_decode_checkpoint_nnx writes a bf16 pure-dict
  tree to orbax. Parallel LoRA decode flow operates on the
  single-nested LoRA delta tree from PR8's get_lora_abstract_state_nnx.
- convert_gpt3_ckpt_from_paxml: parallel NNX state_map keystr
  translation (.params['params']<rest> -> .model<rest>.value, etc.).
  End-to-end paxml -> NNX conversion is wired but not yet validated
  on hardware.

Tests:
- qk_clip_test: 7 new NNX cases covering attention-type guard, MLA
  wq_b/wkv_b math, both intermediate shapes, no-clip-below-threshold,
  missing-stats resilience, Linen<->NNX numeric parity.
- standalone_checkpointer_nnx_test (new): 3 cases for adam mu/nu
  overwrite on TrainStateNNX Module shape, no mutation of state.model
  params, post-split nnx.State shape from setup_training_state.
- generate_param_only_checkpoint_nnx_test (new): 3 cases for scanned
  layer slicing (Llama-style; DeepSeek-style dense+moe split; LoRA
  delta unroll on the single-nested NNX shape).

NNX + AQT in MaxEngine and the layerwise_quantization NNX path are
split into the follow-up PR9.5.
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-qk-clip-and-checkpoint-utils branch from 8c6d306 to 6847046 Compare May 30, 2026 03:59