Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 123 additions & 1 deletion src/maxdiffusion/checkpointing/wan_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@
"""

from abc import ABC, abstractmethod
import json
from typing import Optional, Tuple
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
import jax
from flax import nnx
from maxdiffusion.checkpointing.checkpointing_utils import (
add_sharding_to_struct,
create_orbax_checkpoint_manager,
get_cpu_mesh_and_sharding,
)
from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1
from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2
from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1
Expand Down Expand Up @@ -50,6 +57,121 @@ def _create_optimizer(self, model, config, learning_rate):
tx = max_utils.create_optimizer(config, learning_rate_scheduler)
return tx, learning_rate_scheduler

@classmethod
def load_pretrained_pipeline_or_diffusers(
    cls, config, pipeline_cls, pretrained_state_sources, pretrained_config_transformer_attr
):
  """Build a WAN pipeline from the pretrained Orbax cache, falling back to diffusers.

  Intended for inference when no training checkpoint exists. The cache
  directory is read from `config.pretrained_orbax_dir`; an empty or missing
  value disables both the cache lookup and the cache write.

  Args:
    config: run configuration; `pretrained_orbax_dir` is read if present.
    pipeline_cls: pipeline class exposing `from_checkpoint(config, restored)`
      and `from_pretrained(config)`.
    pretrained_state_sources: `(orbax_item_name, pipeline_attribute)` pairs.
    pretrained_config_transformer_attr: pipeline attribute naming the
      transformer whose config is stored as `wan_config`. It is explicit
      because WAN 2.2 has separate transformer states but still saves a
      single `wan_config`, matching the existing training checkpoint format.

  Returns:
    The loaded pipeline.
  """
  cache_dir = getattr(config, "pretrained_orbax_dir", "")

  if cache_dir:
    # Only the Orbax item names are needed to describe what to restore.
    item_names = tuple(item_name for item_name, _ in pretrained_state_sources)
    restored = cls._restore_pretrained_checkpoint(cache_dir, item_names)
    if restored is not None:
      max_logging.log(f"Loading WAN pipeline from pretrained orbax checkpoint at {cache_dir}")
      return pipeline_cls.from_checkpoint(config, restored)

  # Cache miss (or caching disabled): fall back to the diffusers weights.
  max_logging.log("No checkpoint found, loading default pipeline.")
  loaded_pipeline = pipeline_cls.from_pretrained(config)

  if cache_dir:
    # Seed the cache so that the next run can restore from Orbax instead.
    cls._save_pretrained_checkpoint(
        cache_dir, loaded_pipeline, pretrained_state_sources, pretrained_config_transformer_attr
    )
  return loaded_pipeline

@classmethod
def _restore_pretrained_checkpoint(cls, pretrained_dir: str, state_item_names: Tuple[str, ...]):
  """Restore the pretrained WAN transformer states and config from an Orbax cache.

  Args:
    pretrained_dir: directory handed to the Orbax checkpoint manager.
    state_item_names: names of the transformer state items to restore
      alongside `wan_config`.

  Returns:
    The restored composite checkpoint, or `None` when the directory holds no
    step or when anything goes wrong (the failure is logged, not raised).
  """
  try:
    manager = create_orbax_checkpoint_manager(
        pretrained_dir,
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=WAN_CHECKPOINT,
        use_async=False,
    )
    latest = manager.latest_step()
    if latest is None:
      max_logging.log(f"No pretrained orbax checkpoint found in {pretrained_dir}")
      return None

    max_logging.log(f"Found pretrained orbax checkpoint step {latest} in {pretrained_dir}")
    item_metadata = manager.item_metadata(latest)
    mesh, replicated_sharding = get_cpu_mesh_and_sharding()

    # `wan_config` is inserted first, then one StandardRestore per state item.
    restore_args = {"wan_config": ocp.args.JsonRestore()}
    restore_args.update(
        {
            item_name: cls._standard_restore_arg(getattr(item_metadata, item_name), mesh, replicated_sharding)
            for item_name in state_item_names
        }
    )
    return manager.restore(step=latest, args=ocp.args.Composite(**restore_args))
  except Exception as e:  # pylint: disable=broad-except
    # Best-effort cache: a broken or incompatible cache must not abort loading.
    max_logging.log(f"Failed to load pretrained orbax checkpoint from {pretrained_dir}: {e}")
    return None

@staticmethod
def _standard_restore_arg(metadata, mesh, replicated_sharding):
  """Wrap checkpoint metadata into an `ocp.args.StandardRestore` argument.

  Every leaf of `metadata` is paired with `replicated_sharding`, and
  `add_sharding_to_struct` is mapped over the (leaf, sharding) pairs inside
  the `mesh` context to produce the abstract restore target.
  """

  def _use_replicated(_leaf):
    return replicated_sharding

  shardings = jax.tree_util.tree_map(_use_replicated, metadata)
  with mesh:
    abstract_target = jax.tree_util.tree_map(add_sharding_to_struct, metadata, shardings)
  restore_arg = ocp.args.StandardRestore(abstract_target)
  return restore_arg

@classmethod
def _save_pretrained_checkpoint(
    cls, pretrained_dir: str, pipeline, pretrained_state_sources, pretrained_config_transformer_attr
):
  """Save pretrained WAN transformer states to the inference-only Orbax cache.

  The states are written as step 0. This is best-effort: any exception is
  logged and swallowed so that a failed cache write never blocks inference.

  Args:
    pretrained_dir: directory handed to the Orbax checkpoint manager.
    pipeline: loaded pipeline whose transformer states are saved.
    pretrained_state_sources: `(orbax_item_name, pipeline_attribute)` pairs.
    pretrained_config_transformer_attr: pipeline attribute naming the
      transformer whose config is serialized as `wan_config`.
  """
  try:
    max_logging.log(f"Saving pretrained WAN weights to orbax at {pretrained_dir}")
    checkpoint_manager = create_orbax_checkpoint_manager(
        pretrained_dir,
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=WAN_CHECKPOINT,
        use_async=False,
    )
    save_items = cls._pretrained_save_items(pipeline, pretrained_state_sources, pretrained_config_transformer_attr)
    saved = checkpoint_manager.save(0, args=ocp.args.Composite(**save_items))
    checkpoint_manager.wait_until_finished()
    # The manager reports whether a save was actually performed. It can
    # decline (for example when the step already exists in the directory),
    # so do not claim success in that case. Only an explicit `False` is
    # treated as a skip, to stay compatible with managers that return None.
    if saved is False:
      max_logging.log(
          f"Pretrained orbax checkpoint was not saved to {pretrained_dir}: "
          "the checkpoint manager skipped step 0 (a checkpoint may already exist there)."
      )
    else:
      max_logging.log(f"Pretrained weights saved to {pretrained_dir}")
  except Exception as e:  # pylint: disable=broad-except
    max_logging.log(f"Failed to save pretrained orbax checkpoint to {pretrained_dir}: {e}")

@staticmethod
def _pretrained_save_items(pipeline, pretrained_state_sources, pretrained_config_transformer_attr):
"""Build Orbax save args for pretrained WAN transformer states.

