feat(ltx2): implement centralized, configuration-driven logical sharding strategy for LTX-2 and LTX-2.3#414
Conversation
f6789b2 to
a37a799
Compare
c35d3da to
b45ac10
Compare
|
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
There was a problem hiding this comment.
The Pull Request successfully centralizes the sharding strategy for LTX-2 and LTX-2.3, which is a great architectural improvement. It eliminates hardware-specific logic from individual model layers and moves it to a configuration-driven registry. This significantly improves maintainability and makes it easier to support new hardware in the future.
🔍 General Feedback
- Correctness: Identified a potential
AttributeErrorin core model files (attention_flax.py,embeddings_flax.py) whensharding_specsisNone. This needs to be addressed as it will cause crashes when these components are used with default arguments. - Efficiency: The sharding specs are resolved repeatedly during inference in the pipeline. Storing these specs as pipeline attributes during initialization would be a minor but worthwhile optimization.
- Robustness: The strategy lookup logic silently defaults to a specific hardware profile on unknown input, which could hide configuration typos.
- Tests: The inclusion of
test_logical_sharding_ltx2.pyand updates to existing tests provide good coverage for the new sharding logic.
5bcf2e8 to
48c2d3d
Compare
|
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
There was a problem hiding this comment.
This PR successfully implements a centralized, configuration-driven logical sharding strategy registry for LTX-2 and LTX-2.3. The refactoring significantly improves the modularity and maintainability of the sharding logic by decoupling it from individual model layers and hardware-specific checks.
🔍 General Feedback
- Architecture: The introduction of
logical_sharding_ltx2.pyis a great architectural improvement, making sharding strategies explicit and easily extensible. - Robustness: The use of
safe_getattrand fallback logic ensures that the model remains functional even with incomplete sharding specifications. - Performance: Moving pipeline decisions like VAE replication and text-encoder batching to the registry allows for better hardware-specific tuning.
- Minor Issues: I've noted a likely discrepancy in the text encoder batching logic for Ironwood and a change in the default VAE replication behavior that should be confirmed.
0abfd7e to
794703f
Compare
|
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
|
🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details. |
|
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
There was a problem hiding this comment.
This PR successfully introduces a centralized, configuration-driven sharding strategy for LTX-2 and LTX-2.3 models, which is a significant improvement over hardcoded hardware checks. The implementation uses a registry-based approach that correctly maps hardware profiles to logical sharding specifications.
🔍 General Feedback
- Regression on Ironwood: The
use_batched_text_encoderflag for theironwoodprofile is currently set toFalse, which contradicts the previous logic that enabled it for TPU v7x. This should be corrected to avoid a performance/correctness regression. - Incomplete Parameterization: While the major components (Transformer, Attention) are well-parameterized, some internal layers in the VAE ResNet blocks and the Text Encoder Feature Extractor were missed and still use hardcoded partitioning.
- Improved Maintainability: Moving hardware-specific logic to
logical_sharding_ltx2.pygreatly simplifies the model code and makes it easier to support future hardware. - Unit Testing: The addition of
test_logical_sharding_ltx2.pyprovides good coverage for the new factory logic.
26a35e7 to
441a307
Compare
prishajain1
left a comment
There was a problem hiding this comment.
Just one comment about batched text encoder for ironwood, rest looks good to me!
mbohlool
left a comment
There was a problem hiding this comment.
there are merge conflicts. generally looks good. just two comments.
83a74c5 to
634b642
Compare
|
PTAL @mbohlool, I have addressed your comments |
eb9598e to
3136b13
Compare
Summary
This PR introduces a centralized, configuration-driven logical sharding strategy registry for LTX-2 and LTX-2.3 in MaxDiffusion. It eliminates ad-hoc hardware checks and hardcoded sharding constraints in model layers by moving sharding specifications to a centralized, hardware-aware registry.
Key Changes
logical_sharding_ltx2.pyto define sharding spec profiles for Ironwood (TPU v7x, 1D heads-wise sharding) and Trillium (TPU v6e, 2D heads + embed sharding).attention_flax.pyandembeddings_flax.pyusing generic duck-typing interfaces (getattrfallback logic) to prevent code coupling.force_replication) and text-encoding batching (use_batched_text_encoder) pipeline decisions to be configuration-driven under the central spec registry.sharding,text_encoder_dtype,compile_text_encoder, andbase_output_directoryparameters to LTX-2/2.3 configs, enabling dynamic text-encoder compilation and clean overrides via the CLI.test_logical_sharding_ltx2.py) to verify routing and hardware auto-detection logic.Performance
Conclusion: This change is purely structuring and does not impact peformance