From 51906b8b3dd43bafa8a3a45f2b62b1af6812b91b Mon Sep 17 00:00:00 2001 From: zxuhan7 Date: Thu, 14 May 2026 21:06:33 +0200 Subject: [PATCH] [schedulers] fix RecursionError in CosineDPMSolverMultistepScheduler `CosineDPMSolverMultistepScheduler.step` initialised `BrownianTreeNoiseSampler` with `sigma_min`/`sigma_max` from the config, but the sampler is queried with `self.sigmas[step_index]` values that drift outside those bounds: the Karras/exponential reconstruction of the endpoints in fp32 lands a few ULPs off, and `final_sigmas_type="zero"` makes the last `sigmas` entry strictly below `config.sigma_min`. Out-of-range queries push torchsde into unbounded recursive interval splitting and trip Python's recursion limit (#13274). Initialise the sampler with the actual `self.sigmas` extrema instead, matching the pattern in `scheduling_dpmsolver_sde.py`. Adds a regression test covering both Karras and exponential schedules with `final_sigmas_type="zero"`. --- .../scheduling_cosine_dpmsolver_multistep.py | 10 ++++- ...st_scheduler_cosine_dpmsolver_multistep.py | 40 +++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) create mode 100644 tests/schedulers/test_scheduler_cosine_dpmsolver_multistep.py diff --git a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py index 18ee272ef619..5b22c34b9d70 100644 --- a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py @@ -653,10 +653,16 @@ def step( seed = ( [g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed() ) + # Use the actual extrema of `self.sigmas` rather than the config bounds: + # the Karras/exponential reconstruction of the endpoints in fp32 can drift + # by a few ULPs, and `final_sigmas_type="zero"` makes `sigmas[-1] == 0`, + # both of which fall outside `[config.sigma_min, config.sigma_max]`. An + # out-of-range query drives `torchsde` into unbounded recursive splitting + # of its Brownian interval and eventually raises `RecursionError` (#13274). self.noise_sampler = BrownianTreeNoiseSampler( model_output, - sigma_min=self.config.sigma_min, - sigma_max=self.config.sigma_max, + sigma_min=self.sigmas.min().item(), + sigma_max=self.sigmas.max().item(), seed=seed, ) noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to( diff --git a/tests/schedulers/test_scheduler_cosine_dpmsolver_multistep.py b/tests/schedulers/test_scheduler_cosine_dpmsolver_multistep.py new file mode 100644 index 000000000000..770357645079 --- /dev/null +++ b/tests/schedulers/test_scheduler_cosine_dpmsolver_multistep.py @@ -0,0 +1,40 @@ +import unittest +import warnings + +import torch + +from diffusers import CosineDPMSolverMultistepScheduler + +from ..testing_utils import require_torchsde + + +@require_torchsde +class CosineDPMSolverMultistepSchedulerTest(unittest.TestCase): + """Regression tests for `CosineDPMSolverMultistepScheduler` (used by Stable Audio Open).""" + + def _run_loop(self, **scheduler_kwargs): + scheduler = CosineDPMSolverMultistepScheduler(**scheduler_kwargs) + scheduler.set_timesteps(num_inference_steps=10, device="cpu") + sample = torch.randn(1, 4, 8) + generator = torch.Generator().manual_seed(0) + for t in scheduler.timesteps: + model_output = torch.randn_like(sample) + sample = scheduler.step(model_output, t, sample, generator=generator).prev_sample + return sample + + def test_step_does_not_recurse_with_zero_final_sigma(self): + # See https://github.com/huggingface/diffusers/issues/13274. With the defaults + # used by Stable Audio Open (sigma_min=0.3, sigma_max=500, final_sigmas_type="zero") + # querying the Brownian sampler at sigma_next=0 used to fall below the configured + # `sigma_min` interval and recurse until Python's recursion limit was hit. + for sigma_schedule in ("exponential", "karras"): + with self.subTest(sigma_schedule=sigma_schedule): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + sample = self._run_loop( + sigma_schedule=sigma_schedule, + final_sigmas_type="zero", + sigma_min=0.3, + sigma_max=500.0, + ) + self.assertFalse(torch.isnan(sample).any().item())