
[JAX] Support calling MOE router kernels from JAX side#2711

Open
tdophung wants to merge 3 commits intoNVIDIA:mainfrom
tdophung:router_jax

Conversation


@tdophung tdophung commented Feb 26, 2026

Description

The router kernels currently live in common and are callable from the PyTorch side but not from JAX. This PR adds JAX support for the router kernels, either for standalone use or for later integration into the MaxText MoE layer.

Fixes #2710

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  1. Add custom calls to router kernels, including:
    a. fused_topk_with_score_function_kernel: the main routing kernel that outputs a sparse probs matrix for the chosen experts plus a routing map to feed into permutation. It supports two scoring functions, softmax and sigmoid, and also supports the group_topk algorithm
    b. fused_score_for_moe_aux_loss: step 1 of the side path for calculating the auxiliary load-balancing loss. This step computes the binary routing map and the dense probs matrix over all experts
    c. fused_moe_aux_loss: step 2 of calculating the auxiliary loss. This computes the loss $$L_{\text{aux}} = C \cdot \sum_{i=1}^{E} \left(\sum_{t=1}^{T} p_{t,i}\right) \cdot f_i$$ where:
  • $p_{t,i}$ = probability that token $t$ assigns to expert $i$ (from fused_score_for_moe_aux_loss)
  • $f_i$ = number of tokens routed to expert $i$ (tokens_per_expert[i], derived from routing_map.sum(dim=0))
  • $C = \frac{E \cdot \text{coeff}}{K \cdot T^2}$ where $T$ = total_num_tokens, $K$ = topk, $E$ = num_experts
  2. Add custom partitioning for each of the above kernels where possible (sharded on the num-token dimension for the first two kernels, and pure replication for the last kernel)

  3. Add tests for both the single-GPU and distributed cases to verify sharding correctness
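The aux-loss formula above can be checked against a small pure-JAX reference. The helper name `reference_moe_aux_loss` and the tiny inputs below are illustrative only; this is a sketch of the math, not the fused kernel's API:

```python
import jax
import jax.numpy as jnp

def reference_moe_aux_loss(probs, routing_map, topk, coeff):
    """Pure-JAX sketch of the formula above (hypothetical helper).

    probs:       [T, E] dense per-expert probabilities p_{t,i}
    routing_map: [T, E] boolean map of the chosen experts
    """
    total_num_tokens, num_experts = probs.shape
    # f_i: number of tokens routed to expert i
    tokens_per_expert = routing_map.sum(axis=0)
    # C = E * coeff / (K * T^2)
    scale = num_experts * coeff / (topk * total_num_tokens**2)
    return scale * jnp.sum(probs.sum(axis=0) * tokens_per_expert)

# Tiny example: 4 tokens, 2 experts, top-1 routing
logits = jnp.array([[2.0, 0.0], [0.0, 2.0], [1.0, 0.5], [3.0, 0.0]])
probs = jax.nn.softmax(logits, axis=-1)
routing_map = probs >= jnp.max(probs, axis=-1, keepdims=True)
loss = reference_moe_aux_loss(probs, routing_map, topk=1, coeff=1.0)
```

A perfectly balanced router drives both factors toward their means, which minimizes the product sum; skewed routing (as here, where expert 0 receives 3 of 4 tokens) inflates the loss.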

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: tdophung <tdophung@nvidia.com>

greptile-apps bot commented Feb 26, 2026

Greptile Summary

This PR adds JAX support for MoE (Mixture of Experts) router kernels, enabling JAX/Flax models to use the same fused CUDA kernels already available in PyTorch.

Key Changes

  • Three router operations with forward/backward passes:

    • fused_topk_with_score_function: Main routing kernel that computes sparse probability matrix and routing map
    • fused_compute_score_for_moe_aux_loss: Computes routing map and dense probabilities for auxiliary loss calculation
    • fused_moe_aux_loss: Computes the MoE auxiliary load-balancing loss scalar
  • Custom partitioning for distributed execution:

    • Token dimension can be sharded across GPUs (data parallelism)
    • Expert dimension is kept replicated (no expert parallelism)
    • Aux loss operation uses full replication
  • Comprehensive test coverage:

    • Single-GPU tests validate correctness against reference JAX implementations
    • Distributed tests verify sharding behavior with both GSPMD and Shardy partitioners
    • Forward and backward passes are tested
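The first operation's semantics (score function, then top-k, then sparse outputs) can be sketched as a pure-JAX reference of the kind the tests compare against. The function name and signature below are hypothetical, not the PR's actual API:

```python
import jax
import jax.numpy as jnp

def reference_topk_with_score_function(logits, topk, score_function="softmax"):
    """Hypothetical pure-JAX sketch: sparse probs + boolean routing map."""
    if score_function == "softmax":
        scores = jax.nn.softmax(logits, axis=-1)
    else:  # "sigmoid"
        scores = jax.nn.sigmoid(logits)
    # indices of the top-k experts for each token
    _, topk_idx = jax.lax.top_k(scores, topk)
    # scatter True at the chosen (token, expert) positions
    routing_map = jnp.zeros(scores.shape, dtype=bool).at[
        jnp.arange(scores.shape[0])[:, None], topk_idx
    ].set(True)
    # zero out probabilities of experts that were not selected
    sparse_probs = jnp.where(routing_map, scores, 0.0)
    return sparse_probs, routing_map

# 2 tokens, 3 experts, top-2 routing
logits = jnp.array([[2.0, 0.0, 1.0], [0.0, 3.0, 1.0]])
sparse_probs, routing_map = reference_topk_with_score_function(logits, topk=2)
```

The fused kernel additionally supports grouped top-k selection, which this sketch omits.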

Implementation Notes

  • The JAX implementation matches the PyTorch API and behavior (including not computing gradients for expert_bias parameter)
  • FFI (Foreign Function Interface) is used to call underlying CUDA kernels
  • Custom VJP (Vector-Jacobian Product) rules handle automatic differentiation
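The `jax.custom_vjp` pattern mentioned above looks roughly like the following. This is a generic illustration on a simplified masked-softmax op, not the PR's actual rules (which live in transformer_engine/jax/router.py); note how the boolean routing map gets a `None` cotangent, mirroring the "no gradient for expert_bias"-style behavior:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def masked_scores(logits, routing_map):
    # forward: softmax scores, zeroed outside the chosen experts
    return jnp.where(routing_map, jax.nn.softmax(logits, axis=-1), 0.0)

def masked_scores_fwd(logits, routing_map):
    probs = jax.nn.softmax(logits, axis=-1)
    out = jnp.where(routing_map, probs, 0.0)
    # residuals saved for the backward pass
    return out, (probs, routing_map)

def masked_scores_bwd(res, g):
    probs, routing_map = res
    # cotangents outside the routing map do not flow back
    g = jnp.where(routing_map, g, 0.0)
    # softmax VJP: p * (g - sum(g * p))
    grad_logits = probs * (g - jnp.sum(g * probs, axis=-1, keepdims=True))
    # no gradient for the boolean routing map
    return grad_logits, None

masked_scores.defvjp(masked_scores_fwd, masked_scores_bwd)
```

In the real implementation the forward/backward bodies dispatch to the FFI primitives instead of plain jnp ops, but the custom_vjp wiring is the same.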

Confidence Score: 5/5

  • This PR is safe to merge - well-structured implementation following established patterns
  • The code is comprehensive with proper testing, matches the existing PyTorch implementation, includes correct gradient computations, and has appropriate error handling. Only minor style issue found.
  • No files require special attention

Important Files Changed

Filename Overview
transformer_engine/jax/router.py New high-level JAX API for MoE router operations with proper custom_vjp definitions for automatic differentiation
transformer_engine/jax/cpp_extensions/router.py Primitive definitions with FFI lowering, custom partitioning for distributed execution, and shardy sharding rules
transformer_engine/jax/csrc/extensions/router.cpp C++ FFI handlers that bridge JAX operations to underlying CUDA kernels with proper tensor wrapping and stream handling
tests/jax/test_fused_router.py Comprehensive single-GPU tests comparing fused kernels against reference JAX implementations for forward and backward passes
tests/jax/test_distributed_router.py Distributed execution tests verifying correct sharding behavior across multiple GPUs with both GSPMD and Shardy partitioners
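The sharding layout the partitioning rules implement (token dimension sharded for data parallelism, expert dimension replicated) can be expressed in plain JAX sharding terms. This sketch uses generic `NamedSharding`, not TE's `custom_partitioning` machinery, and the mesh axis name `dp` is an arbitrary choice:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 1-D device mesh over which the token dimension is split
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("dp",))

# [num_tokens, num_experts]: shard tokens on "dp", replicate experts (None)
sharding = NamedSharding(mesh, PartitionSpec("dp", None))

logits = jnp.zeros((8 * len(devices), 16))
logits = jax.device_put(logits, sharding)
```

Because each device holds full expert-dimension rows for its token shard, the first two kernels run independently per shard; only the aux-loss scalar needs a cross-device reduction, which is why that op is simply replicated.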

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Input: Logits<br/>num_tokens x num_experts] --> B{Router Operation}
    
    B --> C[fused_topk_with_score_function]
    B --> D[fused_compute_score_for_moe_aux_loss]
    
    C --> C1[Apply score function<br/>softmax or sigmoid]
    C1 --> C2[Top-K selection<br/>with optional grouping]
    C2 --> C3[Output: Sparse probs +<br/>routing_map boolean]
    
    D --> D1[Apply score function<br/>softmax or sigmoid]
    D1 --> D2[Plain top-K selection]
    D2 --> D3[Output: Dense scores +<br/>routing_map boolean]
    
    C3 --> E[Token permutation<br/>using routing_map]
    D3 --> F[fused_moe_aux_loss]
    
    F --> F1[Compute per-expert<br/>token counts]
    F1 --> F2[Calculate auxiliary loss<br/>for load balancing]
    F2 --> F3[Output: Scalar loss]
    
    F3 --> G[Add to total loss]
    
    style C fill:#e1f5ff
    style D fill:#e1f5ff
    style F fill:#ffe1e1
    style C3 fill:#e8f5e9
    style D3 fill:#e8f5e9
    style F3 fill:#e8f5e9

Last reviewed commit: 1fa388f


@greptile-apps greptile-apps bot left a comment


9 files reviewed, 1 comment


SCORE_FUNCTION_MAP = {"sigmoid": 0, "softmax": 1}


# =========================================== ==================================

inconsistent section header formatting

Suggested change
# =========================================== ==================================
# =============================================================================

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@tdophung tdophung requested review from jberchtold-nvidia and phu0ngng and removed request for jberchtold-nvidia February 26, 2026 19:35
@tdophung
Collaborator Author

I created this PR without being aware of https://github.com/NVIDIA/TransformerEngine/pull/2385/changes

But after reviewing it, I see that both PRs are doing very similar things. I will start by addressing @phu0ngng's comments on that PR.

