diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 36d0893734c7..ca0af8a1c711 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -159,13 +159,19 @@ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProce
         return self.processor
 
     def set_attention_backend(self, backend: str):
-        from .attention_dispatch import AttentionBackendName
+        from .attention_dispatch import (
+            AttentionBackendName,
+            _check_attention_backend_requirements,
+            _maybe_download_kernel_for_backend,
+        )
 
         available_backends = {x.value for x in AttentionBackendName.__members__.values()}
         if backend not in available_backends:
             raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
 
         backend = AttentionBackendName(backend.lower())
+        _check_attention_backend_requirements(backend)
+        _maybe_download_kernel_for_backend(backend)
         self.processor._attention_backend = backend
 
     def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: