Refactor moe.py PR #1#3981
Conversation
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
| local_expert_size: The number of experts handled by the current shard. | ||
| shard_index: The index of the current expert shard (0 to | ||
| num_expert_parallelism - 1). | ||
| self.get_expert_parallelism_size() - 1). |
There was a problem hiding this comment.
you don't need to update comments?
| ), | ||
| check_vma=False, | ||
| ) | ||
| def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs): |
There was a problem hiding this comment.
We should give this function a better name, like sparse_matmul_core?
There was a problem hiding this comment.
thoughts on moe_route_and_compute? (prev suggestion by matt)
There was a problem hiding this comment.
Potentially sparse_matmul_route_and_compute since we are inside of moe.py? We have sparse_matmul vs. dense_matmul strategies.
| ) | ||
| def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs): | ||
| batch_size, sequence_length, _ = x.shape | ||
| x, routing, route_metadata = route(x, logits, pre_bias_logits, rngs) |
There was a problem hiding this comment.
It would be my dream if we can break the shardmap of wrapper into smaller shardmaps, e.g. separate shardmap on route function, gmm_up, and permute/unpermute! But not in this PR of course
There was a problem hiding this comment.
added follow up bug
| def route(x, logits, pre_bias_logits, rngs): | ||
| """Route tokens by expert""" | ||
| num_ep = self.get_expert_parallelism_size() | ||
| expert_shard_id = jax.lax.axis_index(self._expert_parallelism_name) if num_ep > 1 else 0 |
There was a problem hiding this comment.
maybe else None? I imagine expert_shard_id won't ever get used when num_ep == 1.
There was a problem hiding this comment.
roll_to_expert_id=num_experts_per_shard * expert_shard_id is called even when num_ep == 1 in the ring of experts = true path.
Technically users should not configure ring of experts when ep = 1, but setting 0 would prevent a crash if that was set
|
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
There was a problem hiding this comment.
This Pull Request successfully refactors sparse_matmul by extracting its core components—routing and up-projection—into modular helper functions. This is a solid architectural improvement that paves the way for more complex features like chunking and overlapping communication. However, a critical logic issue was identified in the ring-of-experts implementation that needs to be addressed before merging.
🔍 General Feedback
- Positive Modularization: Extracting
routeandgmm_upsignificantly improves the readability of the Moe layer and makes it much more extensible. - Data Encapsulation: The introduction of
RouteMetadataandRouteOutputdataclasses is a great way to handle the numerous return values from the routing logic. - Critical Bug: In
use_ring_of_expertsmode, the filtering ofgroup_sizeswithout corresponding filtering of other routing arrays (likeselected_experts) will likely lead to shape mismatches and runtime errors in downstream projections, especially when using AQT. - JAX Compatibility: It's highly recommended to use
flax.struct.dataclassfor the new data structures to ensure they are properly registered as JAX pytrees, avoiding potential issues with transformations likejitorshard_map.
RissyRan
left a comment
There was a problem hiding this comment.
Thanks for the change! Could you help run a few perf sanity checks to ensure no perf hit after the refactor? I think test for 1) FSDP & 2) FSDP + EP cases should be sufficient. You could test deepseek v2 locally between with/without changes.
| class RouteMetadata: | ||
| """EP communication state needed to undo the forward all-to-all after expert computation.""" | ||
|
|
||
| expert_shard_id: int |
There was a problem hiding this comment.
Could you help add some comments for attributes?
| class RouteOutput: | ||
| """Embed-independent routing state returned by route(). x and RouteMetadata are returned separately.""" | ||
|
|
||
| group_sizes: jax.Array # tokens per expert → gmm_fn |
There was a problem hiding this comment.
Shall we add some comments for them? i.e. lb_loss: Optional[jax.Array]: # auxiliary loss for token distribution among experts
FYI:
NuojCheng
left a comment
There was a problem hiding this comment.
Thank you for the efforts Shuwen!
Description
Refactor sparse_matmul for implementing chunking feature to overlap AG-RS. The refactor extracts individual components (routing, gmm-up, gmm-down) from the monolithic wrapper function so it's more extensible and makes it possible to implement features like chunking.
High level psuedocode for what the end state should look like:
This PR refactors routing and gmm_up into helper functions, and creates Dataclass RouteOutput and RouteMetadata to hold the 11 return values.
Next steps:
Tests
Testing
Sparse matmul:
source /home/shuwenf_google_com/venv-maxtext/bin/activate && PYTHONPATH=src python -m maxtext.trainers.pre_train.train
src/maxtext/configs/base.yml
model_name=deepseek3-tiny
base_output_directory=/tmp/moe_refactor_test
run_name=smoke_test
dataset_type=synthetic
enable_checkpointing=false
sparse_matmul=True
steps=3
per_device_batch_size=1
max_target_length=128
ici_expert_parallelism=8
attention=dot_product
2>&1 | tail -20
dense matmul:
source /home/shuwenf_google_com/venv-maxtext/bin/activate && PYTHONPATH=/home/shuwenf_google_com/maxtext python -m maxtext.trainers.pre_train.train
src/maxtext/configs/base.yml
model_name=deepseek3-tiny
base_output_directory=/tmp/moe_refactor_test
run_name=dense_ep8
dataset_type=synthetic
enable_checkpointing=false
sparse_matmul=False
steps=3
per_device_batch_size=1
max_target_length=128
ici_expert_parallelism=8
attention=dot_product
2>&1 | grep -E "completed step|Error|Traceback" | head -10
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.