[JAX] CGEMM with Shardy #2714

Open
phu0ngng wants to merge 6 commits into NVIDIA:main from phu0ngng:cgemm_shardy

Conversation

@phu0ngng
Collaborator

Description

This PR replaces NotImplementedError with UserWarning in the Shardy lowering rule for GemmPrimitive when a CollectiveOp is present, allowing Collective GEMM to be used with Shardy propagation as long as the output sharding constraint is set correctly.

Previously, any attempt to use Collective GEMM with Shardy active would raise an error and require users to disable Shardy entirely. Now a UserWarning is emitted with actionable guidance on how to apply the correct output sharding constraint for each TE entry point:

  • te.dense vjp: set output_axes
  • te.layernorm_mlp vjp: set dot_2_input_axes
  • te.jax.cpp_extensions.gemm directly: apply jax.lax.with_sharding_constraint on the output
  • MaxText: no action needed
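
For the direct `te.jax.cpp_extensions.gemm` case, the recommended constraint can be applied with plain JAX. The sketch below is not TE code: `jnp.dot` stands in for the collective GEMM, and the mesh axis name `tp` and the tensor shapes are illustrative assumptions.

```python
# Sketch: applying an explicit output sharding constraint, as the warning
# recommends when calling te.jax.cpp_extensions.gemm directly.
# jnp.dot stands in for the collective GEMM; the "tp" axis and the shapes
# are illustrative assumptions, not TE defaults.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("tp",))

@jax.jit
def dense(x, w):
    out = jnp.dot(x, w)  # stand-in for the collective GEMM
    # Constrain the output so Shardy propagates the intended pattern.
    return jax.lax.with_sharding_constraint(
        out, NamedSharding(mesh, P(None, "tp"))
    )

x = jnp.ones((4, 8))
w = jnp.ones((8, 16))
out = dense(x, w)
print(out.shape)  # (4, 16)
```

With `te.dense` or `te.layernorm_mlp`, the same effect is achieved by passing `output_axes` / `dot_2_input_axes` instead of calling `with_sharding_constraint` yourself.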

Note that it is a known issue that, without a sharding constraint, Shardy propagation occasionally does not work as intended. A similar WAR was implemented in #2128.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Feb 27, 2026

Greptile Summary

This PR enables Collective GEMM to work with Shardy propagation by replacing NotImplementedError with UserWarning in the lowering rule. The implementation adds warnings at entry points (te.dense and te.layernorm_mlp) to guide users on setting proper output sharding constraints (output_axes and dot_2_input_axes respectively). Tests are updated to properly set sharding constraints and no longer disable Shardy. The changes are backward-compatible and follow the same pattern as a previous workaround in PR #2128.

Key changes:

  • Replaced error with actionable warning in GemmPrimitive Shardy lowering rule
  • Added parameter-based checks in dense and layernorm_mlp entry points
  • Updated all example tests to use proper sharding constraints
  • Removed all Shardy disabling code from examples

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The changes are well-contained and follow a clear pattern. The core logic replaces a hard error with a warning, which is strictly less restrictive. The warning messages provide clear, actionable guidance. Tests are updated to demonstrate proper usage. No complex logic changes or refactoring. Previous similar approach was successfully implemented in PR [JAX] dot_1_output sharding constraint + use AXIS_IS_UNSHARDED #2128.
  • No files require special attention

Important Files Changed

Filename Overview
transformer_engine/jax/cpp_extensions/gemm.py Replaced NotImplementedError with UserWarning in Shardy lowering rule for CollectiveGEMM, providing guidance on how to apply correct sharding constraints
transformer_engine/jax/dense.py Added warning when collective op is used without output_axes parameter to guide users on proper sharding constraint setup
transformer_engine/jax/layernorm_mlp.py Added warning when collective_op_set_1 is used without dot_2_input_axes parameter to guide users on proper sharding constraint setup

Last reviewed commit: 0fae7a4


@greptile-apps greptile-apps bot left a comment


6 files reviewed, 2 comments


Comment on lines +278 to +283
if collective_op_set_1 != tex.noop_collective_op_set and not dot_2_input_axes:
    warnings.warn(
        "Collective GEMM with Shardy propagation may produce an incorrect sharding pattern"
        " for the output. Set `dot_2_input_axes` to apply the correct sharding constraint.",
        UserWarning,
    )
Contributor


Warning condition misses collective_op_set_2

The guard only checks collective_op_set_1, so users who configure a collective op only on the second GEMM (e.g. a reduce-scatter-only row-parallel layer, where collective_op_set_1 remains noop_collective_op_set) will never see this warning, even though dot_2_input_axes is equally needed to constrain the input to dot_2 (lines 341–345 and 363) and the gradient in _layernorm_mlp_bwd_rule (line 504).

The condition should be broadened to cover either collective op set being active:

Suggested change
if collective_op_set_1 != tex.noop_collective_op_set and not dot_2_input_axes:
    warnings.warn(
        "Collective GEMM with Shardy propagation may produce an incorrect sharding pattern"
        " for the output. Set `dot_2_input_axes` to apply the correct sharding constraint.",
        UserWarning,
    )
if (
    collective_op_set_1 != tex.noop_collective_op_set
    or collective_op_set_2 != tex.noop_collective_op_set
) and not dot_2_input_axes:
    warnings.warn(
        "Collective GEMM with Shardy propagation may produce an incorrect sharding pattern"
        " for the output. Set `dot_2_input_axes` to apply the correct sharding constraint.",
        UserWarning,
    )

Collaborator Author


We don't need to check for collective_op_set_2 here, as the dot_2 output sharding (for RS) will work correctly without an additional sharding constraint.

Comment on lines +97 to +102
if collective_op_set != tex.noop_collective_op_set and not output_axes:
    warnings.warn(
        "Collective GEMM with Shardy propagation may produce an incorrect sharding pattern"
        " for the output. Set `output_axes` to apply the correct sharding constraint.",
        UserWarning,
    )
Contributor


Warning fires even when Shardy is not active

This check runs unconditionally at Python-call time, so users who explicitly disable Shardy (e.g. via JAX_USE_SHARDY_PARTITIONER=false) or who rely on GSPMD propagation will also see this UserWarning whenever they use collective GEMM without setting output_axes. The warning text specifically says "Shardy propagation", which makes it confusing in a non-Shardy context.

Consider gating the check on whether Shardy is active to keep the signal actionable:

if (
    collective_op_set != tex.noop_collective_op_set
    and not output_axes
    and jax.config.jax_use_shardy_partitioner
):
    warnings.warn(...)

The same pattern applies to the analogous guard added in layernorm_mlp.py.

Collaborator Author


Shardy propagation is the default in JAX, and GSPMD will be deprecated this month. So this check is not needed.

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng
Collaborator Author

/te-ci JAX L1
