Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Greptile Summary
This PR enables Collective GEMM to work with Shardy propagation by replacing
Key changes:
Confidence Score: 5/5
Important Files Changed
Last reviewed commit: 0fae7a4
    if collective_op_set_1 != tex.noop_collective_op_set and not dot_2_input_axes:
        warnings.warn(
            "Collective GEMM with Shardy propagation may produce an incorrect sharding pattern"
            " for the output. Set `dot_2_input_axes` to apply the correct sharding constraint.",
            UserWarning,
        )
Warning condition misses collective_op_set_2
The guard only checks collective_op_set_1, so users who configure a collective op only on the second GEMM (e.g. a reduce-scatter-only row-parallel layer, where collective_op_set_1 remains noop_collective_op_set) will never see this warning, even though dot_2_input_axes is equally needed to constrain the input to dot_2 (lines 341–345 and 363) and the gradient in _layernorm_mlp_bwd_rule (line 504).
The condition should be broadened to cover either collective op set being active:
    -if collective_op_set_1 != tex.noop_collective_op_set and not dot_2_input_axes:
    +if (collective_op_set_1 != tex.noop_collective_op_set or collective_op_set_2 != tex.noop_collective_op_set) and not dot_2_input_axes:
         warnings.warn(
             "Collective GEMM with Shardy propagation may produce an incorrect sharding pattern"
             " for the output. Set `dot_2_input_axes` to apply the correct sharding constraint.",
             UserWarning,
         )
We don't need to check `collective_op_set_2` here, because the dot_2 output sharding (for reduce-scatter) works correctly without an additional sharding constraint.
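To make the behavior under discussion concrete, here is a minimal sketch of the guard as it stands, checking only the first collective op set. `NOOP_COLLECTIVE_OP_SET` and `warn_if_unconstrained` are hypothetical stand-ins for illustration; they are not names from the TE code base.

```python
import warnings

# Hypothetical stand-in for tex.noop_collective_op_set (the real object lives in TE).
NOOP_COLLECTIVE_OP_SET = object()


def warn_if_unconstrained(collective_op_set_1, dot_2_input_axes):
    """Warn when the first GEMM uses a collective op but no sharding
    constraint is supplied for the dot_2 input. Returns True if it warned."""
    if collective_op_set_1 != NOOP_COLLECTIVE_OP_SET and not dot_2_input_axes:
        warnings.warn(
            "Collective GEMM with Shardy propagation may produce an incorrect sharding pattern"
            " for the output. Set `dot_2_input_axes` to apply the correct sharding constraint.",
            UserWarning,
        )
        return True
    return False
```

With this shape, a layer that only configures a collective op on the second GEMM never triggers the warning, which matches the reply above.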
    if collective_op_set != tex.noop_collective_op_set and not output_axes:
        warnings.warn(
            "Collective GEMM with Shardy propagation may produce an incorrect sharding pattern"
            " for the output. Set `output_axes` to apply the correct sharding constraint.",
            UserWarning,
        )
Warning fires even when Shardy is not active
This check runs unconditionally at Python-call time, so users who explicitly disable Shardy (e.g. via JAX_USE_SHARDY_PARTITIONER=false) or who rely on GSPMD propagation will also see this UserWarning whenever they use collective GEMM without setting output_axes. The warning text specifically says "Shardy propagation", which makes it confusing in a non-Shardy context.
Consider gating the check on whether Shardy is active to keep the signal actionable:
    if (
        collective_op_set != tex.noop_collective_op_set
        and not output_axes
        and jax.config.jax_use_shardy_partitioner
    ):
        warnings.warn(...)

The same pattern applies to the analogous guard added in layernorm_mlp.py.
Shardy propagation is the default in JAX, and GSPMD will be deprecated this month. So this check is not needed.
/te-ci JAX L1
Description
This PR replaces `NotImplementedError` with `UserWarning` in the Shardy lowering rule for `GemmPrimitive` when a `CollectiveOp` is present, allowing Collective GEMM to be used with Shardy propagation as long as the output sharding constraint is set correctly.

Previously, any attempt to use Collective GEMM with Shardy active would raise an error and require users to disable Shardy entirely. Now a `UserWarning` is emitted with actionable guidance on how to apply the correct output sharding constraint for each TE entry point:

- `te.dense` vjp: set `output_axes`
- `te.layernorm_mlp` vjp: set `dot_2_input_axes`
- `te.jax.cpp_extensions.gemm` directly: apply `jax.lax.with_sharding_constraint` on the output

Note that it is a known issue that Shardy propagation occasionally does not work as intended without a sharding constraint. A similar workaround (WAR) was implemented in #2128.
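For the case where the GEMM is called directly, the constraint could be applied roughly as follows. This is only a sketch under stated assumptions: `jnp.dot` stands in for the TE GEMM call, the mesh axis name `"tp"` and the choice of which output dimension to shard are illustrative, and a single-device mesh is used so the snippet runs anywhere.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-device mesh so the sketch runs anywhere; real use would span several devices.
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("tp",))
# Assumed layout: shard the last output dimension over "tp".
out_sharding = NamedSharding(mesh, P(None, "tp"))


@jax.jit
def gemm_with_constraint(x, w):
    # jnp.dot is a stand-in for the TE GEMM (an assumption of this sketch).
    y = jnp.dot(x, w)
    # Pin the output layout explicitly instead of relying on propagation.
    return jax.lax.with_sharding_constraint(y, out_sharding)


x = jnp.ones((4, 8))
w = jnp.ones((8, 16))
y = gemm_with_constraint(x, w)
```

The point of the explicit constraint is that the output sharding no longer depends on what the propagation pass infers.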
Type of change
Checklist: