[NNX] NNX migration prep (8/N): NNX native lora grpo#3824
Open
ecnal-cienet wants to merge 2 commits into
Open
Conversation
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
03c43e2 to
626ce66
Compare
This was referenced May 7, 2026
78049f9 to
6c65652
Compare
Draft
4 tasks
b47ad17 to
82af9cb
Compare
47d28ee to
2b3f99f
Compare
…acked prefill cache) PR7 (NNX-native MaxEngine inference) made the core prefill/generate/insert path work under pure_nnx=True but left three serving features raising NotImplementedError on the NNX path. This promotes all three to NNX-native. Linen is preserved byte-for-byte: the original model.apply(..., mutable=["cache"]) calls are unchanged, just moved into else: branches, and every NNX edit is gated `if config.pure_nnx:`. maxengine.py: - _prefill_multisampling_jit: drops the NotImplementedError; adds a pure_nnx branch that runs prefill through _nnx_run_model (MODEL_MODE_PREFILL, batch=1) with a fresh _nnx_init_cache_dict. The loop that draws num_samples first tokens from the shared logits is unchanged. - prefill_concat: same swap; the packed positions and segment ids thread through _nnx_run_model unchanged. - stack_prefill_result_cache=True: now supported for both scan_layers values. scan_layers=True already stacks the per-layer KV cache on axis 0 (the Linen post-stack shape), so _maybe_stack/_maybe_unstack_prefill_result_cache are no-ops and prefill_kv_cache_shardings stays the full tree. scan_layers=False keeps unstacked per-layer subtrees under cache["decoder"]["layers"][i] (int keys), so _maybe_stack stacks them into one subtree with a leading layer axis, _maybe_unstack splits it back into the int-keyed per-layer dict that bulk_insert/_insert_jit walk, and _load_params_nnx prepends a layer axis to each prefix-sharding spec (the NNX analog of the Linen P(None, *spec) + ["decoder"]["layers_0"] reshape). tests/integration/maxengine_test.py: - New _build_linen_params helper and a shared _stack_prefill_roundtrip helper. - test_prefill_multisampling_nnx, test_prefill_concat_nnx: NNX vs Linen result-shape parity, finite logits + cache. - test_stack_prefill_result_cache_nnx (scan_layers=True) and test_stack_prefill_result_cache_scan_layers_false_nnx (scan_layers=False): prefill -> insert -> generate round-trip, layer-stacked leaves, finite logits, next_pos advances. Remaining NNX MaxEngine carve-outs are quantization (PR9) and LoRA (PR8), which are other PRs' scope.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
NNX Migration Route Map
pure_nnxflag,init_state_fn,TrainStateNNX, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)get_abstract_state_nnx,get_named_sharding_nnx,set_named_sharding_nnx,get_partition_spec_nnx,get_mesh_from_config. (PR NNX migration prep (2/N): NNX utils and sharding utilities #3470)TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500)4.5. ✅ Linen↔NNX checkpoint converter. (PR [NNX] NNX migration prep (4.5/N): Linen<->NNX checkpoint converter #3843)
4.6. ❌ Linen↔NNX checkpoint comparator (sibling branch on PR4.5).
apply_lora_on_base_params_nnx/unapply_lora_from_base_params_nnx/get_lora_abstract_state_nnx(the maxenginepure_nnx + LoRAcarve-out from PR7 is cleared); NNX-native GRPO trainer viagrpo_loss_fn_nnx+compute_log_probs_nnx+ NNXsetup_train_loop/train_step/eval_steppaths. Stacks on PR7.9.5. ❌ NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix.
custom_vjpfor NNX.True; regenerate sharding goldens; flip back integration-testpure_nnx=Falseannotations.Description
This PR implements NNX-native LoRA serving and NNX-native GRPO by adding NNX-shape walkers and step helpers alongside the existing Linen ones, then dispatching on
config.pure_nnx. Every NNX modification is gated byif config.pure_nnx:, preserving the Linen path byte-for-byte. The diff spans +551 / −84 across 5 source files, plus 2 new test files (515 lines).Part 1: NNX-shape LoRA Walkers
New helpers in
src/maxtext/utils/lora_utils.pyoperating onnnx.Statepure trees (no{"params": ...}outer wrap):apply_lora_on_base_params_nnxmutatesbase_paramsin place:W += B @ A * scaleat target attention pathsunapply_lora_from_base_params_nnxis the symmetric inverseget_lora_abstract_state_nnxwalks the abstractstate.modelsubstate and emits a parallel tree withlora_a.kernel/lora_b.kernelleaves at target attention paths andNoneelsewhere_nnx_param_subtreedrops the outerTrainStateNNXwrappingThe base model stays pristine; "apply" merges the delta into the kernel, "unapply" reverses. No
nnx.LoRAwrapper, no model surgery. The on-disk format (HuggingFace PEFT-stylelora_a.kernel/lora_b.kernel) round-trips between Linen and NNX consumers unchanged.Part 2: LoRA Dispatch in
setup_initial_lora_stateandload_adapterBoth top-level entry points in
lora_utils.pybranch onconfig.pure_nnx:model_creation_utils.create_nnx_abstract_model+TrainStateNNX(model, optimizer)init_initial_state+get_lora_abstract_statepath, untouchedPart 3: MaxEngine LoRA Carve-out Cleared
src/maxtext/inference/maxengine/maxengine.py:load_single_adapterno longer raisesNotImplementedErroronpure_nnxapply_adapter/unapply_adapterbranch onconfig.pure_nnxto call the_nnxsiblingsPart 4: GRPO Loss and Step Helpers
src/maxtext/experimental/rl/grpo_trainer.py:grpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train). Signature matches Linengrpo_loss_fnso callers dispatch on the same shape.dropout_rngandparamsare unused on NNX;reference_modelis a frozennnx.Moduleand the reference forward is wrapped instop_gradient. Returns(loss, LossAux), same dataclass as Linen._train_step_nnx:nnx.merge(graphdef, state)to reconstructTrainStateNNX,value_and_gradover policy params,state.apply_gradients(grads), returnnnx.state(new_state, nnx.Not(nnx.Intermediate))._eval_step_nnx: same merge + loss-fn call, no state update.train_step/eval_stepearly-dispatch onconfig.pure_nnx; Linen branches verbatim.Part 5: GRPO setup_train_loop on NNX
grpo_trainer.py::setup_train_loop:mt.from_config(rngs=create_nnx_rngs(...))create_nnx_abstract_model+TrainStateNNX(model, optimizer, reference_model=...)apply_gradients(sibling field onTrainStateNNX, not embedded inparams)WARNING: GRPO RL trainer does not yet support pure_nnx nativelylog is removedPart 6: GRPO train_loop NNX Branches
grpo_trainer.py::train_loop— three Linen-coupled spots branched onpure_nnx:init_state_fn)metric_logger.write_setup_info_to_tensorboardreceives a flatnnx.Paramstate on NNXTrainStateNNXon NNX; the Linen_split_grpo_state(state)[0]strip is bypassedThe reshard call routes to
pathways_reshard_nnxwhenpure_nnx. New helpers ingrpo_utils.py:compute_log_probs_nnx: NNX model is called directly; intermediates pulled viannx.state(model, nnx.Intermediate).to_pure_dict()pathways_reshard_nnx: splitsstate.modelto a flatnnx.Paramstate, reshards onto the inference mesh, callsinference_engine.update_params(...)Part 7: Carve-outs (NotImplementedError Sites)
gradient_accumulation_steps > 1scan_layers=FalseTests
New unit tests (
tests/unit/lora_utils_nnx_test.py, 10 tests):get_lora_abstract_state_nnx: q/k/v/o shape derivation, target-vs-non-target masking, sharding propagation, leaf type validation, error pathsapply_lora_on_base_params_nnx: apply→unapply identity, target-only mutation, numerical parity vs Linenapply_lora_on_base_paramson the same random inputsapply_lora_on_base_paramsandunapply_lora_from_base_params(no existing unit test for these helpers in the tree)New unit tests (
tests/unit/grpo_nnx_test.py, 8 tests):grpo_loss_fn_nnx:LossAuxshape parity, signature compatibility, identical-policy/reference → zero KL,grpo_beta=0→aux.avg_kl=None, finite policy gradscompute_log_probs_nnx: shape[B, S] → [B, S-1]grpo_loss_fnandcompute_log_probs(the existing Linen integration test is TPU-only and currently@pytest.mark.skip)Modified test:
tests/unit/maxengine_test.pyswapstest_lora_raises_for_nnx(assertedNotImplementedError) fortest_lora_load_single_adapter_reaches_loader_on_nnx(assertsFileNotFoundErrorfrom the loader).Existing Linen tests: untouched and still pass;
pure_nnx=Falsestays default.Test results: 198 passed, 1 skipped (pre-existing CPU-only skip) across the broader NNX regression sweep —
maxengine_test,dpo_nnx_test,train_nnx_test,lora_utils_nnx_test,grpo_nnx_test,train_state_nnx_test,train_utils_nnx_test,gradient_accumulation_nnx_test,linen_nnx_converter_test,compare_linen_nnx_checkpoint_test.Linting:
bash lint.sh— pyink + pylint 10.00/10.Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.