From 89c81b84f5d9b7245680c6b969f1d77430d564e6 Mon Sep 17 00:00:00 2001
From: Mike Cheung
Date: Mon, 23 Mar 2026 13:26:26 +0800
Subject: [PATCH 1/6] Expand unit tests to batch inputs

---
 tests/models/testing_utils/parallelism.py | 11 +++++
 .../test_models_transformer_qwenimage.py | 41 +++++++++++++++----
 2 files changed, 45 insertions(+), 7 deletions(-)

diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py
index db9817c86995..fced7602419b 100644
--- a/tests/models/testing_utils/parallelism.py
+++ b/tests/models/testing_utils/parallelism.py
@@ -22,6 +22,7 @@
 import torch.multiprocessing as mp
 
 from diffusers.models._modeling_parallel import ContextParallelConfig
+from diffusers.models.attention_dispatch import _AttentionBackendRegistry, AttentionBackendName
 
 from ...testing_utils import (
     is_context_parallel,
@@ -167,6 +168,11 @@ def test_context_parallel_inference(self, cp_type):
         if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
             pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
 
+        if cp_type == "ring_degree":
+            active_backend, _ = _AttentionBackendRegistry.get_active_backend()
+            if active_backend == AttentionBackendName.NATIVE:
+                pytest.skip("Ring attention is not supported with the native attention backend.")
+
         world_size = 2
         init_dict = self.get_init_dict()
         inputs_dict = self.get_dummy_inputs()
@@ -209,6 +215,11 @@ def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names)
         if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
             pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
 
+        if cp_type == "ring_degree":
+            active_backend, _ = _AttentionBackendRegistry.get_active_backend()
+            if active_backend == AttentionBackendName.NATIVE:
+                pytest.skip("Ring attention is not supported with the native attention backend.")
+
         world_size = 2
         init_dict = self.get_init_dict()
         inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py
index 713a1bec70a5..c8e03931faee 100644
--- a/tests/models/transformers/test_models_transformer_qwenimage.py
+++ b/tests/models/transformers/test_models_transformer_qwenimage.py
@@ -78,7 +78,7 @@ def get_init_dict(self) -> dict[str, int | list[int]]:
         }
 
     def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
-        batch_size = 1
+        batch_size = 2
         num_latent_channels = embedding_dim = 16
         height = width = 4
         sequence_length = 8
@@ -122,7 +122,7 @@ def test_infers_text_seq_len_from_mask(self):
         assert isinstance(per_sample_len, torch.Tensor)
         assert int(per_sample_len.max().item()) == 2
         assert normalized_mask.dtype == torch.bool
-        assert normalized_mask.sum().item() == 2
+        assert normalized_mask.sum().item() == 4
         assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
 
         inputs["encoder_hidden_states_mask"] = normalized_mask
@@ -139,7 +139,7 @@
         )
 
         assert int(per_sample_len2.max().item()) == 8
-        assert normalized_mask2.sum().item() == 5
+        assert normalized_mask2.sum().item() == 10
 
         rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
             inputs["encoder_hidden_states"], None
@@ -219,7 +219,7 @@
 
         assert isinstance(model.pos_embed, QwenEmbedLayer3DRope)
 
-        batch_size = 1
+        batch_size = 2
         text_seq_len = 8
         img_h, img_w = 4, 4
         layers = 4
@@ -230,9 +230,9 @@ def test_layered_model_with_mask(self):
         encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device)
         encoder_hidden_states_mask[0, 5:] = 0
 
-        timestep = torch.tensor([1.0]).to(torch_device)
+        timestep = torch.tensor([1.0, 1.0]).to(torch_device)
 
-        addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device)
+        addition_t_cond = torch.tensor([0, 0], dtype=torch.long).to(torch_device)
 
         img_shapes = [
             [
@@ -242,7 +242,7 @@ def test_layered_model_with_mask(self):
                 (1, img_h, img_w),
                 (1, img_h, img_w),
             ]
-        ]
+        ] * batch_size
 
         with torch.no_grad():
             output = model(
@@ -276,6 +276,33 @@ class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, Attent
 class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
     """Context Parallel inference tests for QwenImage Transformer."""
 
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
+        batch_size = 1  # TODO: context parallel failed with batch size > 1, need fix
+        num_latent_channels = embedding_dim = 16
+        height = width = 4
+        sequence_length = 8
+        vae_scale_factor = 4
+
+        hidden_states = randn_tensor(
+            (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
+        )
+        encoder_hidden_states = randn_tensor(
+            (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
+        )
+        encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
+        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+        orig_height = height * 2 * vae_scale_factor
+        orig_width = width * 2 * vae_scale_factor
+        img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
+
+        return {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states,
+            "encoder_hidden_states_mask": encoder_hidden_states_mask,
+            "timestep": timestep,
+            "img_shapes": img_shapes,
+        }
+
 
 class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
     """LoRA adapter tests for QwenImage Transformer."""

From 0f2ce00577f6ae4ce80aa4b1a00571adbda0d09d Mon Sep 17 00:00:00 2001
From: Mike Cheung
Date: Mon, 23 Mar 2026 14:33:59 +0800
Subject: [PATCH 2/6] Update according to suggestion

---
 .../test_models_transformer_qwenimage.py | 68 ++++++++-----------
 1 file changed, 28 insertions(+), 40 deletions(-)

diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py
index c8e03931faee..c977cb1cab39 100644
--- a/tests/models/transformers/test_models_transformer_qwenimage.py
+++ b/tests/models/transformers/test_models_transformer_qwenimage.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import functools
 import warnings
 
 import torch
+import pytest
 
 from diffusers import QwenImageTransformer2DModel
 from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
@@ -77,8 +79,7 @@ def get_init_dict(self) -> dict[str, int | list[int]]:
             "axes_dims_rope": (8, 4, 4),
         }
 
-    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
-        batch_size = 2
+    def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
         num_latent_channels = embedding_dim = 16
         height = width = 4
         sequence_length = 8
@@ -106,9 +107,10 @@
 
 
 class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
-    def test_infers_text_seq_len_from_mask(self):
+    @pytest.mark.parametrize("batch_size", [1, 2])
+    def test_infers_text_seq_len_from_mask(self, batch_size):
         init_dict = self.get_init_dict()
-        inputs = self.get_dummy_inputs()
+        inputs = self.get_dummy_inputs(batch_size=batch_size)
         model = self.model_class(**init_dict).to(torch_device)
 
         encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
@@ -122,7 +124,7 @@
         assert isinstance(per_sample_len, torch.Tensor)
         assert int(per_sample_len.max().item()) == 2
         assert normalized_mask.dtype == torch.bool
-        assert normalized_mask.sum().item() == 4
+        assert normalized_mask.sum().item() == 2 * batch_size
         assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
 
         inputs["encoder_hidden_states_mask"] = normalized_mask
@@ -139,7 +141,7 @@
         )
 
         assert int(per_sample_len2.max().item()) == 8
-        assert normalized_mask2.sum().item() == 10
+        assert normalized_mask2.sum().item() == 5 * batch_size
 
         rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
             inputs["encoder_hidden_states"], None
@@ -149,9 +151,10 @@
         assert per_sample_len_none is None
         assert normalized_mask_none is None
 
-    def test_non_contiguous_attention_mask(self):
+    @pytest.mark.parametrize("batch_size", [1, 2])
+    def test_non_contiguous_attention_mask(self, batch_size):
         init_dict = self.get_init_dict()
-        inputs = self.get_dummy_inputs()
+        inputs = self.get_dummy_inputs(batch_size=batch_size)
         model = self.model_class(**init_dict).to(torch_device)
 
         encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
@@ -174,9 +177,10 @@
 
         assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
 
-    def test_txt_seq_lens_deprecation(self):
+    @pytest.mark.parametrize("batch_size", [1, 2])
+    def test_txt_seq_lens_deprecation(self, batch_size):
         init_dict = self.get_init_dict()
-        inputs = self.get_dummy_inputs()
+        inputs = self.get_dummy_inputs(batch_size=batch_size)
         model = self.model_class(**init_dict).to(torch_device)
 
         txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]]
@@ -219,7 +223,7 @@
 
         assert isinstance(model.pos_embed, QwenEmbedLayer3DRope)
 
-        batch_size = 2
+        batch_size = 1
         text_seq_len = 8
         img_h, img_w = 4, 4
         layers = 4
@@ -230,9 +234,9 @@ def test_layered_model_with_mask(self):
         encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device)
         encoder_hidden_states_mask[0, 5:] = 0
 
-        timestep = torch.tensor([1.0, 1.0]).to(torch_device)
+        timestep = torch.tensor([1.0]).to(torch_device)
 
-        addition_t_cond = torch.tensor([0, 0], dtype=torch.long).to(torch_device)
+        addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device)
 
         img_shapes = [
             [
@@ -242,7 +246,7 @@ def test_layered_model_with_mask(self):
                 (1, img_h, img_w),
                 (1, img_h, img_w),
             ]
-        ] * batch_size
+        ]
 
         with torch.no_grad():
             output = model(
@@ -276,32 +280,16 @@ class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, Attent
 class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
     """Context Parallel inference tests for QwenImage Transformer."""
 
-    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
-        batch_size = 1  # TODO: context parallel failed with batch size > 1, need fix
-        num_latent_channels = embedding_dim = 16
-        height = width = 4
-        sequence_length = 8
-        vae_scale_factor = 4
-
-        hidden_states = randn_tensor(
-            (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
-        )
-        encoder_hidden_states = randn_tensor(
-            (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
-        )
-        encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
-        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
-        orig_height = height * 2 * vae_scale_factor
-        orig_width = width * 2 * vae_scale_factor
-        img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
-
-        return {
-            "hidden_states": hidden_states,
-            "encoder_hidden_states": encoder_hidden_states,
-            "encoder_hidden_states_mask": encoder_hidden_states_mask,
-            "timestep": timestep,
-            "img_shapes": img_shapes,
-        }
+    @pytest.mark.parametrize(
+        "batch_size",
+        [
+            1,
+            pytest.param(2, marks=pytest.mark.xfail(reason="Context parallel does not support batch_size > 1")),
+        ],
+    )
+    def test_context_parallel_batch_size(self, batch_size):
+        self.get_dummy_inputs = functools.partial(self.get_dummy_inputs, batch_size=batch_size)
+        self.test_context_parallel_inference("ulysses_degree")
 
 
 class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
     """LoRA adapter tests for QwenImage Transformer."""

From ec5b58fa8d38691e16248b15f0dcfa6aa6e11cae Mon Sep 17 00:00:00 2001
From: Mike Cheung
Date: Mon, 23 Mar 2026 15:54:41 +0800
Subject: [PATCH 3/6] Update according to suggestion 2

---
 tests/models/testing_utils/parallelism.py | 18 ++++++++++++++++--
 .../test_models_transformer_qwenimage.py | 17 ++---------------
 2 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py
index fced7602419b..769752ab599b 100644
--- a/tests/models/testing_utils/parallelism.py
+++ b/tests/models/testing_utils/parallelism.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import os
 import socket
 
@@ -161,7 +162,7 @@
 @require_torch_multi_accelerator
 class ContextParallelTesterMixin:
     @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
-    def test_context_parallel_inference(self, cp_type):
+    def test_context_parallel_inference(self, cp_type, batch_size=None):
         if not torch.distributed.is_available():
             pytest.skip("torch.distributed is not available.")
 
@@ -175,7 +176,12 @@ def test_context_parallel_inference(self, cp_type):
 
         world_size = 2
         init_dict = self.get_init_dict()
-        inputs_dict = self.get_dummy_inputs()
+
+        # get_dummy_inputs may or may not support a batch_size argument
+        if batch_size is not None and "batch_size" in inspect.signature(self.get_dummy_inputs).parameters:
+            inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
+        else:
+            inputs_dict = self.get_dummy_inputs()
 
         # Move all tensors to CPU for multiprocessing
         inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
@@ -200,6 +206,14 @@ def test_context_parallel_inference(self, cp_type):
                 f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
             )
 
+    @pytest.mark.xfail(reason="Context parallel may not support batch_size > 1")
+    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
+    def test_context_parallel_batch_inputs(self, cp_type):
+        if "batch_size" not in inspect.signature(self.get_dummy_inputs).parameters:
+            pytest.skip("get_dummy_inputs does not support a batch_size parameter.")
+
+        self.test_context_parallel_inference(cp_type, batch_size=2)
+
     @pytest.mark.parametrize(
         "cp_type,mesh_shape,mesh_dim_names",
         [
diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py
index c977cb1cab39..d3fbfb61e549 100644
--- a/tests/models/transformers/test_models_transformer_qwenimage.py
+++ b/tests/models/transformers/test_models_transformer_qwenimage.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import functools
 import warnings
 
 import torch
 import pytest
@@ -177,10 +176,9 @@
 
         assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
 
-    @pytest.mark.parametrize("batch_size", [1, 2])
-    def test_txt_seq_lens_deprecation(self, batch_size):
+    def test_txt_seq_lens_deprecation(self):
         init_dict = self.get_init_dict()
-        inputs = self.get_dummy_inputs(batch_size=batch_size)
+        inputs = self.get_dummy_inputs()
         model = self.model_class(**init_dict).to(torch_device)
 
         txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]]
@@ -280,17 +278,6 @@ class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, Attent
 class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
     """Context Parallel inference tests for QwenImage Transformer."""
 
-    @pytest.mark.parametrize(
-        "batch_size",
-        [
-            1,
-            pytest.param(2, marks=pytest.mark.xfail(reason="Context parallel does not support batch_size > 1")),
-        ],
-    )
-    def test_context_parallel_batch_size(self, batch_size):
-        self.get_dummy_inputs = functools.partial(self.get_dummy_inputs, batch_size=batch_size)
-        self.test_context_parallel_inference("ulysses_degree")
-
 
 class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
     """LoRA adapter tests for QwenImage Transformer."""

From 26a511d7e788e96f0cc3f8d7c22d2766f8c76d02 Mon Sep 17 00:00:00 2001
From: Mike Cheung
Date: Mon, 23 Mar 2026 18:25:14 +0800
Subject: [PATCH 4/6] Fix CI

---
 tests/models/testing_utils/parallelism.py | 4 ++--
 .../models/transformers/test_models_transformer_qwenimage.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py
index 769752ab599b..43d2f7176397 100644
--- a/tests/models/testing_utils/parallelism.py
+++ b/tests/models/testing_utils/parallelism.py
@@ -23,7 +23,7 @@
 import torch.multiprocessing as mp
 
 from diffusers.models._modeling_parallel import ContextParallelConfig
-from diffusers.models.attention_dispatch import _AttentionBackendRegistry, AttentionBackendName
+from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
 
 from ...testing_utils import (
     is_context_parallel,
@@ -176,7 +176,7 @@ def test_context_parallel_inference(self, cp_type, batch_size=None):
 
         world_size = 2
         init_dict = self.get_init_dict()
-        
+
         # get_dummy_inputs may or may not support a batch_size argument
         if batch_size is not None and "batch_size" in inspect.signature(self.get_dummy_inputs).parameters:
             inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py
index d3fbfb61e549..5b45577f2dff 100644
--- a/tests/models/transformers/test_models_transformer_qwenimage.py
+++ b/tests/models/transformers/test_models_transformer_qwenimage.py
@@ -14,8 +14,8 @@
 
 import warnings
 
-import torch
 import pytest
+import torch
 
 from diffusers import QwenImageTransformer2DModel
 from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask

From bb6459014dbca85df210003c16cc98d9ae736d2b Mon Sep 17 00:00:00 2001
From: Mike Cheung
Date: Tue, 24 Mar 2026 11:00:55 +0800
Subject: [PATCH 5/6] Update according to suggestion 3

---
 tests/models/testing_utils/parallelism.py | 12 ++----------
 .../transformers/test_models_transformer_flux.py | 3 +--
 .../transformers/test_models_transformer_flux2.py | 3 +--
 3 files changed, 4 insertions(+), 14 deletions(-)

diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py
index 43d2f7176397..46fdcbfd735e 100644
--- a/tests/models/testing_utils/parallelism.py
+++ b/tests/models/testing_utils/parallelism.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
 import os
 import socket
 
@@ -162,7 +161,7 @@
 @require_torch_multi_accelerator
 class ContextParallelTesterMixin:
     @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
-    def test_context_parallel_inference(self, cp_type, batch_size=None):
+    def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
         if not torch.distributed.is_available():
             pytest.skip("torch.distributed is not available.")
 
@@ -177,11 +176,7 @@ def test_context_parallel_inference(self, cp_type, batch_size=None):
         world_size = 2
         init_dict = self.get_init_dict()
 
-        # get_dummy_inputs may or may not support a batch_size argument
-        if batch_size is not None and "batch_size" in inspect.signature(self.get_dummy_inputs).parameters:
-            inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
-        else:
-            inputs_dict = self.get_dummy_inputs()
+        inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
 
         # Move all tensors to CPU for multiprocessing
         inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
@@ -209,9 +204,6 @@
     @pytest.mark.xfail(reason="Context parallel may not support batch_size > 1")
     @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
     def test_context_parallel_batch_inputs(self, cp_type):
-        if "batch_size" not in inspect.signature(self.get_dummy_inputs).parameters:
-            pytest.skip("get_dummy_inputs does not support a batch_size parameter.")
-
         self.test_context_parallel_inference(cp_type, batch_size=2)
 
     @pytest.mark.parametrize(
diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index 24be833d0ed2..a15b7be50b97 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -150,8 +150,7 @@ def get_init_dict(self) -> dict[str, int | list[int]]:
             "axes_dims_rope": [4, 4, 8],
         }
 
-    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
-        batch_size = 1
+    def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
         height = width = 4
         num_latent_channels = 4
         num_image_channels = 3
diff --git a/tests/models/transformers/test_models_transformer_flux2.py b/tests/models/transformers/test_models_transformer_flux2.py
index a109f603411d..77b5f1b86e59 100644
--- a/tests/models/transformers/test_models_transformer_flux2.py
+++ b/tests/models/transformers/test_models_transformer_flux2.py
@@ -90,8 +90,7 @@ def get_init_dict(self) -> dict[str, int | list[int]]:
             "axes_dims_rope": [4, 4, 4, 4],
         }
 
-    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
-        batch_size = 1
+    def get_dummy_inputs(self, height: int = 4, width: int = 4, batch_size: int = 1) -> dict[str, torch.Tensor]:
         num_latent_channels = 4
         sequence_length = 48
         embedding_dim = 32

From 33aad003c808a930ef1c2655c9dde23aa9783364 Mon Sep 17 00:00:00 2001
From: Mike Cheung
Date: Tue, 24 Mar 2026 11:02:58 +0800
Subject: [PATCH 6/6] Remove stray blank line

---
 tests/models/testing_utils/parallelism.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py
index 46fdcbfd735e..2b6aab59a662 100644
--- a/tests/models/testing_utils/parallelism.py
+++ b/tests/models/testing_utils/parallelism.py
@@ -175,7 +175,6 @@ def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
 
         world_size = 2
         init_dict = self.get_init_dict()
-
         inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
 
         # Move all tensors to CPU for multiprocessing
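
Usage note (illustrative, not part of the patches above): after patch 6, ContextParallelTesterMixin.test_context_parallel_inference accepts batch_size: int = 1 and forwards it directly to self.get_dummy_inputs, while test_context_parallel_batch_inputs reruns the same path with batch_size=2 under an xfail marker. The sketch below shows the batch-leading get_dummy_inputs contract a tester config is now expected to provide; the shapes mirror the QwenImage tester, but plain torch.randn stands in for the randn_tensor helper used in the real tests, so this helper is an assumption for illustration, not a copy of any file in the series.

    import torch


    def get_dummy_inputs(batch_size: int = 1) -> dict:
        # Shapes mirror the QwenImage tester: (batch, seq, channels) latents,
        # (batch, seq, dim) text embeddings, and one img_shapes entry per sample.
        num_latent_channels = embedding_dim = 16
        height = width = 4
        sequence_length = 8

        return {
            "hidden_states": torch.randn(batch_size, height * width, num_latent_channels),
            "encoder_hidden_states": torch.randn(batch_size, sequence_length, embedding_dim),
            "encoder_hidden_states_mask": torch.ones(batch_size, sequence_length, dtype=torch.long),
            "timestep": torch.tensor([1.0]).expand(batch_size),
            "img_shapes": [(1, height, width)] * batch_size,
        }


    # Every tensor leads with the batch dimension, so the same helper serves the
    # default batch_size=1 tests and the xfail batch_size=2 context-parallel test.
    inputs = get_dummy_inputs(batch_size=2)
    assert all(t.shape[0] == 2 for t in inputs.values() if isinstance(t, torch.Tensor))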