`pretrained_state_sources` contains `(orbax_item_name, pipeline_attribute)` pairs.
`pretrained_config_transformer_attr` names the transformer whose config should be
serialized as `wan_config`.
"""
pretrained_state_sources = tuple(pretrained_state_sources)
if not pretrained_state_sources:
raise ValueError("pretrained_state_sources must contain at least one transformer source.")

try:
config_transformer = getattr(pipeline, pretrained_config_transformer_attr)
except AttributeError as e:
raise ValueError(
f"Pipeline does not have pretrained config transformer attribute `{pretrained_config_transformer_attr}`."
) from e

items = {}
for state_item_name, transformer_attr in pretrained_state_sources:
try:
transformer = getattr(pipeline, transformer_attr)
except AttributeError as e:
raise ValueError(f"Pipeline does not have pretrained transformer attribute `{transformer_attr}`.") from e

_, state, _ = nnx.split(transformer, nnx.Param, ...)
items[state_item_name] = ocp.args.StandardSave(state.to_pure_dict())

items["wan_config"] = ocp.args.JsonSave(json.loads(config_transformer.to_json_string()))
return items

@abstractmethod
def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
  """Abstract hook for loading WAN configs from Orbax; subclasses must override.

  The base implementation always raises `NotImplementedError`.
  """
  raise NotImplementedError
Expand Down
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,10 @@ names_which_can_be_offloaded: []
# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
checkpoint_dir: ""
# Directory to cache pretrained weights as an orbax checkpoint for fast inference loads.
# On first run (slow, diffusers load), weights are saved here automatically.
# On subsequent runs, weights are loaded from here instead (~10x faster).
pretrained_orbax_dir: ""
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False

Expand Down
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_wan_1_3b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,10 @@ names_which_can_be_offloaded: []
# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
checkpoint_dir: ""
# Directory to cache pretrained weights as an orbax checkpoint for fast inference loads.
# On first run (slow, diffusers load), weights are saved here automatically.
# On subsequent runs, weights are loaded from here instead (~10x faster).
pretrained_orbax_dir: ""
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False

Expand Down
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_wan_27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,10 @@ names_which_can_be_offloaded: []
# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
checkpoint_dir: ""
# Directory to cache pretrained weights as an orbax checkpoint for fast inference loads.
# On first run (slow, diffusers load), weights are saved here automatically.
# On subsequent runs, weights are loaded from here instead (~10x faster).
pretrained_orbax_dir: ""
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False

Expand Down
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_wan_animate.yml
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ names_which_can_be_offloaded: []
# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
checkpoint_dir: ""
# Directory to cache pretrained weights as an orbax checkpoint for fast inference loads.
# On first run (slow, diffusers load), weights are saved here automatically.
# On subsequent runs, weights are loaded from here instead (~10x faster).
pretrained_orbax_dir: ""
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False

Expand Down
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_wan_i2v_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,10 @@ names_which_can_be_offloaded: []
# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
checkpoint_dir: ""
# Directory to cache pretrained weights as an orbax checkpoint for fast inference loads.
# On first run (slow, diffusers load), weights are saved here automatically.
# On subsequent runs, weights are loaded from here instead (~10x faster).
pretrained_orbax_dir: ""
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False

Expand Down
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_wan_i2v_27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,10 @@ names_which_can_be_offloaded: []
# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
checkpoint_dir: ""
# Directory to cache pretrained weights as an orbax checkpoint for fast inference loads.
# On first run (slow, diffusers load), weights are saved here automatically.
# On subsequent runs, weights are loaded from here instead (~10x faster).
pretrained_orbax_dir: ""
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False

Expand Down
22 changes: 21 additions & 1 deletion src/maxdiffusion/generate_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
import flax
from maxdiffusion.common_types import WAN2_1, WAN2_2
from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NNXLoraLoader, Wan2_2NNXLoraLoader
from maxdiffusion.pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1
from maxdiffusion.pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2
from maxdiffusion.pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1
from maxdiffusion.pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2


def upload_video_to_gcs(output_dir: str, video_path: str):
Expand Down Expand Up @@ -196,18 +200,34 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
load_start = time.perf_counter()
model_type = config.model_type
if model_key == WAN2_1:
pipeline_cls = WanPipelineI2V_2_1 if model_type == "I2V" else WanPipeline2_1
pretrained_state_sources = (("wan_state", "transformer"),)
pretrained_config_transformer_attr = "transformer"
if model_type == "I2V":
checkpoint_loader = WanCheckpointerI2V_2_1(config=config)
else:
checkpoint_loader = WanCheckpointer2_1(config=config)
elif model_key == WAN2_2:
pipeline_cls = WanPipelineI2V_2_2 if model_type == "I2V" else WanPipeline2_2
pretrained_state_sources = (
("low_noise_transformer_state", "low_noise_transformer"),
("high_noise_transformer_state", "high_noise_transformer"),
)
# WAN 2.2 training checkpoints save `wan_config` from the low-noise transformer.
pretrained_config_transformer_attr = "low_noise_transformer"
if model_type == "I2V":
checkpoint_loader = WanCheckpointerI2V_2_2(config=config)
else:
checkpoint_loader = WanCheckpointer2_2(config=config)
else:
raise ValueError(f"Unsupported model_name for checkpointer: {model_key}")
pipeline, _, _ = checkpoint_loader.load_checkpoint()
checkpoint_step = checkpoint_loader.checkpoint_manager.latest_step()
if checkpoint_step is not None:
pipeline, _, _ = checkpoint_loader.load_checkpoint(checkpoint_step)
else:
pipeline = checkpoint_loader.load_pretrained_pipeline_or_diffusers(
config, pipeline_cls, pretrained_state_sources, pretrained_config_transformer_attr
)
load_time = time.perf_counter() - load_start
max_logging.log(f"load_time: {load_time:.1f}s")
else:
Expand Down
5 changes: 4 additions & 1 deletion src/maxdiffusion/generate_wan_animate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import jax

from maxdiffusion import max_logging, max_utils, pyconfig
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
from maxdiffusion.pipelines.wan.wan_pipeline_animate import WanAnimatePipeline
from maxdiffusion.train_utils import transformer_engine_context
from maxdiffusion.utils import export_to_video
Expand Down Expand Up @@ -43,7 +44,9 @@ def run(config):
max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}")

load_start = time.perf_counter()
pipeline = WanAnimatePipeline.from_pretrained(config)
pipeline = WanCheckpointer.load_pretrained_pipeline_or_diffusers(
config, WanAnimatePipeline, (("wan_state", "transformer"),), "transformer"
)
load_time = time.perf_counter() - load_start
max_logging.log(f"load_time: {load_time:.1f}s")

Expand Down
33 changes: 29 additions & 4 deletions src/maxdiffusion/pipelines/wan/wan_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,30 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl
return vs


def _select_restored_transformer_state(restored_checkpoint, subfolder: str):
"""Select the transformer state that belongs to a WAN checkpoint restore call.

