Skip to content

[NNX] NNX migration prep (9.5/N): NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix#3844

Open
ecnal-cienet wants to merge 3 commits into
mainfrom
feat/nnx-aqt-maxengine
Open

[NNX] NNX migration prep (9.5/N): NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix#3844
ecnal-cienet wants to merge 3 commits into
mainfrom
feat/nnx-aqt-maxengine

Conversation

@ecnal-cienet
Copy link
Copy Markdown
Collaborator

@ecnal-cienet ecnal-cienet commented May 7, 2026

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)
  2. ✅ NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR NNX migration prep (2/N): NNX utils and sharding utilities #3470)
  3. ✅ NNX fully supported end-to-end: TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500)
  4. ✅ Sharding diagnostics on NNX, plus post-training bugfixes that surfaced once the NNX path got exercised end-to-end. (PR [NNX] NNX migration prep (4/N): sharding tools and post-training fixes #3652)
    4.5. ✅ Linen↔NNX checkpoint converter. (PR [NNX] NNX migration prep (4.5/N): Linen<->NNX checkpoint converter #3843)
    4.6. ❌ Linen↔NNX checkpoint comparator (sibling branch on PR4.5).
  5. ✅ NNX correctness fixes, feature enablements, and vocab tiling on NNX.
  6. ✅ NNX-native DPO.
  7. ✅ NNX-native MaxEngine inference. (PR [NNX] NNX migration prep (7/N): NNX-native MaxEngine inference #3821)
  8. ✅ NNX-native LoRA + GRPO. (PR [NNX] NNX migration prep (8/N): NNX native lora grpo #3824)
  9. ✅ NNX-aware QK-Clip + remaining checkpoint utilities. (PR [NNX] NNX migration prep (9/N): NNX-aware QK-Clip + checkpoint utilities #3836)
    9.5. 🔄 [This PR] NNX + AQT in MaxEngine: pre-quantized loading (checkpoint_is_quantized=True) via quant_mode_str="serve", convert-on-load via TRAIN-mode AQT, and a pre-existing gpt3 prefill / non-TRAIN inference bug fix (Gpt3MultiHeadAttention.__call__ was missing update_kv_caches). Split out of original PR9 on 2026-05-07. Stacks on PR9; PR9 and PR9.5 are file-disjoint.
  10. ❌ Vocab tiling custom_vjp for NNX.
  11. ❌ Set NNX defaults to True; regenerate sharding goldens; flip back integration-test pure_nnx=False annotations.
  12. ❌ Delete Linen-specific code paths and NNX compatibility flags.

Description

This PR migrates the NNX + AQT integration in MaxEngine so pure_nnx=True can both load pre-quantized checkpoints directly and convert full-precision checkpoints to int8 on load. Also bundles a pre-existing gpt3 prefill / autoregressive bug surfaced by the AQT end-to-end validation.

Originally part of PR9; split into its own follow-up on 2026-05-07 because the AQT chain has 5 chained QTensor / sharding / restore-target bugs that warrant focused review independent of the QK-Clip + checkpoint-utilities work in PR9.

Diff: +637 / −57 across 8 files (5 source + 3 tests, of which 2 are new). Stacks on PR9 (feat/nnx-qk-clip-and-checkpoint-utils); both halves are file-disjoint, so PR9.5 could equally well sibling-target PR8 once PR9 lands.

Part 1: NNX + AQT in MaxEngine — two paths

  • Pre-quantized load (checkpoint_is_quantized=True): from_pretrained(quant_mode_str="serve") reads the on-disk qrhs.frozen directly so AQT layers don't materialize the full-precision kernel.
  • Convert-on-load (checkpoint_is_quantized=False + quantization=int8): full-precision kernels load normally, AQT layers quantize per-forward against them. Same numerical result as serve mode for absmax calibration; slower but correct.

Threaded quant_mode_str ("train" | "convert" | "serve") through from_configcreate_modelget_nnx_create_model_fncreate_nnx_abstract_modelfrom_pretrained. Default "train" preserves existing callers; "serve" propagates to configure_quantization. maxengine.__init__ selects the quant mode from config.checkpoint_is_quantized; _load_params_nnx drops its NotImplementedError.

Part 2: _load_and_quantize_nnx — NNX whole-model convert path

src/maxtext/utils/layerwise_quantization.py:

  • Loads full-precision in TRAIN mode via from_pretrained(quant_mode_str="train").
  • Builds a separate CONVERT-mode model and copies kernels into it via _copy_kernel_leaves_.
  • Runs a forward — the ToNNX(AqtDotGeneral) bridge auto-captures qrhs.frozen per flax/nnx/bridge/wrappers.py:230-243.
  • Strips kernels at quantized paths via _strip_kernels_at_quantized_paths.
  • Saves serve-mode-shaped state.

The DeepSeek-only assertion is lifted for NNX since the whole-model approach is decoder-agnostic.

Part 3: Sharding helpers + from_pretrained QTensor handling — 5 chained fixes

The serve-mode reload chain hit five surface bugs in NNX/AQT-serve interaction. All closed here:

  1. Sharding helper for QTensor leaves (maxtext_utils.get_nnx_named_sharding_with_scan_axis): emits a parallel-tree of replicated NamedSharding leaves when a Variable's value is a composite pytree (AQT serve-mode QTensor with an int8 qvalue leaf and a list of bf16 scale leaves). Previously returned the Variable as-is when val had no .shape, leaving ShapeDtypeStruct leaves where the downstream jax.ShapeDtypeStruct(..., sharding=s) call expected Shardings.
  2. Variable indexing on QTensor: _build_value_target, _free_device_memory, and _unwrap_for_align in from_pretrained now use Variable.get_value() instead of v[...]. QTensor's __getitem__ calls qvalue[idx] on a LogicallyPartitioned wrapper — that fails. Composite leaves now flow through unchanged.
  3. Filter widening: both from_pretrained's NNX-detection branch and maxengine._load_params_nnx previously filtered sharded_state to nnx.Param only, dropping AQT qrhs.frozen leaves (which are stored as a separate aqt Variable type, not a Param subclass). They now filter to "everything except nnx.RngState and nnx.Cache". _load_params_nnx also adds a 4-way nnx.split + overlay step so the loaded aqt-typed leaves survive into _nnx_rest_state.
  4. Partitioned-unwrap for QTensor leaf paths: the abstract NNX model's QTensor qvalue / scale come back wrapped in LogicallyPartitioned. Under jax.tree.flatten_with_path, that wrapper adds an extra GetAttrKey('value') to every leaf — so the restore target's tree path looks like qrhs.frozen.value.qvalue.value, but _load_and_quantize_nnx flushes the QTensor as plain arrays at qrhs.frozen.value.qvalue (no extra .value). Orbax silently filled the missing paths with the model's init values (qvalue=0, scale=1 — exactly the symptom we saw). _build_value_target now strips Partitioned wrappers around composite-leaf values so the tree path matches the on-disk layout.
  5. Shape-alignment crash on QTensor: _walk_align previously called ckpt_arr.shape on every leaf, which hit qvalue.shape on a LogicallyPartitioned. Composite leaves are now passed through unchanged in the per-axis alignment dispatch — quantized payloads are saved at the exact model shape, no alignment needed.

Also dropped a redundant jax.set_mesh(mesh) wrap inside create_nnx_abstract_model's nnx.eval_shape call. Under jax.set_mesh, Flax 0.12.6's _to_variable rejects serve-mode AQT variables because they hit NamedSharding(mesh=AbstractMesh, spec=None). Sharding is resolved afterwards via get_nnx_named_sharding_with_scan_axis, so the wrap was redundant; removing it lets serve-mode model construction reach the orbax restore step.

Part 4: gpt3 Prefill / Autoregressive Fix

A pre-existing gpt3 bug surfaced when validating the AQT pre-quantized load end-to-end: Gpt3MultiHeadAttention.__call__ (src/maxtext/models/gpt3.py) invoked self.attention_op(...) without ever calling update_kv_caches to build cached_values, so any non-TRAIN forward (prefill or autoregressive) tripped the assert prefill_kv_cache check at the top of AttentionOp.__call__. Affects every gpt3 inference call regardless of quantization; included here because the AQT e2e validation exercises this path.

Mirrors the standard Attention class plumbing in attentions.py:

  • __init__ constructs a KVCache_0 module when model_mode != MODEL_MODE_TRAIN, sized from max_prefill_predict_length / max_target_length / batch / num_heads / head_dim.
  • __init__ also threads max_prefill_predict_length into AttentionOp (was previously left at the -1 default, breaking the prefill-cache shape sizing).
  • __call__ calls self.KVCache_0(...) to produce [prefill_kv_cache, ar_kv_cache] and passes that as the cached_values argument to attention_op.

TRAIN-mode shape unchanged (KVCache_0 stays None, no extra parameters).

Tests

New unit tests (tests/unit/layerwise_quantization_nnx_test.py, 3 tests): _strip_kernels_at_quantized_paths covering quantized-kernel removal, non-quantized-kernel preservation (norms, embeddings), mixed-shape trees.

New unit tests (tests/unit/aqt_serve_roundtrip_nnx_test.py, 1 test — end-to-end regression): builds a small NNX model in CONVERT mode with int8, runs a forward to populate qrhs.frozen via the ToNNX bridge, saves the serve-mode-shape state to a tmp local orbax checkpoint, reloads via from_pretrained(quant_mode_str="serve"), and asserts every saved qrhs.frozen.qvalue array byte-matches what came back. Guards the full chain of QTensor / Partitioned / filter fixes. Runs on CPU under DECOUPLE_GCLOUD=TRUE.

Modified test (tests/unit/maxengine_test.py): test_quantize_raises_for_nnx (asserted NotImplementedError) replaced by test_quantize_passes_gate_for_nnx (verifies the convert-on-load path reaches from_pretrained in TRAIN mode). Added test_load_pre_quantized_nnx_passes_quant_gate (verifies checkpoint_is_quantized=True reaches from_pretrained in SERVE mode) and test_quantized_prefill_nnx_train_mode (full prefill with quantization=int8 + random params + TRAIN mode produces finite logits — real numerical verification).

Existing Linen tests: untouched and still pass.

End-to-end on TPU (gpt3-52k): convert-mode forward → qrhs.frozen extraction → serve-mode-shape save to orbax → reload via from_pretrained(quant_mode_str="serve") → quantized prefill forward → finite logits. Save-side qvalue nonzero_frac=0.99x; reload preserves bytes exactly.

Linting: bash lint.sh — pyink + pylint clean.

Stats

  • Diff: +637 / −57 across 8 files (2 new, 6 modified).
  • Production-code impact: Linen behavior preserved; every NNX edit is gated on config.pure_nnx or runtime state-shape detection. The gpt3 KVCache plumbing is gated on model_mode != MODEL_MODE_TRAIN, so TRAIN-mode shape is unchanged.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-aqt-maxengine branch from ca957ab to e173538 Compare May 7, 2026 21:53
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-aqt-maxengine branch 15 times, most recently from 31ac0e6 to 88417d0 Compare May 14, 2026 22:51
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-aqt-maxengine branch 9 times, most recently from 54f4f9d to 71525e7 Compare May 21, 2026 19:58
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-aqt-maxengine branch 2 times, most recently from d4bcba4 to b3dd0c1 Compare May 25, 2026 15:26
Closes the QK-Clip TODO and migrates the remaining Linen-only
checkpoint utilities to NNX. Linen paths preserved byte-for-byte
(every NNX edit is gated on `config.pure_nnx` or runtime state-shape
detection).

QK-Clip:
- qk_clip_utils.apply_qk_clip_nnx mutates state.model in place via
  nnx.split -> pure-dict tree_map -> nnx.replace_by_pure_dict ->
  nnx.update. Accepts both the production NNX intermediate shape
  (self_attention.attention_op.max_logits) and the synthetic-fixture
  shape from the existing Linen tests (self_attention.max_logits).
- train.py train_step dispatches to apply_qk_clip_nnx for NNX,
  removing the prior TODO at the QK-Clip call site.

Checkpoint utilities (NNX paths added):
- standalone_checkpointer.checkpoint_loop builds an NNX init_state_fn
  under pure_nnx; add_entropy_to_checkpoint dispatches across Linen
  TrainState, NNX TrainStateNNX Module, and post-split nnx.State
  shapes.
- generate_param_only_checkpoint: NNX init_state_fn under pure_nnx;
  _possibly_unroll_params_nnx slices scanned NNX layers via dict-style
  state mutation; _save_decode_checkpoint_nnx writes a bf16 pure-dict
  tree to orbax. Parallel LoRA decode flow operates on the
  single-nested LoRA delta tree from PR8's get_lora_abstract_state_nnx.
- convert_gpt3_ckpt_from_paxml: parallel NNX state_map keystr
  translation (.params['params']<rest> -> .model<rest>.value, etc.).
  End-to-end paxml -> NNX conversion is wired but not yet validated
  on hardware.

Tests:
- qk_clip_test: 7 new NNX cases covering attention-type guard, MLA
  wq_b/wkv_b math, both intermediate shapes, no-clip-below-threshold,
  missing-stats resilience, Linen<->NNX numeric parity.
- standalone_checkpointer_nnx_test (new): 3 cases for adam mu/nu
  overwrite on TrainStateNNX Module shape, no mutation of state.model
  params, post-split nnx.State shape from setup_training_state.
- generate_param_only_checkpoint_nnx_test (new): 3 cases for scanned
  layer slicing (Llama-style; DeepSeek-style dense+moe split; LoRA
  delta unroll on the single-nested NNX shape).

NNX + AQT in MaxEngine and the layerwise_quantization NNX path are
split into the follow-up PR9.5.
Builds on PR9. Migrates the NNX + AQT integration so MaxEngine can both
load pre-quantized checkpoints directly and convert full-precision
checkpoints to int8 on load. Also bundles a pre-existing gpt3 prefill
bug surfaced by the AQT end-to-end validation.

NNX + AQT in MaxEngine:
- model_creation_utils threads quant_mode_str ("train" | "convert" |
  "serve") through from_config / create_model /
  get_nnx_create_model_fn / create_nnx_abstract_model /
  from_pretrained. Default "train" preserves existing callers; "serve"
  propagates to configure_quantization so AQT layers don't materialize
  the full-precision kernel when the on-disk checkpoint already
  carries qrhs scale factors.
- maxengine.__init__ selects the quant mode from
  config.checkpoint_is_quantized; _load_params_nnx drops its
  NotImplementedError. Two paths: pre-quantized
  (checkpoint_is_quantized=True) loads via quant_mode_str="serve";
  full-precision + quantization=int8 loads in TRAIN mode and AQT
  layers quantize per-forward (same numerical result for absmax
  calibration).
- layerwise_quantization._load_and_quantize_nnx: whole-model NNX
  convert path. Loads full-precision in TRAIN mode, transfers kernels
  into a CONVERT-mode model, runs forward to populate qrhs.frozen via
  the ToNNX(AqtDotGeneral) bridge, strips kernels at quantized paths,
  saves serve-mode-shaped state.

Sharding helpers and from_pretrained QTensor handling (5 chained fixes
that kept the serve-mode reload from working):
- maxtext_utils.get_nnx_named_sharding_with_scan_axis emits a
  parallel-tree of replicated NamedSharding leaves when a Variable's
  value is a composite pytree (AQT serve-mode QTensor with a qvalue
  int8 leaf and a list of bf16 scale leaves).
- model_creation_utils.from_pretrained: drops a redundant
  jax.set_mesh wrap in create_nnx_abstract_model that broke serve-mode
  AQT under Flax 0.12.6. _build_value_target / _free_device_memory /
  _unwrap_for_align use Variable.get_value() instead of v[...]
  indexing for QTensor leaves (QTensor.__getitem__ trips on the
  LogicallyPartitioned wrapper around qvalue). Widens the restore
  filter beyond nnx.Param to cover the aqt-typed qrhs.frozen Variable
  type. Skips QTensor leaves in the per-axis shape-alignment dispatch
  (their saved shape already matches the model). _build_value_target
  strips Partitioned wrappers around composite-leaf values so the
  restore tree path matches the on-disk layout (LogicallyPartitioned
  was adding an extra .value key under each QTensor leaf, which made
  orbax silently fill the path with zero-init values).

gpt3 prefill / autoregressive fix (pre-existing, surfaced here):
- Gpt3MultiHeadAttention.__call__ invoked attention_op(...) without
  ever calling update_kv_caches to build cached_values, so any
  non-TRAIN forward (prefill or autoregressive) tripped the
  `assert prefill_kv_cache` check. Mirror the standard Attention
  plumbing in attentions.py: __init__ constructs a KVCache_0 module
  when model_mode != MODEL_MODE_TRAIN, threads
  max_prefill_predict_length into AttentionOp; __call__ calls
  self.KVCache_0(...) and passes [prefill_kv_cache, ar_kv_cache] as
  cached_values to attention_op. TRAIN-mode shape unchanged.

Tests:
- layerwise_quantization_nnx_test (new): 3 cases for
  _strip_kernels_at_quantized_paths covering quantized removal,
  non-quantized preservation (norms, embeddings), mixed-shape trees.
- aqt_serve_roundtrip_nnx_test (new): end-to-end regression test that
  builds a small NNX model in CONVERT mode with int8, runs a forward
  to populate qrhs.frozen via the ToNNX bridge, saves the
  serve-mode-shape state to a tmp local orbax checkpoint, reloads via
  from_pretrained(quant_mode_str="serve"), and asserts every saved
  qrhs.frozen.qvalue array byte-matches what came back. Guards the
  full chain of QTensor / Partitioned / filter fixes.
- maxengine_test: replaced test_quantize_raises_for_nnx with
  test_quantize_passes_gate_for_nnx; added
  test_load_pre_quantized_nnx_passes_quant_gate and
  test_quantized_prefill_nnx_train_mode (real numerical verification
  with quantization=int8 + random params + TRAIN mode).

End-to-end on TPU (gpt3-52k): convert-mode forward + qrhs.frozen
extraction + serve-mode-shape save + reload via
from_pretrained(quant_mode_str="serve") + maxengine.load_params +
quantized prefill forward all work; loaded qrhs.frozen.qvalue
byte-matches the on-disk state.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant