Skip to content

Refactor moe.py PR #1#3981

Open
Shuwen-Fang wants to merge 1 commit into
mainfrom
refactor_moe
Open

Refactor moe.py PR #1#3981
Shuwen-Fang wants to merge 1 commit into
mainfrom
refactor_moe

Conversation

@Shuwen-Fang
Copy link
Copy Markdown
Collaborator

@Shuwen-Fang Shuwen-Fang commented May 26, 2026

Description

Refactor sparse_matmul for implementing chunking feature to overlap AG-RS. The refactor extracts individual components (routing, gmm-up, gmm-down) from the monolithic wrapper function so it's more extensible and makes it possible to implement features like chunking.

High level psuedocode for what the end state should look like:

def moe_route_and_compute(x, ...):
    routing, route_metadata = route(logits)
    if not chunking:
        x = dispatch(x, route_metadata)
        intermediate = gmm_up(x, w0, w1, routing)
    else:
        intermediate = zeros(...)
        for x_chunk in chunks(x):                      
            x_chunk = dispatch(x_chunk, route_metadata)
            intermediate += gmm_up(x_chunk, w0_chunk, w1_chunk, routing)
    intermediate = maybe_tp_reduce(intermediate)
    output = gmm_down(intermediate, wo, routing)
    return unpermute(output, routing)

This PR refactors routing and gmm_up into helper functions, and creates Dataclass RouteOutput and RouteMetadata to hold the 11 return values.

Next steps:

  • Extract gmm_down and unpermute into helpers; remote wrapper
  • Implement the chunking loop (may required splitting route further)

Tests

Testing

  • Runing maxtext training succeeds:
  • Sparse matmul:
    source /home/shuwenf_google_com/venv-maxtext/bin/activate && PYTHONPATH=src python -m maxtext.trainers.pre_train.train
    src/maxtext/configs/base.yml
    model_name=deepseek3-tiny
    base_output_directory=/tmp/moe_refactor_test
    run_name=smoke_test
    dataset_type=synthetic
    enable_checkpointing=false
    sparse_matmul=True
    steps=3
    per_device_batch_size=1
    max_target_length=128
    ici_expert_parallelism=8
    attention=dot_product
    2>&1 | tail -20

  • dense matmul:
    source /home/shuwenf_google_com/venv-maxtext/bin/activate && PYTHONPATH=/home/shuwenf_google_com/maxtext python -m maxtext.trainers.pre_train.train
    src/maxtext/configs/base.yml
    model_name=deepseek3-tiny
    base_output_directory=/tmp/moe_refactor_test
    run_name=dense_ep8
    dataset_type=synthetic
    enable_checkpointing=false
    sparse_matmul=False
    steps=3
    per_device_batch_size=1
    max_target_length=128
    ici_expert_parallelism=8
    attention=dot_product
    2>&1 | grep -E "completed step|Error|Traceback" | head -10

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@Shuwen-Fang Shuwen-Fang requested a review from NuojCheng May 26, 2026 18:15
@Shuwen-Fang Shuwen-Fang changed the title refactor moe Refactor moe.py PR #1 May 26, 2026
@codecov
Copy link
Copy Markdown

codecov Bot commented May 26, 2026

Codecov Report

❌ Patch coverage is 77.19298% with 13 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/layers/moe.py 77.19% 11 Missing and 2 partials ⚠️

📢 Thoughts on this report? Let us know!

Comment thread src/maxtext/layers/moe.py Outdated
local_expert_size: The number of experts handled by the current shard.
shard_index: The index of the current expert shard (0 to
num_expert_parallelism - 1).
self.get_expert_parallelism_size() - 1).
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you don't need to update comments?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment thread src/maxtext/layers/moe.py
),
check_vma=False,
)
def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should give this function a better name, like sparse_matmul_core?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thoughts on moe_route_and_compute? (prev suggestion by matt)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potentially sparse_matmul_route_and_compute since we are inside of moe.py? We have sparse_matmul vs. dense_matmul strategies.

Comment thread src/maxtext/layers/moe.py
)
def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
batch_size, sequence_length, _ = x.shape
x, routing, route_metadata = route(x, logits, pre_bias_logits, rngs)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be my dream if we can break the shardmap of wrapper into smaller shardmaps, e.g. separate shardmap on route function, gmm_up, and permute/unpermute! But not in this PR of course

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added follow up bug

Comment thread src/maxtext/layers/moe.py Outdated
def route(x, logits, pre_bias_logits, rngs):
"""Route tokens by expert"""
num_ep = self.get_expert_parallelism_size()
expert_shard_id = jax.lax.axis_index(self._expert_parallelism_name) if num_ep > 1 else 0
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe else None? I imagine expert_shard_id won't ever get used when num_ep == 1.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

roll_to_expert_id=num_experts_per_shard * expert_shard_id is called even when num_ep == 1 in the ring of experts = true path.

Technically users should not configure ring of experts when ep = 1, but setting 0 would prevent a crash if that was set

@github-actions
Copy link
Copy Markdown

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

Copy link
Copy Markdown

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

## 📋 Review Summary

This Pull Request successfully refactors sparse_matmul by extracting its core components—routing and up-projection—into modular helper functions. This is a solid architectural improvement that paves the way for more complex features like chunking and overlapping communication. However, a critical logic issue was identified in the ring-of-experts implementation that needs to be addressed before merging.

🔍 General Feedback

  • Positive Modularization: Extracting route and gmm_up significantly improves the readability of the Moe layer and makes it much more extensible.
  • Data Encapsulation: The introduction of RouteMetadata and RouteOutput dataclasses is a great way to handle the numerous return values from the routing logic.
  • Critical Bug: In use_ring_of_experts mode, the filtering of group_sizes without corresponding filtering of other routing arrays (like selected_experts) will likely lead to shape mismatches and runtime errors in downstream projections, especially when using AQT.
  • JAX Compatibility: It's highly recommended to use flax.struct.dataclass for the new data structures to ensure they are properly registered as JAX pytrees, avoiding potential issues with transformations like jit or shard_map.

Comment thread src/maxtext/layers/moe.py Outdated
Comment thread src/maxtext/layers/moe.py Outdated
Comment thread src/maxtext/layers/moe.py
Comment thread src/maxtext/layers/moe.py
Copy link
Copy Markdown
Collaborator

@RissyRan RissyRan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the change! Could you help run a few perf sanity checks to ensure no perf hit after the refactor? I think test for 1) FSDP & 2) FSDP + EP cases should be sufficient. You could test deepseek v2 locally between with/without changes.

Comment thread src/maxtext/layers/moe.py
class RouteMetadata:
"""EP communication state needed to undo the forward all-to-all after expert computation."""

expert_shard_id: int
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you help add some comments for attributes?

Comment thread src/maxtext/layers/moe.py
class RouteOutput:
"""Embed-independent routing state returned by route(). x and RouteMetadata are returned separately."""

group_sizes: jax.Array # tokens per expert → gmm_fn
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we add some comments for them? i.e. lb_loss: Optional[jax.Array]: # auxiliary loss for token distribution among experts

FYI:

Licensed under the Apache License, Version 2.0 (the "License");

Copy link
Copy Markdown
Collaborator

@NuojCheng NuojCheng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the efforts Shuwen!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants