[JAX] Support calling MOE router kernels from JAX side#2711
tdophung wants to merge 3 commits into NVIDIA:main
Conversation
Signed-off-by: tdophung <tdophung@nvidia.com>
Greptile Summary

This PR adds JAX support for MoE (Mixture of Experts) router kernels, enabling JAX/Flax models to use the same fused CUDA kernels already available in PyTorch.

Key Changes
Implementation Notes
Confidence Score: 5/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[Input: Logits<br/>num_tokens x num_experts] --> B{Router Operation}
B --> C[fused_topk_with_score_function]
B --> D[fused_compute_score_for_moe_aux_loss]
C --> C1[Apply score function<br/>softmax or sigmoid]
C1 --> C2[Top-K selection<br/>with optional grouping]
C2 --> C3[Output: Sparse probs +<br/>routing_map boolean]
D --> D1[Apply score function<br/>softmax or sigmoid]
D1 --> D2[Plain top-K selection]
D2 --> D3[Output: Dense scores +<br/>routing_map boolean]
C3 --> E[Token permutation<br/>using routing_map]
D3 --> F[fused_moe_aux_loss]
F --> F1[Compute per-expert<br/>token counts]
F1 --> F2[Calculate auxiliary loss<br/>for load balancing]
F2 --> F3[Output: Scalar loss]
F3 --> G[Add to total loss]
style C fill:#e1f5ff
style D fill:#e1f5ff
style F fill:#ffe1e1
style C3 fill:#e8f5e9
style D3 fill:#e8f5e9
style F3 fill:#e8f5e9
```
Last reviewed commit: 1fa388f
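As a rough illustration of the routing path in the flowchart above, here is a minimal pure-JAX sketch of top-k selection with a score function. The function name and output conventions are hypothetical and only mirror the flowchart; they are not the fused kernel's actual API.

```python
import jax
import jax.numpy as jnp

def topk_with_score_function(logits, topk, score_function="softmax"):
    """Unfused reference: returns sparse probs + boolean routing map.

    Hypothetical sketch of what the fused routing kernel computes;
    names and exact semantics are illustrative only.
    """
    # Apply the score function over the expert dimension.
    if score_function == "softmax":
        scores = jax.nn.softmax(logits, axis=-1)
    else:  # "sigmoid"
        scores = jax.nn.sigmoid(logits)
    # Plain per-token top-k over the expert dimension.
    _, topk_idx = jax.lax.top_k(scores, topk)
    num_experts = logits.shape[-1]
    # Boolean routing map of shape (num_tokens, num_experts):
    # True where a token selected an expert.
    routing_map = jnp.any(
        jax.nn.one_hot(topk_idx, num_experts, dtype=bool), axis=-2
    )
    # Sparse probs: keep scores only for the selected experts.
    probs = jnp.where(routing_map, scores, 0.0)
    return probs, routing_map
```

The real fused kernel additionally supports grouped top-k selection, which this sketch omits.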
SCORE_FUNCTION_MAP = {"sigmoid": 0, "softmax": 1}

inconsistent section header formatting

Suggested change:
- # =========================================== ==================================
+ # =============================================================================

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
I created this PR without being aware of https://github.com/NVIDIA/TransformerEngine/pull/2385/changes. After reviewing it, I see that both PRs are doing very similar things, so I will start by addressing @phu0ngng's comments on that PR.
Description
The router kernels currently live in common and are callable from the PyTorch side but not from JAX. This PR adds JAX support for the router kernels, for either standalone use or later integration into the MaxText MoE layer.
Fixes #2710
Type of change
Changes
Please list the changes introduced in this PR:
a. fused_topk_with_score_function_kernel: the main routing kernel, which outputs a sparse probs matrix for the chosen experts plus a routing map to feed into permutation. It supports two score functions (softmax and sigmoid) as well as the group_topk algorithm.
b. fused_score_for_moe_aux_loss: step 1 of the side path that computes the auxiliary loss for load balancing. This step computes the binary routing map and the dense probs matrix over all experts.
c. fused_moe_aux_loss: step 2 of the auxiliary-loss computation. It calculates the loss from the dense probs (output of fused_score_for_moe_aux_loss), tokens_per_expert[i] (derived from routing_map.sum(dim=0)), total_num_tokens, and num_experts.

Add custom partitioning for each of the above kernels where possible (sharded on the num-token dimension for the first two kernels, and pure replication for the last one).
Add tests for both the single-GPU and distributed cases to verify sharding correctness.
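To make the two-step auxiliary-loss path above concrete, here is an illustrative sketch of a Switch-style load-balancing loss computed from the dense probs and routing map described in steps b and c. The function name, the coefficient, and the exact formula are assumptions for illustration; the fused kernel's formula may differ.

```python
import jax.numpy as jnp

def moe_aux_loss(dense_probs, routing_map, topk, coeff=1e-2):
    """Illustrative load-balancing auxiliary loss (assumed Switch-style form).

    dense_probs: (num_tokens, num_experts) dense scores for every expert.
    routing_map: (num_tokens, num_experts) boolean selection map.
    """
    num_tokens, num_experts = dense_probs.shape
    # tokens_per_expert[i] = routing_map.sum(dim=0), as in the PR description.
    tokens_per_expert = routing_map.sum(axis=0)
    # Fraction of routed token slots assigned to each expert.
    frac_tokens = tokens_per_expert / (num_tokens * topk)
    # Mean dense score each expert received.
    mean_probs = dense_probs.mean(axis=0)
    # Loss is minimized when both distributions are uniform over experts.
    return coeff * num_experts * jnp.sum(frac_tokens * mean_probs)
```

With perfectly balanced routing (uniform probs, one distinct expert per token) this evaluates to coeff, its minimum for a fixed coefficient.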
Checklist: