Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -653,10 +653,16 @@ def step(
seed = (
[g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed()
)
# Use the actual extrema of `self.sigmas` rather than the config bounds:
# the Karras/exponential reconstruction of the endpoints in fp32 can drift
# by a few ULPs, and `final_sigmas_type="zero"` makes `sigmas[-1] == 0`,
# both of which fall outside `[config.sigma_min, config.sigma_max]`. An
# out-of-range query drives `torchsde` into unbounded recursive splitting
# of its Brownian interval and eventually raises `RecursionError` (#13274).
self.noise_sampler = BrownianTreeNoiseSampler(
model_output,
sigma_min=self.config.sigma_min,
sigma_max=self.config.sigma_max,
sigma_min=self.sigmas.min().item(),
sigma_max=self.sigmas.max().item(),
seed=seed,
)
noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to(
Expand Down
40 changes: 40 additions & 0 deletions tests/schedulers/test_scheduler_cosine_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import unittest
import warnings

import torch

from diffusers import CosineDPMSolverMultistepScheduler

from ..testing_utils import require_torchsde


@require_torchsde
class CosineDPMSolverMultistepSchedulerTest(unittest.TestCase):
"""Regression tests for `CosineDPMSolverMultistepScheduler` (used by Stable Audio Open)."""

def _run_loop(self, **scheduler_kwargs):
scheduler = CosineDPMSolverMultistepScheduler(**scheduler_kwargs)
scheduler.set_timesteps(num_inference_steps=10, device="cpu")
sample = torch.randn(1, 4, 8)
generator = torch.Generator().manual_seed(0)
for t in scheduler.timesteps:
model_output = torch.randn_like(sample)
sample = scheduler.step(model_output, t, sample, generator=generator).prev_sample
return sample

def test_step_does_not_recurse_with_zero_final_sigma(self):
# See https://github.com/huggingface/diffusers/issues/13274. With the defaults
# used by Stable Audio Open (sigma_min=0.3, sigma_max=500, final_sigmas_type="zero")
# querying the Brownian sampler at sigma_next=0 used to fall below the configured
# `sigma_min` interval and recurse until Python's recursion limit was hit.
for sigma_schedule in ("exponential", "karras"):
with self.subTest(sigma_schedule=sigma_schedule):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
sample = self._run_loop(
sigma_schedule=sigma_schedule,
final_sigmas_type="zero",
sigma_min=0.3,
sigma_max=500.0,
)
self.assertFalse(torch.isnan(sample).any().item())
Loading