From ff34b3601da2b356a48c889fc256d85c5e272d89 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Sun, 24 May 2026 09:20:36 +0200 Subject: [PATCH 1/2] Fix PyTorch multinomial probability input coercion --- src/pyrecest/_backend/pytorch/random.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pyrecest/_backend/pytorch/random.py b/src/pyrecest/_backend/pytorch/random.py index 318782463..7cda3a2e7 100644 --- a/src/pyrecest/_backend/pytorch/random.py +++ b/src/pyrecest/_backend/pytorch/random.py @@ -120,6 +120,8 @@ def rand(size=None, dtype=None): def multinomial(n, pvals): + device = pvals.device if _torch.is_tensor(pvals) else None + pvals = _torch.as_tensor(pvals, dtype=_torch.float32, device=device) pvals = pvals / pvals.sum() return _torch.multinomial(pvals, n, replacement=True).bincount(minlength=len(pvals)) From 43a6453eafabfd05403450b6fd13a8e17a4a2707 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Sun, 24 May 2026 09:21:35 +0200 Subject: [PATCH 2/2] Add multinomial probability sequence regression test --- tests/test_backend_random.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_backend_random.py b/tests/test_backend_random.py index 31f63e093..cf18e7538 100644 --- a/tests/test_backend_random.py +++ b/tests/test_backend_random.py @@ -70,6 +70,12 @@ def test_choice_samples_matrix_values_along_requested_axis(self): self.assertTrue(set(sample_np[0].tolist()).issubset({0, 1, 2})) self.assertTrue(set(sample_np[1].tolist()).issubset({3, 4, 5})) + def test_multinomial_accepts_python_probability_sequence(self): + sample = random.multinomial(12, [0.25, 0.75]) + + self.assertEqual(sample.shape, (2,)) + self.assertEqual(int(pyrecest.backend.sum(sample)), 12) + @unittest.skipIf( pyrecest.backend.__backend_name__ != "jax", "JAX-specific RNG state contract" )