Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
8c7a057
NNX: native DPO (TrainStateNNX.reference_model + dpo_loss_fn_nnx)
ecnal-cienet Apr 29, 2026
2048e1e
NNX: native MaxEngine inference (drop route-to-Linen path in maxengin…
ecnal-cienet May 5, 2026
7927d4c
NNX: native LoRA + GRPO (drop maxengine LoRA carve-out, drop GRPO pur…
ecnal-cienet May 6, 2026
5fa5ce1
NNX: QK-Clip on NNX + NNX-format checkpoint utilities
ecnal-cienet May 7, 2026
8e1be5b
NNX: AQT in MaxEngine + serve-mode reload + gpt3 prefill fix
ecnal-cienet May 7, 2026
2a172cc
NNX: vocab tiling custom_vjp with output-head carve-out
ecnal-cienet May 8, 2026
e183c68
tests: pin Linen-only vocab tiling and pipeline tests for upcoming NN…
ecnal-cienet May 8, 2026
63c7840
NNX: flip pure_nnx/enable_nnx/pure_nnx_decoder defaults to True
ecnal-cienet May 8, 2026
27b6b7f
fix tests/unit/train_compile_test.py::TrainCompile::test_remat_save_q…
hsuan-lun-chiang May 19, 2026
a4c5f1b
Temp: tests/unit/train_compile_test.py::TrainCompile::test_qk_clip_do…
hsuan-lun-chiang May 19, 2026
9918595
fix cpu UT failure
May 19, 2026
25b6496
fix gpu UT failures
May 20, 2026
05cc14d
Fix tests/unit/muon_utils_test.py::TestGetMuonWeightDimensionNumbersN…
hsuan-lun-chiang May 20, 2026
c2b6545
tests/unit/max_utils_test.py::UnscanTest::test_unscan_train_state_params
hsuan-lun-chiang May 20, 2026
423f3b9
tests/unit/max_utils_test.py::UnscanTest::test_unscan_train_state_params
hsuan-lun-chiang May 20, 2026
606a8ca
Fix test compatibility with pure_nnx=True defaults
hsuan-lun-chiang May 20, 2026
c2c63b8
Fix DiLoCo test simulation for NNX: update state filtering, fix unpac…
hsuan-lun-chiang May 21, 2026
7d8f2fc
Update
hsuan-lun-chiang May 21, 2026
82bf13d
Update generate_param_only_checkpoint_test.py
mesakhcienet May 25, 2026
92999e4
Fix NNX checkpoint restore and Linen test compatibility with pure_nnx…
mesakhcienet May 25, 2026
f9f1efb
Fix param-only checkpoint to filter nnx.Param only, excluding RNG state
mesakhcienet May 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import os
import sys

from flax import nnx
import jax
from jax import random
from jax.sharding import Mesh
Expand All @@ -48,11 +49,15 @@
from maxtext.common import checkpointing
from maxtext.common.common_types import MODEL_MODE_TRAIN
from maxtext.layers import quantizations
from maxtext.layers import train_state_nnx
from maxtext.models.models import transformer_as_linen
from maxtext.optimizers import optimizers
from maxtext.utils import max_logging
from maxtext.utils import max_utils
from maxtext.utils import maxtext_utils
from maxtext.utils import maxtext_utils_nnx
from maxtext.utils import model_creation_utils
from maxtext.utils import train_utils
import numpy as np
from psutil import Process
import tensorstore as ts
Expand Down Expand Up @@ -87,12 +92,23 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
devices_array = maxtext_utils.create_device_mesh(cfg)
mesh = Mesh(devices_array, cfg.mesh_axes)

# Output is Linen-format (keystr_map below uses Linen tree paths). Route to
# Linen regardless of pure_nnx.
quant = quantizations.configure_quantization(cfg)
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
if cfg.pure_nnx:
rngs = maxtext_utils_nnx.create_nnx_rngs(cfg, rng_key=init_rng)
model = model_creation_utils.from_config(cfg, mesh=mesh, rngs=rngs)
_, tx = train_utils.create_training_optimizer(cfg, model)
_create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(cfg, mesh)

def init_state_fn():
nnx_model = _create_model_partial()
optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param)
return train_state_nnx.TrainStateNNX(nnx_model, optimizer)

else:
quant = quantizations.configure_quantization(cfg)
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)

checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
cfg.checkpoint_dir,
Expand All @@ -101,7 +117,6 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
cfg.checkpoint_period,
)

