From 888f786af1e9793b68b369612da998bba9196c7e Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 3 Mar 2026 09:55:07 -0800 Subject: [PATCH 1/6] gspmd warning - only trigger when the primitive is invoked Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/base.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index ae3888cf04..03b23a1970 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -189,7 +189,7 @@ def _warn_gspmd_deprecation_once(): global _gspmd_deprecation_warned if not _gspmd_deprecation_warned: warnings.warn( - "GSPMD sharding propagation is planned to be removed in June 2026." + "GSPMD sharding propagation rules in TE-JAX are planned to be removed in June 2026." " It is no longer maintained or tested. Use it at your own risk." " Please use Shardy partitioner instead." " In case you cannot upgrade to a JAX version that supports Shardy, please reach out!", @@ -234,12 +234,19 @@ def name_of_wrapper_p(): outer_p.def_abstract_eval(cls.outer_abstract) batching.primitive_batchers[outer_p] = cls.batcher outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) + if _JAX_GSPMD_SUPPORTED: - if "infer_sharding_from_operands" in cls.__dict__: - _warn_gspmd_deprecation_once() - gspmd_kwargs = {"infer_sharding_from_operands": cls.infer_sharding_from_operands} + fn = cls.__dict__.get("infer_sharding_from_operands") + if fn is not None: + def _gspmd_wrapper(*args, **kwargs): + _warn_gspmd_deprecation_once() + return fn(*args, **kwargs) + gspmd_kwargs = {"infer_sharding_from_operands": _gspmd_wrapper} + else: + gspmd_kwargs = {"infer_sharding_from_operands": cls.infer_sharding_from_operands} else: gspmd_kwargs = {} + outer_p_lower.def_partition( partition=cls.partition, sharding_rule=cls.shardy_sharding_rule, From b734b8d5c4e768fbd53cb9a898969565c5db362f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 3 Mar 2026 10:02:24 -0800 Subject: [PATCH 2/6] rewords Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 03b23a1970..92cd5b81f6 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -190,8 +190,8 @@ def _warn_gspmd_deprecation_once(): if not _gspmd_deprecation_warned: warnings.warn( "GSPMD sharding propagation rules in TE-JAX are planned to be removed in June 2026." - " It is no longer maintained or tested. Use it at your own risk." - " Please use Shardy partitioner instead." + " They are no longer maintained or tested. Use them at your own risk." + " Please use Shardy propagation instead." " In case you cannot upgrade to a JAX version that supports Shardy, please reach out!", DeprecationWarning, stacklevel=3, From 88e8403551e627f02ad0e4df5a9435429db3d3a7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Mar 2026 18:03:37 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 92cd5b81f6..b65bcfa5d9 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -238,9 +238,11 @@ def name_of_wrapper_p(): if _JAX_GSPMD_SUPPORTED: fn = cls.__dict__.get("infer_sharding_from_operands") if fn is not None: + def _gspmd_wrapper(*args, **kwargs): _warn_gspmd_deprecation_once() return fn(*args, **kwargs) + gspmd_kwargs = {"infer_sharding_from_operands": _gspmd_wrapper} else: gspmd_kwargs = {"infer_sharding_from_operands": cls.infer_sharding_from_operands} From 5029d8ddaf4d7dedbb7b27686bcc4d656483259c Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 3 Mar 2026 10:10:07 -0800 Subject: [PATCH 4/6] Update transformer_engine/jax/cpp_extensions/base.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index b65bcfa5d9..71ebc86db5 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -238,10 +238,11 @@ def name_of_wrapper_p(): if _JAX_GSPMD_SUPPORTED: fn = cls.__dict__.get("infer_sharding_from_operands") if fn is not None: + actual_fn = cls.infer_sharding_from_operands # Use descriptor protocol to unwrap staticmethod def _gspmd_wrapper(*args, **kwargs): _warn_gspmd_deprecation_once() - return fn(*args, **kwargs) + return actual_fn(*args, **kwargs) gspmd_kwargs = {"infer_sharding_from_operands": _gspmd_wrapper} else: From 9cf8b0f7274929134cff3f95f50054e62899adc4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Mar 2026 18:10:52 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 71ebc86db5..bc3b3bc755 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -238,7 +238,9 @@ def name_of_wrapper_p(): if _JAX_GSPMD_SUPPORTED: fn = cls.__dict__.get("infer_sharding_from_operands") if fn is not None: - actual_fn = cls.infer_sharding_from_operands # Use descriptor protocol to unwrap staticmethod + actual_fn = ( + cls.infer_sharding_from_operands + ) # Use descriptor protocol to unwrap staticmethod def _gspmd_wrapper(*args, **kwargs): _warn_gspmd_deprecation_once() From 24b5948968d85a06b0c60252951c526fe8e1f73b Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 3 Mar 2026 10:12:44 -0800 Subject: [PATCH 6/6] stacklevel=2 Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index bc3b3bc755..6eb588c849 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -194,7 +194,7 @@ def _warn_gspmd_deprecation_once(): " Please use Shardy propagation instead." " In case you cannot upgrade to a JAX version that supports Shardy, please reach out!", DeprecationWarning, - stacklevel=3, + stacklevel=2, ) _gspmd_deprecation_warned = True