Skip to content

[JAX] GSPMD Deprecation Warning - Only trigger when the primitive is invoked#2729

Open
phu0ngng wants to merge 6 commits intoNVIDIA:mainfrom
phu0ngng:gspmd
Open

[JAX] GSPMD Deprecation Warning - Only trigger when the primitive is invoked#2729
phu0ngng wants to merge 6 commits intoNVIDIA:mainfrom
phu0ngng:gspmd

Conversation

@phu0ngng
Copy link
Collaborator

@phu0ngng phu0ngng commented Mar 3, 2026

Description

PR #2702 added a GSPMD deprecation warning when registering the primitives, and a JAX version <= 0.9.1 is available. This is a false positive check, as users may not use GSPMD, but the warning is still printed.

This PR adjusts the warning so that it is only triggered when the primitive is invoked with GSPMD.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

phu0ngng added 2 commits March 3, 2026 09:55
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@greptile-apps

This comment was marked as outdated.

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 3, 2026

Additional Comments (1)

transformer_engine/jax/cpp_extensions/base.py
stacklevel no longer points to user code

stacklevel=3 was set when _warn_gspmd_deprecation_once() was called from register_primitive (at import time), with a shallow call stack. Now the warning is raised from inside _gspmd_wrapper, which is itself invoked deep inside JAX's custom_partitioning dispatch machinery — adding several extra frames between the warning site and the user's invocation point.

With stacklevel=3, the warning will report the location as somewhere inside JAX's internals rather than the user's code or the TE library wrapper. Consider reducing to stacklevel=2 so the warning points to _gspmd_wrapper itself, which provides more context about where in TE the issue originated.

        warnings.warn(
            "GSPMD sharding propagation rules in TE-JAX are planned to be removed in June 2026."
            " They are no longer maintained or tested. Use them at your own risk."
            " Please use Shardy propagation instead."
            " In case you cannot upgrade to a JAX version that supports Shardy, please reach out!",
            DeprecationWarning,
            stacklevel=2,
        )

phu0ngng and others added 3 commits March 3, 2026 10:10
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng
Copy link
Collaborator Author

phu0ngng commented Mar 3, 2026

/te-ci JAX L1

cls.infer_sharding_from_operands
) # Use descriptor protocol to unwrap staticmethod

def _gspmd_wrapper(*args, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If JAX has issues internally with the signature here, you may need the functools.wraps decorator

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants