diff --git a/src/pyrecest/_backend/__init__.py b/src/pyrecest/_backend/__init__.py index 83e4e318c..2c48ab684 100644 --- a/src/pyrecest/_backend/__init__.py +++ b/src/pyrecest/_backend/__init__.py @@ -290,6 +290,33 @@ def _deduplicated_attributes(attributes): BACKEND_ATTRIBUTES[_module_name] = _deduplicated_attributes(_attributes) +def _quantile_with_numpy_axis(quantile_func, asarray_func): + """Return a NumPy-compatible quantile wrapper for stricter backends.""" + + @wraps(quantile_func) + def quantile( + a, + q, + axis=None, + out=None, + overwrite_input=False, + method="linear", + keepdims=False, + *, + interpolation=None, + ): + del overwrite_input + if interpolation is not None: + method = interpolation + + kwargs = {"dim": axis, "keepdim": keepdims, "interpolation": method} + if out is not None: + kwargs["out"] = out + return quantile_func(asarray_func(a), asarray_func(q), **kwargs) + + return quantile + + def _meshgrid_with_arraylike_axes(meshgrid_func, asarray_func, atleast_1d_func): """Return a NumPy-compatible meshgrid wrapper for stricter backends.""" @@ -375,6 +402,15 @@ def _create_backend_module(self, backend_name: str): getattr(backend, "asarray"), getattr(backend, "atleast_1d"), ) + if ( + module_name == "" + and attribute_name == "quantile" + and backend_name == "pytorch" + ): + attribute = _quantile_with_numpy_axis( + attribute, + getattr(backend, "asarray"), + ) setattr(new_submodule, attribute_name, attribute) for attribute_name in OPTIONAL_BACKEND_ATTRIBUTES.get(module_name, []): diff --git a/tests/backend_support/test_pytorch_quantile_contract.py b/tests/backend_support/test_pytorch_quantile_contract.py new file mode 100644 index 000000000..ac7e0ef28 --- /dev/null +++ b/tests/backend_support/test_pytorch_quantile_contract.py @@ -0,0 +1,32 @@ +"""Regression tests for PyTorch backend quantile keyword compatibility.""" + +from __future__ import annotations + +import importlib.util + +import pytest +from tests.support.backend_runner import run_backend_code + + +@pytest.mark.backend_portable +def test_pytorch_quantile_accepts_numpy_axis_and_keepdims_keywords(): + if importlib.util.find_spec("torch") is None: + pytest.skip("PyTorch is not installed") + + result = run_backend_code( + "pytorch", + """ +import pyrecest.backend as backend + +values = backend.asarray([[1.0, 3.0], [2.0, 5.0], [4.0, 7.0]]) +median = backend.quantile(values, 0.5, axis=0) +median_keepdims = backend.quantile(values, 0.5, axis=0, keepdims=True) +list_quantile = backend.quantile([[1.0, 4.0], [3.0, 8.0]], [0.25, 0.75], axis=0) + +assert backend.to_numpy(median).tolist() == [2.0, 5.0] +assert backend.to_numpy(median_keepdims).tolist() == [[2.0, 5.0]] +assert backend.to_numpy(list_quantile).tolist() == [[1.5, 5.0], [2.5, 7.0]] +""", + ) + + assert result.returncode == 0, result.stderr