WAN 2.1 and Wan Animate checkpoints use a single `wan_state`. WAN 2.2 checkpoints
store separate transformer states: diffusers `transformer_2` is the low-noise
transformer, while `transformer` is the high-noise transformer.
"""
checkpoint_keys = restored_checkpoint.keys()
if "wan_state" in checkpoint_keys:
return restored_checkpoint["wan_state"]

if subfolder == "transformer_2":
if "low_noise_transformer_state" not in checkpoint_keys:
raise ValueError("WAN checkpoint is missing `low_noise_transformer_state` for subfolder `transformer_2`.")
return restored_checkpoint["low_noise_transformer_state"]

if subfolder == "transformer":
if "high_noise_transformer_state" not in checkpoint_keys:
raise ValueError("WAN checkpoint is missing `high_noise_transformer_state` for subfolder `transformer`.")
return restored_checkpoint["high_noise_transformer_state"]

raise ValueError(f"Unsupported WAN checkpoint transformer subfolder `{subfolder}`.")


# For some reason, jitting this function increases the memory significantly, so instead manually move weights to device.
def create_sharded_logical_transformer(
devices_array: np.array,
Expand Down Expand Up @@ -155,13 +179,14 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
state = dict(nnx.to_flat_state(state))

# 4. Load pretrained weights and move them to device using the state shardings from (3) above.
# This helps with loading sharded weights directly into the accelerators without fist copying them
# This helps with loading sharded weights directly into the accelerators without first copying them
# all to one device and then distributing them, thus using low HBM memory.
if restored_checkpoint:
if "params" in restored_checkpoint["wan_state"]: # if checkpointed with optimizer
params = restored_checkpoint["wan_state"]["params"]
checkpoint_state = _select_restored_transformer_state(restored_checkpoint, subfolder)
if "params" in checkpoint_state: # if checkpointed with optimizer
params = checkpoint_state["params"]
else: # if not checkpointed with optimizer
params = restored_checkpoint["wan_state"]
params = checkpoint_state
else:
params = load_wan_transformer(
config.wan_transformer_pretrained_model_name_or_path,
Expand Down
Loading
Loading