From 4851a8b3dc96044b6017a496420a5b83e3106c22 Mon Sep 17 00:00:00 2001 From: Rishabh Manoj Date: Sat, 30 May 2026 10:21:14 +0000 Subject: [PATCH] feat(ltx2): make run_text_encoder_on_tpu default False and dynamically load torchax --- src/maxdiffusion/configs/ltx2_3_video.yml | 4 ++++ src/maxdiffusion/configs/ltx2_video.yml | 2 +- src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py | 9 +++++---- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/maxdiffusion/configs/ltx2_3_video.yml b/src/maxdiffusion/configs/ltx2_3_video.yml index 042678bc..a5710623 100644 --- a/src/maxdiffusion/configs/ltx2_3_video.yml +++ b/src/maxdiffusion/configs/ltx2_3_video.yml @@ -108,6 +108,10 @@ profiler_steps: 5 replicate_vae: False +run_text_encoder_on_tpu: False +# Dynamically disables VAE slicing and distributes the batch dimension to avoid HBM OOM for larger batch sizes. +enable_dynamic_vae_sharding: True + allow_split_physical_axes: False learning_rate_schedule_steps: -1 max_train_steps: 500 diff --git a/src/maxdiffusion/configs/ltx2_video.yml b/src/maxdiffusion/configs/ltx2_video.yml index 34cfae41..8ddfaba1 100644 --- a/src/maxdiffusion/configs/ltx2_video.yml +++ b/src/maxdiffusion/configs/ltx2_video.yml @@ -114,7 +114,7 @@ profiler_steps: 5 replicate_vae: False use_bwe: False -run_text_encoder_on_tpu: True +run_text_encoder_on_tpu: False # Dynamically disables VAE slicing and distributes the batch dimension to avoid HBM OOM for larger batch sizes. enable_dynamic_vae_sharding: True allow_split_physical_axes: False diff --git a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py index b8062569..259fff59 100644 --- a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py @@ -23,8 +23,6 @@ import jax import jax.numpy as jnp from jax.sharding import Mesh, NamedSharding, PartitionSpec as P -from torchax import default_env -from maxdiffusion.models.ltx2.text_encoders.torchax_text_encoder import TorchaxGemma3TextEncoder from maxdiffusion.maxdiffusion_utils import get_dummy_ltx2_inputs import contextlib import flax @@ -352,7 +350,10 @@ def load_text_encoder(cls, config: HyperParameters): ) text_encoder.eval() - if getattr(config, "run_text_encoder_on_tpu", True): + if getattr(config, "run_text_encoder_on_tpu", False): + from torchax import default_env + from maxdiffusion.models.ltx2.text_encoders.torchax_text_encoder import TorchaxGemma3TextEncoder + with default_env(): text_encoder = text_encoder.to("jax") text_encoder = TorchaxGemma3TextEncoder(text_encoder) @@ -855,7 +856,7 @@ def _get_gemma_prompt_embeds( prompt = [p.strip() for p in prompt] if self.text_encoder is not None: - run_text_encoder_on_tpu = getattr(self.config, "run_text_encoder_on_tpu", True) if hasattr(self, "config") else True + run_text_encoder_on_tpu = getattr(self.config, "run_text_encoder_on_tpu", False) if hasattr(self, "config") else False if run_text_encoder_on_tpu: # Torchax Text Encoder text_inputs = self.tokenizer(