diff --git a/src/pyrecest/_backend/pytorch/__init__.py b/src/pyrecest/_backend/pytorch/__init__.py index 5d8ee111c..05b927aeb 100644 --- a/src/pyrecest/_backend/pytorch/__init__.py +++ b/src/pyrecest/_backend/pytorch/__init__.py @@ -199,6 +199,112 @@ def cov(input, correction=1, fweights=None, aweights=None, bias=False): return cov_matrix +def _quantile_q(q, x): + if _torch.is_tensor(q): + return q.to(device=x.device, dtype=x.dtype) + if _np.isscalar(q): + return float(q) + return _torch.as_tensor(q, dtype=x.dtype, device=x.device) + + +def _quantile_q_shape(q): + if _torch.is_tensor(q): + return tuple(q.shape) + return tuple(_np.shape(q)) + + +def quantile( + a, + q, + axis=None, + out=None, + overwrite_input=False, + method="linear", + keepdims=False, + *, + dim=None, + keepdim=None, + interpolation=None, +): + """Return quantiles using NumPy-compatible argument names.""" + del overwrite_input + + if dim is not None: + if axis is not None and axis != dim: + raise TypeError("quantile() got both 'axis' and 'dim'") + axis = dim + if keepdim is not None: + if keepdims is not False and keepdims != keepdim: + raise TypeError("quantile() got both 'keepdims' and 'keepdim'") + keepdims = keepdim + if interpolation is not None: + method = interpolation + + x = array(a) + if is_complex(x): + raise TypeError("a must be an array of real numbers") + if not is_floating(x): + x = cast(x, dtype=get_default_dtype()) + + q_arg = _quantile_q(q, x) + q_shape = _quantile_q_shape(q) + + if axis is None or isinstance(axis, (int, _np.integer)): + kwargs = {"dim": axis, "keepdim": keepdims, "interpolation": method} + if out is not None: + kwargs["out"] = out + return _torch.quantile(x, q_arg, **kwargs) + + axes = _normalize_reduction_axes(axis, x.ndim) + if not axes: + result = x + if q_shape: + result = _torch.broadcast_to(result, q_shape + tuple(x.shape)) + if out is not None: + out.copy_(result) + return out + return result + + remaining_axes = tuple(dim for dim in range(x.ndim) if dim not in axes) + permuted = x.permute(axes + remaining_axes) + reduced_size = int(_np.prod([x.shape[dim] for dim in axes])) + reduced = permuted.reshape( + (reduced_size,) + tuple(x.shape[dim] for dim in remaining_axes) + ) + result = _torch.quantile(reduced, q_arg, dim=0, interpolation=method) + + if keepdims: + result = result.reshape( + q_shape + + tuple(1 if dim in axes else x.shape[dim] for dim in range(x.ndim)) + ) + if out is not None: + out.copy_(result) + return out + return result + + +def count_nonzero(a, axis=None, keepdims=False): + """Count non-zero entries using NumPy-compatible reduction semantics.""" + x = array(a) + if axis is None: + result = _torch.count_nonzero(x) + if keepdims: + return result.reshape((1,) * x.ndim) + return result + + counts = (x != 0).to(dtype=_torch.int64) + result = _reduce_over_axes( + counts, axis, lambda values, one_axis: _torch.sum(values, dim=one_axis) + ) + if keepdims: + axes = _normalize_reduction_axes(axis, x.ndim) + result = result.reshape( + tuple(1 if dim in axes else x.shape[dim] for dim in range(x.ndim)) + ) + return result + + def has_autodiff(): """If allows for automatic differentiation. diff --git a/tests/test_pytorch_backend.py b/tests/test_pytorch_backend.py index 8ca91e664..d64d66ecc 100644 --- a/tests/test_pytorch_backend.py +++ b/tests/test_pytorch_backend.py @@ -111,6 +111,32 @@ def test_all_accepts_tuple_axis_in_any_order(self): self.assertEqual(tuple(result.shape), (3,)) self.assertEqual(result.tolist(), [False, False, True]) + def test_quantile_accepts_numpy_style_method_axis_and_keepdims(self): + values = pytorch_backend.array( + list(range(24)), dtype=pytorch_backend.float64 + ).reshape(2, 3, 4) + + result = pytorch_backend.quantile( + values, [0.25, 0.5], axis=(0, 2), keepdims=True, method="linear" + ) + + expected = pytorch_backend.array( + [[[[1.75], [5.75], [9.75]]], [[[7.5], [11.5], [15.5]]]], + dtype=pytorch_backend.float64, + ) + self.assertEqual(tuple(result.shape), (2, 1, 3, 1)) + self.assertTrue(pytorch_backend.allclose(result, expected)) + + def test_count_nonzero_accepts_numpy_style_axis_and_keepdims(self): + values = pytorch_backend.array( + [[[0, 1], [2, 0]], [[3, 4], [0, 0]]], dtype=pytorch_backend.int64 + ) + + result = pytorch_backend.count_nonzero(values, axis=(0, 2), keepdims=True) + + self.assertEqual(tuple(result.shape), (1, 2, 1)) + self.assertEqual(result.tolist(), [[[3], [1]]]) + def test_where_promotes_scalar_to_tensor_dtype(self): mask = pytorch_backend.array([True, False]) fallback = pytorch_backend.array([2.0, 3.0], dtype=pytorch_backend.float64) @@ -125,7 +151,6 @@ def test_where_promotes_scalar_to_tensor_dtype(self): ) ) - @unittest.skipIf(pytorch_backend is None, "PyTorch is not installed") class TestPytorchBackendRandom(unittest.TestCase): def test_choice_accepts_weighted_sampling_without_replacement(self):