init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn)
max_logging.log("start")
max_utils.print_mem_stats("After params initialized")
Expand Down Expand Up @@ -186,10 +201,21 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
"['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None),
}

state_map = {
".step": ("step", None),
".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
}
if cfg.pure_nnx:
# NNX state-tree paths after `nnx.split(TrainStateNNX)`:
# model params -> ['model']<rest>.value
# adam mu / nu -> ['optimizer']['opt_state']['mu' | 'nu']<rest>.value
# step -> ['optimizer']['step'].value
# opt count -> ['optimizer']['opt_state']['count'].value
state_map = {
".optimizer.step.value": ("step", None),
".optimizer.opt_state.count.value": ("opt_states_0.no_prefix_0.count", None),
}
else:
state_map = {
".step": ("step", None),
".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
}

def get_layer_prefix(keystr_pax):
# different path format between decoder_layer variable
Expand All @@ -201,19 +227,27 @@ def get_layer_prefix(keystr_pax):
return prefix_pax_opt_state

for keystr_maxtext, (keystr_pax, transform_fn) in keystr_map.items():
# model variable
state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
prefix_pax_opt_state = get_layer_prefix(keystr_pax)
# first momentum in optimizer state
state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
transform_fn,
)
# second momentum in optimizer state
state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
transform_fn,
)
if cfg.pure_nnx:
state_map[f".model{keystr_maxtext}.value"] = (f"mdl_vars{keystr_pax}", transform_fn)
state_map[f".optimizer.opt_state.mu{keystr_maxtext}.value"] = (
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
transform_fn,
)
state_map[f".optimizer.opt_state.nu{keystr_maxtext}.value"] = (
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
transform_fn,
)
else:
state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
transform_fn,
)
state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
transform_fn,
)

def verify_fn(key_path, _):
keystr = jax.tree_util.keystr(key_path)
Expand Down Expand Up @@ -265,10 +299,11 @@ def map_fn(key_path, value):
max_logging.log("converted state finished")
max_utils.print_mem_stats("converted state finished")

if checkpointing.save_checkpoint(checkpoint_manager, converted_state.step, converted_state):
max_logging.log(f"saved a checkpoint at step {converted_state.step}")
step_value = int(converted_state.optimizer.step.value) if cfg.pure_nnx else converted_state.step
if checkpointing.save_checkpoint(checkpoint_manager, step_value, converted_state):
max_logging.log(f"saved a checkpoint at step {step_value}")
# Upon preemption, exit when and only when all ongoing saves are complete.
if checkpoint_manager.reached_preemption(converted_state.step):
if checkpoint_manager.reached_preemption(step_value):
checkpoint_manager.wait_until_finished()
sys.exit()

Expand Down
31 changes: 22 additions & 9 deletions src/maxtext/checkpoint_conversion/to_maxtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,23 +313,36 @@ def get_maxtext_model_info(config):
# Get abstract model structure (name, shape) without materializing the weights to save memory
abstract_params_tree = maxtext_utils.get_abstract_param(maxtext_model_flax, config)["params"]

abstract_params_flat, _ = jax.tree_util.tree_flatten_with_path(abstract_params_tree)
# Standardize abstract tree for later unflattening
abstract_params_tree = jax.tree.map(
lambda _: 0,
abstract_params_tree,
is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned),
abstract_params_flat, abstract_params_treedef = jax.tree_util.tree_flatten_with_path(
abstract_params_tree, is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned)
)
abstract_params_treedef = jax.tree_util.tree_structure(abstract_params_tree)

max_logging.log("MaxText abstract model and state initialized.")

# preprocess state
maxtext_abstract_dict = {}
for mt_target_idx, (path_tuple, abstract_leaf_value) in enumerate(abstract_params_flat):
key_parts = [k.key for k in path_tuple if hasattr(k, "key")]
key_parts = []
for k in path_tuple:
# JAX path components can be DictKey(key), GetItemKey(key), or SequenceKey(idx).
# We prefer string keys. If we see an integer or digit-string index, we assume it's
# a layer/block index and join it with the previous part using '_', matching
# MaxText's Linen-style naming convention (e.g., layers_0).
val = getattr(k, "key", getattr(k, "idx", None))
if val is None:
val = str(k)

val_str = str(val)
if (isinstance(val, int) or val_str.isdigit()) and key_parts:
key_parts[-1] = f"{key_parts[-1]}_{val_str}"
else:
key_parts.append(val_str)

