Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions src/pyrecest/_backend/pytorch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@
from ._dtype import _allow_complex_dtype, _modify_func_default_dtype


_COMPLEX_TO_FLOAT_DTYPE = {
_torch.complex64: _torch.float32,
_torch.complex128: _torch.float64,
}


def _choice_size(size):
if size is None:
return None, 1
Expand Down Expand Up @@ -153,11 +159,37 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None):
return (high - low) * _torch.rand(size, dtype=dtype, device=device) + low


def _tensor_device(*values):
for value in values:
if _torch.is_tensor(value):
return value.device
return None


def _floating_distribution_dtype(*values):
for value in values:
if not _torch.is_tensor(value):
continue
if value.dtype.is_floating_point:
return value.dtype
if value.dtype.is_complex:
return _COMPLEX_TO_FLOAT_DTYPE[value.dtype]
return _torch.get_default_dtype()


def _normal_sample_size(size):
if size is None:
return ()
if not hasattr(size, "__iter__"):
return (size,)
return tuple(size)


@_modify_func_default_dtype(copy=False, kw_only=True)
@_allow_complex_dtype
def multivariate_normal(mean, cov, size=None):
if size is None:
size = ()
elif not hasattr(size, "__iter__"):
size = (size,)
return _MultivariateNormal(mean, cov).sample(size)
device = _tensor_device(mean, cov)
dtype = _floating_distribution_dtype(mean, cov)
mean = _torch.as_tensor(mean, dtype=dtype, device=device)
cov = _torch.as_tensor(cov, dtype=mean.dtype, device=mean.device)
return _MultivariateNormal(mean, cov).sample(_normal_sample_size(size))
7 changes: 7 additions & 0 deletions tests/test_backend_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ def test_choice_samples_matrix_values_along_requested_axis(self):
self.assertTrue(set(sample_np[0].tolist()).issubset({0, 1, 2}))
self.assertTrue(set(sample_np[1].tolist()).issubset({3, 4, 5}))

def test_multivariate_normal_accepts_python_sequences(self):
samples = random.multivariate_normal(
[0.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], size=(6,)
)

self.assertEqual(tuple(pyrecest.backend.shape(samples)), (6, 2))

def test_multinomial_accepts_python_probability_sequence(self):
sample = random.multinomial(12, [0.25, 0.75])

Expand Down
Loading