From 4190f094da188f370e2a9e4091bc7ed23aaae113 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Sun, 24 May 2026 09:27:00 +0200 Subject: [PATCH 1/2] Fix PyTorch multivariate_normal array-like inputs --- src/pyrecest/_backend/pytorch/random.py | 42 ++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/src/pyrecest/_backend/pytorch/random.py b/src/pyrecest/_backend/pytorch/random.py index 318782463..be4a7daec 100644 --- a/src/pyrecest/_backend/pytorch/random.py +++ b/src/pyrecest/_backend/pytorch/random.py @@ -13,6 +13,12 @@ from ._dtype import _allow_complex_dtype, _modify_func_default_dtype +_COMPLEX_TO_FLOAT_DTYPE = { + _torch.complex64: _torch.float32, + _torch.complex128: _torch.float64, +} + + def _choice_size(size): if size is None: return None, 1 @@ -151,11 +157,37 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None): return (high - low) * _torch.rand(size, dtype=dtype, device=device) + low +def _tensor_device(*values): + for value in values: + if _torch.is_tensor(value): + return value.device + return None + + +def _floating_distribution_dtype(*values): + for value in values: + if not _torch.is_tensor(value): + continue + if value.dtype.is_floating_point: + return value.dtype + if value.dtype.is_complex: + return _COMPLEX_TO_FLOAT_DTYPE[value.dtype] + return _torch.get_default_dtype() + + +def _normal_sample_size(size): + if size is None: + return () + if not hasattr(size, "__iter__"): + return (size,) + return tuple(size) + + @_modify_func_default_dtype(copy=False, kw_only=True) @_allow_complex_dtype def multivariate_normal(mean, cov, size=None): - if size is None: - size = () - elif not hasattr(size, "__iter__"): - size = (size,) - return _MultivariateNormal(mean, cov).sample(size) + device = _tensor_device(mean, cov) + dtype = _floating_distribution_dtype(mean, cov) + mean = _torch.as_tensor(mean, dtype=dtype, device=device) + cov = _torch.as_tensor(cov, dtype=mean.dtype, device=mean.device) + return _MultivariateNormal(mean, cov).sample(_normal_sample_size(size)) \ No newline at end of file From a1895c085081b185bdeb250949eb2a9a65ab9d81 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Sun, 24 May 2026 09:27:33 +0200 Subject: [PATCH 2/2] Add multivariate_normal array-like backend regression --- tests/test_backend_random.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/test_backend_random.py b/tests/test_backend_random.py index 31f63e093..30bcd5be4 100644 --- a/tests/test_backend_random.py +++ b/tests/test_backend_random.py @@ -70,6 +70,13 @@ def test_choice_samples_matrix_values_along_requested_axis(self): self.assertTrue(set(sample_np[0].tolist()).issubset({0, 1, 2})) self.assertTrue(set(sample_np[1].tolist()).issubset({3, 4, 5})) + def test_multivariate_normal_accepts_python_sequences(self): + samples = random.multivariate_normal( + [0.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], size=(6,) + ) + + self.assertEqual(tuple(pyrecest.backend.shape(samples)), (6, 2)) + @unittest.skipIf( pyrecest.backend.__backend_name__ != "jax", "JAX-specific RNG state contract" ) @@ -101,4 +108,4 @@ def test_jax_multinomial_uses_and_advances_global_state(self): if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file