mt_param_key = "params-" + "-".join(key_parts)
mt_target_shape = abstract_leaf_value.shape
if isinstance(abstract_leaf_value, nn.LogicallyPartitioned):
mt_target_shape = abstract_leaf_value.value.shape
else:
mt_target_shape = abstract_leaf_value.shape
maxtext_abstract_dict[mt_param_key] = (mt_target_idx, mt_target_shape)

return maxtext_abstract_dict, abstract_params_treedef
Expand Down
24 changes: 22 additions & 2 deletions src/maxtext/checkpoint_conversion/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,17 @@ def extract_nnx_weights(weights_dict: dict) -> dict[str, np.ndarray]:
result = {}
leaves_with_paths = jax.tree_util.tree_leaves_with_path(weights_dict)
for path_tuple, leaf_value in leaves_with_paths:
path_keys = [k.key for k in path_tuple]
path_keys = []
for k in path_tuple:
val = getattr(k, "key", getattr(k, "idx", None))
if val is None:
val = str(k)
val_str = str(val)
if (isinstance(val, int) or val_str.isdigit()) and path_keys:
path_keys[-1] = f"{path_keys[-1]}_{val_str}"
else:
path_keys.append(val_str)

# Skip NNX RNG state variables (not model weights)
if "to_nnx__rngs" in path_keys or any(k.endswith("_rngs") for k in path_keys):
continue
Expand Down Expand Up @@ -909,7 +919,17 @@ def extract_linen_weights(weights_dict: dict) -> dict[str, np.ndarray]:
result = {}
leaves_with_paths = jax.tree_util.tree_leaves_with_path(weights_dict)
for path_tuple, leaf_value in leaves_with_paths:
path_keys = [k.key for k in path_tuple]
path_keys = []
for k in path_tuple:
val = getattr(k, "key", getattr(k, "idx", None))
if val is None:
val = str(k)
val_str = str(val)
if (isinstance(val, int) or val_str.isdigit()) and path_keys:
path_keys[-1] = f"{path_keys[-1]}_{val_str}"
else:
path_keys.append(val_str)

# Construct maxtext_param_key from path_tuple
maxtext_param_key = "params-" + "-".join(path_keys)
if not isinstance(leaf_value, (jax.Array, np.ndarray)):
Expand Down
12 changes: 9 additions & 3 deletions src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,17 @@ def _load_full_state_from_path(
The loaded state.
"""

# Convert nnx.State to pure dict to match how NNX checkpoints are saved
# (maybe_save_checkpoint calls state.to_pure_dict() before saving).
restore_target = abstract_unboxed_pre_state
if isinstance(abstract_unboxed_pre_state, nnx.State):
restore_target = abstract_unboxed_pre_state.to_pure_dict()

if enable_orbax_v1:
if source_checkpoint_layout == "orbax":
context = ocp_v1.Context(checkpoint_layout=ocp_v1.options.CheckpointLayout.ORBAX)
with context:
return ocp_v1.load_pytree(path, abstract_unboxed_pre_state)
return ocp_v1.load_pytree(path, restore_target)
elif source_checkpoint_layout == "safetensors":
context = ocp_v1.Context(checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS)
with context:
Expand Down Expand Up @@ -226,9 +232,9 @@ def combine_sharding(sds, shardings):
)
# Provide sharding info to ensure restoration returns JAX arrays (not NumPy arrays).
restore_args = jax.tree_util.tree_map(
lambda x: ocp.type_handlers.ArrayRestoreArgs(sharding=x.sharding), abstract_unboxed_pre_state
lambda x: ocp.type_handlers.ArrayRestoreArgs(sharding=x.sharding), restore_target
)
return ocp.Checkpointer(handler).restore(p, abstract_unboxed_pre_state, restore_args=restore_args)
return ocp.Checkpointer(handler).restore(p, restore_target, restore_args=restore_args)


def create_orbax_checkpoint_manager(
Expand Down
6 changes: 3 additions & 3 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1180,9 +1180,9 @@ position_id_per_seconds: 25
subslice_shape: ""

# NNX
enable_nnx: False
pure_nnx_decoder: False
pure_nnx: False
enable_nnx: True
pure_nnx_decoder: True
pure_nnx: True

################################## Qwen3-Next Specific Configs ##################################
# Kernel size for the 1D convolution in the Gated Delta Net
Expand Down
Loading
Loading