[JAX] GSPMD Deprecation Warning - Only trigger when the primitive is invoked#2729
[JAX] GSPMD Deprecation Warning - Only trigger when the primitive is invoked#2729phu0ngng wants to merge 6 commits intoNVIDIA:mainfrom
Conversation
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
This comment was marked as outdated.
This comment was marked as outdated.
Additional Comments (1)
With |
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
|
/te-ci JAX L1 |
| cls.infer_sharding_from_operands | ||
| ) # Use descriptor protocol to unwrap staticmethod | ||
|
|
||
| def _gspmd_wrapper(*args, **kwargs): |
There was a problem hiding this comment.
If JAX has issues internally with the signature here, you may need the functools.wraps decorator
Description
PR #2702 added a GSPMD deprecation warning when registering the primitives, and a JAX version <= 0.9.1 is available. This is a false positive check, as users may not use GSPMD, but the warning is still printed.
This PR adjusts the warning so that it is only triggered when the primitive is invoked with GSPMD.
Type of change
Checklist: