diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index 4ab909715..e02072f9c 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -15,8 +15,15 @@ """ from abc import ABC, abstractmethod +import json from typing import Optional, Tuple -from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) +import jax +from flax import nnx +from maxdiffusion.checkpointing.checkpointing_utils import ( + add_sharding_to_struct, + create_orbax_checkpoint_manager, + get_cpu_mesh_and_sharding, +) from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1 from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2 from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1 @@ -50,6 +57,121 @@ def _create_optimizer(self, model, config, learning_rate): tx = max_utils.create_optimizer(config, learning_rate_scheduler) return tx, learning_rate_scheduler + @classmethod + def load_pretrained_pipeline_or_diffusers( + cls, config, pipeline_cls, pretrained_state_sources, pretrained_config_transformer_attr + ): + """Load a WAN pipeline from the pretrained Orbax cache, or seed it from diffusers. + + This helper is used only for inference when no training checkpoint exists. + `pretrained_config_transformer_attr` is explicit because WAN 2.2 has separate + transformer states but still saves one `wan_config`, matching the existing + training checkpoint format. + """ + pretrained_dir = getattr(config, "pretrained_orbax_dir", "") + if pretrained_dir: + restored_checkpoint = cls._restore_pretrained_checkpoint( + pretrained_dir, tuple(state_item_name for state_item_name, _ in pretrained_state_sources) + ) + if restored_checkpoint is not None: + max_logging.log(f"Loading WAN pipeline from pretrained orbax checkpoint at {pretrained_dir}") + return pipeline_cls.from_checkpoint(config, restored_checkpoint) + + max_logging.log("No checkpoint found, loading default pipeline.") + pipeline = pipeline_cls.from_pretrained(config) + if pretrained_dir: + cls._save_pretrained_checkpoint(pretrained_dir, pipeline, pretrained_state_sources, pretrained_config_transformer_attr) + return pipeline + + @classmethod + def _restore_pretrained_checkpoint(cls, pretrained_dir: str, state_item_names: Tuple[str, ...]): + """Restore pretrained WAN transformer states and config from an Orbax cache.""" + try: + checkpoint_manager = create_orbax_checkpoint_manager( + pretrained_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=WAN_CHECKPOINT, + use_async=False, + ) + step = checkpoint_manager.latest_step() + if step is None: + max_logging.log(f"No pretrained orbax checkpoint found in {pretrained_dir}") + return None + + max_logging.log(f"Found pretrained orbax checkpoint step {step} in {pretrained_dir}") + metadatas = checkpoint_manager.item_metadata(step) + mesh, replicated_sharding = get_cpu_mesh_and_sharding() + restore_items = {"wan_config": ocp.args.JsonRestore()} + for state_item_name in state_item_names: + restore_items[state_item_name] = cls._standard_restore_arg( + getattr(metadatas, state_item_name), mesh, replicated_sharding + ) + return checkpoint_manager.restore(step=step, args=ocp.args.Composite(**restore_items)) + except Exception as e: # pylint: disable=broad-except + max_logging.log(f"Failed to load pretrained orbax checkpoint from {pretrained_dir}: {e}") + return None + + @staticmethod + def _standard_restore_arg(metadata, mesh, replicated_sharding): + target_shardings = jax.tree_util.tree_map(lambda _: replicated_sharding, metadata) + with mesh: + abstract_state = jax.tree_util.tree_map(add_sharding_to_struct, metadata, target_shardings) + return ocp.args.StandardRestore(abstract_state) + + @classmethod + def _save_pretrained_checkpoint( + cls, pretrained_dir: str, pipeline, pretrained_state_sources, pretrained_config_transformer_attr + ): + """Save pretrained WAN transformer states to the inference-only Orbax cache.""" + try: + max_logging.log(f"Saving pretrained WAN weights to orbax at {pretrained_dir}") + checkpoint_manager = create_orbax_checkpoint_manager( + pretrained_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=WAN_CHECKPOINT, + use_async=False, + ) + save_items = cls._pretrained_save_items(pipeline, pretrained_state_sources, pretrained_config_transformer_attr) + checkpoint_manager.save(0, args=ocp.args.Composite(**save_items)) + checkpoint_manager.wait_until_finished() + max_logging.log(f"Pretrained weights saved to {pretrained_dir}") + except Exception as e: # pylint: disable=broad-except + max_logging.log(f"Failed to save pretrained orbax checkpoint to {pretrained_dir}: {e}") + + @staticmethod + def _pretrained_save_items(pipeline, pretrained_state_sources, pretrained_config_transformer_attr): + """Build Orbax save args for pretrained WAN transformer states. + + `pretrained_state_sources` contains `(orbax_item_name, pipeline_attribute)` pairs. + `pretrained_config_transformer_attr` names the transformer whose config should be + serialized as `wan_config`. + """ + pretrained_state_sources = tuple(pretrained_state_sources) + if not pretrained_state_sources: + raise ValueError("pretrained_state_sources must contain at least one transformer source.") + + try: + config_transformer = getattr(pipeline, pretrained_config_transformer_attr) + except AttributeError as e: + raise ValueError( + f"Pipeline does not have pretrained config transformer attribute `{pretrained_config_transformer_attr}`." + ) from e + + items = {} + for state_item_name, transformer_attr in pretrained_state_sources: + try: + transformer = getattr(pipeline, transformer_attr) + except AttributeError as e: + raise ValueError(f"Pipeline does not have pretrained transformer attribute `{transformer_attr}`.") from e + + _, state, _ = nnx.split(transformer, nnx.Param, ...) + items[state_item_name] = ocp.args.StandardSave(state.to_pure_dict()) + + items["wan_config"] = ocp.args.JsonSave(json.loads(config_transformer.to_json_string())) + return items + @abstractmethod def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: raise NotImplementedError diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 319bfbc72..9ccc9d8b4 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -290,6 +290,10 @@ names_which_can_be_offloaded: [] # checkpoint every number of samples, -1 means don't checkpoint. checkpoint_every: -1 checkpoint_dir: "" +# Directory to cache pretrained weights as an orbax checkpoint for fast inference loads. +# On first run (slow, diffusers load), weights are saved here automatically. +# On subsequent runs, weights are loaded from here instead (~10x faster). +pretrained_orbax_dir: "" # enables one replica to read the ckpt then broadcast to the rest enable_single_replica_ckpt_restoring: False diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml index 3134ed93d..f10a02292 100644 --- a/src/maxdiffusion/configs/base_wan_1_3b.yml +++ b/src/maxdiffusion/configs/base_wan_1_3b.yml @@ -242,6 +242,10 @@ names_which_can_be_offloaded: [] # checkpoint every number of samples, -1 means don't checkpoint. checkpoint_every: -1 checkpoint_dir: "" +# Directory to cache pretrained weights as an orbax checkpoint for fast inference loads. +# On first run (slow, diffusers load), weights are saved here automatically. +# On subsequent runs, weights are loaded from here instead (~10x faster). +pretrained_orbax_dir: "" # enables one replica to read the ckpt then broadcast to the rest enable_single_replica_ckpt_restoring: False diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index dfe300ddf..206737d91 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -258,6 +258,10 @@ names_which_can_be_offloaded: [] # checkpoint every number of samples, -1 means don't checkpoint. checkpoint_every: -1 checkpoint_dir: "" +# Directory to cache pretrained weights as an orbax checkpoint for fast inference loads. +# On first run (slow, diffusers load), weights are saved here automatically. +# On subsequent runs, weights are loaded from here instead (~10x faster). +pretrained_orbax_dir: "" # enables one replica to read the ckpt then broadcast to the rest enable_single_replica_ckpt_restoring: False diff --git a/src/maxdiffusion/configs/base_wan_animate.yml b/src/maxdiffusion/configs/base_wan_animate.yml index 7b3334c79..7503716c9 100644 --- a/src/maxdiffusion/configs/base_wan_animate.yml +++ b/src/maxdiffusion/configs/base_wan_animate.yml @@ -253,6 +253,10 @@ names_which_can_be_offloaded: [] # checkpoint every number of samples, -1 means don't checkpoint. checkpoint_every: -1 checkpoint_dir: "" +# Directory to cache pretrained weights as an orbax checkpoint for fast inference loads. +# On first run (slow, diffusers load), weights are saved here automatically. +# On subsequent runs, weights are loaded from here instead (~10x faster). +pretrained_orbax_dir: "" # enables one replica to read the ckpt then broadcast to the rest enable_single_replica_ckpt_restoring: False diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index f722e04e2..4c1567fc6 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -256,6 +256,10 @@ names_which_can_be_offloaded: [] # checkpoint every number of samples, -1 means don't checkpoint. checkpoint_every: -1 checkpoint_dir: "" +# Directory to cache pretrained weights as an orbax checkpoint for fast inference loads. +# On first run (slow, diffusers load), weights are saved here automatically. +# On subsequent runs, weights are loaded from here instead (~10x faster). +pretrained_orbax_dir: "" # enables one replica to read the ckpt then broadcast to the rest enable_single_replica_ckpt_restoring: False diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index 0aa533b40..dddf01804 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -257,6 +257,10 @@ names_which_can_be_offloaded: [] # checkpoint every number of samples, -1 means don't checkpoint. checkpoint_every: -1 checkpoint_dir: "" +# Directory to cache pretrained weights as an orbax checkpoint for fast inference loads. +# On first run (slow, diffusers load), weights are saved here automatically. +# On subsequent runs, weights are loaded from here instead (~10x faster). +pretrained_orbax_dir: "" # enables one replica to read the ckpt then broadcast to the rest enable_single_replica_ckpt_restoring: False diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 8c3c30b87..87d38fdaf 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -30,6 +30,10 @@ import flax from maxdiffusion.common_types import WAN2_1, WAN2_2 from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NNXLoraLoader, Wan2_2NNXLoraLoader +from maxdiffusion.pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1 +from maxdiffusion.pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2 +from maxdiffusion.pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1 +from maxdiffusion.pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2 def upload_video_to_gcs(output_dir: str, video_path: str): @@ -196,18 +200,34 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): load_start = time.perf_counter() model_type = config.model_type if model_key == WAN2_1: + pipeline_cls = WanPipelineI2V_2_1 if model_type == "I2V" else WanPipeline2_1 + pretrained_state_sources = (("wan_state", "transformer"),) + pretrained_config_transformer_attr = "transformer" if model_type == "I2V": checkpoint_loader = WanCheckpointerI2V_2_1(config=config) else: checkpoint_loader = WanCheckpointer2_1(config=config) elif model_key == WAN2_2: + pipeline_cls = WanPipelineI2V_2_2 if model_type == "I2V" else WanPipeline2_2 + pretrained_state_sources = ( + ("low_noise_transformer_state", "low_noise_transformer"), + ("high_noise_transformer_state", "high_noise_transformer"), + ) + # WAN 2.2 training checkpoints save `wan_config` from the low-noise transformer. + pretrained_config_transformer_attr = "low_noise_transformer" if model_type == "I2V": checkpoint_loader = WanCheckpointerI2V_2_2(config=config) else: checkpoint_loader = WanCheckpointer2_2(config=config) else: raise ValueError(f"Unsupported model_name for checkpointer: {model_key}") - pipeline, _, _ = checkpoint_loader.load_checkpoint() + checkpoint_step = checkpoint_loader.checkpoint_manager.latest_step() + if checkpoint_step is not None: + pipeline, _, _ = checkpoint_loader.load_checkpoint(checkpoint_step) + else: + pipeline = checkpoint_loader.load_pretrained_pipeline_or_diffusers( + config, pipeline_cls, pretrained_state_sources, pretrained_config_transformer_attr + ) load_time = time.perf_counter() - load_start max_logging.log(f"load_time: {load_time:.1f}s") else: diff --git a/src/maxdiffusion/generate_wan_animate.py b/src/maxdiffusion/generate_wan_animate.py index d2d88473a..fa253cbe3 100644 --- a/src/maxdiffusion/generate_wan_animate.py +++ b/src/maxdiffusion/generate_wan_animate.py @@ -11,6 +11,7 @@ import jax from maxdiffusion import max_logging, max_utils, pyconfig +from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer from maxdiffusion.pipelines.wan.wan_pipeline_animate import WanAnimatePipeline from maxdiffusion.train_utils import transformer_engine_context from maxdiffusion.utils import export_to_video @@ -43,7 +44,9 @@ def run(config): max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}") load_start = time.perf_counter() - pipeline = WanAnimatePipeline.from_pretrained(config) + pipeline = WanCheckpointer.load_pretrained_pipeline_or_diffusers( + config, WanAnimatePipeline, (("wan_state", "transformer"),), "transformer" + ) load_time = time.perf_counter() - load_start max_logging.log(f"load_time: {load_time:.1f}s") diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 9e92449c7..d471e4fe4 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -99,6 +99,30 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl return vs +def _select_restored_transformer_state(restored_checkpoint, subfolder: str): + """Select the transformer state that belongs to a WAN checkpoint restore call. + + WAN 2.1 and Wan Animate checkpoints use a single `wan_state`. WAN 2.2 checkpoints + store separate transformer states: diffusers `transformer_2` is the low-noise + transformer, while `transformer` is the high-noise transformer. + """ + checkpoint_keys = restored_checkpoint.keys() + if "wan_state" in checkpoint_keys: + return restored_checkpoint["wan_state"] + + if subfolder == "transformer_2": + if "low_noise_transformer_state" not in checkpoint_keys: + raise ValueError("WAN checkpoint is missing `low_noise_transformer_state` for subfolder `transformer_2`.") + return restored_checkpoint["low_noise_transformer_state"] + + if subfolder == "transformer": + if "high_noise_transformer_state" not in checkpoint_keys: + raise ValueError("WAN checkpoint is missing `high_noise_transformer_state` for subfolder `transformer`.") + return restored_checkpoint["high_noise_transformer_state"] + + raise ValueError(f"Unsupported WAN checkpoint transformer subfolder `{subfolder}`.") + + # For some reason, jitting this function increases the memory significantly, so instead manually move weights to device. def create_sharded_logical_transformer( devices_array: np.array, @@ -155,13 +179,14 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): state = dict(nnx.to_flat_state(state)) # 4. Load pretrained weights and move them to device using the state shardings from (3) above. - # This helps with loading sharded weights directly into the accelerators without fist copying them + # This helps with loading sharded weights directly into the accelerators without first copying them # all to one device and then distributing them, thus using low HBM memory. if restored_checkpoint: - if "params" in restored_checkpoint["wan_state"]: # if checkpointed with optimizer - params = restored_checkpoint["wan_state"]["params"] + checkpoint_state = _select_restored_transformer_state(restored_checkpoint, subfolder) + if "params" in checkpoint_state: # if checkpointed with optimizer + params = checkpoint_state["params"] else: # if not checkpointed with optimizer - params = restored_checkpoint["wan_state"] + params = checkpoint_state else: params = load_wan_transformer( config.wan_transformer_pretrained_model_name_or_path, diff --git a/src/maxdiffusion/tests/wan/wan_checkpointer_test.py b/src/maxdiffusion/tests/wan/wan_checkpointer_test.py index a1674a57a..b18b0df7e 100644 --- a/src/maxdiffusion/tests/wan/wan_checkpointer_test.py +++ b/src/maxdiffusion/tests/wan/wan_checkpointer_test.py @@ -13,13 +13,150 @@ import unittest from unittest.mock import patch, MagicMock +from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer from maxdiffusion.checkpointing.wan_checkpointer_2_1 import WanCheckpointer2_1 from maxdiffusion.checkpointing.wan_checkpointer_2_2 import WanCheckpointer2_2 from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p1 import WanCheckpointerI2V_2_1 from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2 import WanCheckpointerI2V_2_2 +from maxdiffusion.pipelines.wan.wan_pipeline import _select_restored_transformer_state from maxdiffusion.pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1 +class WanPretrainedCacheTest(unittest.TestCase): + """Tests for the shared WAN pretrained Orbax cache helper.""" + + def setUp(self): + self.config = MagicMock() + self.config.pretrained_orbax_dir = "/tmp/wan_pretrained_cache" + self.pipeline_cls = MagicMock() + self.state_sources = (("wan_state", "transformer"),) + self.config_transformer_attr = "transformer" + + @patch.object(WanCheckpointer, "_restore_pretrained_checkpoint") + def test_loads_from_pretrained_cache_hit(self, mock_restore): + restored_checkpoint = MagicMock() + mock_restore.return_value = restored_checkpoint + pipeline = MagicMock() + self.pipeline_cls.from_checkpoint.return_value = pipeline + + result = WanCheckpointer.load_pretrained_pipeline_or_diffusers( + self.config, self.pipeline_cls, self.state_sources, self.config_transformer_attr + ) + + mock_restore.assert_called_once_with(self.config.pretrained_orbax_dir, ("wan_state",)) + self.pipeline_cls.from_checkpoint.assert_called_once_with(self.config, restored_checkpoint) + self.pipeline_cls.from_pretrained.assert_not_called() + self.assertEqual(result, pipeline) + + @patch.object(WanCheckpointer, "_save_pretrained_checkpoint") + @patch.object(WanCheckpointer, "_restore_pretrained_checkpoint") + def test_loads_from_diffusers_and_saves_on_cache_miss(self, mock_restore, mock_save): + mock_restore.return_value = None + pipeline = MagicMock() + self.pipeline_cls.from_pretrained.return_value = pipeline + + result = WanCheckpointer.load_pretrained_pipeline_or_diffusers( + self.config, self.pipeline_cls, self.state_sources, self.config_transformer_attr + ) + + self.pipeline_cls.from_pretrained.assert_called_once_with(self.config) + mock_save.assert_called_once_with( + self.config.pretrained_orbax_dir, pipeline, self.state_sources, self.config_transformer_attr + ) + self.assertEqual(result, pipeline) + + @patch.object(WanCheckpointer, "_save_pretrained_checkpoint") + @patch.object(WanCheckpointer, "_restore_pretrained_checkpoint") + def test_empty_pretrained_dir_uses_diffusers_without_cache(self, mock_restore, mock_save): + self.config.pretrained_orbax_dir = "" + pipeline = MagicMock() + self.pipeline_cls.from_pretrained.return_value = pipeline + + result = WanCheckpointer.load_pretrained_pipeline_or_diffusers( + self.config, self.pipeline_cls, self.state_sources, self.config_transformer_attr + ) + + mock_restore.assert_not_called() + mock_save.assert_not_called() + self.pipeline_cls.from_pretrained.assert_called_once_with(self.config) + self.assertEqual(result, pipeline) + + @patch("maxdiffusion.checkpointing.wan_checkpointer.nnx.split") + def test_pretrained_save_items_uses_explicit_transformer_config(self, mock_split): + pipeline = MagicMock() + low_noise_transformer = MagicMock() + high_noise_transformer = MagicMock() + low_noise_transformer.to_json_string.return_value = '{"model_type": "wan"}' + pipeline.low_noise_transformer = low_noise_transformer + pipeline.high_noise_transformer = high_noise_transformer + low_noise_state = MagicMock() + high_noise_state = MagicMock() + low_noise_state.to_pure_dict.return_value = {"low": "state"} + high_noise_state.to_pure_dict.return_value = {"high": "state"} + mock_split.side_effect = [ + (None, low_noise_state, None), + (None, high_noise_state, None), + ] + + items = WanCheckpointer._pretrained_save_items( + pipeline, + ( + ("low_noise_transformer_state", "low_noise_transformer"), + ("high_noise_transformer_state", "high_noise_transformer"), + ), + "low_noise_transformer", + ) + + low_noise_transformer.to_json_string.assert_called_once() + high_noise_transformer.to_json_string.assert_not_called() + mock_split.assert_any_call(low_noise_transformer, unittest.mock.ANY, ...) + mock_split.assert_any_call(high_noise_transformer, unittest.mock.ANY, ...) + self.assertIn("low_noise_transformer_state", items) + self.assertIn("high_noise_transformer_state", items) + self.assertIn("wan_config", items) + + def test_pretrained_save_items_requires_transformer_source(self): + with self.assertRaisesRegex(ValueError, "at least one transformer source"): + WanCheckpointer._pretrained_save_items(MagicMock(), (), "transformer") + + +class WanRestoredTransformerStateTest(unittest.TestCase): + """Tests for strict WAN checkpoint state selection.""" + + def test_selects_single_wan_state(self): + restored_checkpoint = {"wan_state": {"params": {}}} + + self.assertEqual(_select_restored_transformer_state(restored_checkpoint, ""), {"params": {}}) + + def test_selects_wan_2_2_low_noise_state(self): + restored_checkpoint = {"low_noise_transformer_state": {"low": "state"}} + + self.assertEqual( + _select_restored_transformer_state(restored_checkpoint, "transformer_2"), + {"low": "state"}, + ) + + def test_selects_wan_2_2_high_noise_state(self): + restored_checkpoint = {"high_noise_transformer_state": {"high": "state"}} + + self.assertEqual( + _select_restored_transformer_state(restored_checkpoint, "transformer"), + {"high": "state"}, + ) + + def test_rejects_mismatched_wan_2_2_state(self): + restored_checkpoint = {"low_noise_transformer_state": {"low": "state"}} + + with self.assertRaisesRegex(ValueError, "high_noise_transformer_state"): + _select_restored_transformer_state(restored_checkpoint, "transformer") + + def test_rejects_unknown_subfolder(self): + restored_checkpoint = {"low_noise_transformer_state": {}, "high_noise_transformer_state": {}} + + with self.assertRaisesRegex(ValueError, "Unsupported WAN checkpoint transformer subfolder"): + _select_restored_transformer_state(restored_checkpoint, "unexpected") + + class WanCheckpointer2_1Test(unittest.TestCase): """Tests for WAN 2.1 checkpointer."""