Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions src/pyrecest/_backend/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,112 @@ def cov(input, correction=1, fweights=None, aweights=None, bias=False):
return cov_matrix


def _quantile_q(q, x):
if _torch.is_tensor(q):
return q.to(device=x.device, dtype=x.dtype)
if _np.isscalar(q):
return float(q)
return _torch.as_tensor(q, dtype=x.dtype, device=x.device)


def _quantile_q_shape(q):
if _torch.is_tensor(q):
return tuple(q.shape)
return tuple(_np.shape(q))


def quantile(
a,
q,
axis=None,
out=None,
overwrite_input=False,
method="linear",
keepdims=False,
*,
dim=None,
keepdim=None,
interpolation=None,
):
"""Return quantiles using NumPy-compatible argument names."""
del overwrite_input

if dim is not None:
if axis is not None and axis != dim:
raise TypeError("quantile() got both 'axis' and 'dim'")
axis = dim
if keepdim is not None:
if keepdims is not False and keepdims != keepdim:
raise TypeError("quantile() got both 'keepdims' and 'keepdim'")
keepdims = keepdim
if interpolation is not None:
method = interpolation

x = array(a)
if is_complex(x):
raise TypeError("a must be an array of real numbers")
if not is_floating(x):
x = cast(x, dtype=get_default_dtype())

q_arg = _quantile_q(q, x)
q_shape = _quantile_q_shape(q)

if axis is None or isinstance(axis, (int, _np.integer)):
kwargs = {"dim": axis, "keepdim": keepdims, "interpolation": method}
if out is not None:
kwargs["out"] = out
return _torch.quantile(x, q_arg, **kwargs)

axes = _normalize_reduction_axes(axis, x.ndim)
if not axes:
result = x
if q_shape:
result = _torch.broadcast_to(result, q_shape + tuple(x.shape))
if out is not None:
out.copy_(result)
return out
return result

remaining_axes = tuple(dim for dim in range(x.ndim) if dim not in axes)
permuted = x.permute(axes + remaining_axes)
reduced_size = int(_np.prod([x.shape[dim] for dim in axes]))
reduced = permuted.reshape(
(reduced_size,) + tuple(x.shape[dim] for dim in remaining_axes)
)
result = _torch.quantile(reduced, q_arg, dim=0, interpolation=method)

if keepdims:
result = result.reshape(
q_shape
+ tuple(1 if dim in axes else x.shape[dim] for dim in range(x.ndim))
)
if out is not None:
out.copy_(result)
return out
return result


def count_nonzero(a, axis=None, keepdims=False):
"""Count non-zero entries using NumPy-compatible reduction semantics."""
x = array(a)
if axis is None:
result = _torch.count_nonzero(x)
if keepdims:
return result.reshape((1,) * x.ndim)
return result

counts = (x != 0).to(dtype=_torch.int64)
result = _reduce_over_axes(
counts, axis, lambda values, one_axis: _torch.sum(values, dim=one_axis)
)
if keepdims:
axes = _normalize_reduction_axes(axis, x.ndim)
result = result.reshape(
tuple(1 if dim in axes else x.shape[dim] for dim in range(x.ndim))
)
return result


def has_autodiff():
"""If allows for automatic differentiation.

Expand Down
27 changes: 26 additions & 1 deletion tests/test_pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,32 @@ def test_all_accepts_tuple_axis_in_any_order(self):
self.assertEqual(tuple(result.shape), (3,))
self.assertEqual(result.tolist(), [False, False, True])

def test_quantile_accepts_numpy_style_method_axis_and_keepdims(self):
values = pytorch_backend.array(
list(range(24)), dtype=pytorch_backend.float64
).reshape(2, 3, 4)

result = pytorch_backend.quantile(
values, [0.25, 0.5], axis=(0, 2), keepdims=True, method="linear"
)

expected = pytorch_backend.array(
[[[[1.75], [5.75], [9.75]]], [[[7.5], [11.5], [15.5]]]],
dtype=pytorch_backend.float64,
)
self.assertEqual(tuple(result.shape), (2, 1, 3, 1))
self.assertTrue(pytorch_backend.allclose(result, expected))

def test_count_nonzero_accepts_numpy_style_axis_and_keepdims(self):
values = pytorch_backend.array(
[[[0, 1], [2, 0]], [[3, 4], [0, 0]]], dtype=pytorch_backend.int64
)

result = pytorch_backend.count_nonzero(values, axis=(0, 2), keepdims=True)

self.assertEqual(tuple(result.shape), (1, 2, 1))
self.assertEqual(result.tolist(), [[[3], [1]]])

def test_where_promotes_scalar_to_tensor_dtype(self):
mask = pytorch_backend.array([True, False])
fallback = pytorch_backend.array([2.0, 3.0], dtype=pytorch_backend.float64)
Expand All @@ -125,7 +151,6 @@ def test_where_promotes_scalar_to_tensor_dtype(self):
)
)


@unittest.skipIf(pytorch_backend is None, "PyTorch is not installed")
class TestPytorchBackendRandom(unittest.TestCase):
def test_choice_accepts_weighted_sampling_without_replacement(self):
Expand Down
Loading