diff --git a/src/pyrecest/_backend/pytorch/random.py b/src/pyrecest/_backend/pytorch/random.py index 318782463..7cda3a2e7 100644 --- a/src/pyrecest/_backend/pytorch/random.py +++ b/src/pyrecest/_backend/pytorch/random.py @@ -120,6 +120,8 @@ def rand(size=None, dtype=None): def multinomial(n, pvals): + device = pvals.device if _torch.is_tensor(pvals) else None + pvals = _torch.as_tensor(pvals, dtype=_torch.float32, device=device) pvals = pvals / pvals.sum() return _torch.multinomial(pvals, n, replacement=True).bincount(minlength=len(pvals)) diff --git a/tests/test_backend_random.py b/tests/test_backend_random.py index 31f63e093..cf18e7538 100644 --- a/tests/test_backend_random.py +++ b/tests/test_backend_random.py @@ -70,6 +70,12 @@ def test_choice_samples_matrix_values_along_requested_axis(self): self.assertTrue(set(sample_np[0].tolist()).issubset({0, 1, 2})) self.assertTrue(set(sample_np[1].tolist()).issubset({3, 4, 5})) + def test_multinomial_accepts_python_probability_sequence(self): + sample = random.multinomial(12, [0.25, 0.75]) + + self.assertEqual(sample.shape, (2,)) + self.assertEqual(int(pyrecest.backend.sum(sample)), 12) + @unittest.skipIf( pyrecest.backend.__backend_name__ != "jax", "JAX-specific RNG state contract" )