
model_infrastructure model/pipeline review #13655

@hlky


Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

All requested model_infrastructure files were reviewed:

src/diffusers/models/__init__.py
src/diffusers/models/_modeling_parallel.py
src/diffusers/models/activations.py
src/diffusers/models/attention.py
src/diffusers/models/attention_dispatch.py
src/diffusers/models/attention_flax.py
src/diffusers/models/attention_processor.py
src/diffusers/models/auto_model.py
src/diffusers/models/cache_utils.py
src/diffusers/models/downsampling.py
src/diffusers/models/embeddings.py
src/diffusers/models/embeddings_flax.py
src/diffusers/models/lora.py
src/diffusers/models/model_loading_utils.py
src/diffusers/models/modeling_flax_pytorch_utils.py
src/diffusers/models/modeling_flax_utils.py
src/diffusers/models/modeling_outputs.py
src/diffusers/models/modeling_pytorch_flax_utils.py
src/diffusers/models/modeling_utils.py
src/diffusers/models/normalization.py
src/diffusers/models/resnet.py
src/diffusers/models/resnet_flax.py
src/diffusers/models/upsampling.py
src/diffusers/models/vae_flax.py
src/diffusers/models/vq_model.py

Duplicate-search status: searched existing GitHub Issues and PRs for model_infrastructure, affected file/function names, and the specific failures below. No direct duplicates were found. Related but not duplicate: huggingface/diffusers#12409 and #12533 around distributed/context-parallel availability.

Test coverage status: fast/unit coverage exists for parts of AutoModel, layer helpers, cache utilities, attention backends, and parallelism helpers, but the slow/integration coverage gaps listed in Issue 9 remain.

Issue 1: AutoModel.from_pretrained rejects PathLike pipeline roots

Affected code:

load_id_kwargs = {"pretrained_model_name_or_path": pretrained_model_or_path, **kwargs}
parts = [load_id_kwargs.get(field, "null") for field in DIFFUSERS_LOAD_ID_FIELDS]
load_id = "|".join("null" if p is None else p for p in parts)

Problem:
from_pretrained() accepts Union[str, os.PathLike], but the model-index path builds _diffusers_load_id with "|".join(parts) while parts may contain a Path object. str.join() only accepts strings, so passing a Path root together with subfolder fails before any weights are loaded.

Impact:
Documented/local loading behavior breaks for users who pass pathlib.Path, and the failure is a low-level TypeError rather than a loading error.

Reproduction:

import json
import tempfile
from pathlib import Path
from diffusers import AutoModel, UNet2DModel

with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as d:
    root = Path(d)
    unet_dir = root / "unet"
    UNet2DModel(
        sample_size=4, in_channels=1, out_channels=1, layers_per_block=1,
        block_out_channels=(4,), down_block_types=("DownBlock2D",),
        up_block_types=("UpBlock2D",), norm_num_groups=1,
    ).save_pretrained(unet_dir, safe_serialization=False)
    (root / "model_index.json").write_text(json.dumps({"_class_name": "DummyPipeline", "unet": ["diffusers", "UNet2DModel"]}))

    try:
        AutoModel.from_pretrained(root, subfolder="unet", use_safetensors=False)
    except Exception as e:
        print(type(e).__name__)
        print(str(e).splitlines()[0])

Relevant precedent:
Other loading paths normalize path-like inputs before string operations.

Suggested fix:

load_id = "|".join("null" if p is None else str(p) for p in parts)

Issue 2: AutoModel.from_pretrained leaks config_name after model-index loading

Affected code:

# Always attempt to fetch model_index.json first
try:
    cls.config_name = "model_index.json"
    config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)

    if subfolder is not None and subfolder in config:
        library, orig_class_name = config[subfolder]
        load_config_kwargs.update({"subfolder": subfolder})
except EnvironmentError as e:
    logger.debug(e)

# Unable to load from model_index.json so fallback to loading from config
if library is None and orig_class_name is None:
    cls.config_name = "config.json"
    config = cls.load_config(pretrained_model_or_path, subfolder=subfolder, **load_config_kwargs)

Problem:
from_pretrained() mutates the class attribute cls.config_name to "model_index.json" and does not restore it after a successful model-index load. A later AutoModel.from_config(model_dir) then looks for model_index.json inside a plain model directory instead of config.json.

Impact:
One successful AutoModel.from_pretrained() call can change later behavior process-wide.

Reproduction:

import json
import tempfile
from pathlib import Path
from diffusers import AutoModel, UNet2DModel

with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as d:
    root = Path(d)
    unet_dir = root / "unet"
    UNet2DModel(
        sample_size=4, in_channels=1, out_channels=1, layers_per_block=1,
        block_out_channels=(4,), down_block_types=("DownBlock2D",),
        up_block_types=("UpBlock2D",), norm_num_groups=1,
    ).save_pretrained(unet_dir, safe_serialization=False)
    (root / "model_index.json").write_text(json.dumps({"_class_name": "DummyPipeline", "unet": ["diffusers", "UNet2DModel"]}))

    AutoModel.from_pretrained(str(root), subfolder="unet", use_safetensors=False)
    print("leaked config_name:", AutoModel.config_name)
    try:
        AutoModel.from_config(str(unet_dir))
    except Exception as e:
        print(type(e).__name__, str(e).splitlines()[0])

Relevant precedent:
Config name overrides should be local to the load attempt, not stored on the public class.

Suggested fix:

old_config_name = cls.config_name
try:
    cls.config_name = "model_index.json"
    model_index, kwargs = cls.load_config(...)

    # existing model-index handling
finally:
    cls.config_name = old_config_name
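Restoring the attribute in a finally block keeps the override local to this call even when load_config raises, so the config.json fallback and any later from_config or from_pretrained call see the class default again.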

Issue 3: CacheMixin.cache_context leaves stale hook state after exceptions

Affected code:

@contextmanager
def cache_context(self, name: str):
    r"""Context manager that provides additional methods for cache management."""
    from ..hooks import HookRegistry

    registry = HookRegistry.check_if_exists_or_initialize(self)
    registry._set_context(name)

    yield

    registry._set_context(None)

Problem:
cache_context() clears the hook context only after the yield. If the wrapped forward pass raises, _current_context remains set on stateful hooks.

Impact:
A failed denoising step can leak cached/stateful hook state into later calls, and get_state() can incorrectly succeed outside a context.

Reproduction:

import torch
from diffusers.hooks.hooks import BaseState, HookRegistry, ModelHook, StateManager
from diffusers.models.cache_utils import CacheMixin

class State(BaseState):
    def reset(self):
        pass

class StatefulHook(ModelHook):
    _is_stateful = True
    def __init__(self):
        super().__init__()
        self.state_manager = StateManager(State)
    def reset_state(self, module):
        self.state_manager.reset()

class Model(torch.nn.Module, CacheMixin):
    def forward(self, x):
        return x

model = Model()
hook = StatefulHook()
HookRegistry.check_if_exists_or_initialize(model).register_hook(hook, "stateful")

try:
    with model.cache_context("failed-call"):
        raise RuntimeError("simulate an interrupted denoise step")
except RuntimeError:
    pass

print(hook.state_manager._current_context)
print(type(hook.state_manager.get_state()).__name__)

Relevant precedent:
Context managers that mutate global or hook state should restore that state in finally.

Suggested fix:

registry._set_context(name)
try:
    yield
finally:
    registry._set_context(None)
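With the finally block, an interrupted forward pass still clears the context, so the reproduction above should print None and stateful hooks no longer report a stale context outside cache_context.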

Issue 4: FIR up/downsampling fails on bfloat16 and float16 inputs

Affected code:

# FIR downsample, module path (conv branch)
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
    kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)

kernel = kernel * gain

if self.use_conv:
    _, _, convH, convW = weight.shape
    pad_value = (kernel.shape[0] - factor) + (convW - 1)
    stride_value = [factor, factor]
    upfirdn_input = upfirdn2d_native(
        hidden_states,
        torch.tensor(kernel, device=hidden_states.device),
        pad=((pad_value + 1) // 2, pad_value // 2),
    )
    output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
else:
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        torch.tensor(kernel, device=hidden_states.device),
        down=factor,

# FIR downsample, helper path (downsample_2d)
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
    kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)

kernel = kernel * gain
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
    hidden_states,
    kernel.to(device=hidden_states.device),
    down=factor,
    pad=((pad_value + 1) // 2, pad_value // 2),

# FIR upsample, module path (conv branch)
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
    kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)

kernel = kernel * (gain * (factor**2))

if self.use_conv:
    convH = weight.shape[2]
    convW = weight.shape[3]
    inC = weight.shape[1]

    pad_value = (kernel.shape[0] - factor) - (convW - 1)

    stride = (factor, factor)
    # Determine data dimensions.
    output_shape = (
        (hidden_states.shape[2] - 1) * factor + convH,
        (hidden_states.shape[3] - 1) * factor + convW,
    )
    output_padding = (
        output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
        output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
    )
    assert output_padding[0] >= 0 and output_padding[1] >= 0
    num_groups = hidden_states.shape[1] // inC

    # Transpose weights.
    weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
    weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
    weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

    inverse_conv = F.conv_transpose2d(
        hidden_states,
        weight,
        stride=stride,
        output_padding=output_padding,
        padding=0,
    )

    output = upfirdn2d_native(
        inverse_conv,
        torch.tensor(kernel, device=inverse_conv.device),
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
    )
else:
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        torch.tensor(kernel, device=hidden_states.device),
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),

# FIR upsample, helper path (upsample_2d)
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
    kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)

kernel = kernel * (gain * (factor**2))
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
    hidden_states,
    kernel.to(device=hidden_states.device),
    up=factor,
    pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),

Problem:
The FIR helpers build the filter kernel as a float32 tensor and pass it into upfirdn2d_native and the follow-up convolutions against bfloat16/float16 hidden states, so those ops see mismatched dtypes.

Impact:
Low-precision model execution can fail in FIR downsample/upsample paths with dtype mismatch.

Reproduction:

import torch
from diffusers.models.downsampling import downsample_2d
from diffusers.models.upsampling import upsample_2d

x = torch.randn(1, 1, 4, 4, dtype=torch.bfloat16)
for fn in (downsample_2d, upsample_2d):
    try:
        print(fn.__name__, fn(x).dtype)
    except Exception as e:
        print(fn.__name__, type(e).__name__, str(e).splitlines()[0])

Relevant precedent:
Upsample2D already has low-precision test coverage; FIR helper/module paths need equivalent dtype handling.

Suggested fix:

kernel = torch.as_tensor(kernel, device=hidden_states.device, dtype=torch.float32)
if kernel.ndim == 1:
    kernel = torch.outer(kernel, kernel)
kernel = kernel / kernel.sum() * gain
kernel = kernel.to(dtype=hidden_states.dtype)
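Normalizing the kernel in float32 and casting to hidden_states.dtype at the end keeps the kernel sum numerically stable while letting upfirdn2d_native and the convolutions run in the input precision; the same change applies to all four FIR paths quoted above.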

Issue 5: enable_parallelism uses the wrong distributed guard

Affected code:

if not torch.distributed.is_available() and not torch.distributed.is_initialized():
    raise RuntimeError(
        "torch.distributed must be available and initialized before calling `enable_parallelism`."
    )

Problem:
The guard uses and instead of or: not is_available() and not is_initialized(). When distributed is unavailable, the first operand is true, so is_initialized() is still evaluated; when distributed is available but uninitialized, the first operand is false, the whole condition is false, and the guard does not raise.

Impact:
Users get backend-specific errors instead of the intended clear RuntimeError, or parallelism proceeds before torch.distributed is initialized.

Reproduction:

from diffusers import ContextParallelConfig, UNet2DModel

model = UNet2DModel(
    sample_size=4, in_channels=1, out_channels=1, layers_per_block=1,
    block_out_channels=(4,), down_block_types=("DownBlock2D",),
    up_block_types=("UpBlock2D",), norm_num_groups=1,
)
try:
    model.enable_parallelism(config=ContextParallelConfig(ring_degree=2))
except Exception as e:
    print(type(e).__name__)
    print(str(e).splitlines()[0])

Relevant precedent:
Distributed APIs normally require both availability and initialization before model wrapping.

Suggested fix:

if not torch.distributed.is_available() or not torch.distributed.is_initialized():
    raise RuntimeError("torch.distributed must be available and initialized before calling `enable_parallelism`.")

Issue 6: Legacy Attention.set_use_xla_flash_attention checks the function object instead of calling it

Affected code:

def set_use_xla_flash_attention(
    self,
    use_xla_flash_attention: bool,
    partition_spec: tuple[str | None, ...] | None = None,
    is_flux=False,
) -> None:
    r"""
    Set whether to use xla flash attention from `torch_xla` or not.

    Args:
        use_xla_flash_attention (`bool`):
            Whether to use pallas flash attention kernel from `torch_xla` or not.
        partition_spec (`tuple[]`, *optional*):
            Specify the partition specification if using SPMD. Otherwise None.
    """
    if use_xla_flash_attention:
        if not is_torch_xla_available:
            raise "torch_xla is not available"
        elif is_torch_xla_version("<", "2.3"):
            raise "flash attention pallas kernel is supported from torch_xla version 2.3"
        elif is_spmd() and is_torch_xla_version("<", "2.4"):
            raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
        else:
            if is_flux:
                processor = XLAFluxFlashAttnProcessor2_0(partition_spec)
            else:
                processor = XLAFlashAttnProcessor2_0(partition_spec)
    else:
        processor = (
            AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
        )

    self.set_processor(processor)

Problem:
The method tests the function object is_torch_xla_available (always truthy) instead of calling is_torch_xla_available(), so environments without torch_xla never hit the availability error and fall through to the XLA version checks, which can raise InvalidVersion. The raise statements also use bare strings, which is itself a TypeError in Python 3.

Impact:
Users enabling XLA flash attention get confusing exceptions instead of a clear dependency/version error.

Reproduction:

from diffusers.models.attention_processor import Attention

attn = Attention(query_dim=4, heads=1, dim_head=4)
try:
    attn.set_use_xla_flash_attention(True)
except Exception as e:
    print(type(e).__name__)
    print(str(e).splitlines()[0])

Relevant precedent:
AttentionModuleMixin.set_use_xla_flash_attention() in src/diffusers/models/attention.py calls is_torch_xla_available() and raises ImportError.

Suggested fix:

if use_xla_flash_attention:
    if not is_torch_xla_available():
        raise ImportError("torch_xla is not available")
    if is_torch_xla_version("<", "2.3"):
        raise ImportError("flash attention pallas kernel is supported from torch_xla version 2.3")

Issue 7: FlaxModelMixin.from_pretrained(config=...) can reference unused_kwargs before assignment

Affected code:

# Load config if we don't provide one
if config is None:
    config, unused_kwargs = cls.load_config(
        pretrained_model_name_or_path,
        cache_dir=cache_dir,
        return_unused_kwargs=True,
        force_download=force_download,
        proxies=proxies,
        local_files_only=local_files_only,
        token=token,
        revision=revision,
        subfolder=subfolder,
        **kwargs,
    )

model, model_kwargs = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs)

Problem:
unused_kwargs is assigned only in the config is None branch but is used unconditionally afterwards, so passing a preloaded config raises UnboundLocalError. This was verified statically; the local .venv does not include Flax.

Impact:
Callers that pass an already loaded Flax config cannot reliably use from_pretrained(config=...).

Reproduction:

from diffusers import FlaxAutoencoderKL

config = {
    "in_channels": 3,
    "out_channels": 3,
    "down_block_types": ("DownEncoderBlock2D",),
    "up_block_types": ("UpDecoderBlock2D",),
    "block_out_channels": (32,),
    "layers_per_block": 1,
    "act_fn": "silu",
    "latent_channels": 4,
    "norm_num_groups": 32,
    "sample_size": 32,
    "scaling_factor": 0.18215,
}

try:
    FlaxAutoencoderKL.from_pretrained("does-not-matter", config=config, local_files_only=True)
except Exception as e:
    print(type(e).__name__, str(e).splitlines()[0])

Relevant precedent:
The PyTorch loading path preserves extra kwargs regardless of whether config is loaded internally or supplied by the caller.

Suggested fix:

if config is None:
    config, unused_kwargs = cls.load_config(..., **kwargs)
else:
    unused_kwargs = kwargs
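With the else branch, kwargs supplied by the caller still reach from_config when a preloaded config is passed, matching how the PyTorch loading path treats extra kwargs.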

Issue 8: Sinusoidal embedding helper defaults to float64 on non-MPS devices

Affected code:

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False, dtype=None):
    """
    This function generates 1D positional embeddings from a grid.

    Args:
        embed_dim (`int`): The embedding dimension `D`
        pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
        output_type (`str`, *optional*, defaults to `"np"`): Output type. Use `"pt"` for PyTorch tensors.
        flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip sine and cosine embeddings.
        dtype (`torch.dtype`, *optional*): Data type for frequency calculations. If `None`, defaults to
            `torch.float32` on MPS devices (which don't support `torch.float64`) and `torch.float64` on other devices.

    Returns:
        `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
    """
    if output_type == "np":
        deprecation_message = (
            "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
            " `from_numpy` is no longer required."
            " Pass `output_type='pt' to use the new version now."
        )
        deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
        return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # Auto-detect appropriate dtype if not specified
    if dtype is None:
        dtype = torch.float32 if pos.device.type == "mps" else torch.float64

    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=dtype)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.outer(pos, omega)  # (M, D/2), outer product

Problem:
For output_type="pt", the helper defaults dtype to torch.float64 unless the device is MPS. The review rules call out NPU float64 limitations, but NPU takes the same non-MPS branch.

Impact:
Embedding creation can fail on NPU, and CPU/GPU callers get an unexpectedly high-precision tensor unless they override dtype.

Reproduction:

import torch
from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid

pos = torch.arange(4, dtype=torch.float32)
emb = get_1d_sincos_pos_embed_from_grid(8, pos, output_type="pt")
print(emb.dtype)

# On NPU this same non-MPS branch attempts float64 tensor creation:
# pos = pos.to("npu")
# get_1d_sincos_pos_embed_from_grid(8, pos, output_type="pt")

Relevant precedent:
NPU-safe code paths in the repository avoid implicit float64 tensors.

Suggested fix:

if dtype is None:
    dtype = torch.float32
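Callers that depend on float64 frequency precision can still pass dtype=torch.float64 explicitly; the float32 default only changes the implicit behavior and avoids the float64 limitation on NPU.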

Issue 9: Slow/integration coverage is missing for shared model infrastructure regressions

Affected code:

class TestAutoModel(unittest.TestCase):
    @patch(
        "diffusers.models.AutoModel.load_config",
        side_effect=[EnvironmentError("File not found"), {"_class_name": "UNet2DConditionModel"}],
    )
    def test_load_from_config_diffusers_with_subfolder(self, mock_load_config):
        model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet")
        assert isinstance(model, UNet2DConditionModel)

    @patch(
        "diffusers.models.AutoModel.load_config",
        side_effect=[EnvironmentError("File not found"), {"model_type": "clip_text_model"}],
    )
    def test_load_from_config_transformers_with_subfolder(self, mock_load_config):
        model = AutoModel.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder", use_safetensors=False
        )
        assert isinstance(model, CLIPTextModel)

    def test_load_from_config_without_subfolder(self):
        model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-longformer")
        assert isinstance(model, LongformerModel)

    def test_load_from_model_index(self):
        model = AutoModel.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder", use_safetensors=False
        )
        assert isinstance(model, CLIPTextModel)

    def test_load_dynamic_module_from_local_path_with_subfolder(self):
        CUSTOM_MODEL_CODE = (
            "import torch\n"
            "from diffusers import ModelMixin, ConfigMixin\n"
            "from diffusers.configuration_utils import register_to_config\n"
            "\n"
            "class CustomModel(ModelMixin, ConfigMixin):\n"
            "    @register_to_config\n"
            "    def __init__(self, hidden_size=8):\n"
            "        super().__init__()\n"
            "        self.linear = torch.nn.Linear(hidden_size, hidden_size)\n"
            "\n"
            "    def forward(self, x):\n"
            "        return self.linear(x)\n"
        )
        with tempfile.TemporaryDirectory() as tmpdir:
            subfolder = "custom_model"
            model_dir = os.path.join(tmpdir, subfolder)
            os.makedirs(model_dir)

            with open(os.path.join(model_dir, "modeling.py"), "w") as f:
                f.write(CUSTOM_MODEL_CODE)

            config = {
                "_class_name": "CustomModel",
                "_diffusers_version": "0.0.0",
                "auto_map": {"AutoModel": "modeling.CustomModel"},
                "hidden_size": 8,
            }
            with open(os.path.join(model_dir, "config.json"), "w") as f:
                json.dump(config, f)

            torch.save({}, os.path.join(model_dir, "diffusion_pytorch_model.bin"))

            model = AutoModel.from_pretrained(tmpdir, subfolder=subfolder, trust_remote_code=True)
            assert model.__class__.__name__ == "CustomModel"
            assert model.config["hidden_size"] == 8

class Upsample2DBlockTests(unittest.TestCase):
    def test_upsample_default(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32, 32)
        upsample = Upsample2D(channels=32, use_conv=False)
        with torch.no_grad():
            upsampled = upsample(sample)

        assert upsampled.shape == (1, 32, 64, 64)
        output_slice = upsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    @require_torch_version_greater_equal("2.1")
    def test_upsample_bfloat16(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32, 32).to(torch.bfloat16)
        upsample = Upsample2D(channels=32, use_conv=False)
        with torch.no_grad():
            upsampled = upsample(sample)

        assert upsampled.shape == (1, 32, 64, 64)
        output_slice = upsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254], dtype=torch.bfloat16
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_upsample_with_conv(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32, 32)
        upsample = Upsample2D(channels=32, use_conv=True)
        with torch.no_grad():
            upsampled = upsample(sample)

        assert upsampled.shape == (1, 32, 64, 64)
        output_slice = upsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([0.7145, 1.3773, 0.3492, 0.8448, 1.0839, -0.3341, 0.5956, 0.1250, -0.4841])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_upsample_with_conv_out_dim(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32, 32)
        upsample = Upsample2D(channels=32, use_conv=True, out_channels=64)
        with torch.no_grad():
            upsampled = upsample(sample)

        assert upsampled.shape == (1, 64, 64, 64)
        output_slice = upsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([0.2703, 0.1656, -0.2538, -0.0553, -0.2984, 0.1044, 0.1155, 0.2579, 0.7755])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_upsample_with_transpose(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32, 32)
        upsample = Upsample2D(channels=32, use_conv=False, use_conv_transpose=True)
        with torch.no_grad():
            upsampled = upsample(sample)

        assert upsampled.shape == (1, 32, 64, 64)
        output_slice = upsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([-0.3028, -0.1582, 0.0071, 0.0350, -0.4799, -0.1139, 0.1056, -0.1153, -0.1046])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
class Downsample2DBlockTests(unittest.TestCase):
    def test_downsample_default(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64)
        downsample = Downsample2D(channels=32, use_conv=False)
        with torch.no_grad():
            downsampled = downsample(sample)

        assert downsampled.shape == (1, 32, 32, 32)
        output_slice = downsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([-0.0513, -0.3889, 0.0640, 0.0836, -0.5460, -0.0341, -0.0169, -0.6967, 0.1179])
        max_diff = (output_slice.flatten() - expected_slice).abs().sum().item()
        assert max_diff <= 1e-3
        # assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-1)

    def test_downsample_with_conv(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64)
        downsample = Downsample2D(channels=32, use_conv=True)
        with torch.no_grad():
            downsampled = downsample(sample)

        assert downsampled.shape == (1, 32, 32, 32)
        output_slice = downsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [0.9267, 0.5878, 0.3337, 1.2321, -0.1191, -0.3984, -0.7532, -0.0715, -0.3913],
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_downsample_with_conv_pad1(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64)
        downsample = Downsample2D(channels=32, use_conv=True, padding=1)
        with torch.no_grad():
            downsampled = downsample(sample)

        assert downsampled.shape == (1, 32, 32, 32)
        output_slice = downsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([0.9267, 0.5878, 0.3337, 1.2321, -0.1191, -0.3984, -0.7532, -0.0715, -0.3913])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_downsample_with_conv_out_dim(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64)
        downsample = Downsample2D(channels=32, use_conv=True, out_channels=16)
        with torch.no_grad():
            downsampled = downsample(sample)

        assert downsampled.shape == (1, 16, 32, 32)
        output_slice = downsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([-0.6586, 0.5985, 0.0721, 0.1256, -0.1492, 0.4436, -0.2544, 0.5021, 1.1522])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
class ResnetBlock2DTests(unittest.TestCase):
    def test_resnet_default(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        temb = torch.randn(1, 128).to(torch_device)
        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128).to(torch_device)
        with torch.no_grad():
            output_tensor = resnet_block(sample, temb)

        assert output_tensor.shape == (1, 32, 64, 64)
        output_slice = output_tensor[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [-1.9010, -0.2974, -0.8245, -1.3533, 0.8742, -0.9645, -2.0584, 1.3387, -0.4746], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_restnet_with_use_in_shortcut(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        temb = torch.randn(1, 128).to(torch_device)
        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, use_in_shortcut=True).to(torch_device)
        with torch.no_grad():
            output_tensor = resnet_block(sample, temb)

        assert output_tensor.shape == (1, 32, 64, 64)
        output_slice = output_tensor[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [0.2226, -1.0791, -0.1629, 0.3659, -0.2889, -1.2376, 0.0582, 0.9206, 0.0044], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_resnet_up(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        temb = torch.randn(1, 128).to(torch_device)
        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, up=True).to(torch_device)
        with torch.no_grad():
            output_tensor = resnet_block(sample, temb)

        assert output_tensor.shape == (1, 32, 128, 128)
        output_slice = output_tensor[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [1.2130, -0.8753, -0.9027, 1.5783, -0.5362, -0.5001, 1.0726, -0.7732, -0.4182], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_resnet_down(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        temb = torch.randn(1, 128).to(torch_device)
        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, down=True).to(torch_device)
        with torch.no_grad():
            output_tensor = resnet_block(sample, temb)

        assert output_tensor.shape == (1, 32, 32, 32)
        output_slice = output_tensor[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [-0.3002, -0.7135, 0.1359, 0.0561, -0.7935, 0.0113, -0.1766, -0.6714, -0.0436], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_restnet_with_kernel_fir(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        temb = torch.randn(1, 128).to(torch_device)
        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, kernel="fir", down=True).to(torch_device)
        with torch.no_grad():
            output_tensor = resnet_block(sample, temb)

        assert output_tensor.shape == (1, 32, 32, 32)
        output_slice = output_tensor[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [-0.0934, -0.5729, 0.0909, -0.2710, -0.5044, 0.0243, -0.0665, -0.5267, -0.3136], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_restnet_with_kernel_sde_vp(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        temb = torch.randn(1, 128).to(torch_device)
        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, kernel="sde_vp", down=True).to(torch_device)
        with torch.no_grad():
            output_tensor = resnet_block(sample, temb)

        assert output_tensor.shape == (1, 32, 32, 32)
        output_slice = output_tensor[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [-0.3002, -0.7135, 0.1359, 0.0561, -0.7935, 0.0113, -0.1766, -0.6714, -0.0436], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

class Transformer2DModelTests(unittest.TestCase):
    def test_spatial_transformer_default(self):
        torch.manual_seed(0)
        backend_manual_seed(torch_device, 0)

def _test_cache_context_manager(self, atol=1e-5, rtol=0):
    """Test the cache_context context manager properly isolates cache state."""
    init_dict = self.get_init_dict()
    inputs_dict = self.get_dummy_inputs()

    model = self.model_class(**init_dict).to(torch_device)
    model.eval()

    config = self._get_cache_config()
    model.enable_cache(config)

    # Run inference in first context
    with model.cache_context("context_1"):
        output_ctx1 = model(**inputs_dict, return_dict=False)[0]

    # Run same inference in second context (cache should be reset)
    with model.cache_context("context_2"):
        output_ctx2 = model(**inputs_dict, return_dict=False)[0]

    # Both contexts should produce the same output (first pass in each)
    assert_tensors_close(
        output_ctx1,
        output_ctx2,
        atol=atol,
        rtol=rtol,
        msg="First pass in different cache contexts should produce the same output.",
    )

    model.disable_cache()

Problem:
Fast coverage exists, but it does not exercise the integration-style failures found above: AutoModel PathLike handling and config_name leakage across calls, cache cleanup after exceptions, the FIR low-precision paths, and unavailable/uninitialized distributed backends. No dedicated slow test covers these shared infrastructure behaviors through a tiny pipeline/model load.

Impact:
Regressions in shared model infrastructure can affect many model and pipeline families without being caught by family-specific fast tests.

Reproduction:

from pathlib import Path

interesting = [
    Path("tests/models/test_models_auto.py"),
    Path("tests/models/test_layers_utils.py"),
    Path("tests/models/testing_utils/cache.py"),
    Path("tests/others/test_attention_backends.py"),
]
for path in interesting:
    text = path.read_text()
    print(path, "@slow" in text or "slow(" in text)

Relevant precedent:
Other model/pipeline families combine focused fast tests with slow tests that exercise actual loading/runtime behavior.

Suggested fix:
Add fast regression tests for Issues 1-8. Add at least one slow/integration test using a tiny Hub fixture or saved tiny local pipeline to exercise AutoModel.from_pretrained, cache/offload/attention setup, and shared dtype/device behavior through public APIs.
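As an illustration of the fast-test half, a minimal sketch covering Issues 3 and 5 could look like the following (test names and placement are hypothetical, not existing tests; it reuses only APIs exercised in the reproductions above):

import pytest
import torch

from diffusers import ContextParallelConfig, UNet2DModel
from diffusers.hooks.hooks import BaseState, HookRegistry, ModelHook, StateManager
from diffusers.models.cache_utils import CacheMixin


class _State(BaseState):
    def reset(self):
        pass


class _StatefulHook(ModelHook):
    _is_stateful = True

    def __init__(self):
        super().__init__()
        self.state_manager = StateManager(_State)

    def reset_state(self, module):
        self.state_manager.reset()


class _TinyModel(torch.nn.Module, CacheMixin):
    def forward(self, x):
        return x


def test_cache_context_clears_state_on_exception():
    # Issue 3: an exception inside cache_context must not leave a stale context behind.
    model = _TinyModel()
    hook = _StatefulHook()
    HookRegistry.check_if_exists_or_initialize(model).register_hook(hook, "stateful")
    with pytest.raises(RuntimeError):
        with model.cache_context("failed-call"):
            raise RuntimeError("interrupted step")
    assert hook.state_manager._current_context is None


def test_enable_parallelism_requires_initialized_distributed():
    # Issue 5: without an initialized process group the guard must raise the clear RuntimeError.
    model = UNet2DModel(
        sample_size=4, in_channels=1, out_channels=1, layers_per_block=1,
        block_out_channels=(4,), down_block_types=("DownBlock2D",),
        up_block_types=("UpBlock2D",), norm_num_groups=1,
    )
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        pytest.skip("a process group is unexpectedly initialized")
    with pytest.raises(RuntimeError):
        model.enable_parallelism(config=ContextParallelConfig(ring_degree=2))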
