diff --git a/src/diffusers/models/autoencoders/autoencoder_vidtok.py b/src/diffusers/models/autoencoders/autoencoder_vidtok.py index 63aadb2dbc9c..296c7bd8d85a 100644 --- a/src/diffusers/models/autoencoders/autoencoder_vidtok.py +++ b/src/diffusers/models/autoencoders/autoencoder_vidtok.py @@ -1502,5 +1502,5 @@ def forward( dec = dec[:, :, :-time_padding, :, :] if not return_dict: - return dec + return (dec,) return DecoderOutput(sample=dec) diff --git a/tests/models/autoencoders/test_models_autoencoder_vidtok.py b/tests/models/autoencoders/test_models_autoencoder_vidtok.py index 087dca5debfa..9810296a07d9 100644 --- a/tests/models/autoencoders/test_models_autoencoder_vidtok.py +++ b/tests/models/autoencoders/test_models_autoencoder_vidtok.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import torch from diffusers import AutoencoderVidTok @@ -80,7 +79,6 @@ def get_dummy_inputs(self) -> dict: class TestAutoencoderVidTok(AutoencoderVidTokTesterConfig, ModelTesterMixin): - @pytest.mark.skip("VidTok output structure not compatible with recursive output check.") def test_outputs_equivalence(self): super().test_outputs_equivalence() diff --git a/tests/models/testing_utils/training.py b/tests/models/testing_utils/training.py index 44cce6af68e5..31e1a0085659 100644 --- a/tests/models/testing_utils/training.py +++ b/tests/models/testing_utils/training.py @@ -210,11 +210,14 @@ def test_mixed_precision_training(self): # Test with bfloat16 if torch.device(torch_device).type != "cpu": - model.zero_grad() - with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16): - output = model(**inputs_dict, return_dict=False)[0] + if torch.device(torch_device).type == "cuda" and not torch.cuda.is_bf16_supported(): + pytest.skip("bfloat16 training is not supported on this GPU.") + else: + model.zero_grad() + with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16): + output = model(**inputs_dict, return_dict=False)[0] - noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device) - loss = torch.nn.functional.mse_loss(output, noise) + noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device) + loss = torch.nn.functional.mse_loss(output, noise) - loss.backward() + loss.backward()