From 36c203da57381eab637b626f4d1be00a92c4c32b Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Fri, 17 Apr 2026 17:06:26 +0800
Subject: [PATCH 01/16] model-code
---
diffsynth/configs/model_configs.py | 112 ++-
diffsynth/diffusion/flow_match.py | 23 +-
diffsynth/models/ace_step_conditioner.py | 709 ++++++++++++++
diffsynth/models/ace_step_dit.py | 908 ++++++++++++++++++
diffsynth/models/ace_step_lm.py | 79 ++
diffsynth/models/ace_step_text_encoder.py | 80 ++
diffsynth/models/ace_step_tokenizer.py | 732 ++++++++++++++
diffsynth/models/ace_step_vae.py | 241 +++++
diffsynth/pipelines/ace_step.py | 527 ++++++++++
.../ace_step_conditioner.py | 48 +
.../state_dict_converters/ace_step_dit.py | 43 +
.../state_dict_converters/ace_step_lm.py | 55 ++
.../ace_step_text_encoder.py | 39 +
.../ace_step_tokenizer.py | 27 +
.../model_inference/Ace-Step1.5-SimpleMode.py | 180 ++++
.../ace_step/model_inference/Ace-Step1.5.py | 67 ++
.../model_inference/acestep-v15-base.py | 52 +
.../model_inference/acestep-v15-sft.py | 52 +
.../acestep-v15-turbo-shift1.py | 52 +
.../acestep-v15-turbo-shift3.py | 52 +
.../model_inference/acestep-v15-xl-base.py | 52 +
.../model_inference/acestep-v15-xl-sft.py | 50 +
.../model_inference/acestep-v15-xl-turbo.py | 52 +
23 files changed, 4230 insertions(+), 2 deletions(-)
create mode 100644 diffsynth/models/ace_step_conditioner.py
create mode 100644 diffsynth/models/ace_step_dit.py
create mode 100644 diffsynth/models/ace_step_lm.py
create mode 100644 diffsynth/models/ace_step_text_encoder.py
create mode 100644 diffsynth/models/ace_step_tokenizer.py
create mode 100644 diffsynth/models/ace_step_vae.py
create mode 100644 diffsynth/pipelines/ace_step.py
create mode 100644 diffsynth/utils/state_dict_converters/ace_step_conditioner.py
create mode 100644 diffsynth/utils/state_dict_converters/ace_step_dit.py
create mode 100644 diffsynth/utils/state_dict_converters/ace_step_lm.py
create mode 100644 diffsynth/utils/state_dict_converters/ace_step_text_encoder.py
create mode 100644 diffsynth/utils/state_dict_converters/ace_step_tokenizer.py
create mode 100644 examples/ace_step/model_inference/Ace-Step1.5-SimpleMode.py
create mode 100644 examples/ace_step/model_inference/Ace-Step1.5.py
create mode 100644 examples/ace_step/model_inference/acestep-v15-base.py
create mode 100644 examples/ace_step/model_inference/acestep-v15-sft.py
create mode 100644 examples/ace_step/model_inference/acestep-v15-turbo-shift1.py
create mode 100644 examples/ace_step/model_inference/acestep-v15-turbo-shift3.py
create mode 100644 examples/ace_step/model_inference/acestep-v15-xl-base.py
create mode 100644 examples/ace_step/model_inference/acestep-v15-xl-sft.py
create mode 100644 examples/ace_step/model_inference/acestep-v15-xl-turbo.py
diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py
index 5fc95c3e..ee97fec3 100644
--- a/diffsynth/configs/model_configs.py
+++ b/diffsynth/configs/model_configs.py
@@ -916,4 +916,114 @@
},
]
-MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series
+ace_step_series = [
+ # === Standard DiT variants (24 layers, hidden_size=2048) ===
+ # Covers: turbo, turbo-shift1, turbo-shift3, turbo-continuous, base, sft
+ # All share identical state_dict structure → same hash
+ {
+ # Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
+ "model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
+ "model_name": "ace_step_dit",
+ "model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.ace_step_dit_converter",
+ },
+ # === XL DiT variants (32 layers, hidden_size=2560) ===
+ # Covers: xl-base, xl-sft, xl-turbo
+ {
+ # Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
+ "model_hash": "3a28a410c2246f125153ef792d8bc828",
+ "model_name": "ace_step_dit",
+ "model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.ace_step_dit_converter",
+ "extra_kwargs": {
+ "hidden_size": 2560,
+ "intermediate_size": 9728,
+ "num_hidden_layers": 32,
+ "num_attention_heads": 32,
+ "num_key_value_heads": 8,
+ "head_dim": 128,
+ "encoder_hidden_size": 2048,
+ "layer_types": ["sliding_attention", "full_attention"] * 16,
+ },
+ },
+ # === Conditioner (shared by all DiT variants, same architecture) ===
+ {
+ # Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
+ "model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
+ "model_name": "ace_step_conditioner",
+ "model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.ace_step_conditioner_converter",
+ },
+ # === XL Conditioner (same architecture, but checkpoint includes XL decoder → different file hash) ===
+ {
+ # Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
+ "model_hash": "3a28a410c2246f125153ef792d8bc828",
+ "model_name": "ace_step_conditioner",
+ "model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.ace_step_conditioner_converter",
+ },
+ # === LLM variants ===
+ {
+ # Example: ModelConfig(model_id="ACE-Step/acestep-5Hz-lm-0.6B", origin_file_pattern="model.safetensors")
+ "model_hash": "f3ab4bef9e00745fd0fea7aa8b2a4041",
+ "model_name": "ace_step_lm",
+ "model_class": "diffsynth.models.ace_step_lm.AceStepLM",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_lm.ace_step_lm_converter",
+ "extra_kwargs": {
+ "variant": "acestep-5Hz-lm-0.6B",
+ },
+ },
+ {
+ # Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-5Hz-lm-1.7B/model.safetensors")
+ "model_hash": "a14b6e422b0faa9b41e7efe0fee46766",
+ "model_name": "ace_step_lm",
+ "model_class": "diffsynth.models.ace_step_lm.AceStepLM",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_lm.ace_step_lm_converter",
+ "extra_kwargs": {
+ "variant": "acestep-5Hz-lm-1.7B",
+ },
+ },
+ {
+ # Example: ModelConfig(model_id="ACE-Step/acestep-5Hz-lm-4B", origin_file_pattern="model-*.safetensors")
+ "model_hash": "046a3934f2e6f2f6d450bad23b1f4933",
+ "model_name": "ace_step_lm",
+ "model_class": "diffsynth.models.ace_step_lm.AceStepLM",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_lm.ace_step_lm_converter",
+ "extra_kwargs": {
+ "variant": "acestep-5Hz-lm-4B",
+ },
+ },
+ # === Qwen3-Embedding (text encoder) ===
+ {
+ # Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors")
+ "model_hash": "3509bea17b0e8cffc3dd4a15cc7899d0",
+ "model_name": "ace_step_text_encoder",
+ "model_class": "diffsynth.models.ace_step_text_encoder.AceStepTextEncoder",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_text_encoder.ace_step_text_encoder_converter",
+ },
+ # === VAE (AutoencoderOobleck CNN) ===
+ {
+ # Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
+ "model_hash": "51420834e54474986a7f4be0e4d6f687",
+ "model_name": "ace_step_vae",
+ "model_class": "diffsynth.models.ace_step_vae.AceStepVAE",
+ },
+ # === Tokenizer (VAE latent discretization: tokenizer + detokenizer) ===
+ {
+ # Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
+ "model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
+ "model_name": "ace_step_tokenizer",
+ "model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.ace_step_tokenizer_converter",
+ },
+ # === XL Tokenizer (XL models share same tokenizer architecture) ===
+ {
+ # Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
+ "model_hash": "3a28a410c2246f125153ef792d8bc828",
+ "model_name": "ace_step_tokenizer",
+ "model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.ace_step_tokenizer_converter",
+ },
+]
+
+MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series
diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py
index 6c1b846b..f2416838 100644
--- a/diffsynth/diffusion/flow_match.py
+++ b/diffsynth/diffusion/flow_match.py
@@ -4,7 +4,7 @@
class FlowMatchScheduler():
- def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning", "ERNIE-Image"] = "FLUX.1"):
+ def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning", "ERNIE-Image", "ACE-Step"] = "FLUX.1"):
self.set_timesteps_fn = {
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
"Wan": FlowMatchScheduler.set_timesteps_wan,
@@ -14,6 +14,7 @@ def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
"ERNIE-Image": FlowMatchScheduler.set_timesteps_ernie_image,
+ "ACE-Step": FlowMatchScheduler.set_timesteps_ace_step,
}.get(template, FlowMatchScheduler.set_timesteps_flux)
self.num_train_timesteps = 1000
@@ -142,6 +143,26 @@ def set_timesteps_ernie_image(num_inference_steps=50, denoising_strength=1.0, sh
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
+ @staticmethod
+ def set_timesteps_ace_step(num_inference_steps=8, denoising_strength=1.0, shift=3.0):
+ """ACE-Step Flow Matching scheduler.
+
+ Timesteps range from 1.0 to 0.0 (not multiplied by 1000).
+ Shift transformation: t = shift * t / (1 + (shift - 1) * t)
+
+ Args:
+ num_inference_steps: Number of diffusion steps.
+ denoising_strength: Denoising strength (1.0 = full denoising).
+ shift: Timestep shift parameter (default 3.0 for turbo).
+ """
+ num_train_timesteps = 1000
+ sigma_start = denoising_strength
+ sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps)
+ if shift is not None and shift != 1.0:
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+ timesteps = sigmas # ACE-Step uses [0, 1] range directly
+ return sigmas, timesteps
+
@staticmethod
def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
sigma_min = 0.0
diff --git a/diffsynth/models/ace_step_conditioner.py b/diffsynth/models/ace_step_conditioner.py
new file mode 100644
index 00000000..93fe0d32
--- /dev/null
+++ b/diffsynth/models/ace_step_conditioner.py
@@ -0,0 +1,709 @@
+# Copyright 2025 The ACE-Step Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from einops import rearrange
+
+from ..core.attention import attention_forward
+from ..core.gradient import gradient_checkpoint_forward
+
+from transformers.cache_utils import Cache
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_outputs import BaseModelOutput
+from transformers.processing_utils import Unpack
+from transformers.utils import can_return_tuple, logging
+from transformers.models.qwen3.modeling_qwen3 import (
+ Qwen3MLP,
+ Qwen3RMSNorm,
+ Qwen3RotaryEmbedding,
+ apply_rotary_pos_emb,
+)
+
+logger = logging.get_logger(__name__)
+
+
+def create_4d_mask(
+ seq_len: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ attention_mask: Optional[torch.Tensor] = None,
+ sliding_window: Optional[int] = None,
+ is_sliding_window: bool = False,
+ is_causal: bool = True,
+) -> torch.Tensor:
+ indices = torch.arange(seq_len, device=device)
+ diff = indices.unsqueeze(1) - indices.unsqueeze(0)
+ valid_mask = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
+ if is_causal:
+ valid_mask = valid_mask & (diff >= 0)
+ if is_sliding_window and sliding_window is not None:
+ if is_causal:
+ valid_mask = valid_mask & (diff <= sliding_window)
+ else:
+ valid_mask = valid_mask & (torch.abs(diff) <= sliding_window)
+ valid_mask = valid_mask.unsqueeze(0).unsqueeze(0)
+ if attention_mask is not None:
+ padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
+ valid_mask = valid_mask & padding_mask_4d
+ min_dtype = torch.finfo(dtype).min
+ mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
+ mask_tensor.masked_fill_(valid_mask, 0.0)
+ return mask_tensor
+
+
+def pack_sequences(hidden1: torch.Tensor, hidden2: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor):
+ hidden_cat = torch.cat([hidden1, hidden2], dim=1)
+ mask_cat = torch.cat([mask1, mask2], dim=1)
+ B, L, D = hidden_cat.shape
+ sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True)
+ hidden_left = torch.gather(hidden_cat, 1, sort_idx.unsqueeze(-1).expand(B, L, D))
+ lengths = mask_cat.sum(dim=1)
+ new_mask = (torch.arange(L, dtype=torch.long, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1))
+ return hidden_left, new_mask
+
+
+class Lambda(nn.Module):
+ def __init__(self, func):
+ super().__init__()
+ self.func = func
+
+ def forward(self, x):
+ return self.func(x)
+
+
+class AceStepAttention(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ num_attention_heads: int,
+ num_key_value_heads: int,
+ rms_norm_eps: float,
+ attention_bias: bool,
+ attention_dropout: float,
+ layer_types: list,
+ head_dim: Optional[int] = None,
+ sliding_window: Optional[int] = None,
+ layer_idx: int = 0,
+ is_cross_attention: bool = False,
+ is_causal: bool = False,
+ ):
+ super().__init__()
+ self.layer_idx = layer_idx
+ self.head_dim = head_dim or hidden_size // num_attention_heads
+ self.num_key_value_groups = num_attention_heads // num_key_value_heads
+ self.scaling = self.head_dim ** -0.5
+ self.attention_dropout = attention_dropout
+ if is_cross_attention:
+ is_causal = False
+ self.is_causal = is_causal
+ self.is_cross_attention = is_cross_attention
+
+ self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=attention_bias)
+ self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
+ self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
+ self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=attention_bias)
+ self.q_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
+ self.k_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
+ self.attention_type = layer_types[layer_idx]
+ self.sliding_window = sliding_window if layer_types[layer_idx] == "sliding_attention" else None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ past_key_value: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+
+ is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None
+
+ if is_cross_attention:
+ encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
+ if past_key_value is not None:
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
+ curr_past_key_value = past_key_value.cross_attention_cache
+ if not is_updated:
+ key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
+ value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
+ key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx)
+ past_key_value.is_updated[self.layer_idx] = True
+ else:
+ key_states = curr_past_key_value.layers[self.layer_idx].keys
+ value_states = curr_past_key_value.layers[self.layer_idx].values
+ else:
+ key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
+ value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
+
+ else:
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ if position_embeddings is not None:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ if self.num_key_value_groups > 1:
+ key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
+ value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
+
+ attn_output = attention_forward(
+ query_states, key_states, value_states,
+ q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d",
+ attn_mask=attention_mask,
+ )
+ attn_weights = None
+
+ attn_output = attn_output.transpose(1, 2).flatten(2, 3).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class AceStepEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ num_attention_heads: int,
+ num_key_value_heads: int,
+ rms_norm_eps: float,
+ attention_bias: bool,
+ attention_dropout: float,
+ layer_types: list,
+ head_dim: Optional[int] = None,
+ sliding_window: Optional[int] = None,
+ layer_idx: int = 0,
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.layer_idx = layer_idx
+
+ self.self_attn = AceStepAttention(
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=num_key_value_heads,
+ rms_norm_eps=rms_norm_eps,
+ attention_bias=attention_bias,
+ attention_dropout=attention_dropout,
+ layer_types=layer_types,
+ head_dim=head_dim,
+ sliding_window=sliding_window,
+ layer_idx=layer_idx,
+ is_cross_attention=False,
+ is_causal=False,
+ )
+ self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
+ self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
+
+ mlp_config = type('Config', (), {
+ 'hidden_size': hidden_size,
+ 'intermediate_size': intermediate_size,
+ 'hidden_act': 'silu',
+ })()
+ self.mlp = Qwen3MLP(mlp_config)
+ self.attention_type = layer_types[layer_idx]
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = False,
+ **kwargs,
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ use_cache=False,
+ past_key_value=None,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+ if output_attentions:
+ outputs += (self_attn_weights,)
+ return outputs
+
+
+class AceStepLyricEncoder(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int = 2048,
+ intermediate_size: int = 6144,
+ num_hidden_layers: int = 24,
+ num_attention_heads: int = 16,
+ num_key_value_heads: int = 8,
+ rms_norm_eps: float = 1e-6,
+ attention_bias: bool = False,
+ attention_dropout: float = 0.0,
+ layer_types: Optional[list] = None,
+ head_dim: Optional[int] = None,
+ sliding_window: Optional[int] = 128,
+ use_sliding_window: bool = True,
+ use_cache: bool = True,
+ rope_theta: float = 1000000,
+ max_position_embeddings: int = 32768,
+ initializer_range: float = 0.02,
+ text_hidden_dim: int = 1024,
+ num_lyric_encoder_hidden_layers: int = 8,
+ **kwargs,
+ ):
+ super().__init__()
+ self.num_lyric_encoder_hidden_layers = num_lyric_encoder_hidden_layers
+ self.text_hidden_dim = text_hidden_dim
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.rms_norm_eps = rms_norm_eps
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
+ self.head_dim = head_dim or hidden_size // num_attention_heads
+ self.sliding_window = sliding_window
+ self.use_sliding_window = use_sliding_window
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.max_position_embeddings = max_position_embeddings
+ self.initializer_range = initializer_range
+ self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
+
+ self.embed_tokens = nn.Linear(text_hidden_dim, hidden_size)
+ self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
+ rope_config = type('RopeConfig', (), {
+ 'hidden_size': hidden_size,
+ 'num_attention_heads': num_attention_heads,
+ 'num_key_value_heads': num_key_value_heads,
+ 'head_dim': head_dim,
+ 'max_position_embeddings': max_position_embeddings,
+ 'rope_theta': rope_theta,
+ 'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
+ 'rms_norm_eps': rms_norm_eps,
+ 'attention_bias': attention_bias,
+ 'attention_dropout': attention_dropout,
+ 'hidden_act': 'silu',
+ 'intermediate_size': intermediate_size,
+ 'layer_types': self.layer_types,
+ 'sliding_window': sliding_window,
+ '_attn_implementation': self._attn_implementation,
+ })()
+ self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
+ self.gradient_checkpointing = False
+
+ self.layers = nn.ModuleList([
+ AceStepEncoderLayer(
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=num_key_value_heads,
+ rms_norm_eps=rms_norm_eps,
+ attention_bias=attention_bias,
+ attention_dropout=attention_dropout,
+ layer_types=self.layer_types,
+ head_dim=head_dim,
+ sliding_window=sliding_window,
+ layer_idx=layer_idx,
+ )
+ for layer_idx in range(num_lyric_encoder_hidden_layers)
+ ])
+
+
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+ ) -> BaseModelOutput:
+ output_attentions = output_attentions if output_attentions is not None else False
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else False
+
+ assert input_ids is None, "Only `inputs_embeds` is supported for the lyric encoder."
+ assert attention_mask is not None, "Attention mask must be provided for the lyric encoder."
+ assert inputs_embeds is not None, "Inputs embeddings must be provided for the lyric encoder."
+
+ inputs_embeds = self.embed_tokens(inputs_embeds)
+ cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ seq_len = inputs_embeds.shape[1]
+ dtype = inputs_embeds.dtype
+ device = inputs_embeds.device
+
+ full_attn_mask = create_4d_mask(
+ seq_len=seq_len, dtype=dtype, device=device,
+ attention_mask=attention_mask, sliding_window=None,
+ is_sliding_window=False, is_causal=False
+ )
+ sliding_attn_mask = None
+ if self.use_sliding_window:
+ sliding_attn_mask = create_4d_mask(
+ seq_len=seq_len, dtype=dtype, device=device,
+ attention_mask=attention_mask, sliding_window=self.sliding_window,
+ is_sliding_window=True, is_causal=False
+ )
+
+ self_attn_mask_mapping = {
+ "full_attention": full_attn_mask,
+ "sliding_attention": sliding_attn_mask,
+ }
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for layer_module in self.layers[: self.num_lyric_encoder_hidden_layers]:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = layer_module(
+ hidden_states, position_embeddings,
+ self_attn_mask_mapping[layer_module.attention_type],
+ position_ids, output_attentions,
+ **flash_attn_kwargs,
+ )
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class AceStepTimbreEncoder(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int = 2048,
+ intermediate_size: int = 6144,
+ num_hidden_layers: int = 24,
+ num_attention_heads: int = 16,
+ num_key_value_heads: int = 8,
+ rms_norm_eps: float = 1e-6,
+ attention_bias: bool = False,
+ attention_dropout: float = 0.0,
+ layer_types: Optional[list] = None,
+ head_dim: Optional[int] = None,
+ sliding_window: Optional[int] = 128,
+ use_sliding_window: bool = True,
+ use_cache: bool = True,
+ rope_theta: float = 1000000,
+ max_position_embeddings: int = 32768,
+ initializer_range: float = 0.02,
+ timbre_hidden_dim: int = 64,
+ num_timbre_encoder_hidden_layers: int = 4,
+ **kwargs,
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.rms_norm_eps = rms_norm_eps
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
+ self.head_dim = head_dim or hidden_size // num_attention_heads
+ self.sliding_window = sliding_window
+ self.use_sliding_window = use_sliding_window
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.max_position_embeddings = max_position_embeddings
+ self.initializer_range = initializer_range
+ self.timbre_hidden_dim = timbre_hidden_dim
+ self.num_timbre_encoder_hidden_layers = num_timbre_encoder_hidden_layers
+ self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
+
+ self.embed_tokens = nn.Linear(timbre_hidden_dim, hidden_size)
+ self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
+ rope_config = type('RopeConfig', (), {
+ 'hidden_size': hidden_size,
+ 'num_attention_heads': num_attention_heads,
+ 'num_key_value_heads': num_key_value_heads,
+ 'head_dim': head_dim,
+ 'max_position_embeddings': max_position_embeddings,
+ 'rope_theta': rope_theta,
+ 'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
+ 'rms_norm_eps': rms_norm_eps,
+ 'attention_bias': attention_bias,
+ 'attention_dropout': attention_dropout,
+ 'hidden_act': 'silu',
+ 'intermediate_size': intermediate_size,
+ 'layer_types': self.layer_types,
+ 'sliding_window': sliding_window,
+ '_attn_implementation': self._attn_implementation,
+ })()
+ self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
+ self.gradient_checkpointing = False
+ self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size))
+ self.layers = nn.ModuleList([
+ AceStepEncoderLayer(
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=num_key_value_heads,
+ rms_norm_eps=rms_norm_eps,
+ attention_bias=attention_bias,
+ attention_dropout=attention_dropout,
+ layer_types=self.layer_types,
+ head_dim=head_dim,
+ sliding_window=sliding_window,
+ layer_idx=layer_idx,
+ )
+ for layer_idx in range(num_timbre_encoder_hidden_layers)
+ ])
+
+
+ def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask):
+ N, d = timbre_embs_packed.shape
+ device = timbre_embs_packed.device
+ dtype = timbre_embs_packed.dtype
+ B = int(refer_audio_order_mask.max().item() + 1)
+ counts = torch.bincount(refer_audio_order_mask, minlength=B)
+ max_count = counts.max().item()
+ sorted_indices = torch.argsort(refer_audio_order_mask * N + torch.arange(N, device=device), stable=True)
+ sorted_batch_ids = refer_audio_order_mask[sorted_indices]
+ positions = torch.arange(N, device=device)
+ batch_starts = torch.cat([torch.tensor([0], device=device), torch.cumsum(counts, dim=0)[:-1]])
+ positions_in_sorted = positions - batch_starts[sorted_batch_ids]
+ inverse_indices = torch.empty_like(sorted_indices)
+ inverse_indices[sorted_indices] = torch.arange(N, device=device)
+ positions_in_batch = positions_in_sorted[inverse_indices]
+ indices_2d = refer_audio_order_mask * max_count + positions_in_batch
+ one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(dtype)
+ timbre_embs_flat = one_hot.t() @ timbre_embs_packed
+ timbre_embs_unpack = timbre_embs_flat.reshape(B, max_count, d)
+ mask_flat = (one_hot.sum(dim=0) > 0).long()
+ new_mask = mask_flat.reshape(B, max_count)
+ return timbre_embs_unpack, new_mask
+
+ @can_return_tuple
+ def forward(
+ self,
+ refer_audio_acoustic_hidden_states_packed: Optional[torch.FloatTensor] = None,
+ refer_audio_order_mask: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+ ) -> BaseModelOutput:
+ inputs_embeds = refer_audio_acoustic_hidden_states_packed
+ inputs_embeds = self.embed_tokens(inputs_embeds)
+ # Handle 2D (packed) or 3D (batched) input
+ is_packed = inputs_embeds.dim() == 2
+ if is_packed:
+ seq_len = inputs_embeds.shape[0]
+ cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
+ position_ids = cache_position.unsqueeze(0)
+ inputs_embeds = inputs_embeds.unsqueeze(0)
+ else:
+ seq_len = inputs_embeds.shape[1]
+ cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
+ position_ids = cache_position.unsqueeze(0)
+
+ dtype = inputs_embeds.dtype
+ device = inputs_embeds.device
+
+ full_attn_mask = create_4d_mask(
+ seq_len=seq_len, dtype=dtype, device=device,
+ attention_mask=attention_mask, sliding_window=None,
+ is_sliding_window=False, is_causal=False
+ )
+ sliding_attn_mask = None
+ if self.use_sliding_window:
+ sliding_attn_mask = create_4d_mask(
+ seq_len=seq_len, dtype=dtype, device=device,
+ attention_mask=attention_mask, sliding_window=self.sliding_window,
+ is_sliding_window=True, is_causal=False
+ )
+
+ self_attn_mask_mapping = {
+ "full_attention": full_attn_mask,
+ "sliding_attention": sliding_attn_mask,
+ }
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for layer_module in self.layers[: self.num_timbre_encoder_hidden_layers]:
+ layer_outputs = layer_module(
+ hidden_states, position_embeddings,
+ self_attn_mask_mapping[layer_module.attention_type],
+ position_ids,
+ **flash_attn_kwargs,
+ )
+ hidden_states = layer_outputs[0]
+
+ hidden_states = self.norm(hidden_states)
+ # For packed input: reshape [1, T, D] -> [T, D] for unpacking
+ if is_packed:
+ hidden_states = hidden_states.squeeze(0)
+ timbre_embs_unpack, timbre_embs_mask = self.unpack_timbre_embeddings(hidden_states, refer_audio_order_mask)
+ return timbre_embs_unpack, timbre_embs_mask
+
+
class AceStepConditionEncoder(nn.Module):
    """Fuses text, lyric and timbre conditions into one packed conditioning sequence.

    Text embeddings are projected to the model width; lyrics run through an
    ``AceStepLyricEncoder``; reference-audio features run through an
    ``AceStepTimbreEncoder``. The three streams are then packed (valid tokens
    first) into a single ``(hidden_states, attention_mask)`` pair.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_hidden_layers: int = 24,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        use_cache: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        text_hidden_dim: int = 1024,
        timbre_hidden_dim: int = 64,
        num_lyric_encoder_hidden_layers: int = 8,
        num_timbre_encoder_hidden_layers: int = 4,
        **kwargs,
    ):
        super().__init__()
        # Record the full configuration on the instance (read by converters/pipelines).
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.rms_norm_eps = rms_norm_eps
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.sliding_window = sliding_window
        self.use_sliding_window = use_sliding_window
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.text_hidden_dim = text_hidden_dim
        self.timbre_hidden_dim = timbre_hidden_dim
        self.num_lyric_encoder_hidden_layers = num_lyric_encoder_hidden_layers
        self.num_timbre_encoder_hidden_layers = num_timbre_encoder_hidden_layers
        self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")

        # Arguments shared verbatim by both sub-encoders. Note: the raw
        # `layer_types` argument (possibly None) is forwarded so each
        # sub-encoder applies its own default.
        shared_encoder_kwargs = dict(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            use_sliding_window=use_sliding_window,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
        )
        # Submodule construction order is kept to preserve parameter-init RNG order.
        self.text_projector = nn.Linear(text_hidden_dim, hidden_size, bias=False)
        self.null_condition_emb = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.lyric_encoder = AceStepLyricEncoder(
            text_hidden_dim=text_hidden_dim,
            num_lyric_encoder_hidden_layers=num_lyric_encoder_hidden_layers,
            **shared_encoder_kwargs,
        )
        self.timbre_encoder = AceStepTimbreEncoder(
            timbre_hidden_dim=timbre_hidden_dim,
            num_timbre_encoder_hidden_layers=num_timbre_encoder_hidden_layers,
            **shared_encoder_kwargs,
        )

    def forward(
        self,
        text_hidden_states: Optional[torch.FloatTensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        lyric_hidden_states: Optional[torch.LongTensor] = None,
        lyric_attention_mask: Optional[torch.Tensor] = None,
        refer_audio_acoustic_hidden_states_packed: Optional[torch.Tensor] = None,
        refer_audio_order_mask: Optional[torch.LongTensor] = None,
    ):
        """Encode each condition stream and pack them into one sequence.

        Returns:
            Tuple ``(encoder_hidden_states, encoder_attention_mask)`` with the
            lyric, timbre and text tokens packed so valid tokens come first.
        """
        projected_text = self.text_projector(text_hidden_states)
        lyric_states = self.lyric_encoder(
            inputs_embeds=lyric_hidden_states,
            attention_mask=lyric_attention_mask,
        ).last_hidden_state
        timbre_states, timbre_mask = self.timbre_encoder(
            refer_audio_acoustic_hidden_states_packed,
            refer_audio_order_mask,
        )

        # Pack order: lyrics + timbre first, then text appended to that result.
        packed_states, packed_mask = pack_sequences(
            lyric_states, timbre_states, lyric_attention_mask, timbre_mask
        )
        packed_states, packed_mask = pack_sequences(
            packed_states, projected_text, packed_mask, text_attention_mask
        )
        return packed_states, packed_mask
diff --git a/diffsynth/models/ace_step_dit.py b/diffsynth/models/ace_step_dit.py
new file mode 100644
index 00000000..c4621feb
--- /dev/null
+++ b/diffsynth/models/ace_step_dit.py
@@ -0,0 +1,908 @@
+# Copyright 2025 The ACESTEP Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ..core.attention.attention import attention_forward
+from ..core import gradient_checkpoint_forward
+
+from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_outputs import BaseModelOutput
+from transformers.processing_utils import Unpack
+from transformers.utils import logging
+
+from transformers.models.qwen3.modeling_qwen3 import (
+ Qwen3MLP,
+ Qwen3RMSNorm,
+ Qwen3RotaryEmbedding,
+ apply_rotary_pos_emb,
+)
+
+logger = logging.get_logger(__name__)
+
+
def create_4d_mask(
    seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,  # [Batch, Seq_Len], 1=valid 0=pad
    sliding_window: Optional[int] = None,
    is_sliding_window: bool = False,
    is_causal: bool = True,
) -> torch.Tensor:
    """Build an additive 4D attention mask usable with SDPA/eager on any device.

    Supported combinations:
      1. Causal full:          is_causal=True,  is_sliding_window=False (GPT-style)
      2. Causal sliding:       is_causal=True,  is_sliding_window=True  (local window)
      3. Bidirectional full:   is_causal=False, is_sliding_window=False (encoder)
      4. Bidirectional sliding:is_causal=False, is_sliding_window=True  (local, both sides)

    Returns:
        [Batch, 1, Seq_Len, Seq_Len] tensor with 0.0 at attended positions and
        the dtype's most-negative value at masked positions.
    """
    # offset[i, j] = i - j: positive when key j lies in the past of query i.
    positions = torch.arange(seq_len, device=device)
    offset = positions[:, None] - positions[None, :]

    # Start from "everything visible" and intersect with each constraint.
    allowed = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)
    if is_causal:
        allowed &= offset >= 0
    if is_sliding_window and sliding_window is not None:
        if is_causal:
            # Only the most recent `sliding_window` past steps remain visible.
            allowed &= offset <= sliding_window
        else:
            # Symmetric local window around each query position.
            allowed &= offset.abs() <= sliding_window

    # [L, L] -> [1, 1, L, L] so it broadcasts over (batch, heads).
    allowed = allowed[None, None, :, :]

    # Mask out padded keys (columns): [B, L] -> [B, 1, 1, L] broadcast.
    if attention_mask is not None:
        key_valid = attention_mask.to(torch.bool).view(attention_mask.shape[0], 1, 1, seq_len)
        allowed = allowed & key_valid

    # Additive form: 0.0 where attending is allowed, dtype-min elsewhere.
    neg_inf = torch.finfo(dtype).min
    return torch.where(
        allowed,
        torch.zeros((), dtype=dtype, device=device),
        torch.full((), neg_inf, dtype=dtype, device=device),
    )
+
+
def pack_sequences(hidden1: torch.Tensor, hidden2: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor):
    """Concatenate two masked sequences and compact valid tokens to the front.

    Args:
        hidden1: First hidden states, shape [B, L1, D].
        hidden2: Second hidden states, shape [B, L2, D].
        mask1: Mask for the first sequence, shape [B, L1] (1=valid, 0=pad).
        mask2: Mask for the second sequence, shape [B, L2].

    Returns:
        Tuple ``(packed_hidden_states, new_mask)``:
          - packed_hidden_states: [B, L1+L2, D] with all valid tokens first,
            keeping their original relative order.
          - new_mask: [B, L1+L2] boolean mask marking the valid prefix.
    """
    merged = torch.cat([hidden1, hidden2], dim=1)          # [B, L, D]
    merged_mask = torch.cat([mask1, mask2], dim=1)         # [B, L]
    batch, total_len, dim = merged.shape

    # Stable descending sort on the mask pushes valid tokens (1) ahead of
    # padding (0) without reordering tokens within each group.
    order = merged_mask.argsort(dim=1, descending=True, stable=True)
    packed = torch.gather(merged, 1, order.unsqueeze(-1).expand(batch, total_len, dim))

    # The new mask is simply "position < number of valid tokens in this row".
    valid_counts = merged_mask.sum(dim=1, keepdim=True)    # [B, 1]
    token_positions = torch.arange(total_len, dtype=torch.long, device=merged.device).unsqueeze(0)
    packed_mask = token_positions < valid_counts

    return packed, packed_mask
+
+
class TimestepEmbedding(nn.Module):
    """Diffusion timestep embedding: sinusoidal encoding followed by an MLP.

    Maps a 1-D tensor of timesteps to a conditioning vector ``temb`` of size
    ``time_embed_dim`` plus a six-way modulation projection used for AdaLN
    scale/shift/gate parameters.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        scale: float = 1000,
    ):
        super().__init__()

        # Two-layer MLP on top of the sinusoidal features.
        self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
        self.act1 = nn.SiLU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=True)
        self.in_channels = in_channels

        # Produces 6 modulation vectors (shift/scale/gate for attn and MLP).
        self.act2 = nn.SiLU()
        self.time_proj = nn.Linear(time_embed_dim, time_embed_dim * 6)
        # Timesteps are multiplied by this factor before encoding.
        self.scale = scale

    def timestep_embedding(self, t, dim, max_period=10000):
        """Return sinusoidal embeddings of shape (N, dim) for timesteps ``t``.

        Args:
            t: 1-D tensor of N (possibly fractional) timestep values.
            dim: Output embedding dimension.
            max_period: Controls the minimum encoded frequency.
        """
        scaled = t * self.scale
        half = dim // 2
        exponent = -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
        freqs = torch.exp(exponent).to(device=t.device)
        angles = scaled[:, None].float() * freqs[None]
        emb = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        # Odd dims get one zero column so the output is exactly `dim` wide.
        if dim % 2:
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
        return emb

    def forward(self, t):
        """Return ``(temb, timestep_proj)`` where proj has shape (N, 6, D)."""
        freq_emb = self.timestep_embedding(t, self.in_channels)
        temb = self.linear_2(self.act1(self.linear_1(freq_emb.to(t.dtype))))
        timestep_proj = self.time_proj(self.act2(temb)).unflatten(1, (6, -1))
        return temb, timestep_proj
+
+
class AceStepAttention(nn.Module):
    """
    Multi-headed attention module for AceStep model.

    Implements the attention mechanism from 'Attention Is All You Need' paper,
    with support for both self-attention and cross-attention modes. Uses RMSNorm
    for query and key normalization (per head dim), and supports sliding window
    attention for efficient long-sequence processing.

    Grouped-query attention (GQA) is supported: when `num_key_value_heads` is
    smaller than `num_attention_heads`, K/V heads are repeated to match.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        rms_norm_eps: float,
        attention_bias: bool,
        attention_dropout: float,
        layer_types: list,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = None,
        layer_idx: int = 0,
        is_cross_attention: bool = False,
        is_causal: bool = False,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.num_key_value_groups = num_attention_heads // num_key_value_heads
        # Standard 1/sqrt(d) attention scaling factor.
        self.scaling = self.head_dim ** -0.5
        self.attention_dropout = attention_dropout
        # Cross-attention is never causal regardless of the caller's flag.
        if is_cross_attention:
            is_causal = False
        self.is_causal = is_causal
        self.is_cross_attention = is_cross_attention

        self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=attention_bias)
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
        # Sliding window only applies to layers whose type is "sliding_attention".
        self.attention_type = layer_types[layer_idx]
        self.sliding_window = sliding_window if layer_types[layer_idx] == "sliding_attention" else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention over `hidden_states` (self) or `encoder_hidden_states` (cross).

        Returns:
            Tuple of (attn_output, attn_weights); attn_weights is always None
            because the unified attention backend does not return weights.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project and normalize query states
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)

        # Determine if this is cross-attention (requires encoder_hidden_states)
        is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None

        # Cross-attention path: attend to encoder hidden states
        if is_cross_attention:
            encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
            if past_key_value is not None:
                is_updated = past_key_value.is_updated.get(self.layer_idx)
                # After the first generated token, we can reuse all key/value states from cache
                curr_past_key_value = past_key_value.cross_attention_cache

                # Conditions for calculating key and value states
                if not is_updated:
                    # Compute and cache K/V for the first time
                    key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
                    value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
                    # Update cache: save all key/value states to cache for fast auto-regressive generation
                    key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx)
                    # Set flag that this layer's cross-attention cache is updated
                    past_key_value.is_updated[self.layer_idx] = True
                else:
                    # Reuse cached key/value states for subsequent tokens
                    key_states = curr_past_key_value.layers[self.layer_idx].keys
                    value_states = curr_past_key_value.layers[self.layer_idx].values
            else:
                # No cache used, compute K/V directly
                key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
                value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)

        # Self-attention path: attend to the same sequence
        else:
            # Project and normalize key/value states for self-attention
            key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
            value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            # Apply rotary position embeddings (RoPE) if provided
            if position_embeddings is not None:
                cos, sin = position_embeddings
                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            # Update cache for auto-regressive generation
            if past_key_value is not None:
                # Sin and cos are specific to RoPE models; cache_position needed for the static cache
                # NOTE(review): `sin`/`cos` only exist when position_embeddings was
                # given — caching without RoPE would raise NameError; confirm callers
                # always pass position_embeddings when using a cache.
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # GQA expansion: if num_key_value_heads < num_attention_heads,
        # each K/V head is repeated `num_key_value_groups` times.
        if self.num_key_value_groups > 1:
            key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
            value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)

        # Use DiffSynth unified attention
        # Tensors are already in (batch, heads, seq, dim) format -> "b n s d"
        attn_output = attention_forward(
            query_states, key_states, value_states,
            q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d",
            attn_mask=attention_mask,
        )

        attn_weights = None  # attention_forward doesn't return weights

        # Flatten and project output: (B, n_heads, seq, dim) -> (B, seq, n_heads*dim)
        attn_output = attn_output.transpose(1, 2).flatten(2, 3).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
+
+
class AceStepEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer: self-attention then MLP, each with a
    residual connection."""

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        intermediate_size: int = 6144,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: list = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = None,
        layer_idx: int = 0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.layer_idx = layer_idx

        # Bidirectional (non-causal) self-attention sub-layer.
        self.self_attn = AceStepAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            layer_idx=layer_idx,
            is_cross_attention=False,
            is_causal=False,
        )
        self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)

        # Qwen3MLP only reads these three fields, so a throwaway config suffices.
        mlp_config = type("Config", (), {
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "hidden_act": "silu",
        })()
        self.mlp = Qwen3MLP(config=mlp_config)
        self.attention_type = layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[
        torch.FloatTensor,
        Optional[tuple[torch.FloatTensor, torch.FloatTensor]],
    ]:
        """Apply self-attention and MLP blocks; returns (hidden_states,) and
        optionally the attention weights."""
        # Pre-norm self-attention with residual add.
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            # Encoders never cache key/value states.
            use_cache=False,
            past_key_value=None,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Pre-norm MLP with residual add.
        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
+
+
class AceStepDiTLayer(nn.Module):
    """
    DiT (Diffusion Transformer) layer for AceStep model.

    Implements a transformer layer with three main components:
    1. Self-attention with adaptive layer norm (AdaLN)
    2. Cross-attention (optional) for conditioning on encoder outputs
    3. Feed-forward MLP with adaptive layer norm

    Uses scale-shift modulation from timestep embeddings for adaptive normalization.
    """
    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        intermediate_size: int,
        rms_norm_eps: float,
        attention_bias: bool,
        attention_dropout: float,
        layer_types: list,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = None,
        layer_idx: int = 0,
        use_cross_attention: bool = True,
    ):
        super().__init__()

        self.self_attn_norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        self.self_attn = AceStepAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            layer_idx=layer_idx,
        )

        # Cross-attention is optional; when disabled the layer is purely
        # self-attention + MLP.
        self.use_cross_attention = use_cross_attention
        if self.use_cross_attention:
            self.cross_attn_norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
            self.cross_attn = AceStepAttention(
                hidden_size=hidden_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                rms_norm_eps=rms_norm_eps,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                layer_types=layer_types,
                head_dim=head_dim,
                sliding_window=sliding_window,
                layer_idx=layer_idx,
                is_cross_attention=True,
            )

        self.mlp_norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        self.mlp = Qwen3MLP(
            config=type('Config', (), {
                'hidden_size': hidden_size,
                'intermediate_size': intermediate_size,
                'hidden_act': 'silu',
            })()
        )

        # Learned base values for the 6 AdaLN modulation vectors; the timestep
        # embedding is added on top of this table at every forward pass.
        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, hidden_size) / hidden_size**0.5)
        self.attention_type = layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[EncoderDecoderCache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run AdaLN self-attention, optional cross-attention, and AdaLN MLP.

        Returns:
            Tuple of (hidden_states,) plus, when output_attentions is True,
            (self_attn_weights, cross_attn_weights); cross_attn_weights is
            None when cross-attention is disabled.
        """
        # Extract scale-shift parameters for adaptive layer norm from timestep embeddings
        # 6 values: (shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa)
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
            self.scale_shift_table + temb
        ).chunk(6, dim=1)

        # Step 1: Self-attention with adaptive layer norm (AdaLN)
        # Apply adaptive normalization: norm(x) * (1 + scale) + shift
        norm_hidden_states = (self.self_attn_norm(hidden_states) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
        attn_output, self_attn_weights = self.self_attn(
            hidden_states=norm_hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            use_cache=False,
            past_key_value=None,
            **kwargs,
        )
        # Apply gated residual connection: x = x + attn_output * gate
        hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states)

        # Step 2: Cross-attention (if enabled) for conditioning on encoder outputs
        # Fix: pre-initialize so the output tuple below doesn't raise NameError
        # when output_attentions=True but cross-attention is disabled.
        cross_attn_weights = None
        if self.use_cross_attention:
            norm_hidden_states = self.cross_attn_norm(hidden_states).type_as(hidden_states)
            attn_output, cross_attn_weights = self.cross_attn(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                **kwargs,
            )
            # Standard residual connection for cross-attention
            hidden_states = hidden_states + attn_output

        # Step 3: Feed-forward (MLP) with adaptive layer norm
        # Apply adaptive normalization for MLP: norm(x) * (1 + scale) + shift
        norm_hidden_states = (self.mlp_norm(hidden_states) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)
        ff_output = self.mlp(norm_hidden_states)
        # Apply gated residual connection: x = x + mlp_output * gate
        hidden_states = (hidden_states + ff_output * c_gate_msa).type_as(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        return outputs
+
+
+
class Lambda(nn.Module):
    """Adapter that lets an arbitrary callable participate in ``nn.Sequential``.

    Wraps a plain function (e.g. a transpose) as a Module so it can be composed
    with other layers; calling the module simply applies the stored callable.
    """

    def __init__(self, func):
        super().__init__()
        # The wrapped callable; applied verbatim on every forward pass.
        self.func = func

    def forward(self, x):
        return self.func(x)
+
+
+class AceStepDiTModel(nn.Module):
+ """
+ DiT (Diffusion Transformer) model for AceStep.
+
+ Main diffusion model that generates audio latents conditioned on text, lyrics,
+ and timbre. Uses patch-based processing with transformer layers, timestep
+ conditioning, and cross-attention to encoder outputs.
+ """
+ def __init__(
+ self,
+ hidden_size: int = 2048,
+ intermediate_size: int = 6144,
+ num_hidden_layers: int = 24,
+ num_attention_heads: int = 16,
+ num_key_value_heads: int = 8,
+ rms_norm_eps: float = 1e-6,
+ attention_bias: bool = False,
+ attention_dropout: float = 0.0,
+ layer_types: Optional[list] = None,
+ head_dim: Optional[int] = None,
+ sliding_window: Optional[int] = 128,
+ use_sliding_window: bool = True,
+ use_cache: bool = True,
+ rope_theta: float = 1000000,
+ max_position_embeddings: int = 32768,
+ initializer_range: float = 0.02,
+ patch_size: int = 2,
+ in_channels: int = 192,
+ audio_acoustic_hidden_dim: int = 64,
+ encoder_hidden_size: Optional[int] = None,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
+ self.use_sliding_window = use_sliding_window
+ self.sliding_window = sliding_window
+ self.use_cache = use_cache
+ encoder_hidden_size = encoder_hidden_size or hidden_size
+
+ # Rotary position embeddings for transformer layers
+ rope_config = type('RopeConfig', (), {
+ 'hidden_size': hidden_size,
+ 'num_attention_heads': num_attention_heads,
+ 'num_key_value_heads': num_key_value_heads,
+ 'head_dim': head_dim,
+ 'max_position_embeddings': max_position_embeddings,
+ 'rope_theta': rope_theta,
+ 'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
+ 'rms_norm_eps': rms_norm_eps,
+ 'attention_bias': attention_bias,
+ 'attention_dropout': attention_dropout,
+ 'hidden_act': 'silu',
+ 'intermediate_size': intermediate_size,
+ 'layer_types': self.layer_types,
+ 'sliding_window': sliding_window,
+ })()
+ self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
+
+ # Stack of DiT transformer layers
+ self.layers = nn.ModuleList([
+ AceStepDiTLayer(
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=num_key_value_heads,
+ intermediate_size=intermediate_size,
+ rms_norm_eps=rms_norm_eps,
+ attention_bias=attention_bias,
+ attention_dropout=attention_dropout,
+ layer_types=self.layer_types,
+ head_dim=head_dim,
+ sliding_window=sliding_window,
+ layer_idx=layer_idx,
+ )
+ for layer_idx in range(num_hidden_layers)
+ ])
+
+ self.patch_size = patch_size
+
+ # Input projection: patch embedding using 1D convolution
+ self.proj_in = nn.Sequential(
+ Lambda(lambda x: x.transpose(1, 2)),
+ nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=hidden_size,
+ kernel_size=patch_size,
+ stride=patch_size,
+ padding=0,
+ ),
+ Lambda(lambda x: x.transpose(1, 2)),
+ )
+
+ # Timestep embeddings for diffusion conditioning
+ self.time_embed = TimestepEmbedding(in_channels=256, time_embed_dim=hidden_size)
+ self.time_embed_r = TimestepEmbedding(in_channels=256, time_embed_dim=hidden_size)
+
+ # Project encoder hidden states to model dimension
+ self.condition_embedder = nn.Linear(encoder_hidden_size, hidden_size, bias=True)
+
+ # Output normalization and projection
+ self.norm_out = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
+ self.proj_out = nn.Sequential(
+ Lambda(lambda x: x.transpose(1, 2)),
+ nn.ConvTranspose1d(
+ in_channels=hidden_size,
+ out_channels=audio_acoustic_hidden_dim,
+ kernel_size=patch_size,
+ stride=patch_size,
+ padding=0,
+ ),
+ Lambda(lambda x: x.transpose(1, 2)),
+ )
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 2, hidden_size) / hidden_size**0.5)
+
+ self.gradient_checkpointing = False
+
    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        timestep_r: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        context_latents: torch.Tensor,
        use_cache: Optional[bool] = None,
        past_key_values: Optional[EncoderDecoderCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        return_hidden_states: int = None,
        custom_layers_config: Optional[dict] = None,
        enable_early_exit: bool = False,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ):
        """Denoise one step of patchified audio latents with text cross-attention.

        Args:
            hidden_states: noisy latents [B, T, C]; concatenated with
                ``context_latents`` along the channel dim before patching.
            timestep / timestep_r: diffusion timesteps. Two embeddings are
                combined: one for t and one for the difference (t - r).
            attention_mask / encoder_attention_mask: accepted for API
                compatibility but ignored — 4D masks are rebuilt internally
                (the target library also discards the passed-in mask).
            encoder_hidden_states: text-condition sequence, projected to the
                model width and consumed through cross-attention.
            context_latents: per-frame conditioning (source latents + chunk
                masks) concatenated onto the latent channels.
            return_hidden_states: if truthy, return the raw transformer states
                before the output head (feature-extraction path).
            custom_layers_config / enable_early_exit: allow stopping after the
                deepest layer index listed in ``custom_layers_config``.

        Returns:
            Tuple ``(denoised_latents, past_key_values[, cross_attentions])``;
            the latents are cropped back to the input length. When
            ``return_hidden_states`` is truthy, only the hidden states.
        """

        use_cache = use_cache if use_cache is not None else self.use_cache

        # Disable cache during training or when gradient checkpointing is enabled
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False
        if self.training:
            use_cache = False

        # Initialize cache if needed (only during inference for auto-regressive generation)
        if not self.training and use_cache and past_key_values is None:
            past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

        # Compute timestep embeddings for diffusion conditioning.
        # Two embeddings: one for timestep t, one for timestep difference (t - r).
        temb_t, timestep_proj_t = self.time_embed(timestep)
        temb_r, timestep_proj_r = self.time_embed_r(timestep - timestep_r)
        # Combine embeddings additively.
        temb = temb_t + temb_r
        timestep_proj = timestep_proj_t + timestep_proj_r

        # Concatenate context latents (source latents + chunk masks) with hidden states
        hidden_states = torch.cat([context_latents, hidden_states], dim=-1)
        # Record original sequence length for later restoration after padding
        original_seq_len = hidden_states.shape[1]
        # Apply zero padding if sequence length is not divisible by patch_size,
        # so that proj_in can extract an integer number of patches.
        pad_length = 0
        if hidden_states.shape[1] % self.patch_size != 0:
            pad_length = self.patch_size - (hidden_states.shape[1] % self.patch_size)
            hidden_states = F.pad(hidden_states, (0, 0, 0, pad_length), mode='constant', value=0)

        # Project input to patches and project encoder states to model width
        hidden_states = self.proj_in(hidden_states)
        encoder_hidden_states = self.condition_embedder(encoder_hidden_states)

        # Cache positions: continue from however many tokens are already cached
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            )

        # Position IDs default to the cache positions (batch dim of 1, broadcast)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        seq_len = hidden_states.shape[1]
        encoder_seq_len = encoder_hidden_states.shape[1]
        dtype = hidden_states.dtype
        device = hidden_states.device

        # Initialize mask variables
        full_attn_mask = None
        sliding_attn_mask = None
        encoder_attn_mask = None
        decoder_attn_mask = None
        # Target library discards the passed-in attention_mask for 4D mask
        # construction (line 1384: attention_mask = None)
        attention_mask = None

        # 1. Full Attention (Bidirectional, Global)
        full_attn_mask = create_4d_mask(
            seq_len=seq_len,
            dtype=dtype,
            device=device,
            attention_mask=attention_mask,
            sliding_window=None,
            is_sliding_window=False,
            is_causal=False
        )
        # Cross-attention mask: build square at the larger length, then crop
        # to [seq_len, encoder_seq_len].
        max_len = max(seq_len, encoder_seq_len)

        encoder_attn_mask = create_4d_mask(
            seq_len=max_len,
            dtype=dtype,
            device=device,
            attention_mask=attention_mask,
            sliding_window=None,
            is_sliding_window=False,
            is_causal=False
        )
        encoder_attn_mask = encoder_attn_mask[:, :, :seq_len, :encoder_seq_len]

        # 2. Sliding Attention (Bidirectional, Local)
        if self.use_sliding_window:
            sliding_attn_mask = create_4d_mask(
                seq_len=seq_len,
                dtype=dtype,
                device=device,
                attention_mask=attention_mask,
                sliding_window=self.sliding_window,
                is_sliding_window=True,
                is_causal=False
            )

        # Map each layer's declared attention type to the corresponding mask
        self_attn_mask_mapping = {
            "full_attention": full_attn_mask,
            "sliding_attention": sliding_attn_mask,
            "encoder_attention_mask": encoder_attn_mask,
        }

        # Create position embeddings once, shared across all decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        all_cross_attentions = () if output_attentions else None

        # Handle early exit for custom layer configurations: force attention
        # outputs on so intermediate results are collectable.
        max_needed_layer = float('inf')
        if custom_layers_config is not None and enable_early_exit:
            max_needed_layer = max(custom_layers_config.keys())
            output_attentions = True
            if all_cross_attentions is None:
                all_cross_attentions = ()

        # Process through transformer layers
        for index_block, layer_module in enumerate(self.layers):
            # Early exit optimization: skip layers beyond the deepest one needed
            if index_block > max_needed_layer:
                break

            # Positional args must match the layer's forward signature order
            layer_args = (
                hidden_states,
                position_embeddings,
                timestep_proj,
                self_attn_mask_mapping[layer_module.attention_type],
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
                encoder_hidden_states,
                self_attn_mask_mapping["encoder_attention_mask"],
            )
            layer_kwargs = flash_attn_kwargs

            # Use gradient checkpointing if enabled (training memory saving)
            if use_gradient_checkpointing or use_gradient_checkpointing_offload:
                layer_outputs = gradient_checkpoint_forward(
                    layer_module,
                    use_gradient_checkpointing,
                    use_gradient_checkpointing_offload,
                    *layer_args,
                    **layer_kwargs,
                )
            else:
                layer_outputs = layer_module(
                    *layer_args,
                    **layer_kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions and self.layers[index_block].use_cross_attention:
                # layer_outputs structure: (hidden_states, self_attn_weights, cross_attn_weights)
                if len(layer_outputs) >= 3:
                    all_cross_attentions += (layer_outputs[2],)

        # Feature-extraction path: skip the output head entirely
        if return_hidden_states:
            return hidden_states

        # Extract scale-shift parameters for adaptive output normalization
        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        # Apply adaptive layer norm: norm(x) * (1 + scale) + shift
        hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states)
        # Project output: de-patchify back to original sequence format
        hidden_states = self.proj_out(hidden_states)

        # Crop back to original sequence length to ensure exact length match (remove padding)
        hidden_states = hidden_states[:, :original_seq_len, :]

        outputs = (hidden_states, past_key_values)

        if output_attentions:
            outputs += (all_cross_attentions,)
        return outputs
diff --git a/diffsynth/models/ace_step_lm.py b/diffsynth/models/ace_step_lm.py
new file mode 100644
index 00000000..fc8c0817
--- /dev/null
+++ b/diffsynth/models/ace_step_lm.py
@@ -0,0 +1,79 @@
+import torch
+
+
# Architecture hyper-parameters for the supported ACE-Step LM variants.
# Only the size-dependent fields live here; the Qwen3 settings shared by
# all variants (vocab size, rope, KV heads, ...) are set in AceStepLM.__init__.
LM_CONFIGS = {
    "acestep-5Hz-lm-0.6B": {
        "hidden_size": 1024,
        "intermediate_size": 3072,
        "num_hidden_layers": 28,
        "num_attention_heads": 16,
        "layer_types": ["full_attention"] * 28,
        "max_window_layers": 28,
    },
    "acestep-5Hz-lm-1.7B": {
        "hidden_size": 2048,
        "intermediate_size": 6144,
        "num_hidden_layers": 28,
        "num_attention_heads": 16,
        "layer_types": ["full_attention"] * 28,
        "max_window_layers": 28,
    },
    "acestep-5Hz-lm-4B": {
        "hidden_size": 2560,
        "intermediate_size": 9728,
        "num_hidden_layers": 36,
        "num_attention_heads": 32,
        "layer_types": ["full_attention"] * 36,
        "max_window_layers": 36,
    },
}
+
+
class AceStepLM(torch.nn.Module):
    """Prompt-rewriting language model for ACE-Step.

    Turns free-form natural-language prompts into the structured fields the
    music pipeline consumes (caption, lyrics, bpm, keyscale, duration,
    timesignature, ...).

    Thin wrapper around a transformers ``Qwen3ForCausalLM``: the config is
    assembled here from shared defaults plus the per-variant sizes in
    ``LM_CONFIGS``; weights arrive later through DiffSynth's standard
    safetensors loading path.
    """

    def __init__(
        self,
        variant: str = "acestep-5Hz-lm-1.7B",
    ):
        super().__init__()
        from transformers import Qwen3Config, Qwen3ForCausalLM

        # Settings common to every LM variant; the variant entry overrides /
        # extends these with its size-dependent fields.
        config_kwargs = dict(
            attention_bias=False,
            attention_dropout=0.0,
            bos_token_id=151643,
            dtype="bfloat16",
            eos_token_id=151645,
            head_dim=128,
            hidden_act="silu",
            initializer_range=0.02,
            max_position_embeddings=40960,
            model_type="qwen3",
            num_key_value_heads=8,
            pad_token_id=151643,
            rms_norm_eps=1e-06,
            rope_scaling=None,
            rope_theta=1000000,
            sliding_window=None,
            tie_word_embeddings=True,
            use_cache=True,
            use_sliding_window=False,
            vocab_size=217204,
        )
        config_kwargs.update(LM_CONFIGS[variant])
        config = Qwen3Config(**config_kwargs)

        self.model = Qwen3ForCausalLM(config)
        self.config = config
diff --git a/diffsynth/models/ace_step_text_encoder.py b/diffsynth/models/ace_step_text_encoder.py
new file mode 100644
index 00000000..58b52a7e
--- /dev/null
+++ b/diffsynth/models/ace_step_text_encoder.py
@@ -0,0 +1,80 @@
+import torch
+
+
class AceStepTextEncoder(torch.nn.Module):
    """Text encoder for ACE-Step built on Qwen3-Embedding-0.6B.

    Tokenized text/lyrics go in; last-layer hidden states come out. The
    ACE-Step ConditionEncoder consumes these embeddings downstream.

    Thin wrapper around a transformers ``Qwen3Model``: the constructor only
    builds the architecture from a fixed config — weights are loaded later
    from safetensors via DiffSynth's standard mechanism.
    """

    def __init__(
        self,
    ):
        super().__init__()
        from transformers import Qwen3Config, Qwen3Model

        # Fixed Qwen3-Embedding-0.6B architecture hyper-parameters.
        config_kwargs = dict(
            attention_bias=False,
            attention_dropout=0.0,
            bos_token_id=151643,
            dtype="bfloat16",
            eos_token_id=151643,
            head_dim=128,
            hidden_act="silu",
            hidden_size=1024,
            initializer_range=0.02,
            intermediate_size=3072,
            layer_types=["full_attention"] * 28,
            max_position_embeddings=32768,
            max_window_layers=28,
            model_type="qwen3",
            num_attention_heads=16,
            num_hidden_layers=28,
            num_key_value_heads=8,
            pad_token_id=151643,
            rms_norm_eps=1e-06,
            rope_scaling=None,
            rope_theta=1000000,
            sliding_window=None,
            tie_word_embeddings=True,
            use_cache=True,
            use_sliding_window=False,
            vocab_size=151669,
        )
        config = Qwen3Config(**config_kwargs)

        self.model = Qwen3Model(config)
        self.config = config
        self.hidden_size = config.hidden_size

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor,
    ):
        """Encode token IDs to hidden states.

        Args:
            input_ids: [B, T] token IDs.
            attention_mask: [B, T] padding mask.

        Returns:
            last_hidden_state: [B, T, hidden_size].
        """
        encoded = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        return encoded.last_hidden_state

    def to(self, *args, **kwargs):
        # Delegate device/dtype moves to the wrapped model; stay chainable.
        self.model.to(*args, **kwargs)
        return self
diff --git a/diffsynth/models/ace_step_tokenizer.py b/diffsynth/models/ace_step_tokenizer.py
new file mode 100644
index 00000000..c01e9d50
--- /dev/null
+++ b/diffsynth/models/ace_step_tokenizer.py
@@ -0,0 +1,732 @@
+# Copyright 2025 The ACE-Step Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ACE-Step Audio Tokenizer — VAE latent discretization pathway.
+
+Contains:
+- AceStepAudioTokenizer: continuous VAE latent → discrete FSQ tokens
+- AudioTokenDetokenizer: discrete tokens → continuous VAE-latent-shaped features
+
+Only used in cover song mode (is_covers=True). Bypassed in text-to-music.
+"""
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from ..core.attention import attention_forward
+from ..core.gradient import gradient_checkpoint_forward
+
+from transformers.cache_utils import Cache
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_outputs import BaseModelOutput
+from transformers.processing_utils import Unpack
+from transformers.utils import can_return_tuple, logging
+from transformers.models.qwen3.modeling_qwen3 import (
+ Qwen3MLP,
+ Qwen3RMSNorm,
+ Qwen3RotaryEmbedding,
+ apply_rotary_pos_emb,
+)
+from vector_quantize_pytorch import ResidualFSQ
+
+logger = logging.get_logger(__name__)
+
+
def create_4d_mask(
    seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,
    sliding_window: Optional[int] = None,
    is_sliding_window: bool = False,
    is_causal: bool = True,
) -> torch.Tensor:
    """Build an additive 4D attention mask of shape [B or 1, 1, S, S].

    Allowed positions hold 0.0; disallowed positions hold the dtype's most
    negative value (so the mask can be added to attention logits).

    Args:
        seq_len: sequence length S.
        attention_mask: optional [B, S] key-padding mask (1 = keep).
        sliding_window: local-attention radius, used when is_sliding_window.
        is_sliding_window: restrict attention to a band around the diagonal
            (one-sided when causal, symmetric otherwise).
        is_causal: forbid attending to future positions.
    """
    positions = torch.arange(seq_len, device=device)
    # offset[i, j] = i - j: positive when key j precedes query i.
    offset = positions[:, None] - positions[None, :]
    allowed = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
    if is_causal:
        allowed = allowed & (offset >= 0)
    if is_sliding_window and sliding_window is not None:
        band = offset <= sliding_window if is_causal else offset.abs() <= sliding_window
        allowed = allowed & band
    # Lift to [1, 1, S, S] and intersect with the key-padding mask, if any.
    allowed = allowed[None, None]
    if attention_mask is not None:
        keep = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
        allowed = allowed & keep
    neg_inf = torch.finfo(dtype).min
    out = torch.full(allowed.shape, neg_inf, dtype=dtype, device=device)
    out.masked_fill_(allowed, 0.0)
    return out
+
+
class Lambda(nn.Module):
    """Adapter that lifts a plain callable into an ``nn.Module``.

    Lets stateless ops (e.g. transposes) participate in ``nn.Sequential``.
    """

    def __init__(self, func):
        super().__init__()
        # The wrapped callable, invoked verbatim on every forward pass.
        self.func = func

    def forward(self, x):
        return self.func(x)
+
+
class AceStepAttention(nn.Module):
    """Qwen3-style multi-head attention with per-head QK-RMSNorm and GQA.

    Supports three usages:
    - self-attention (with rotary position embeddings),
    - cross-attention over ``encoder_hidden_states`` (no RoPE, never causal),
    - optional KV caching via a transformers ``Cache`` / ``EncoderDecoderCache``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        rms_norm_eps: float,
        attention_bias: bool,
        attention_dropout: float,
        layer_types: list,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = None,
        layer_idx: int = 0,
        is_cross_attention: bool = False,
        is_causal: bool = False,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.head_dim = head_dim or hidden_size // num_attention_heads
        # GQA: each key/value head serves this many query heads.
        self.num_key_value_groups = num_attention_heads // num_key_value_heads
        self.scaling = self.head_dim ** -0.5
        self.attention_dropout = attention_dropout
        # Cross-attention is always bidirectional.
        if is_cross_attention:
            is_causal = False
        self.is_causal = is_causal
        self.is_cross_attention = is_cross_attention

        self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=attention_bias)
        # Per-head RMSNorm on queries and keys (Qwen3 convention).
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.attention_type = layer_types[layer_idx]
        self.sliding_window = sliding_window if layer_types[layer_idx] == "sliding_attention" else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(attn_output, attn_weights)``.

        ``attn_weights`` is always None here because the fused attention
        kernel does not expose per-head weights.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # [B, S, H*D] -> [B, H, S, D] with per-head norm on queries.
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)

        is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None

        if is_cross_attention:
            encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
            if past_key_value is not None:
                # Cross-attention KV depends only on the encoder sequence, so it is
                # computed once and then reused from the cross-attention cache.
                is_updated = past_key_value.is_updated.get(self.layer_idx)
                curr_past_key_value = past_key_value.cross_attention_cache
                if not is_updated:
                    key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
                    value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
                    key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx)
                    past_key_value.is_updated[self.layer_idx] = True
                else:
                    key_states = curr_past_key_value.layers[self.layer_idx].keys
                    value_states = curr_past_key_value.layers[self.layer_idx].values
            else:
                key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
                value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)

        else:
            key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
            value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            # Rotary position embeddings apply to self-attention only.
            if position_embeddings is not None:
                cos, sin = position_embeddings
                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            if past_key_value is not None:
                # NOTE(review): assumes position_embeddings is always provided when a
                # self-attention cache is in use — cos/sin are otherwise undefined
                # here and this line would raise NameError. Confirm with callers.
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand KV heads to match the query-head count (GQA replication).
        if self.num_key_value_groups > 1:
            key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
            value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)

        attn_output = attention_forward(
            query_states, key_states, value_states,
            q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d",
            attn_mask=attention_mask,
        )
        attn_weights = None

        # [B, H, S, D] -> [B, S, H*D]
        attn_output = attn_output.transpose(1, 2).flatten(2, 3).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
+
+
class AceStepEncoderLayer(nn.Module):
    """One bidirectional pre-norm transformer block: self-attention + MLP.

    Both sub-blocks use RMSNorm before the op and a residual connection
    around it. No KV cache is used — the layer is purely an encoder block.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        rms_norm_eps: float,
        attention_bias: bool,
        attention_dropout: float,
        layer_types: list,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = None,
        layer_idx: int = 0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.layer_idx = layer_idx

        # Bidirectional self-attention (never causal, never cross-attending).
        self.self_attn = AceStepAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            layer_idx=layer_idx,
            is_cross_attention=False,
            is_causal=False,
        )
        self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)

        # Qwen3MLP expects a config-like object; supply a minimal stand-in.
        _mlp_cfg = type('Config', (), {
            'hidden_size': hidden_size,
            'intermediate_size': intermediate_size,
            'hidden_act': 'silu',
        })()
        self.mlp = Qwen3MLP(_mlp_cfg)
        self.attention_type = layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Return ``(hidden_states,)`` or ``(hidden_states, attn_weights)``."""
        # --- Self-attention sub-block (pre-norm + residual) ---
        normed = self.input_layernorm(hidden_states)
        attn_out, attn_weights = self.self_attn(
            hidden_states=normed,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            use_cache=False,
            past_key_value=None,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # --- Feed-forward sub-block (pre-norm + residual) ---
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
+
+
class AttentionPooler(nn.Module):
    """Pools every pool_window_size frames into 1 representation via transformer + CLS token.

    Input is [B, T, P, D] (T windows of P frames). A learned special token is
    prepended to each window, a small bidirectional transformer runs over each
    window independently, and the special-token position is returned as the
    pooled vector, giving [B, T, D].
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        num_attention_pooler_hidden_layers: int = 2,
        **kwargs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.rms_norm_eps = rms_norm_eps
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Default matches target library config (24 alternating entries).
        self.layer_types = layer_types or (["sliding_attention", "full_attention"] * 12)
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.sliding_window = sliding_window
        self.use_sliding_window = use_sliding_window
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
        self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")

        # Linear embedding of the incoming frame features.
        self.embed_tokens = nn.Linear(hidden_size, hidden_size)
        self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        # Slice layer_types to our own layer count
        pooler_layer_types = self.layer_types[:num_attention_pooler_hidden_layers]
        # Qwen3RotaryEmbedding expects a config object; build a lightweight stand-in.
        # NOTE(review): forwards the raw `head_dim` argument (may be None), not the
        # derived self.head_dim — confirm Qwen3RotaryEmbedding tolerates None here.
        rope_config = type('RopeConfig', (), {
            'hidden_size': hidden_size,
            'num_attention_heads': num_attention_heads,
            'num_key_value_heads': num_key_value_heads,
            'head_dim': head_dim,
            'max_position_embeddings': max_position_embeddings,
            'rope_theta': rope_theta,
            'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
            'rms_norm_eps': rms_norm_eps,
            'attention_bias': attention_bias,
            'attention_dropout': attention_dropout,
            'hidden_act': 'silu',
            'intermediate_size': intermediate_size,
            'layer_types': pooler_layer_types,
            'sliding_window': sliding_window,
            '_attn_implementation': self._attn_implementation,
        })()
        self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
        self.gradient_checkpointing = False
        # Learned CLS-style token prepended to every window.
        self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.layers = nn.ModuleList([
            AceStepEncoderLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                rms_norm_eps=rms_norm_eps,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                layer_types=pooler_layer_types,
                head_dim=head_dim,
                sliding_window=sliding_window,
                layer_idx=layer_idx,
            )
            for layer_idx in range(num_attention_pooler_hidden_layers)
        ])

    @can_return_tuple
    def forward(
        self,
        x,
        attention_mask: Optional[torch.Tensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        """Pool [B, T, P, D] windows to [B, T, D] via the prepended CLS token."""
        B, T, P, D = x.shape
        x = self.embed_tokens(x)
        # Prepend the learned special token to every window.
        special_tokens = self.special_token.expand(B, T, 1, -1)
        x = torch.cat([special_tokens, x], dim=2)
        # Fold windows into the batch dim so each window attends independently.
        x = rearrange(x, "b t p c -> (b t) p c")

        cache_position = torch.arange(0, x.shape[1], device=x.device)
        position_ids = cache_position.unsqueeze(0)
        hidden_states = x
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        seq_len = x.shape[1]
        dtype = x.dtype
        device = x.device

        # Bidirectional masks (global + optional local band), one per layer type.
        full_attn_mask = create_4d_mask(
            seq_len=seq_len, dtype=dtype, device=device,
            attention_mask=attention_mask, sliding_window=None,
            is_sliding_window=False, is_causal=False
        )
        sliding_attn_mask = None
        if self.use_sliding_window:
            sliding_attn_mask = create_4d_mask(
                seq_len=seq_len, dtype=dtype, device=device,
                attention_mask=attention_mask, sliding_window=self.sliding_window,
                is_sliding_window=True, is_causal=False
            )

        self_attn_mask_mapping = {
            "full_attention": full_attn_mask,
            "sliding_attention": sliding_attn_mask,
        }

        for layer_module in self.layers:
            layer_outputs = layer_module(
                hidden_states, position_embeddings,
                attention_mask=self_attn_mask_mapping[layer_module.attention_type],
                **flash_attn_kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)
        # Position 0 is the special token: use it as the pooled representation.
        cls_output = hidden_states[:, 0, :]
        return rearrange(cls_output, "(b t) c -> b t c", b=B)
+
+
class AceStepAudioTokenizer(nn.Module):
    """Converts continuous acoustic features (VAE latents) into discrete quantized tokens.

    Pipeline: project 64-dim latent frames to model width, pool every
    pool_window_size frames to one vector with the AttentionPooler, then
    discretize with a residual FSQ quantizer.

    Input: [B, T, 64] (VAE latent dim)
    Output: quantized [B, T/5, 2048], indices [B, T/5, 1]
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        audio_acoustic_hidden_dim: int = 64,
        pool_window_size: int = 5,
        fsq_dim: int = 2048,
        fsq_input_levels: list = None,
        fsq_input_num_quantizers: int = 1,
        num_attention_pooler_hidden_layers: int = 2,
        **kwargs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.rms_norm_eps = rms_norm_eps
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Default matches target library config (24 alternating entries).
        self.layer_types = layer_types or (["sliding_attention", "full_attention"] * 12)
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.sliding_window = sliding_window
        self.use_sliding_window = use_sliding_window
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.audio_acoustic_hidden_dim = audio_acoustic_hidden_dim
        self.pool_window_size = pool_window_size
        self.fsq_dim = fsq_dim
        # Default FSQ levels [8, 8, 8, 5, 5, 5] — standard FSQ configuration.
        self.fsq_input_levels = fsq_input_levels or [8, 8, 8, 5, 5, 5]
        self.fsq_input_num_quantizers = fsq_input_num_quantizers
        self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
        self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")

        # 64-dim latent frames -> model width.
        self.audio_acoustic_proj = nn.Linear(audio_acoustic_hidden_dim, hidden_size)
        # Slice layer_types for the attention pooler
        pooler_layer_types = self.layer_types[:num_attention_pooler_hidden_layers]
        self.attention_pooler = AttentionPooler(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=pooler_layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            use_sliding_window=use_sliding_window,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers,
        )
        self.quantizer = ResidualFSQ(
            dim=self.fsq_dim,
            levels=self.fsq_input_levels,
            num_quantizers=self.fsq_input_num_quantizers,
            force_quantization_f32=False,  # avoid autocast bug in vector_quantize_pytorch
        )

    @can_return_tuple
    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Project, pool, and quantize pre-windowed latents.

        Expects ``hidden_states`` already windowed to [B, T, P, 64]
        (see ``tokenize`` for the rearrange); returns (quantized, indices).
        """
        hidden_states = self.audio_acoustic_proj(hidden_states)
        hidden_states = self.attention_pooler(hidden_states)
        quantized, indices = self.quantizer(hidden_states)
        return quantized, indices

    def tokenize(self, x):
        """Convenience: takes [B, T, 64], rearranges to patches, runs forward."""
        x = rearrange(x, 'n (t_patch p) d -> n t_patch p d', p=self.pool_window_size)
        return self.forward(x)
+
+
class AudioTokenDetokenizer(nn.Module):
    """Converts quantized audio tokens back to continuous acoustic representations.

    Each token is expanded to pool_window_size frames (learned per-position
    offsets), refined per window by a small bidirectional transformer, and
    projected down to the VAE latent dimension.

    Input: [B, T/5, hidden_size] (quantized vectors)
    Output: [B, T, 64] (VAE-latent-shaped continuous features)
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        pool_window_size: int = 5,
        audio_acoustic_hidden_dim: int = 64,
        num_attention_pooler_hidden_layers: int = 2,
        **kwargs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.rms_norm_eps = rms_norm_eps
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Default matches target library config (24 alternating entries).
        self.layer_types = layer_types or (["sliding_attention", "full_attention"] * 12)
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.sliding_window = sliding_window
        self.use_sliding_window = use_sliding_window
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.pool_window_size = pool_window_size
        self.audio_acoustic_hidden_dim = audio_acoustic_hidden_dim
        self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
        self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")

        self.embed_tokens = nn.Linear(hidden_size, hidden_size)
        self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
        # Slice layer_types to our own layer count (use num_audio_decoder_hidden_layers)
        detok_layer_types = self.layer_types[:num_attention_pooler_hidden_layers]
        # Qwen3RotaryEmbedding expects a config object; build a lightweight stand-in.
        # NOTE(review): forwards the raw `head_dim` argument (may be None), not the
        # derived self.head_dim — confirm Qwen3RotaryEmbedding tolerates None here.
        rope_config = type('RopeConfig', (), {
            'hidden_size': hidden_size,
            'num_attention_heads': num_attention_heads,
            'num_key_value_heads': num_key_value_heads,
            'head_dim': head_dim,
            'max_position_embeddings': max_position_embeddings,
            'rope_theta': rope_theta,
            'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
            'rms_norm_eps': rms_norm_eps,
            'attention_bias': attention_bias,
            'attention_dropout': attention_dropout,
            'hidden_act': 'silu',
            'intermediate_size': intermediate_size,
            'layer_types': detok_layer_types,
            'sliding_window': sliding_window,
            '_attn_implementation': self._attn_implementation,
        })()
        self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
        self.gradient_checkpointing = False
        # Learned per-position offsets distinguishing the P frames of a window.
        self.special_tokens = nn.Parameter(torch.randn(1, pool_window_size, hidden_size) * 0.02)
        self.layers = nn.ModuleList([
            AceStepEncoderLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                rms_norm_eps=rms_norm_eps,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                layer_types=detok_layer_types,
                head_dim=head_dim,
                sliding_window=sliding_window,
                layer_idx=layer_idx,
            )
            for layer_idx in range(num_attention_pooler_hidden_layers)
        ])
        # Project refined frames back down to the VAE latent dimension.
        self.proj_out = nn.Linear(hidden_size, audio_acoustic_hidden_dim)

    @can_return_tuple
    def forward(
        self,
        x,
        attention_mask: Optional[torch.Tensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        """Expand [B, T, hidden] tokens to [B, T * pool_window_size, 64] frames."""
        B, T, D = x.shape
        x = self.embed_tokens(x)
        # Replicate each token across its window, then add per-position offsets.
        x = x.unsqueeze(2).repeat(1, 1, self.pool_window_size, 1)
        # (1, P, H) broadcasts against the two leading dims -> (B, T, P, H).
        special_tokens = self.special_tokens.expand(B, T, -1, -1)
        x = x + special_tokens
        # Fold windows into the batch dim so each window attends independently.
        x = rearrange(x, "b t p c -> (b t) p c")

        cache_position = torch.arange(0, x.shape[1], device=x.device)
        position_ids = cache_position.unsqueeze(0)
        hidden_states = x
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        seq_len = x.shape[1]
        dtype = x.dtype
        device = x.device

        # Bidirectional masks (global + optional local band), one per layer type.
        full_attn_mask = create_4d_mask(
            seq_len=seq_len, dtype=dtype, device=device,
            attention_mask=attention_mask, sliding_window=None,
            is_sliding_window=False, is_causal=False
        )
        sliding_attn_mask = None
        if self.use_sliding_window:
            sliding_attn_mask = create_4d_mask(
                seq_len=seq_len, dtype=dtype, device=device,
                attention_mask=attention_mask, sliding_window=self.sliding_window,
                is_sliding_window=True, is_causal=False
            )

        self_attn_mask_mapping = {
            "full_attention": full_attn_mask,
            "sliding_attention": sliding_attn_mask,
        }

        for layer_module in self.layers:
            layer_outputs = layer_module(
                hidden_states, position_embeddings,
                attention_mask=self_attn_mask_mapping[layer_module.attention_type],
                **flash_attn_kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)
        hidden_states = self.proj_out(hidden_states)
        # Unfold windows back out of the batch dim: [B, T * P, 64].
        return rearrange(hidden_states, "(b t) p c -> b (t p) c", b=B, p=self.pool_window_size)
+
+
class AceStepTokenizer(nn.Module):
    """Container pairing AceStepAudioTokenizer with AudioTokenDetokenizer.

    Exposes encode/decode convenience methods for discretizing VAE latents.
    Cover-song mode uses it to map source-audio latents to discrete tokens
    and back to continuous conditioning hints.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        intermediate_size: int = 6144,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        layer_types: Optional[list] = None,
        head_dim: Optional[int] = None,
        sliding_window: Optional[int] = 128,
        use_sliding_window: bool = True,
        rope_theta: float = 1000000,
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        audio_acoustic_hidden_dim: int = 64,
        pool_window_size: int = 5,
        fsq_dim: int = 2048,
        fsq_input_levels: list = None,
        fsq_input_num_quantizers: int = 1,
        num_attention_pooler_hidden_layers: int = 2,
        num_audio_decoder_hidden_layers: int = 24,
        **kwargs,
    ):
        super().__init__()
        # Default layer_types mirrors the reference config: 24 alternating
        # entries; sub-modules slice the first N for their own depth.
        if layer_types is None:
            layer_types = ["sliding_attention", "full_attention"] * 12

        # Arguments passed identically to both sub-modules.
        shared_kwargs = dict(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            layer_types=layer_types,
            head_dim=head_dim,
            sliding_window=sliding_window,
            use_sliding_window=use_sliding_window,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            audio_acoustic_hidden_dim=audio_acoustic_hidden_dim,
            pool_window_size=pool_window_size,
            num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers,
            **kwargs,
        )
        # FSQ quantization settings only concern the tokenizer side.
        self.tokenizer = AceStepAudioTokenizer(
            fsq_dim=fsq_dim,
            fsq_input_levels=fsq_input_levels,
            fsq_input_num_quantizers=fsq_input_num_quantizers,
            **shared_kwargs,
        )
        self.detokenizer = AudioTokenDetokenizer(**shared_kwargs)

    def encode(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """VAE latent [B, T, 64] → discrete tokens."""
        return self.tokenizer(hidden_states)

    def decode(self, quantized: torch.Tensor) -> torch.Tensor:
        """Discrete tokens [B, T/5, hidden_size] → continuous [B, T, 64]."""
        return self.detokenizer(quantized)

    def tokenize(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Convenience: [B, T, 64] → quantized + indices via patch rearrangement."""
        return self.tokenizer.tokenize(x)
diff --git a/diffsynth/models/ace_step_vae.py b/diffsynth/models/ace_step_vae.py
new file mode 100644
index 00000000..dd78a0a6
--- /dev/null
+++ b/diffsynth/models/ace_step_vae.py
@@ -0,0 +1,241 @@
+# Copyright 2025 The ACE-Step Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ACE-Step Audio VAE (AutoencoderOobleck CNN architecture).
+
+This is a CNN-based VAE for audio waveform encoding/decoding.
+It uses weight-normalized convolutions and Snake1d activations.
+Does NOT depend on diffusers — pure nn.Module implementation.
+"""
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch.nn.utils import weight_norm
+
+
class Snake1d(nn.Module):
    """Snake activation: x + sin(alpha * x)^2 / (beta + eps).

    ``alpha`` and ``beta`` are per-channel parameters of shape [1, C, 1].
    When ``logscale`` is set they are stored in log space and exponentiated
    on every forward pass (so zero-initialized parameters mean alpha=beta=1).
    """

    def __init__(self, hidden_dim: int, logscale: bool = True):
        super().__init__()
        self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1))
        self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1))
        self.logscale = logscale

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        original_shape = hidden_states.shape
        if self.logscale:
            alpha = torch.exp(self.alpha)
            beta = torch.exp(self.beta)
        else:
            alpha, beta = self.alpha, self.beta
        # Flatten trailing dims so the [1, C, 1] parameters broadcast over time.
        flat = hidden_states.reshape(original_shape[0], original_shape[1], -1)
        # eps keeps the reciprocal finite when beta is (near) zero.
        flat = flat + (beta + 1e-9).reciprocal() * torch.sin(alpha * flat).pow(2)
        return flat.reshape(original_shape)
+
+
class OobleckResidualUnit(nn.Module):
    """Dilated residual unit: Snake → 7-tap dilated conv → Snake → 1x1 conv, plus skip."""

    def __init__(self, dimension: int = 16, dilation: int = 1):
        super().__init__()
        # "Same"-style padding for kernel size 7 at the given dilation.
        pad = ((7 - 1) * dilation) // 2
        self.snake1 = Snake1d(dimension)
        self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad))
        self.snake2 = Snake1d(dimension)
        self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1))

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        residual = hidden_state
        out = self.conv1(self.snake1(residual))
        out = self.conv2(self.snake2(out))
        # Center-crop the skip path if the conv path came out shorter.
        trim = (residual.shape[-1] - out.shape[-1]) // 2
        if trim > 0:
            residual = residual[..., trim:-trim]
        return residual + out
+
+
class OobleckEncoderBlock(nn.Module):
    """Encoder stage: three dilated residual units, then a strided downsampling conv."""

    def __init__(self, input_dim: int, output_dim: int, stride: int = 1):
        super().__init__()
        # Growing dilations (1, 3, 9) widen the receptive field cheaply.
        self.res_unit1 = OobleckResidualUnit(input_dim, dilation=1)
        self.res_unit2 = OobleckResidualUnit(input_dim, dilation=3)
        self.res_unit3 = OobleckResidualUnit(input_dim, dilation=9)
        self.snake1 = Snake1d(input_dim)
        self.conv1 = weight_norm(
            nn.Conv1d(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        out = self.res_unit3(self.res_unit2(self.res_unit1(hidden_state)))
        return self.conv1(self.snake1(out))
+
+
class OobleckDecoderBlock(nn.Module):
    """Decoder stage: upsampling transposed conv, then three dilated residual units."""

    def __init__(self, input_dim: int, output_dim: int, stride: int = 1):
        super().__init__()
        self.snake1 = Snake1d(input_dim)
        self.conv_t1 = weight_norm(
            nn.ConvTranspose1d(
                input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2),
            )
        )
        # Growing dilations (1, 3, 9) mirror the encoder block.
        self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1)
        self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3)
        self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        out = self.conv_t1(self.snake1(hidden_state))
        return self.res_unit3(self.res_unit2(self.res_unit1(out)))
+
+
class OobleckEncoder(nn.Module):
    """Full encoder: audio → latent representation [B, encoder_hidden_size, T'].

    Layout: conv1 → [encoder blocks] → snake1 → conv2.
    """

    def __init__(
        self,
        encoder_hidden_size: int = 128,
        audio_channels: int = 2,
        downsampling_ratios: list = None,
        channel_multiples: list = None,
    ):
        super().__init__()
        strides = downsampling_ratios or [2, 4, 4, 6, 10]
        # Prepend 1 so indices i / i+1 give the in/out multiples of stage i.
        multiples = [1] + (channel_multiples or [1, 2, 4, 8, 16])

        self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3))

        self.block = nn.ModuleList([
            OobleckEncoderBlock(
                input_dim=encoder_hidden_size * multiples[stage],
                output_dim=encoder_hidden_size * multiples[stage + 1],
                stride=stride,
            )
            for stage, stride in enumerate(strides)
        ])

        d_model = encoder_hidden_size * multiples[-1]
        self.snake1 = Snake1d(d_model)
        self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1))

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        out = self.conv1(hidden_state)
        for stage in self.block:
            out = stage(out)
        return self.conv2(self.snake1(out))
+
+
class OobleckDecoder(nn.Module):
    """Full decoder: latent → audio waveform [B, audio_channels, T].

    Layout: conv1 → [decoder blocks] → snake1 → conv2 (no bias).
    """

    def __init__(
        self,
        channels: int = 128,
        input_channels: int = 64,
        audio_channels: int = 2,
        upsampling_ratios: list = None,
        channel_multiples: list = None,
    ):
        super().__init__()
        strides = upsampling_ratios or [10, 6, 4, 4, 2]
        # Prepend 1; stages walk the multiples from widest back down to 1.
        multiples = [1] + (channel_multiples or [1, 2, 4, 8, 16])
        num_stages = len(strides)

        self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * multiples[-1], kernel_size=7, padding=3))

        self.block = nn.ModuleList([
            OobleckDecoderBlock(
                input_dim=channels * multiples[num_stages - stage],
                output_dim=channels * multiples[num_stages - stage - 1],
                stride=stride,
            )
            for stage, stride in enumerate(strides)
        ])

        self.snake1 = Snake1d(channels)
        # conv2 has no bias (matches checkpoint: only weight_g/weight_v, no bias key).
        self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False))

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        out = self.conv1(hidden_state)
        for stage in self.block:
            out = stage(out)
        return self.conv2(self.snake1(out))
+
+
class AceStepVAE(nn.Module):
    """Audio VAE for ACE-Step (AutoencoderOobleck architecture).

    Encodes audio waveform → latent, decodes latent → audio waveform.
    Uses Snake1d activations and weight-normalized convolutions.

    NOTE(review): with the defaults the encoder emits ``encoder_hidden_size``
    (128) channels while the decoder consumes ``decoder_input_channels`` (64)
    — presumably the 128 channels are a mean/scale pair split and sampled
    elsewhere before decoding, so the ``forward`` round-trip only works when
    the two dimensions match. Confirm against the loading/sampling code.

    Args:
        encoder_hidden_size: base channel count (and output dim) of the encoder.
        downsampling_ratios: per-stage encoder strides (decoder uses the reverse).
        channel_multiples: per-stage channel growth factors.
        decoder_channels: base channel count of the decoder.
        decoder_input_channels: latent channels consumed by the decoder.
        audio_channels: waveform channels (2 = stereo).
        sampling_rate: nominal audio sample rate; stored for callers' reference.
    """

    def __init__(
        self,
        encoder_hidden_size: int = 128,
        downsampling_ratios: list = None,
        channel_multiples: list = None,
        decoder_channels: int = 128,
        decoder_input_channels: int = 64,
        audio_channels: int = 2,
        sampling_rate: int = 48000,
    ):
        super().__init__()
        downsampling_ratios = downsampling_ratios or [2, 4, 4, 6, 10]
        channel_multiples = channel_multiples or [1, 2, 4, 8, 16]
        # The decoder mirrors the encoder's strides in reverse order.
        upsampling_ratios = downsampling_ratios[::-1]
        # Fix: sampling_rate was previously accepted but silently discarded.
        self.sampling_rate = sampling_rate

        self.encoder = OobleckEncoder(
            encoder_hidden_size=encoder_hidden_size,
            audio_channels=audio_channels,
            downsampling_ratios=downsampling_ratios,
            channel_multiples=channel_multiples,
        )
        self.decoder = OobleckDecoder(
            channels=decoder_channels,
            input_channels=decoder_input_channels,
            audio_channels=audio_channels,
            upsampling_ratios=upsampling_ratios,
            channel_multiples=channel_multiples,
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Audio waveform [B, audio_channels, T] → latent [B, encoder_hidden_size, T']."""
        return self.encoder(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Latent [B, decoder_input_channels, T] → audio waveform [B, audio_channels, T']."""
        return self.decoder(z)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        """Full round-trip: encode → decode.

        Consistency fix: route through ``decode()`` (was ``self.decoder(z)``)
        so both halves of the round-trip use the public API uniformly.
        """
        z = self.encode(sample)
        return self.decode(z)
diff --git a/diffsynth/pipelines/ace_step.py b/diffsynth/pipelines/ace_step.py
new file mode 100644
index 00000000..09c31786
--- /dev/null
+++ b/diffsynth/pipelines/ace_step.py
@@ -0,0 +1,527 @@
+"""
+ACE-Step Pipeline for DiffSynth-Studio.
+
+Text-to-Music generation pipeline using ACE-Step 1.5 model.
+"""
+import torch
+from typing import Optional
+from tqdm import tqdm
+
+from ..core.device.npu_compatible_device import get_device_type
+from ..diffusion import FlowMatchScheduler
+from ..core import ModelConfig
+from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
+
+from ..models.ace_step_dit import AceStepDiTModel
+from ..models.ace_step_conditioner import AceStepConditionEncoder
+from ..models.ace_step_text_encoder import AceStepTextEncoder
+from ..models.ace_step_vae import AceStepVAE
+
+
class AceStepPipeline(BasePipeline):
    """Pipeline for ACE-Step text-to-music generation.

    Orchestrates text encoding, condition building, the flow-matching
    denoise loop over the DiT, and VAE decoding to a waveform.
    """

    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        super().__init__(
            device=device,
            torch_dtype=torch_dtype,
            height_division_factor=1,
            width_division_factor=1,
        )
        self.scheduler = FlowMatchScheduler("ACE-Step")
        self.text_encoder: AceStepTextEncoder = None
        self.conditioner: AceStepConditionEncoder = None
        self.dit: AceStepDiTModel = None
        self.vae = None  # AutoencoderOobleck (diffusers) or AceStepVAE
        # Bug fix: initialize tokenizer here. from_pretrained only assigns it
        # when a tokenizer config is given; units read `pipe.tokenizer`, which
        # previously raised AttributeError when no config was provided.
        self.tokenizer = None

        # Unit chain order — 7 units total
        #
        # 1. ShapeChecker: duration → seq_len
        # 2. PromptEmbedder: prompt/lyrics → text/lyric embeddings (shared for CFG)
        # 3. SilenceLatentInitializer: seq_len → src_latents + chunk_masks
        # 4. ContextLatentBuilder: src_latents + chunk_masks → context_latents (shared, same for CFG+)
        # 5. ConditionEmbedder: text/lyric → encoder_hidden_states (separate for CFG+/-)
        # 6. NoiseInitializer: context_latents → noise
        # 7. InputAudioEmbedder: noise → latents
        #
        # ContextLatentBuilder runs before ConditionEmbedder so that
        # context_latents is available for noise shape computation.
        self.in_iteration_models = ("dit",)
        self.units = [
            AceStepUnit_ShapeChecker(),
            AceStepUnit_PromptEmbedder(),
            AceStepUnit_SilenceLatentInitializer(),
            AceStepUnit_ContextLatentBuilder(),
            AceStepUnit_ConditionEmbedder(),
            AceStepUnit_NoiseInitializer(),
            AceStepUnit_InputAudioEmbedder(),
        ]
        self.model_fn = model_fn_ace_step
        self.compilable_models = ["dit"]

        self.sample_rate = 48000

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: str = get_device_type(),
        model_configs: Optional[list[ModelConfig]] = None,
        text_tokenizer_config: ModelConfig = None,
        vram_limit: float = None,
    ):
        """Load pipeline from pretrained checkpoints.

        Args:
            torch_dtype: dtype for all loaded models.
            device: target device string.
            model_configs: checkpoint configs; defaults to an empty list.
                (Fix: the default was a shared mutable ``[]``.)
            text_tokenizer_config: optional tokenizer checkpoint config.
            vram_limit: optional VRAM budget forwarded to model loading.

        Returns:
            A ready-to-use AceStepPipeline.
        """
        model_configs = model_configs if model_configs is not None else []
        pipe = AceStepPipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)

        pipe.text_encoder = model_pool.fetch_model("ace_step_text_encoder")
        pipe.conditioner = model_pool.fetch_model("ace_step_conditioner")
        pipe.dit = model_pool.fetch_model("ace_step_dit")
        pipe.vae = model_pool.fetch_model("ace_step_vae")

        if text_tokenizer_config is not None:
            text_tokenizer_config.download_if_necessary()
            from transformers import AutoTokenizer
            pipe.tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_config.path)

        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 1.0,
        # Lyrics
        lyrics: str = "",
        # Reference audio (optional, for timbre conditioning)
        reference_audio = None,
        # Shape
        duration: float = 60.0,
        # Randomness
        seed: int = None,
        rand_device: str = "cpu",
        # Steps
        num_inference_steps: int = 8,
        # Scheduler-specific parameters
        shift: float = 3.0,
        # Progress
        progress_bar_cmd=tqdm,
    ):
        """Generate audio from a text prompt (and optional lyrics).

        Returns:
            Audio tensor of shape [C, T], float32 (see output_audio_format_check).
        """
        # 1. Scheduler
        self.scheduler.set_timesteps(
            num_inference_steps=num_inference_steps,
            denoising_strength=1.0,
            shift=shift,
        )

        # 2. Three input dicts (shared / positive / negative CFG branches)
        inputs_posi = {"prompt": prompt}
        inputs_nega = {"negative_prompt": negative_prompt}
        inputs_shared = {
            "cfg_scale": cfg_scale,
            "lyrics": lyrics,
            "reference_audio": reference_audio,
            "duration": duration,
            "seed": seed,
            "rand_device": rand_device,
            "num_inference_steps": num_inference_steps,
            "shift": shift,
        }

        # 3. Run the unit chain
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
                unit, self, inputs_shared, inputs_posi, inputs_nega
            )

        # 4. Denoise loop
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = self.step(
                self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
            )

        # 5. VAE decode
        self.load_models_to_device(['vae'])
        # DiT output is [B, T, 64] (channels-last), VAE expects [B, 64, T] (channels-first)
        latents = inputs_shared["latents"].transpose(1, 2)
        vae_output = self.vae.decode(latents)
        # VAE returns OobleckDecoderOutput with .sample attribute
        audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output
        audio = self.output_audio_format_check(audio_output)
        self.load_models_to_device([])
        return audio

    def output_audio_format_check(self, audio_output):
        """Convert VAE output to standard audio format [C, T], float32.

        VAE decode outputs [B, C, T] (audio waveform).
        We squeeze batch dim and return [C, T].
        """
        if audio_output.ndim == 3:
            audio_output = audio_output.squeeze(0)
        return audio_output.float()
+
+
class AceStepUnit_ShapeChecker(PipelineUnit):
    """Derive the latent sequence length from the requested duration."""
    def __init__(self):
        super().__init__(
            input_params=("duration",),
            output_params=("duration", "seq_len"),
        )

    def process(self, pipe, duration):
        # ACE-Step latents run at a fixed 25 Hz frame rate.
        frames_per_second = 25
        return {"duration": duration, "seq_len": int(duration * frames_per_second)}
+
+
class AceStepUnit_PromptEmbedder(PipelineUnit):
    """Encode prompt and lyrics using Qwen3-Embedding.

    Uses seperate_cfg=True to read prompt from inputs_posi (not inputs_shared).
    The negative condition uses null_condition_emb (handled by ConditionEmbedder),
    so negative text encoding is not needed here.
    """
    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={},
            input_params=("lyrics",),
            output_params=("text_hidden_states", "text_attention_mask", "lyric_hidden_states", "lyric_attention_mask"),
            onload_model_names=("text_encoder",)
        )

    def _encode_text(self, pipe, text):
        """Encode text using Qwen3-Embedding → [B, T, 1024].

        Returns (hidden_states, attention_mask), or (None, None) when the
        pipeline has no tokenizer loaded.
        """
        # NOTE(review): `pipe.tokenizer` is only assigned in from_pretrained
        # when a tokenizer config is given — confirm the attribute always
        # exists (it is read here before any None check can help otherwise).
        if pipe.tokenizer is None:
            return None, None
        # Fixed-length padding keeps the text branch shape-stable across prompts.
        text_inputs = pipe.tokenizer(
            text,
            padding="max_length",
            max_length=512,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = text_inputs.input_ids.to(pipe.device)
        attention_mask = text_inputs.attention_mask.to(pipe.device)
        hidden_states = pipe.text_encoder(input_ids, attention_mask)
        return hidden_states, attention_mask

    def process(self, pipe, prompt, lyrics, negative_prompt=None):
        pipe.load_models_to_device(['text_encoder'])

        text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt)

        # Lyrics encoding — use empty string if not provided
        lyric_text = lyrics if lyrics else ""
        lyric_hidden_states, lyric_attention_mask = self._encode_text(pipe, lyric_text)

        # Only emit outputs when both encodings succeeded; otherwise the
        # declared output params are left unset for downstream units.
        if text_hidden_states is not None and lyric_hidden_states is not None:
            return {
                "text_hidden_states": text_hidden_states,
                "text_attention_mask": text_attention_mask,
                "lyric_hidden_states": lyric_hidden_states,
                "lyric_attention_mask": lyric_attention_mask,
            }
        return {}
+
+
class AceStepUnit_SilenceLatentInitializer(PipelineUnit):
    """Produce silence source latents (zeros) and chunk masks for text2music.

    Mirrors the reference `prepare_condition()`:
        context_latents = torch.cat([src_latents, chunk_masks.to(dtype)], dim=-1)

    For text2music mode:
    - src_latents = zeros [B, T, 64] (VAE latent dimension)
    - chunk_masks = ones [B, T, 64] (full visibility, is_covers=0 case)
    - downstream: context_latents = [B, T, 128]
    """
    def __init__(self):
        super().__init__(
            input_params=("seq_len",),
            output_params=("silence_latent", "src_latents", "chunk_masks"),
        )

    def process(self, pipe, seq_len):
        latent_dim = 64  # VAE latent channel dimension
        shape = (1, seq_len, latent_dim)
        # Text2music starts from silence: all-zero source latents.
        silence_latent = torch.zeros(shape, device=pipe.device, dtype=pipe.torch_dtype)
        src_latents = silence_latent.clone()
        # Full-visibility chunk mask — all ones in text2music mode.
        chunk_masks = torch.ones(shape, device=pipe.device, dtype=pipe.torch_dtype)
        return {"silence_latent": silence_latent, "src_latents": src_latents, "chunk_masks": chunk_masks}
+
+
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
    """Concatenate src_latents with chunk_masks into context_latents.

    Mirrors the reference `prepare_condition()`:
        context_latents = torch.cat([src_latents, chunk_masks.to(dtype)], dim=-1)

    context_latents is identical for the positive and negative CFG passes
    (it derives from latents/masks, not from text), so this is a shared-mode
    unit whose outputs land in inputs_shared.
    """
    def __init__(self):
        super().__init__(
            input_params=("src_latents", "chunk_masks"),
            output_params=("context_latents", "attention_mask"),
        )

    def process(self, pipe, src_latents, chunk_masks):
        batch_size, seq_len = src_latents.shape[0], src_latents.shape[1]
        # [B, T, 64] ++ [B, T, 64] -> [B, T, 128]
        context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
        # Full-visibility mask for the DiT's attention over context_latents.
        attention_mask = torch.ones(
            batch_size, seq_len, device=pipe.device, dtype=pipe.torch_dtype
        )
        return {"context_latents": context_latents, "attention_mask": attention_mask}
+
+
class AceStepUnit_ConditionEmbedder(PipelineUnit):
    """Generate encoder_hidden_states via ACEStepConditioner.

    Target library reference: `prepare_condition()` line 1674-1681:
        encoder_hidden_states, encoder_attention_mask = self.encoder(...)

    Uses seperate_cfg mode:
    - Positive: encode with full condition (text + lyrics + reference audio)
    - Negative: replace text with null_condition_emb, keep lyrics/timbre same

    context_latents is handled by ContextLatentBuilder (a shared-mode unit), not here.
    """
    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={
                "text_hidden_states": "text_hidden_states",
                "text_attention_mask": "text_attention_mask",
                "lyric_hidden_states": "lyric_hidden_states",
                "lyric_attention_mask": "lyric_attention_mask",
                "reference_audio": "reference_audio",
                "refer_audio_order_mask": "refer_audio_order_mask",
            },
            input_params_nega={},
            input_params=("cfg_scale",),
            output_params=(
                "encoder_hidden_states", "encoder_attention_mask",
                "negative_encoder_hidden_states", "negative_encoder_attention_mask",
            ),
            onload_model_names=("conditioner",)
        )

    def _prepare_condition(self, pipe, text_hidden_states, text_attention_mask,
                           lyric_hidden_states, lyric_attention_mask,
                           refer_audio_acoustic_hidden_states_packed=None,
                           refer_audio_order_mask=None):
        """Call ACEStepConditioner forward to produce encoder_hidden_states.

        Returns (encoder_hidden_states, encoder_attention_mask).
        """
        pipe.load_models_to_device(['conditioner'])

        # Handle reference audio
        if refer_audio_acoustic_hidden_states_packed is None:
            # No reference audio: create 2D packed zeros [N=1, d=64]
            # TimbreEncoder.unpack expects [N, d], not [B, T, d]
            refer_audio_acoustic_hidden_states_packed = torch.zeros(
                1, 64, device=pipe.device, dtype=pipe.torch_dtype
            )
            refer_audio_order_mask = torch.LongTensor([0]).to(pipe.device)

        encoder_hidden_states, encoder_attention_mask = pipe.conditioner(
            text_hidden_states=text_hidden_states,
            text_attention_mask=text_attention_mask,
            lyric_hidden_states=lyric_hidden_states,
            lyric_attention_mask=lyric_attention_mask,
            refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
            refer_audio_order_mask=refer_audio_order_mask,
        )

        return encoder_hidden_states, encoder_attention_mask

    def _prepare_negative_condition(self, pipe, lyric_hidden_states, lyric_attention_mask,
                                    refer_audio_acoustic_hidden_states_packed=None,
                                    refer_audio_order_mask=None):
        """Generate negative condition using null_condition_emb.

        Returns (None, None) when the conditioner lacks a null embedding.
        """
        if pipe.conditioner is None or not hasattr(pipe.conditioner, 'null_condition_emb'):
            return None, None

        null_emb = pipe.conditioner.null_condition_emb  # [1, 1, hidden_size]
        bsz = 1
        if lyric_hidden_states is not None:
            bsz = lyric_hidden_states.shape[0]
        null_hidden_states = null_emb.expand(bsz, -1, -1)
        null_attn_mask = torch.ones(bsz, 1, device=pipe.device, dtype=pipe.torch_dtype)

        # For negative: use null_condition_emb as text, keep lyrics and timbre
        neg_encoder_hidden_states, neg_encoder_attention_mask = pipe.conditioner(
            text_hidden_states=null_hidden_states,
            text_attention_mask=null_attn_mask,
            lyric_hidden_states=lyric_hidden_states,
            lyric_attention_mask=lyric_attention_mask,
            refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
            refer_audio_order_mask=refer_audio_order_mask,
        )

        return neg_encoder_hidden_states, neg_encoder_attention_mask

    def process(self, pipe, text_hidden_states, text_attention_mask,
                lyric_hidden_states, lyric_attention_mask,
                reference_audio=None, refer_audio_order_mask=None,
                negative_prompt=None, cfg_scale=1.0):

        # Positive condition
        # NOTE(review): reference_audio is received but a zero timbre
        # placeholder is always used (None is passed below) — confirm
        # reference-audio conditioning is wired up elsewhere.
        pos_enc_hs, pos_enc_mask = self._prepare_condition(
            pipe, text_hidden_states, text_attention_mask,
            lyric_hidden_states, lyric_attention_mask,
            None, refer_audio_order_mask,
        )

        # Negative condition: only needed when CFG is active (cfg_scale > 1.0)
        # For cfg_scale=1.0 (turbo), skip to avoid null_condition_emb dimension mismatch
        result = {
            "encoder_hidden_states": pos_enc_hs,
            "encoder_attention_mask": pos_enc_mask,
        }

        if cfg_scale > 1.0:
            neg_enc_hs, neg_enc_mask = self._prepare_negative_condition(
                pipe, lyric_hidden_states, lyric_attention_mask,
                None, refer_audio_order_mask,
            )
            if neg_enc_hs is not None:
                result["negative_encoder_hidden_states"] = neg_enc_hs
                result["negative_encoder_attention_mask"] = neg_enc_mask

        return result
+
+
class AceStepUnit_NoiseInitializer(PipelineUnit):
    """Draw the initial Gaussian noise for the denoise loop.

    Target library reference: `prepare_noise()` line 1781-1818:
        src_latents_shape = (bsz, context_latents.shape[1], context_latents.shape[-1] // 2)

    The noise channel count is half the context channel count:
    [B, T, 128] context → [B, T, 64] noise.
    """
    def __init__(self):
        super().__init__(
            input_params=("seed", "seq_len", "rand_device", "context_latents"),
            output_params=("noise",),
        )

    def process(self, pipe, seed, seq_len, rand_device, context_latents):
        batch_size, num_frames, context_dim = context_latents.shape
        # context_latents is [src_latents | chunk_masks]; the noise matches
        # only the src_latents half of the channels.
        noise = pipe.generate_noise(
            (batch_size, num_frames, context_dim // 2),
            seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype,
        )
        return {"noise": noise}
+
+
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
    """Seed the denoise loop with latents.

    Text2music has no input audio, so the loop starts from pure noise and
    input_latents stays None.

    Target library reference: `generate_audio()` line 1972:
        xt = noise (when cover_noise_strength == 0)
    """
    def __init__(self):
        super().__init__(
            input_params=("noise",),
            output_params=("latents", "input_latents"),
        )

    def process(self, pipe, noise):
        return {"latents": noise, "input_latents": None}
+
+
def model_fn_ace_step(
    dit: AceStepDiTModel,
    latents=None,
    timestep=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    context_latents=None,
    attention_mask=None,
    past_key_values=None,
    negative_encoder_hidden_states=None,
    negative_encoder_attention_mask=None,
    negative_context_latents=None,
    **kwargs,
):
    """Run one ACE-Step DiT forward pass and return the velocity prediction.

    Timesteps are already in [0, 1] — no /1000 scaling is applied.

    Target library reference: `generate_audio()` line 2009-2020:
        decoder_outputs = self.decoder(
            hidden_states=x, timestep=t_curr_tensor, timestep_r=t_curr_tensor,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            context_latents=context_latents,
            use_cache=True, past_key_values=past_key_values,
        )

    Args:
        dit: AceStepDiTModel
        latents: [B, T, 64] noise/latent tensor (same shape as src_latents)
        timestep: scalar tensor in [0, 1]
        encoder_hidden_states: [B, T_text, 2048] condition from the Conditioner
            (positive or negative depending on the CFG pass — cfg_guided_model_fn
            supplies inputs_posi for positive, inputs_nega for negative)
        encoder_attention_mask: [B, T_text]
        context_latents: [B, T, 128] = cat([src_latents, chunk_masks], dim=-1)
            (same tensor for both CFG passes in text2music mode)
        attention_mask: [B, T] ones mask for the DiT
        past_key_values: EncoderDecoderCache for KV caching

    The DiT internally concatenates cat([context_latents, latents], dim=-1)
    = [B, T, 192] as its actual input (128 + 64 channels).
    """
    # Broadcast the scalar timestep across the batch.
    batch_size = latents.shape[0]
    step = timestep.squeeze().expand(batch_size)

    decoder_outputs = dit(
        hidden_states=latents,
        timestep=step,
        timestep_r=step,
        attention_mask=attention_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        context_latents=context_latents,
        use_cache=True,
        past_key_values=past_key_values,
    )

    # The velocity prediction is the first element of the decoder outputs.
    return decoder_outputs[0]
diff --git a/diffsynth/utils/state_dict_converters/ace_step_conditioner.py b/diffsynth/utils/state_dict_converters/ace_step_conditioner.py
new file mode 100644
index 00000000..b6984b88
--- /dev/null
+++ b/diffsynth/utils/state_dict_converters/ace_step_conditioner.py
@@ -0,0 +1,48 @@
+"""
+State dict converter for ACE-Step Conditioner model.
+
+The original checkpoint stores all model weights in a single file
+(nested in AceStepConditionGenerationModel). The Conditioner weights are
+prefixed with 'encoder.'.
+
+This converter extracts only keys starting with 'encoder.' and strips
+the prefix to match the standalone AceStepConditionEncoder in DiffSynth.
+"""
+
+
+def ace_step_conditioner_converter(state_dict):
+    """
+    Convert ACE-Step Conditioner checkpoint keys to DiffSynth format.
+
+    ``state_dict`` is a DiskMap-like mapping: iterating over it yields key
+    names, and ``state_dict[key]`` loads the actual tensor value.
+
+    Original checkpoint contains all model weights under prefixes:
+    - decoder.* (DiT)
+    - encoder.* (Conditioner)
+    - tokenizer.* (Audio Tokenizer)
+    - detokenizer.* (Audio Detokenizer)
+    - null_condition_emb (CFG null embedding)
+
+    This extracts only 'encoder.' keys and strips the prefix.
+
+    Example mapping:
+        encoder.lyric_encoder.layers.0.self_attn.q_proj.weight -> lyric_encoder.layers.0.self_attn.q_proj.weight
+        encoder.attention_pooler.layers.0.self_attn.q_proj.weight -> attention_pooler.layers.0.self_attn.q_proj.weight
+        encoder.timbre_encoder.layers.0.self_attn.q_proj.weight -> timbre_encoder.layers.0.self_attn.q_proj.weight
+        encoder.audio_tokenizer.audio_acoustic_proj.weight -> audio_tokenizer.audio_acoustic_proj.weight
+        encoder.detokenizer.layers.0.self_attn.q_proj.weight -> detokenizer.layers.0.self_attn.q_proj.weight
+    """
+    new_state_dict = {}
+    prefix = "encoder."
+
+    for key in state_dict:
+        if key.startswith(prefix):
+            new_key = key[len(prefix):]
+            new_state_dict[new_key] = state_dict[key]
+
+    # Extract null_condition_emb from top level (used for CFG negative condition)
+    if "null_condition_emb" in state_dict:
+        new_state_dict["null_condition_emb"] = state_dict["null_condition_emb"]
+
+    return new_state_dict
diff --git a/diffsynth/utils/state_dict_converters/ace_step_dit.py b/diffsynth/utils/state_dict_converters/ace_step_dit.py
new file mode 100644
index 00000000..758462cc
--- /dev/null
+++ b/diffsynth/utils/state_dict_converters/ace_step_dit.py
@@ -0,0 +1,43 @@
+"""
+State dict converter for ACE-Step DiT model.
+
+The original checkpoint stores all model weights in a single file
+(nested in AceStepConditionGenerationModel). The DiT weights are
+prefixed with 'decoder.'.
+
+This converter extracts only keys starting with 'decoder.' and strips
+the prefix to match the standalone AceStepDiTModel in DiffSynth.
+"""
+
+
+def ace_step_dit_converter(state_dict):
+    """
+    Convert ACE-Step DiT checkpoint keys to DiffSynth format.
+
+    ``state_dict`` is a DiskMap-like mapping: iterating over it yields key
+    names, and ``state_dict[key]`` loads the actual tensor value.
+
+    The original checkpoint bundles every sub-model in one file, grouped by prefix:
+    - decoder.* (DiT)
+    - encoder.* (Conditioner)
+    - tokenizer.* (Audio Tokenizer)
+    - detokenizer.* (Audio Detokenizer)
+    - null_condition_emb (CFG null embedding)
+
+    Only the 'decoder.' keys are kept, with the prefix stripped to match the
+    standalone AceStepDiTModel layout.
+
+    Example mapping:
+        decoder.layers.0.self_attn.q_proj.weight -> layers.0.self_attn.q_proj.weight
+        decoder.proj_in.0.linear_1.weight -> proj_in.0.linear_1.weight
+        decoder.time_embed.linear_1.weight -> time_embed.linear_1.weight
+        decoder.rotary_emb.inv_freq -> rotary_emb.inv_freq
+    """
+    prefix = "decoder."
+    cut = len(prefix)
+    # Single pass over the DiskMap; values are only loaded for matching keys.
+    return {
+        key[cut:]: state_dict[key]
+        for key in state_dict
+        if key.startswith(prefix)
+    }
diff --git a/diffsynth/utils/state_dict_converters/ace_step_lm.py b/diffsynth/utils/state_dict_converters/ace_step_lm.py
new file mode 100644
index 00000000..2067cb16
--- /dev/null
+++ b/diffsynth/utils/state_dict_converters/ace_step_lm.py
@@ -0,0 +1,55 @@
+"""
+State dict converter for ACE-Step LLM (Qwen3-based).
+
+The safetensors file stores Qwen3 model weights. Different checkpoints
+may have different key formats:
+- Qwen3ForCausalLM format: model.embed_tokens.weight, model.layers.0.*
+- Qwen3Model format: embed_tokens.weight, layers.0.*
+
+Qwen3ForCausalLM wraps a .model attribute (Qwen3Model), so its
+state_dict() has keys:
+ model.model.embed_tokens.weight
+ model.model.layers.0.self_attn.q_proj.weight
+ model.model.norm.weight
+ model.lm_head.weight (tied to model.model.embed_tokens)
+
+This converter normalizes all keys to the Qwen3ForCausalLM format.
+
+Example mapping:
+ model.embed_tokens.weight -> model.model.embed_tokens.weight
+ embed_tokens.weight -> model.model.embed_tokens.weight
+ model.layers.0.self_attn.q_proj.weight -> model.model.layers.0.self_attn.q_proj.weight
+ layers.0.self_attn.q_proj.weight -> model.model.layers.0.self_attn.q_proj.weight
+ model.norm.weight -> model.model.norm.weight
+ norm.weight -> model.model.norm.weight
+"""
+
+
+def ace_step_lm_converter(state_dict):
+    """
+    Convert ACE-Step LLM checkpoint keys to match Qwen3ForCausalLM state dict.
+
+    ``state_dict`` is a DiskMap-like mapping: iterating over it yields key
+    names, and ``state_dict[key]`` loads the actual tensor value.
+    """
+    new_state_dict = {}
+    model_prefix = "model."
+    nested_prefix = "model.model."
+
+    for key in state_dict:
+        if key.startswith(nested_prefix):
+            # Already in Qwen3ForCausalLM's nested form — keep unchanged.
+            new_key = key
+        elif key.startswith(model_prefix):
+            # Qwen3ForCausalLM-style key: nest one level deeper. NOTE(review): this would also nest a checkpoint 'model.lm_head.weight' under model.model — assumed absent (tied head is rebuilt below); confirm against real checkpoints.
+            new_key = "model." + key
+        else:
+            # Bare Qwen3Model key: add the full 'model.model.' nesting.
+            new_key = "model.model." + key
+        new_state_dict[new_key] = state_dict[key]
+
+    # Tied word embeddings: mirror embed_tokens into lm_head (overwrites any converted lm_head entry).
+    if "model.model.embed_tokens.weight" in new_state_dict:
+        new_state_dict["model.lm_head.weight"] = new_state_dict["model.model.embed_tokens.weight"]
+
+    return new_state_dict
diff --git a/diffsynth/utils/state_dict_converters/ace_step_text_encoder.py b/diffsynth/utils/state_dict_converters/ace_step_text_encoder.py
new file mode 100644
index 00000000..de0b6c7b
--- /dev/null
+++ b/diffsynth/utils/state_dict_converters/ace_step_text_encoder.py
@@ -0,0 +1,39 @@
+"""
+State dict converter for ACE-Step Text Encoder (Qwen3-Embedding-0.6B).
+
+The safetensors stores Qwen3Model weights with keys:
+ embed_tokens.weight
+ layers.0.self_attn.q_proj.weight
+ norm.weight
+
+AceStepTextEncoder wraps a .model attribute (Qwen3Model), so its
+state_dict() has keys with 'model.' prefix:
+ model.embed_tokens.weight
+ model.layers.0.self_attn.q_proj.weight
+ model.norm.weight
+
+This converter adds 'model.' prefix to match the nested structure.
+"""
+
+
+def ace_step_text_encoder_converter(state_dict):
+    """
+    Convert ACE-Step Text Encoder checkpoint keys to the wrapped Qwen3Model layout.
+
+    ``state_dict`` is a DiskMap-like mapping: iterating over it yields key
+    names, and ``state_dict[key]`` loads the actual tensor value.
+
+    Bare Qwen3Model keys (embed_tokens.*, layers.*, norm.*) gain a 'model.'
+    prefix, matching AceStepTextEncoder which stores the backbone as ``self.model``.
+    """
+    new_state_dict = {}
+    nested_prefix = "model.model."
+
+    for key in state_dict:
+        if key.startswith(nested_prefix):
+            new_key = key  # already fully qualified — keep as is
+        else:
+            # The original's separate 'model.'-prefix branch duplicated this body.
+            new_key = "model." + key
+        new_state_dict[new_key] = state_dict[key]
+    return new_state_dict
diff --git a/diffsynth/utils/state_dict_converters/ace_step_tokenizer.py b/diffsynth/utils/state_dict_converters/ace_step_tokenizer.py
new file mode 100644
index 00000000..d4cb2bab
--- /dev/null
+++ b/diffsynth/utils/state_dict_converters/ace_step_tokenizer.py
@@ -0,0 +1,27 @@
+"""
+State dict converter for ACE-Step Tokenizer model.
+
+The original checkpoint stores tokenizer and detokenizer weights at the top level:
+- tokenizer.* (AceStepAudioTokenizer: audio_acoustic_proj, attention_pooler, quantizer)
+- detokenizer.* (AudioTokenDetokenizer: embed_tokens, layers, proj_out)
+
+These map directly to the AceStepTokenizer class which wraps both as
+self.tokenizer and self.detokenizer submodules.
+"""
+
+
+def ace_step_tokenizer_converter(state_dict):
+    """
+    Convert ACE-Step Tokenizer checkpoint keys to DiffSynth format.
+
+    The checkpoint keys ``tokenizer.*`` and ``detokenizer.*`` already line up
+    with the AceStepTokenizer submodules (self.tokenizer / self.detokenizer),
+    so this only filters out unrelated keys; no renaming is required.
+    """
+    # str.startswith accepts a tuple of prefixes — replaces the chained `or`.
+    wanted = ("tokenizer.", "detokenizer.")
+    return {
+        key: state_dict[key]
+        for key in state_dict
+        if key.startswith(wanted)
+    }
diff --git a/examples/ace_step/model_inference/Ace-Step1.5-SimpleMode.py b/examples/ace_step/model_inference/Ace-Step1.5-SimpleMode.py
new file mode 100644
index 00000000..edcf2cd4
--- /dev/null
+++ b/examples/ace_step/model_inference/Ace-Step1.5-SimpleMode.py
@@ -0,0 +1,180 @@
+"""
+Ace-Step 1.5 — Text-to-Music with Simple Mode (LLM expansion).
+
+Uses the ACE-Step LLM to expand a simple description into structured
+parameters (caption, lyrics, bpm, keyscale, etc.), then feeds them
+to the DiffSynth Pipeline.
+
+The LLM expansion uses the target library's LLMHandler. If vLLM is
+not available, it falls back to using pre-structured parameters.
+
+Usage:
+ python examples/ace_step/model_inference/Ace-Step1.5-SimpleMode.py
+"""
+import os
+import sys
+import json
+import torch
+import soundfile as sf
+
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+
+
+# ---------------------------------------------------------------------------
+# Simple Mode: LLM expansion
+# ---------------------------------------------------------------------------
+
+def try_load_llm_handler(checkpoint_dir: str, lm_model_path: str = "acestep-5Hz-lm-1.7B",
+                         backend: str = "vllm"):
+    """Try to load the target library's LLMHandler. Returns (handler, success)."""
+    try:
+        from acestep.llm_inference import LLMHandler  # optional dependency — may be absent
+        handler = LLMHandler()
+        status, success = handler.initialize(
+            checkpoint_dir=checkpoint_dir,
+            lm_model_path=lm_model_path,
+            backend=backend,
+        )
+        if success:
+            print(f"[Simple Mode] LLM loaded via {backend} backend: {status}")
+            return handler, True
+        else:
+            print(f"[Simple Mode] LLM init failed: {status}")
+            return None, False
+    except Exception as e:  # broad on purpose: any import/init failure means "no LLM"
+        print(f"[Simple Mode] LLMHandler not available: {e}")
+        return None, False
+
+
+def expand_with_llm(llm_handler, description: str, duration: float = 30.0):
+    """Expand a short description via LLM Chain-of-Thought; return a params dict, or None on failure."""
+    result = llm_handler.generate_with_stop_condition(
+        caption=description,
+        lyrics="",
+        infer_type="dit",  # metadata only
+        temperature=0.85,
+        cfg_scale=1.0,
+        use_cot_metas=True,
+        use_cot_caption=True,
+        use_cot_language=True,
+        user_metadata={"duration": int(duration)},
+    )
+
+    if result.get("success") and result.get("metadata"):
+        meta = result["metadata"]
+        # Every field falls back to a neutral default if the LLM omitted it.
+        return {
+            "caption": meta.get("caption", description),
+            "lyrics": meta.get("lyrics", ""),
+            "bpm": meta.get("bpm", 100),
+            "keyscale": meta.get("keyscale", ""),
+            "language": meta.get("language", "en"),
+            "timesignature": meta.get("timesignature", "4"),
+            "duration": meta.get("duration", duration),
+        }
+    print(f"[Simple Mode] LLM expansion failed: {result.get('error', 'unknown')}")
+    return None  # caller falls back to fallback_expand()
+
+
+def fallback_expand(description: str, duration: float = 30.0):
+    """Fallback: use the description itself as the caption with defaults."""
+    print("[Simple Mode] LLM not available. Using description as caption.")  # was a placeholder-free f-string (lint F541)
+    return {
+        "caption": description,
+        "lyrics": "",
+        "bpm": 100,
+        "keyscale": "",
+        "language": "en",
+        "timesignature": "4",
+        "duration": duration,
+    }
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+
+def main():
+    # Target library path (for LLMHandler)
+    TARGET_LIB = os.path.join(os.path.dirname(__file__), "../../../../ACE-Step-1.5")
+    if TARGET_LIB not in sys.path:
+        sys.path.insert(0, TARGET_LIB)  # make the optional `acestep` package importable
+
+    description = "a soft Bengali love song for a quiet evening"
+    duration = 30.0  # seconds of audio to generate
+
+    # 1. Try to load LLM
+    print("=" * 60)
+    print("Ace-Step 1.5 — Simple Mode (LLM expansion)")
+    print("=" * 60)
+    print(f"\n[Simple Mode] Input: '{description}'")
+
+    llm_handler, llm_ok = try_load_llm_handler(
+        checkpoint_dir=TARGET_LIB,
+        lm_model_path="acestep-5Hz-lm-1.7B",
+    )
+
+    # 2. Expand parameters (LLM first, structured defaults as fallback)
+    if llm_ok:
+        params = expand_with_llm(llm_handler, description, duration=duration)
+        if params is None:  # LLM loaded but expansion failed
+            params = fallback_expand(description, duration)
+    else:
+        params = fallback_expand(description, duration)
+
+    print(f"\n[Simple Mode] Parameters:")
+    print(f"  Caption: {params['caption'][:100]}...")
+    print(f"  Lyrics: {len(params['lyrics'])} chars")
+    print(f"  BPM: {params['bpm']}, Keyscale: {params['keyscale']}")
+    print(f"  Language: {params['language']}, Time Sig: {params['timesignature']}")
+    print(f"  Duration: {params['duration']}s")
+
+    # 3. Load Pipeline
+    print(f"\n[Pipeline] Loading Ace-Step 1.5 (turbo)...")
+    pipe = AceStepPipeline.from_pretrained(
+        torch_dtype=torch.bfloat16,
+        device="cuda",  # requires a CUDA device
+        model_configs=[
+            ModelConfig(
+                model_id="ACE-Step/Ace-Step1.5",
+                origin_file_pattern="acestep-v15-turbo/model.safetensors"
+            ),
+            ModelConfig(
+                model_id="ACE-Step/Ace-Step1.5",
+                origin_file_pattern="acestep-v15-turbo/model.safetensors"  # NOTE(review): same file listed twice — confirm each entry maps to a distinct sub-model
+            ),
+            ModelConfig(
+                model_id="ACE-Step/Ace-Step1.5",
+                origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
+            ),
+        ],
+        tokenizer_config=ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="Qwen3-Embedding-0.6B/"
+        ),
+        vae_config=ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="vae/"
+        ),
+    )
+
+    # 4. Generate
+    print(f"\n[Generation] Running Pipeline...")
+    audio = pipe(
+        prompt=params["caption"],
+        lyrics=params["lyrics"],
+        duration=params["duration"],
+        seed=42,  # fixed seed for reproducibility
+        num_inference_steps=8,
+        cfg_scale=1.0,
+        shift=3.0,
+    )
+
+    output_path = "Ace-Step1.5-SimpleMode.wav"
+    sf.write(output_path, audio.cpu().numpy(), pipe.sample_rate)
+    print(f"\n[Done] Saved to {output_path}")
+    print(f"  Shape: {audio.shape}, Duration: {audio.shape[-1] / pipe.sample_rate:.1f}s")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/ace_step/model_inference/Ace-Step1.5.py b/examples/ace_step/model_inference/Ace-Step1.5.py
new file mode 100644
index 00000000..a9fda153
--- /dev/null
+++ b/examples/ace_step/model_inference/Ace-Step1.5.py
@@ -0,0 +1,67 @@
+"""
+Ace-Step 1.5 — Text-to-Music (Turbo) inference example.
+
+Demonstrates the standard text2music pipeline with structured parameters
+(caption, lyrics, duration, etc.) — no LLM expansion needed.
+
+For Simple Mode (LLM expands a short description), see:
+ - Ace-Step1.5-SimpleMode.py
+"""
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+import torch
+import soundfile as sf
+
+
+pipe = AceStepPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",  # requires a CUDA device
+    model_configs=[
+        ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="acestep-v15-turbo/model.safetensors"
+        ),
+        ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="acestep-v15-turbo/model.safetensors"
+        ),
+        ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
+        ),
+    ],
+    tokenizer_config=ModelConfig(
+        model_id="ACE-Step/Ace-Step1.5",
+        origin_file_pattern="Qwen3-Embedding-0.6B/"
+    ),
+    vae_config=ModelConfig(
+        model_id="ACE-Step/Ace-Step1.5",
+        origin_file_pattern="vae/"
+    ),
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline."
+lyrics = """[Intro - Synth Brass Fanfare]
+
+[Verse 1]
+黑夜里的风吹过耳畔
+甜蜜时光转瞬即逝
+脚步飘摇在星光上
+
+[Chorus]
+心电感应在震动间
+拥抱未来勇敢冒险
+
+[Outro - Instrumental]"""
+
+audio = pipe(
+    prompt=prompt,
+    lyrics=lyrics,
+    duration=30.0,  # seconds of audio
+    seed=42,  # fixed seed for reproducibility
+    num_inference_steps=8,  # few-step turbo sampling
+    cfg_scale=1.0,  # turbo: no CFG
+    shift=3.0,
+)
+
+sf.write("Ace-Step1.5.wav", audio.cpu().numpy(), pipe.sample_rate)
+print(f"Saved to Ace-Step1.5.wav, shape: {audio.shape}, duration: {audio.shape[-1] / pipe.sample_rate:.1f}s")
diff --git a/examples/ace_step/model_inference/acestep-v15-base.py b/examples/ace_step/model_inference/acestep-v15-base.py
new file mode 100644
index 00000000..480a6fec
--- /dev/null
+++ b/examples/ace_step/model_inference/acestep-v15-base.py
@@ -0,0 +1,52 @@
+"""
+Ace-Step 1.5 Base (non-turbo, 24 layers) — Text-to-Music inference example.
+
+Uses cfg_scale=7.0 (standard CFG guidance) and more steps for higher quality.
+"""
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+import torch
+import soundfile as sf
+
+
+pipe = AceStepPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",  # requires a CUDA device
+    model_configs=[
+        ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="acestep-v15-base/model.safetensors"
+        ),
+        ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="acestep-v15-base/model.safetensors"  # NOTE(review): same file listed twice — confirm each entry maps to a distinct sub-model
+        ),
+        ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
+        ),
+    ],
+    tokenizer_config=ModelConfig(
+        model_id="ACE-Step/Ace-Step1.5",
+        origin_file_pattern="Qwen3-Embedding-0.6B/"
+    ),
+    vae_config=ModelConfig(
+        model_id="ACE-Step/Ace-Step1.5",
+        origin_file_pattern="vae/"
+    ),
+)
+
+prompt = "A cinematic orchestral piece with soaring strings and heroic brass"
+lyrics = "[Intro - Orchestra]\n\n[Verse 1]\nAcross the mountains, through the valley\nA journey of a thousand miles\n\n[Chorus]\nRise above the stormy skies\nLet the music carry you"
+
+audio = pipe(
+    prompt=prompt,
+    lyrics=lyrics,
+    duration=30.0,  # seconds of audio
+    seed=42,  # fixed seed for reproducibility
+    num_inference_steps=20,  # non-turbo: more steps for higher quality
+    cfg_scale=7.0,  # Base model uses CFG
+    shift=3.0,
+)
+
+sf.write("acestep-v15-base.wav", audio.cpu().numpy(), pipe.sample_rate)
+print(f"Saved, shape: {audio.shape}")
diff --git a/examples/ace_step/model_inference/acestep-v15-sft.py b/examples/ace_step/model_inference/acestep-v15-sft.py
new file mode 100644
index 00000000..c9ec0fff
--- /dev/null
+++ b/examples/ace_step/model_inference/acestep-v15-sft.py
@@ -0,0 +1,52 @@
+"""
+Ace-Step 1.5 SFT (supervised fine-tuned, 24 layers) — Text-to-Music inference example.
+
+SFT variant is fine-tuned for specific music styles.
+"""
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+import torch
+import soundfile as sf
+
+
+pipe = AceStepPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",  # requires a CUDA device
+    model_configs=[
+        ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="acestep-v15-sft/model.safetensors"
+        ),
+        ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="acestep-v15-sft/model.safetensors"  # NOTE(review): same file listed twice — confirm each entry maps to a distinct sub-model
+        ),
+        ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
+        ),
+    ],
+    tokenizer_config=ModelConfig(
+        model_id="ACE-Step/Ace-Step1.5",
+        origin_file_pattern="Qwen3-Embedding-0.6B/"
+    ),
+    vae_config=ModelConfig(
+        model_id="ACE-Step/Ace-Step1.5",
+        origin_file_pattern="vae/"
+    ),
+)
+
+prompt = "A jazzy lo-fi beat with smooth saxophone and vinyl crackle, late night vibes"
+lyrics = "[Intro - Vinyl crackle]\n\n[Verse 1]\nMidnight city, neon glow\nSmooth jazz flowing to and fro\n\n[Chorus]\nLay back, let the music play\nJazzy nights, dreams drift away"
+
+audio = pipe(
+    prompt=prompt,
+    lyrics=lyrics,
+    duration=30.0,  # seconds of audio
+    seed=42,  # fixed seed for reproducibility
+    num_inference_steps=20,
+    cfg_scale=7.0,  # non-turbo SFT variant: CFG enabled
+    shift=3.0,
+)
+
+sf.write("acestep-v15-sft.wav", audio.cpu().numpy(), pipe.sample_rate)
+print(f"Saved, shape: {audio.shape}")
diff --git a/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py b/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py
new file mode 100644
index 00000000..447f6b0d
--- /dev/null
+++ b/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py
@@ -0,0 +1,52 @@
+"""
+Ace-Step 1.5 Turbo (shift=1) — Text-to-Music inference example.
+
+Uses shift=1.0 (no timestep transformation) for smoother, slower denoising.
+"""
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+import torch
+import soundfile as sf
+
+
+pipe = AceStepPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",  # requires a CUDA device
+    model_configs=[
+        ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="acestep-v15-turbo/model.safetensors"
+        ),
+        ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="acestep-v15-turbo/model.safetensors"  # NOTE(review): same file listed twice — confirm each entry maps to a distinct sub-model
+        ),
+        ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
+        ),
+    ],
+    tokenizer_config=ModelConfig(
+        model_id="ACE-Step/Ace-Step1.5",
+        origin_file_pattern="Qwen3-Embedding-0.6B/"
+    ),
+    vae_config=ModelConfig(
+        model_id="ACE-Step/Ace-Step1.5",
+        origin_file_pattern="vae/"
+    ),
+)
+
+prompt = "A gentle acoustic guitar melody with soft piano accompaniment, peaceful and warm atmosphere"
+lyrics = "[Verse 1]\nSunlight filtering through the trees\nA quiet moment, just the breeze\n\n[Chorus]\nPeaceful heart, open mind\nLeaving all the noise behind"
+
+audio = pipe(
+    prompt=prompt,
+    lyrics=lyrics,
+    duration=30.0,  # seconds of audio
+    seed=42,  # fixed seed for reproducibility
+    num_inference_steps=8,  # few-step turbo sampling
+    cfg_scale=1.0,  # turbo: no CFG
+    shift=1.0,  # shift=1: no timestep transformation
+)
+
+sf.write("acestep-v15-turbo-shift1.wav", audio.cpu().numpy(), pipe.sample_rate)
+print(f"Saved, shape: {audio.shape}")
diff --git a/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py b/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py
new file mode 100644
index 00000000..8091500c
--- /dev/null
+++ b/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py
@@ -0,0 +1,52 @@
+"""
+Ace-Step 1.5 Turbo (shift=3) — Text-to-Music inference example.
+
+Uses shift=3.0 (default turbo shift) for faster denoising convergence.
+"""
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+import torch
+import soundfile as sf
+
+
+pipe = AceStepPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",  # requires a CUDA device
+    model_configs=[
+        ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="acestep-v15-turbo/model.safetensors"
+        ),
+        ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="acestep-v15-turbo/model.safetensors"  # NOTE(review): same file listed twice — confirm each entry maps to a distinct sub-model
+        ),
+        ModelConfig(
+            model_id="ACE-Step/Ace-Step1.5",
+            origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
+        ),
+    ],
+    tokenizer_config=ModelConfig(
+        model_id="ACE-Step/Ace-Step1.5",
+        origin_file_pattern="Qwen3-Embedding-0.6B/"
+    ),
+    vae_config=ModelConfig(
+        model_id="ACE-Step/Ace-Step1.5",
+        origin_file_pattern="vae/"
+    ),
+)
+
+prompt = "An explosive, high-energy pop-rock track with anime theme song feel"
+lyrics = "[Intro]\n\n[Verse 1]\nRunning through the neon lights\nChasing dreams across the night\n\n[Chorus]\nFeel the fire in my soul\nMusic takes complete control"
+
+audio = pipe(
+    prompt=prompt,
+    lyrics=lyrics,
+    duration=30.0,  # seconds of audio
+    seed=42,  # fixed seed for reproducibility
+    num_inference_steps=8,  # few-step turbo sampling
+    cfg_scale=1.0,  # turbo: no CFG
+    shift=3.0,  # default turbo shift
+)
+
+sf.write("acestep-v15-turbo-shift3.wav", audio.cpu().numpy(), pipe.sample_rate)
+print(f"Saved, shape: {audio.shape}")
diff --git a/examples/ace_step/model_inference/acestep-v15-xl-base.py b/examples/ace_step/model_inference/acestep-v15-xl-base.py
new file mode 100644
index 00000000..f1c5b4ec
--- /dev/null
+++ b/examples/ace_step/model_inference/acestep-v15-xl-base.py
@@ -0,0 +1,52 @@
+"""
+Ace-Step 1.5 XL Base (32 layers, hidden_size=2560) — Text-to-Music inference example.
+
+XL variant with larger capacity for higher quality generation.
+"""
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+import torch
+import soundfile as sf
+
+
+pipe = AceStepPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",  # requires a CUDA device
+    model_configs=[
+        ModelConfig(
+            model_id="ACE-Step/acestep-v15-xl-base",
+            origin_file_pattern="model-*.safetensors"  # glob over multi-part weight files
+        ),
+        ModelConfig(
+            model_id="ACE-Step/acestep-v15-xl-base",
+            origin_file_pattern="model-*.safetensors"  # NOTE(review): same pattern listed twice — confirm each entry maps to a distinct sub-model
+        ),
+        ModelConfig(
+            model_id="ACE-Step/acestep-v15-xl-base",
+            origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
+        ),
+    ],
+    tokenizer_config=ModelConfig(
+        model_id="ACE-Step/acestep-v15-xl-base",
+        origin_file_pattern="Qwen3-Embedding-0.6B/"
+    ),
+    vae_config=ModelConfig(
+        model_id="ACE-Step/acestep-v15-xl-base",
+        origin_file_pattern="vae/"
+    ),
+)
+
+prompt = "An epic symphonic metal track with double bass drums and soaring vocals"
+lyrics = "[Intro - Heavy guitar riff]\n\n[Verse 1]\nSteel and thunder, fire and rain\nBurning through the endless pain\n\n[Chorus]\nRise up, break the chains\nUnleash the fire in your veins"
+
+audio = pipe(
+    prompt=prompt,
+    lyrics=lyrics,
+    duration=30.0,  # seconds of audio
+    seed=42,  # fixed seed for reproducibility
+    num_inference_steps=20,  # non-turbo: more steps
+    cfg_scale=7.0,  # base variant: CFG enabled
+    shift=3.0,
+)
+
+sf.write("acestep-v15-xl-base.wav", audio.cpu().numpy(), pipe.sample_rate)
+print(f"Saved, shape: {audio.shape}")
diff --git a/examples/ace_step/model_inference/acestep-v15-xl-sft.py b/examples/ace_step/model_inference/acestep-v15-xl-sft.py
new file mode 100644
index 00000000..73d54d96
--- /dev/null
+++ b/examples/ace_step/model_inference/acestep-v15-xl-sft.py
@@ -0,0 +1,50 @@
+"""
+Ace-Step 1.5 XL SFT (32 layers, supervised fine-tuned) — Text-to-Music inference example.
+"""
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+import torch
+import soundfile as sf
+
+
+pipe = AceStepPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",  # requires a CUDA device
+    model_configs=[
+        ModelConfig(
+            model_id="ACE-Step/acestep-v15-xl-sft",
+            origin_file_pattern="model-*.safetensors"  # glob over multi-part weight files
+        ),
+        ModelConfig(
+            model_id="ACE-Step/acestep-v15-xl-sft",
+            origin_file_pattern="model-*.safetensors"  # NOTE(review): same pattern listed twice — confirm each entry maps to a distinct sub-model
+        ),
+        ModelConfig(
+            model_id="ACE-Step/acestep-v15-xl-sft",
+            origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
+        ),
+    ],
+    tokenizer_config=ModelConfig(
+        model_id="ACE-Step/acestep-v15-xl-sft",
+        origin_file_pattern="Qwen3-Embedding-0.6B/"
+    ),
+    vae_config=ModelConfig(
+        model_id="ACE-Step/acestep-v15-xl-sft",
+        origin_file_pattern="vae/"
+    ),
+)
+
+prompt = "A beautiful piano ballad with lush strings and emotional vocals, cinematic feel"
+lyrics = "[Intro - Solo piano]\n\n[Verse 1]\nWhispers of a distant shore\nMemories I hold so dear\n\n[Chorus]\nIn your eyes I see the dawn\nAll my fears are gone"
+
+audio = pipe(
+    prompt=prompt,
+    lyrics=lyrics,
+    duration=30.0,  # seconds of audio
+    seed=42,  # fixed seed for reproducibility
+    num_inference_steps=20,  # non-turbo: more steps
+    cfg_scale=7.0,  # SFT variant: CFG enabled
+    shift=3.0,
+)
+
+sf.write("acestep-v15-xl-sft.wav", audio.cpu().numpy(), pipe.sample_rate)
+print(f"Saved, shape: {audio.shape}")
diff --git a/examples/ace_step/model_inference/acestep-v15-xl-turbo.py b/examples/ace_step/model_inference/acestep-v15-xl-turbo.py
new file mode 100644
index 00000000..9116567f
--- /dev/null
+++ b/examples/ace_step/model_inference/acestep-v15-xl-turbo.py
@@ -0,0 +1,52 @@
+"""
+Ace-Step 1.5 XL Turbo (32 layers) — Text-to-Music inference example.
+
+XL turbo with fast generation (8 steps, shift=3.0, no CFG).
+"""
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+import torch
+import soundfile as sf
+
+
+pipe = AceStepPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",  # requires a CUDA device
+    model_configs=[
+        ModelConfig(
+            model_id="ACE-Step/acestep-v15-xl-turbo",
+            origin_file_pattern="model-*.safetensors"  # glob over multi-part weight files
+        ),
+        ModelConfig(
+            model_id="ACE-Step/acestep-v15-xl-turbo",
+            origin_file_pattern="model-*.safetensors"  # NOTE(review): same pattern listed twice — confirm each entry maps to a distinct sub-model
+        ),
+        ModelConfig(
+            model_id="ACE-Step/acestep-v15-xl-turbo",
+            origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
+        ),
+    ],
+    tokenizer_config=ModelConfig(
+        model_id="ACE-Step/acestep-v15-xl-turbo",
+        origin_file_pattern="Qwen3-Embedding-0.6B/"
+    ),
+    vae_config=ModelConfig(
+        model_id="ACE-Step/acestep-v15-xl-turbo",
+        origin_file_pattern="vae/"
+    ),
+)
+
+prompt = "An upbeat electronic dance track with pulsing synths and driving bassline"
+lyrics = "[Intro - Synth build]\n\n[Verse 1]\nFeel the rhythm in the air\nElectric beats are everywhere\n\n[Drop]\n\n[Chorus]\nDance until the break of dawn\nMove your body, carry on"
+
+audio = pipe(
+    prompt=prompt,
+    lyrics=lyrics,
+    duration=30.0,  # seconds of audio
+    seed=42,  # fixed seed for reproducibility
+    num_inference_steps=8,  # few-step turbo sampling
+    cfg_scale=1.0,  # turbo: no CFG
+    shift=3.0,
+)
+
+sf.write("acestep-v15-xl-turbo.wav", audio.cpu().numpy(), pipe.sample_rate)
+print(f"Saved, shape: {audio.shape}")
From a604d76339f4f84dc3630eddc09a35d2507821fa Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Fri, 17 Apr 2026 17:45:52 +0800
Subject: [PATCH 02/16] pipeline_t2m
---
examples/ace_step/model_inference/Ace-Step1.5.py | 10 +++-------
1 file changed, 3 insertions(+), 7 deletions(-)
diff --git a/examples/ace_step/model_inference/Ace-Step1.5.py b/examples/ace_step/model_inference/Ace-Step1.5.py
index a9fda153..ca3616f1 100644
--- a/examples/ace_step/model_inference/Ace-Step1.5.py
+++ b/examples/ace_step/model_inference/Ace-Step1.5.py
@@ -22,21 +22,17 @@
),
ModelConfig(
model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="acestep-v15-turbo/model.safetensors"
+ origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
),
ModelConfig(
model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
+ origin_file_pattern="vae/diffusion_pytorch_model.safetensors"
),
],
- tokenizer_config=ModelConfig(
+ text_tokenizer_config=ModelConfig(
model_id="ACE-Step/Ace-Step1.5",
origin_file_pattern="Qwen3-Embedding-0.6B/"
),
- vae_config=ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="vae/"
- ),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline."
From 9d09e0431c785dcc1619f9634d530e9b4495e18b Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Tue, 21 Apr 2026 13:16:15 +0800
Subject: [PATCH 03/16] acestep t2m
---
diffsynth/diffusion/flow_match.py | 4 +-
diffsynth/models/ace_step_conditioner.py | 25 +-
diffsynth/models/ace_step_dit.py | 4 +-
diffsynth/models/ace_step_text_encoder.py | 27 -
diffsynth/models/ace_step_vae.py | 1 +
diffsynth/pipelines/ace_step.py | 554 +++++++++---------
diffsynth/utils/data/audio.py | 1 +
.../model_inference/Ace-Step1.5-SimpleMode.py | 39 +-
.../ace_step/model_inference/Ace-Step1.5.py | 36 +-
9 files changed, 307 insertions(+), 384 deletions(-)
diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py
index f2416838..8e77ea80 100644
--- a/diffsynth/diffusion/flow_match.py
+++ b/diffsynth/diffusion/flow_match.py
@@ -157,10 +157,10 @@ def set_timesteps_ace_step(num_inference_steps=8, denoising_strength=1.0, shift=
"""
num_train_timesteps = 1000
sigma_start = denoising_strength
- sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps)
+ sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps + 1)[:-1]
if shift is not None and shift != 1.0:
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
- timesteps = sigmas # ACE-Step uses [0, 1] range directly
+ timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
diff --git a/diffsynth/models/ace_step_conditioner.py b/diffsynth/models/ace_step_conditioner.py
index 93fe0d32..76cc502b 100644
--- a/diffsynth/models/ace_step_conditioner.py
+++ b/diffsynth/models/ace_step_conditioner.py
@@ -540,17 +540,9 @@ def forward(
) -> BaseModelOutput:
inputs_embeds = refer_audio_acoustic_hidden_states_packed
inputs_embeds = self.embed_tokens(inputs_embeds)
- # Handle 2D (packed) or 3D (batched) input
- is_packed = inputs_embeds.dim() == 2
- if is_packed:
- seq_len = inputs_embeds.shape[0]
- cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
- position_ids = cache_position.unsqueeze(0)
- inputs_embeds = inputs_embeds.unsqueeze(0)
- else:
- seq_len = inputs_embeds.shape[1]
- cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
- position_ids = cache_position.unsqueeze(0)
+ seq_len = inputs_embeds.shape[1]
+ cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
+ position_ids = cache_position.unsqueeze(0)
dtype = inputs_embeds.dtype
device = inputs_embeds.device
@@ -586,9 +578,8 @@ def forward(
hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
+ hidden_states = hidden_states[:, 0, :]
# For packed input: reshape [1, T, D] -> [T, D] for unpacking
- if is_packed:
- hidden_states = hidden_states.squeeze(0)
timbre_embs_unpack, timbre_embs_mask = self.unpack_timbre_embeddings(hidden_states, refer_audio_order_mask)
return timbre_embs_unpack, timbre_embs_mask
@@ -686,7 +677,7 @@ def forward(
text_attention_mask: Optional[torch.Tensor] = None,
lyric_hidden_states: Optional[torch.LongTensor] = None,
lyric_attention_mask: Optional[torch.Tensor] = None,
- refer_audio_acoustic_hidden_states_packed: Optional[torch.Tensor] = None,
+ reference_latents: Optional[torch.Tensor] = None,
refer_audio_order_mask: Optional[torch.LongTensor] = None,
):
text_hidden_states = self.text_projector(text_hidden_states)
@@ -695,11 +686,7 @@ def forward(
attention_mask=lyric_attention_mask,
)
lyric_hidden_states = lyric_encoder_outputs.last_hidden_state
- timbre_embs_unpack, timbre_embs_mask = self.timbre_encoder(
- refer_audio_acoustic_hidden_states_packed,
- refer_audio_order_mask
- )
-
+ timbre_embs_unpack, timbre_embs_mask = self.timbre_encoder(reference_latents, refer_audio_order_mask)
encoder_hidden_states, encoder_attention_mask = pack_sequences(
lyric_hidden_states, timbre_embs_unpack, lyric_attention_mask, timbre_embs_mask
)
diff --git a/diffsynth/models/ace_step_dit.py b/diffsynth/models/ace_step_dit.py
index c4621feb..d9172771 100644
--- a/diffsynth/models/ace_step_dit.py
+++ b/diffsynth/models/ace_step_dit.py
@@ -165,7 +165,7 @@ def __init__(
self,
in_channels: int,
time_embed_dim: int,
- scale: float = 1000,
+ scale: float = 1,
):
super().__init__()
@@ -711,7 +711,7 @@ def forward(
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
context_latents: torch.Tensor,
- use_cache: Optional[bool] = None,
+ use_cache: Optional[bool] = False,
past_key_values: Optional[EncoderDecoderCache] = None,
cache_position: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
diff --git a/diffsynth/models/ace_step_text_encoder.py b/diffsynth/models/ace_step_text_encoder.py
index 58b52a7e..1d0ac57f 100644
--- a/diffsynth/models/ace_step_text_encoder.py
+++ b/diffsynth/models/ace_step_text_encoder.py
@@ -2,17 +2,6 @@
class AceStepTextEncoder(torch.nn.Module):
- """
- Text encoder for ACE-Step using Qwen3-Embedding-0.6B.
-
- Converts text/lyric tokens to hidden state embeddings that are
- further processed by the ACE-Step ConditionEncoder.
-
- Wraps a Qwen3Model transformers model. Config is manually
- constructed, and model weights are loaded via DiffSynth's
- standard mechanism from safetensors files.
- """
-
def __init__(
self,
):
@@ -49,8 +38,6 @@ def __init__(
)
self.model = Qwen3Model(config)
- self.config = config
- self.hidden_size = config.hidden_size
@torch.no_grad()
def forward(
@@ -58,23 +45,9 @@ def forward(
input_ids: torch.LongTensor,
attention_mask: torch.Tensor,
):
- """
- Encode text/lyric tokens to hidden states.
-
- Args:
- input_ids: [B, T] token IDs
- attention_mask: [B, T] attention mask
-
- Returns:
- last_hidden_state: [B, T, hidden_size]
- """
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True,
)
return outputs.last_hidden_state
-
- def to(self, *args, **kwargs):
- self.model.to(*args, **kwargs)
- return self
diff --git a/diffsynth/models/ace_step_vae.py b/diffsynth/models/ace_step_vae.py
index dd78a0a6..168f8517 100644
--- a/diffsynth/models/ace_step_vae.py
+++ b/diffsynth/models/ace_step_vae.py
@@ -226,6 +226,7 @@ def __init__(
upsampling_ratios=upsampling_ratios,
channel_multiples=channel_multiples,
)
+ self.sampling_rate = sampling_rate
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Audio waveform [B, audio_channels, T] → latent [B, encoder_hidden_size, T']."""
diff --git a/diffsynth/pipelines/ace_step.py b/diffsynth/pipelines/ace_step.py
index 09c31786..f9254020 100644
--- a/diffsynth/pipelines/ace_step.py
+++ b/diffsynth/pipelines/ace_step.py
@@ -3,8 +3,9 @@
Text-to-Music generation pipeline using ACE-Step 1.5 model.
"""
+import re
import torch
-from typing import Optional
+from typing import Optional, Dict, Any, List, Tuple
from tqdm import tqdm
from ..core.device.npu_compatible_device import get_device_type
@@ -16,6 +17,7 @@
from ..models.ace_step_conditioner import AceStepConditionEncoder
from ..models.ace_step_text_encoder import AceStepTextEncoder
from ..models.ace_step_vae import AceStepVAE
+from ..models.ace_step_tokenizer import AceStepTokenizer
class AceStepPipeline(BasePipeline):
@@ -32,29 +34,18 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
self.text_encoder: AceStepTextEncoder = None
self.conditioner: AceStepConditionEncoder = None
self.dit: AceStepDiTModel = None
- self.vae = None # AutoencoderOobleck (diffusers) or AceStepVAE
-
- # Unit chain order — 7 units total
- #
- # 1. ShapeChecker: duration → seq_len
- # 2. PromptEmbedder: prompt/lyrics → text/lyric embeddings (shared for CFG)
- # 3. SilenceLatentInitializer: seq_len → src_latents + chunk_masks
- # 4. ContextLatentBuilder: src_latents + chunk_masks → context_latents (shared, same for CFG+)
- # 5. ConditionEmbedder: text/lyric → encoder_hidden_states (separate for CFG+/-)
- # 6. NoiseInitializer: context_latents → noise
- # 7. InputAudioEmbedder: noise → latents
- #
- # ContextLatentBuilder runs before ConditionEmbedder so that
- # context_latents is available for noise shape computation.
+ self.vae: AceStepVAE = None
+ self.tokenizer_model: AceStepTokenizer = None # AceStepTokenizer (tokenizer + detokenizer)
+
self.in_iteration_models = ("dit",)
self.units = [
- AceStepUnit_ShapeChecker(),
AceStepUnit_PromptEmbedder(),
- AceStepUnit_SilenceLatentInitializer(),
- AceStepUnit_ContextLatentBuilder(),
+ AceStepUnit_ReferenceAudioEmbedder(),
AceStepUnit_ConditionEmbedder(),
+ AceStepUnit_ContextLatentBuilder(),
AceStepUnit_NoiseInitializer(),
AceStepUnit_InputAudioEmbedder(),
+ AceStepUnit_AudioCodeDecoder(),
]
self.model_fn = model_fn_ace_step
self.compilable_models = ["dit"]
@@ -66,7 +57,8 @@ def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: str = get_device_type(),
model_configs: list[ModelConfig] = [],
- text_tokenizer_config: ModelConfig = None,
+ text_tokenizer_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ silence_latent_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
vram_limit: float = None,
):
"""Load pipeline from pretrained checkpoints."""
@@ -77,11 +69,15 @@ def from_pretrained(
pipe.conditioner = model_pool.fetch_model("ace_step_conditioner")
pipe.dit = model_pool.fetch_model("ace_step_dit")
pipe.vae = model_pool.fetch_model("ace_step_vae")
+ pipe.tokenizer_model = model_pool.fetch_model("ace_step_tokenizer")
if text_tokenizer_config is not None:
text_tokenizer_config.download_if_necessary()
from transformers import AutoTokenizer
pipe.tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_config.path)
+ if silence_latent_config is not None:
+ silence_latent_config.download_if_necessary()
+ pipe.silence_latent = torch.load(silence_latent_config.path, weights_only=True).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
@@ -97,9 +93,19 @@ def __call__(
# Lyrics
lyrics: str = "",
# Reference audio (optional, for timbre conditioning)
- reference_audio = None,
+ reference_audios: List[torch.Tensor] = None,
+ # Src audio
+ src_audio: torch.Tensor = None,
+ denoising_strength: float = 1.0,
+ # Simple Mode: LLM-generated audio codes (optional)
+ audio_codes: str = None,
# Shape
- duration: float = 60.0,
+ duration: int = 60,
+ # Audio Meta
+ bpm: Optional[int] = 100,
+ keyscale: Optional[str] = "B minor",
+ timesignature: Optional[str] = "4",
+ vocal_language: Optional[str] = 'zh',
# Randomness
seed: int = None,
rand_device: str = "cpu",
@@ -111,11 +117,7 @@ def __call__(
progress_bar_cmd=tqdm,
):
# 1. Scheduler
- self.scheduler.set_timesteps(
- num_inference_steps=num_inference_steps,
- denoising_strength=1.0,
- shift=shift,
- )
+ self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=1.0, shift=shift)
# 2. 三字典输入
inputs_posi = {"prompt": prompt}
@@ -123,8 +125,11 @@ def __call__(
inputs_shared = {
"cfg_scale": cfg_scale,
"lyrics": lyrics,
- "reference_audio": reference_audio,
+ "reference_audios": reference_audios,
+ "src_audio": src_audio,
+ "audio_codes": audio_codes,
"duration": duration,
+ "bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
"seed": seed,
"rand_device": rand_device,
"num_inference_steps": num_inference_steps,
@@ -159,6 +164,10 @@ def __call__(
# VAE returns OobleckDecoderOutput with .sample attribute
audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output
audio = self.output_audio_format_check(audio_output)
+
+ # Peak normalization to match target library behavior
+ audio = self.normalize_audio(audio, target_db=-1.0)
+
self.load_models_to_device([])
return audio
@@ -172,294 +181,303 @@ def output_audio_format_check(self, audio_output):
audio_output = audio_output.squeeze(0)
return audio_output.float()
+ def normalize_audio(self, audio: torch.Tensor, target_db: float = -1.0) -> torch.Tensor:
+ """Apply peak normalization to audio data, matching target library behavior.
-class AceStepUnit_ShapeChecker(PipelineUnit):
- """Check and compute sequence length from duration."""
- def __init__(self):
- super().__init__(
- input_params=("duration",),
- output_params=("duration", "seq_len"),
- )
+ Target library reference: `acestep/audio_utils.py:normalize_audio()`
+ peak = max(abs(audio))
+ gain = 10^(target_db/20) / peak
+ audio = audio * gain
- def process(self, pipe, duration):
- # ACE-Step: 25 Hz latent rate
- seq_len = int(duration * 25)
- return {"duration": duration, "seq_len": seq_len}
+ Args:
+ audio: Audio tensor [C, T]
+ target_db: Target peak level in dB (default: -1.0)
+ Returns:
+ Normalized audio tensor
+ """
+ peak = torch.max(torch.abs(audio))
+ if peak < 1e-6:
+ return audio
+ target_amp = 10 ** (target_db / 20.0)
+ gain = target_amp / peak
+ return audio * gain
class AceStepUnit_PromptEmbedder(PipelineUnit):
- """Encode prompt and lyrics using Qwen3-Embedding.
+ SFT_GEN_PROMPT = "# Instruction\n{}\n\n# Caption\n{}\n\n# Metas\n{}<|endoftext|>\n"
+ INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
+ LYRIC_PROMPT = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|>"
- Uses seperate_cfg=True to read prompt from inputs_posi (not inputs_shared).
- The negative condition uses null_condition_emb (handled by ConditionEmbedder),
- so negative text encoding is not needed here.
- """
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
- input_params_nega={},
- input_params=("lyrics",),
+ input_params_nega={"prompt": "prompt"},
+ input_params=("lyrics", "duration", "bpm", "keyscale", "timesignature", "vocal_language"),
output_params=("text_hidden_states", "text_attention_mask", "lyric_hidden_states", "lyric_attention_mask"),
onload_model_names=("text_encoder",)
)
- def _encode_text(self, pipe, text):
+ def _encode_text(self, pipe, text, max_length=256):
"""Encode text using Qwen3-Embedding → [B, T, 1024]."""
- if pipe.tokenizer is None:
- return None, None
text_inputs = pipe.tokenizer(
text,
- padding="max_length",
- max_length=512,
+ max_length=max_length,
truncation=True,
return_tensors="pt",
)
input_ids = text_inputs.input_ids.to(pipe.device)
- attention_mask = text_inputs.attention_mask.to(pipe.device)
+ attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
hidden_states = pipe.text_encoder(input_ids, attention_mask)
return hidden_states, attention_mask
- def process(self, pipe, prompt, lyrics, negative_prompt=None):
- pipe.load_models_to_device(['text_encoder'])
-
- text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt)
-
- # Lyrics encoding — use empty string if not provided
- lyric_text = lyrics if lyrics else ""
- lyric_hidden_states, lyric_attention_mask = self._encode_text(pipe, lyric_text)
-
- if text_hidden_states is not None and lyric_hidden_states is not None:
- return {
- "text_hidden_states": text_hidden_states,
- "text_attention_mask": text_attention_mask,
- "lyric_hidden_states": lyric_hidden_states,
- "lyric_attention_mask": lyric_attention_mask,
- }
- return {}
-
-
-class AceStepUnit_SilenceLatentInitializer(PipelineUnit):
- """Generate silence latent (all zeros) and chunk_masks for text2music.
-
- Target library reference: `prepare_condition()` line 1698-1699:
- context_latents = torch.cat([src_latents, chunk_masks.to(dtype)], dim=-1)
-
- For text2music mode:
- - src_latents = zeros [B, T, 64] (VAE latent dimension)
- - chunk_masks = ones [B, T, 64] (full visibility mask for text2music)
- - context_latents = [B, T, 128] (concat of src_latents + chunk_masks)
- """
- def __init__(self):
- super().__init__(
- input_params=("seq_len",),
- output_params=("silence_latent", "src_latents", "chunk_masks"),
+ def _encode_lyrics(self, pipe, lyric_text, max_length=2048):
+ text_inputs = pipe.tokenizer(
+ lyric_text,
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
)
+ input_ids = text_inputs.input_ids.to(pipe.device)
+ attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
+ hidden_states = pipe.text_encoder.model.embed_tokens(input_ids)
+ return hidden_states, attention_mask
- def process(self, pipe, seq_len):
- # silence_latent shape: [B, T, 64] — 64 is the VAE latent dimension
- silence_latent = torch.zeros(1, seq_len, 64, device=pipe.device, dtype=pipe.torch_dtype)
- # For text2music: src_latents = silence_latent
- src_latents = silence_latent.clone()
-
- # chunk_masks: [B, T, 64] of ones (same shape as src_latents)
- # In text2music mode (is_covers=0), chunk_masks are all 1.0
- # This matches the target library's behavior at line 1699
- chunk_masks = torch.ones(1, seq_len, 64, device=pipe.device, dtype=pipe.torch_dtype)
-
- return {"silence_latent": silence_latent, "src_latents": src_latents, "chunk_masks": chunk_masks}
-
+ def _dict_to_meta_string(self, meta_dict: Dict[str, Any]) -> str:
+ bpm = meta_dict.get("bpm", "N/A")
+ timesignature = meta_dict.get("timesignature", "N/A")
+ keyscale = meta_dict.get("keyscale", "N/A")
+ duration = meta_dict.get("duration", 30)
+ duration = f"{int(duration)} seconds"
+ return (
+ f"- bpm: {bpm}\n"
+ f"- timesignature: {timesignature}\n"
+ f"- keyscale: {keyscale}\n"
+ f"- duration: {duration}\n"
+ )
-class AceStepUnit_ContextLatentBuilder(PipelineUnit):
- """Build context_latents from src_latents and chunk_masks.
+ def process(self, pipe, prompt, lyrics, duration, bpm, keyscale, timesignature, vocal_language):
+ pipe.load_models_to_device(['text_encoder'])
+ meta_dict = {"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "duration": duration}
+ prompt = self.SFT_GEN_PROMPT.format(self.INSTRUCTION, prompt, self._dict_to_meta_string(meta_dict))
+ text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt, max_length=256)
+
+ lyric_text = self.LYRIC_PROMPT.format(vocal_language, lyrics)
+ lyric_hidden_states, lyric_attention_mask = self._encode_lyrics(pipe, lyric_text, max_length=2048)
+
+ # TODO: remove this
+ newtext = prompt + "\n\n" + lyric_text
+ return {
+ "text_hidden_states": text_hidden_states,
+ "text_attention_mask": text_attention_mask,
+ "lyric_hidden_states": lyric_hidden_states,
+ "lyric_attention_mask": lyric_attention_mask,
+ }
- Target library reference: `prepare_condition()` line 1699:
- context_latents = torch.cat([src_latents, chunk_masks.to(dtype)], dim=-1)
- context_latents is the SAME for positive and negative CFG paths
- (it comes from src_latents + chunk_masks, not from text encoding).
- So this is a普通模式 Unit — outputs go to inputs_shared.
- """
+class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
- input_params=("src_latents", "chunk_masks"),
- output_params=("context_latents", "attention_mask"),
+ input_params=("reference_audios",),
+ output_params=("reference_latents", "refer_audio_order_mask"),
+ onload_model_names=("vae",)
)
- def process(self, pipe, src_latents, chunk_masks):
- # context_latents: cat([src_latents, chunk_masks], dim=-1) → [B, T, 128]
- context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
-
- # attention_mask for the DiT: ones [B, T]
- # The target library uses this for cross-attention with context_latents
- attention_mask = torch.ones(src_latents.shape[0], src_latents.shape[1],
- device=pipe.device, dtype=pipe.torch_dtype)
-
- return {"context_latents": context_latents, "attention_mask": attention_mask}
+ def process(self, pipe, reference_audios):
+ pipe.load_models_to_device(['vae'])
+ if reference_audios is not None and len(reference_audios) > 0:
+ # TODO: implement reference audio embedding using VAE encode, and generate refer_audio_order_mask
+ pass
+ else:
+ reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
+ reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, reference_audios)
+ return {"reference_latents": reference_latents, "refer_audio_order_mask": refer_audio_order_mask}
+
+ def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Infer packed reference-audio latents and order mask."""
+ refer_audio_order_mask = []
+ refer_audio_latents = []
+
+ def _normalize_audio_2d(a: torch.Tensor) -> torch.Tensor:
+ if not isinstance(a, torch.Tensor):
+ raise TypeError(f"refer_audio must be a torch.Tensor, got {type(a)!r}")
+ if a.dim() == 3 and a.shape[0] == 1:
+ a = a.squeeze(0)
+ if a.dim() == 1:
+ a = a.unsqueeze(0)
+ if a.dim() != 2:
+ raise ValueError(f"refer_audio must be 1D/2D/3D(1,2,T); got shape={tuple(a.shape)}")
+ if a.shape[0] == 1:
+ a = torch.cat([a, a], dim=0)
+ return a[:2]
+
+ def _ensure_latent_3d(z: torch.Tensor) -> torch.Tensor:
+ if z.dim() == 4 and z.shape[0] == 1:
+ z = z.squeeze(0)
+ if z.dim() == 2:
+ z = z.unsqueeze(0)
+ return z
+
+ refer_encode_cache: Dict[int, torch.Tensor] = {}
+ for batch_idx, refer_audios in enumerate(refer_audioss):
+ if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
+ refer_audio_latent = _ensure_latent_3d(pipe.silence_latent[:, :750, :])
+ refer_audio_latents.append(refer_audio_latent)
+ refer_audio_order_mask.append(batch_idx)
+ else:
+ # TODO: check
+ for refer_audio in refer_audios:
+ cache_key = refer_audio.data_ptr()
+ if cache_key in refer_encode_cache:
+ refer_audio_latent = refer_encode_cache[cache_key].clone()
+ else:
+ refer_audio = _normalize_audio_2d(refer_audio)
+ refer_audio_latent = pipe.vae.encode(refer_audio)
+ refer_audio_latent = refer_audio_latent.to(dtype=pipe.torch_dtype, device=pipe.device)
+ if refer_audio_latent.dim() == 2:
+ refer_audio_latent = refer_audio_latent.unsqueeze(0)
+ refer_audio_latent = _ensure_latent_3d(refer_audio_latent.transpose(1, 2))
+ refer_encode_cache[cache_key] = refer_audio_latent
+ refer_audio_latents.append(refer_audio_latent)
+ refer_audio_order_mask.append(batch_idx)
+
+ refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
+ refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=pipe.device, dtype=torch.long)
+ return refer_audio_latents, refer_audio_order_mask
class AceStepUnit_ConditionEmbedder(PipelineUnit):
- """Generate encoder_hidden_states via ACEStepConditioner.
- Target library reference: `prepare_condition()` line 1674-1681:
- encoder_hidden_states, encoder_attention_mask = self.encoder(...)
-
- Uses seperate_cfg mode:
- - Positive: encode with full condition (text + lyrics + reference audio)
- - Negative: replace text with null_condition_emb, keep lyrics/timbre same
-
- context_latents is handled by ContextLatentBuilder (普通模式), not here.
- """
def __init__(self):
super().__init__(
seperate_cfg=True,
- input_params_posi={
- "text_hidden_states": "text_hidden_states",
- "text_attention_mask": "text_attention_mask",
- "lyric_hidden_states": "lyric_hidden_states",
- "lyric_attention_mask": "lyric_attention_mask",
- "reference_audio": "reference_audio",
- "refer_audio_order_mask": "refer_audio_order_mask",
- },
- input_params_nega={},
- input_params=("cfg_scale",),
- output_params=(
- "encoder_hidden_states", "encoder_attention_mask",
- "negative_encoder_hidden_states", "negative_encoder_attention_mask",
- ),
- onload_model_names=("conditioner",)
+ input_params_posi={"text_hidden_states": "text_hidden_states", "text_attention_mask": "text_attention_mask", "lyric_hidden_states": "lyric_hidden_states", "lyric_attention_mask": "lyric_attention_mask"},
+ input_params_nega={"text_hidden_states": "text_hidden_states", "text_attention_mask": "text_attention_mask", "lyric_hidden_states": "lyric_hidden_states", "lyric_attention_mask": "lyric_attention_mask"},
+ input_params=("reference_latents", "refer_audio_order_mask"),
+ output_params=("encoder_hidden_states", "encoder_attention_mask"),
+ onload_model_names=("conditioner",),
)
- def _prepare_condition(self, pipe, text_hidden_states, text_attention_mask,
- lyric_hidden_states, lyric_attention_mask,
- refer_audio_acoustic_hidden_states_packed=None,
- refer_audio_order_mask=None):
- """Call ACEStepConditioner forward to produce encoder_hidden_states."""
+ def process(self, pipe, text_hidden_states, text_attention_mask, lyric_hidden_states, lyric_attention_mask, reference_latents, refer_audio_order_mask):
pipe.load_models_to_device(['conditioner'])
-
- # Handle reference audio
- if refer_audio_acoustic_hidden_states_packed is None:
- # No reference audio: create 2D packed zeros [N=1, d=64]
- # TimbreEncoder.unpack expects [N, d], not [B, T, d]
- refer_audio_acoustic_hidden_states_packed = torch.zeros(
- 1, 64, device=pipe.device, dtype=pipe.torch_dtype
- )
- refer_audio_order_mask = torch.LongTensor([0]).to(pipe.device)
-
encoder_hidden_states, encoder_attention_mask = pipe.conditioner(
text_hidden_states=text_hidden_states,
text_attention_mask=text_attention_mask,
lyric_hidden_states=lyric_hidden_states,
lyric_attention_mask=lyric_attention_mask,
- refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
+ reference_latents=reference_latents,
refer_audio_order_mask=refer_audio_order_mask,
)
+ return {"encoder_hidden_states": encoder_hidden_states, "encoder_attention_mask": encoder_attention_mask}
- return encoder_hidden_states, encoder_attention_mask
-
- def _prepare_negative_condition(self, pipe, lyric_hidden_states, lyric_attention_mask,
- refer_audio_acoustic_hidden_states_packed=None,
- refer_audio_order_mask=None):
- """Generate negative condition using null_condition_emb."""
- if pipe.conditioner is None or not hasattr(pipe.conditioner, 'null_condition_emb'):
- return None, None
-
- null_emb = pipe.conditioner.null_condition_emb # [1, 1, hidden_size]
- bsz = 1
- if lyric_hidden_states is not None:
- bsz = lyric_hidden_states.shape[0]
- null_hidden_states = null_emb.expand(bsz, -1, -1)
- null_attn_mask = torch.ones(bsz, 1, device=pipe.device, dtype=pipe.torch_dtype)
-
- # For negative: use null_condition_emb as text, keep lyrics and timbre
- neg_encoder_hidden_states, neg_encoder_attention_mask = pipe.conditioner(
- text_hidden_states=null_hidden_states,
- text_attention_mask=null_attn_mask,
- lyric_hidden_states=lyric_hidden_states,
- lyric_attention_mask=lyric_attention_mask,
- refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
- refer_audio_order_mask=refer_audio_order_mask,
- )
-
- return neg_encoder_hidden_states, neg_encoder_attention_mask
- def process(self, pipe, text_hidden_states, text_attention_mask,
- lyric_hidden_states, lyric_attention_mask,
- reference_audio=None, refer_audio_order_mask=None,
- negative_prompt=None, cfg_scale=1.0):
-
- # Positive condition
- pos_enc_hs, pos_enc_mask = self._prepare_condition(
- pipe, text_hidden_states, text_attention_mask,
- lyric_hidden_states, lyric_attention_mask,
- None, refer_audio_order_mask,
+class AceStepUnit_ContextLatentBuilder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("duration", "src_audio"),
+            output_params=("context_latents", "attention_mask"),
)
- # Negative condition: only needed when CFG is active (cfg_scale > 1.0)
- # For cfg_scale=1.0 (turbo), skip to avoid null_condition_emb dimension mismatch
- result = {
- "encoder_hidden_states": pos_enc_hs,
- "encoder_attention_mask": pos_enc_mask,
- }
-
- if cfg_scale > 1.0:
- neg_enc_hs, neg_enc_mask = self._prepare_negative_condition(
- pipe, lyric_hidden_states, lyric_attention_mask,
- None, refer_audio_order_mask,
- )
- if neg_enc_hs is not None:
- result["negative_encoder_hidden_states"] = neg_enc_hs
- result["negative_encoder_attention_mask"] = neg_enc_mask
-
- return result
+ def _get_silence_latent_slice(self, pipe, length: int) -> torch.Tensor:
+ available = pipe.silence_latent.shape[1]
+ if length <= available:
+ return pipe.silence_latent[0, :length, :]
+ repeats = (length + available - 1) // available
+ tiled = pipe.silence_latent[0].repeat(repeats, 1)
+ return tiled[:length, :]
+
+ def process(self, pipe, duration, src_audio):
+ if src_audio is not None:
+ raise NotImplementedError("Src audio conditioning is not implemented yet. Please set src_audio to None.")
+ else:
+ max_latent_length = duration * pipe.sample_rate // 1920
+ src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
+ chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
+ attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
+ context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
+ return {"context_latents": context_latents, "attention_mask": attention_mask}
class AceStepUnit_NoiseInitializer(PipelineUnit):
- """Generate initial noise tensor.
-
- Target library reference: `prepare_noise()` line 1781-1818:
- src_latents_shape = (bsz, context_latents.shape[1], context_latents.shape[-1] // 2)
-
- Noise shape = [B, T, context_latents.shape[-1] // 2] = [B, T, 128 // 2] = [B, T, 64]
- """
def __init__(self):
super().__init__(
- input_params=("seed", "seq_len", "rand_device", "context_latents"),
+ input_params=("context_latents", "seed", "rand_device"),
output_params=("noise",),
)
- def process(self, pipe, seed, seq_len, rand_device, context_latents):
- # Noise shape: [B, T, context_latents.shape[-1] // 2]
- # context_latents = [B, T, 128] → noise = [B, T, 64]
- # This matches the target library's prepare_noise() at line 1796
- noise_shape = (context_latents.shape[0], context_latents.shape[1],
- context_latents.shape[-1] // 2)
- noise = pipe.generate_noise(
- noise_shape,
- seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype
- )
+ def process(self, pipe, context_latents, seed, rand_device):
+ src_latents_shape = (context_latents.shape[0], context_latents.shape[1], context_latents.shape[-1] // 2)
+ noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
-
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
- """Set up latents for denoise loop.
-
- For text2music (no input audio): latents = noise, input_latents = None.
-
- Target library reference: `generate_audio()` line 1972:
- xt = noise (when cover_noise_strength == 0)
- """
def __init__(self):
super().__init__(
- input_params=("noise",),
+ input_params=("noise", "input_audio"),
output_params=("latents", "input_latents"),
)
- def process(self, pipe, noise):
- # For text2music: start from pure noise
+ def process(self, pipe, noise, input_audio):
+ if input_audio is None:
+ return {"latents": noise}
+ # TODO: support for train
return {"latents": noise, "input_latents": None}
+class AceStepUnit_AudioCodeDecoder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("audio_codes", "seq_len", "silence_latent"),
+ output_params=("lm_hints_25Hz",),
+ onload_model_names=("tokenizer_model",),
+ )
+
+ @staticmethod
+ def _parse_audio_code_string(code_str: str) -> list:
+ """Extract integer audio codes from tokens like <|audio_code_123|>."""
+ if not code_str:
+ return []
+ codes = []
+ max_audio_code = 63999
+ for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
+ code_value = int(x)
+ codes.append(max(0, min(code_value, max_audio_code)))
+ return codes
+
+ def process(self, pipe, audio_codes, seq_len, silence_latent):
+ if audio_codes is None or not audio_codes.strip():
+ return {"lm_hints_25Hz": None}
+
+ code_ids = self._parse_audio_code_string(audio_codes)
+ if len(code_ids) == 0:
+ return {"lm_hints_25Hz": None}
+
+ pipe.load_models_to_device(["tokenizer_model"])
+
+ quantizer = pipe.tokenizer_model.tokenizer.quantizer
+ detokenizer = pipe.tokenizer_model.detokenizer
+
+ indices = torch.tensor(code_ids, device=pipe.device, dtype=torch.long)
+ indices = indices.unsqueeze(0).unsqueeze(-1) # [1, N, 1]
+
+ quantized = quantizer.get_output_from_indices(indices) # [1, N, 2048]
+ if quantized.dtype != pipe.torch_dtype:
+ quantized = quantized.to(pipe.torch_dtype)
+
+ lm_hints = detokenizer(quantized) # [1, N*5, 64]
+
+ # Pad or truncate to seq_len
+ current_len = lm_hints.shape[1]
+ if current_len < seq_len:
+ pad_len = seq_len - current_len
+ pad = silence_latent[:, :pad_len, :]
+ lm_hints = torch.cat([lm_hints, pad], dim=1)
+ elif current_len > seq_len:
+ lm_hints = lm_hints[:, :seq_len, :]
+
+ return {"lm_hints_25Hz": lm_hints}
+
+
def model_fn_ace_step(
dit: AceStepDiTModel,
latents=None,
@@ -468,49 +486,11 @@ def model_fn_ace_step(
encoder_attention_mask=None,
context_latents=None,
attention_mask=None,
- past_key_values=None,
- negative_encoder_hidden_states=None,
- negative_encoder_attention_mask=None,
- negative_context_latents=None,
+ use_gradient_checkpointing=False,
+ use_gradient_checkpointing_offload=False,
**kwargs,
):
- """Model function for ACE-Step DiT forward.
-
- Timestep is already in [0, 1] range — no scaling needed.
-
- Target library reference: `generate_audio()` line 2009-2020:
- decoder_outputs = self.decoder(
- hidden_states=x, timestep=t_curr_tensor, timestep_r=t_curr_tensor,
- attention_mask=attention_mask,
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=encoder_attention_mask,
- context_latents=context_latents,
- use_cache=True, past_key_values=past_key_values,
- )
-
- Args:
- dit: AceStepDiTModel
- latents: [B, T, 64] noise/latent tensor (same shape as src_latents)
- timestep: scalar tensor in [0, 1]
- encoder_hidden_states: [B, T_text, 2048] condition from Conditioner
- (positive or negative depending on CFG pass — the cfg_guided_model_fn
- passes inputs_posi for positive, inputs_nega for negative)
- encoder_attention_mask: [B, T_text]
- context_latents: [B, T, 128] = cat([src_latents, chunk_masks], dim=-1)
- (same for both CFG+/- paths in text2music mode)
- attention_mask: [B, T] ones mask for DiT
- past_key_values: EncoderDecoderCache for KV caching
-
- The DiT internally concatenates: cat([context_latents, latents], dim=-1) = [B, T, 192]
- as the actual input (128 + 64 = 192 channels).
- """
- # ACE-Step uses timestep directly in [0, 1] range — no /1000 scaling
- timestep = timestep.squeeze()
-
- # Expand timestep to match batch size
- bsz = latents.shape[0]
- timestep = timestep.expand(bsz)
-
+ timestep = timestep.unsqueeze(0)
decoder_outputs = dit(
hidden_states=latents,
timestep=timestep,
@@ -519,9 +499,5 @@ def model_fn_ace_step(
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
context_latents=context_latents,
- use_cache=True,
- past_key_values=past_key_values,
- )
-
- # Return velocity prediction (first element of decoder_outputs)
- return decoder_outputs[0]
+ )[0]
+ return decoder_outputs
diff --git a/diffsynth/utils/data/audio.py b/diffsynth/utils/data/audio.py
index 414fcb2e..1add550a 100644
--- a/diffsynth/utils/data/audio.py
+++ b/diffsynth/utils/data/audio.py
@@ -99,6 +99,7 @@ def save_audio(waveform: torch.Tensor, sample_rate: int, save_path: str, backend
"""
if waveform.dim() == 3:
waveform = waveform[0]
+ waveform.cpu()
if backend == "torchcodec":
from torchcodec.encoders import AudioEncoder
diff --git a/examples/ace_step/model_inference/Ace-Step1.5-SimpleMode.py b/examples/ace_step/model_inference/Ace-Step1.5-SimpleMode.py
index edcf2cd4..261c548d 100644
--- a/examples/ace_step/model_inference/Ace-Step1.5-SimpleMode.py
+++ b/examples/ace_step/model_inference/Ace-Step1.5-SimpleMode.py
@@ -2,8 +2,8 @@
Ace-Step 1.5 — Text-to-Music with Simple Mode (LLM expansion).
Uses the ACE-Step LLM to expand a simple description into structured
-parameters (caption, lyrics, bpm, keyscale, etc.), then feeds them
-to the DiffSynth Pipeline.
+parameters (caption, lyrics, bpm, keyscale, etc.) AND audio codes,
+then feeds them to the DiffSynth Pipeline.
The LLM expansion uses the target library's LLMHandler. If vLLM is
not available, it falls back to using pre-structured parameters.
@@ -47,11 +47,14 @@ def try_load_llm_handler(checkpoint_dir: str, lm_model_path: str = "acestep-5Hz-
def expand_with_llm(llm_handler, description: str, duration: float = 30.0):
- """Expand a simple description using LLM Chain-of-Thought."""
+ """Expand a simple description using LLM Chain-of-Thought.
+
+ Returns (params_dict, audio_codes_string).
+ """
result = llm_handler.generate_with_stop_condition(
caption=description,
lyrics="",
- infer_type="dit", # metadata only
+ infer_type="dit", # metadata + audio codes
temperature=0.85,
cfg_scale=1.0,
use_cot_metas=True,
@@ -62,7 +65,7 @@ def expand_with_llm(llm_handler, description: str, duration: float = 30.0):
if result.get("success") and result.get("metadata"):
meta = result["metadata"]
- return {
+ params = {
"caption": meta.get("caption", description),
"lyrics": meta.get("lyrics", ""),
"bpm": meta.get("bpm", 100),
@@ -71,9 +74,11 @@ def expand_with_llm(llm_handler, description: str, duration: float = 30.0):
"timesignature": meta.get("timesignature", "4"),
"duration": meta.get("duration", duration),
}
+ audio_codes = result.get("audio_codes", "")
+ return params, audio_codes
print(f"[Simple Mode] LLM expansion failed: {result.get('error', 'unknown')}")
- return None
+ return None, ""
def fallback_expand(description: str, duration: float = 30.0):
@@ -87,7 +92,7 @@ def fallback_expand(description: str, duration: float = 30.0):
"language": "en",
"timesignature": "4",
"duration": duration,
- }
+ }, ""
# ---------------------------------------------------------------------------
@@ -114,13 +119,13 @@ def main():
lm_model_path="acestep-5Hz-lm-1.7B",
)
- # 2. Expand parameters
+ # 2. Expand parameters + audio codes
if llm_ok:
- params = expand_with_llm(llm_handler, description, duration=duration)
+ params, audio_codes = expand_with_llm(llm_handler, description, duration=duration)
if params is None:
- params = fallback_expand(description, duration)
+ params, audio_codes = fallback_expand(description, duration)
else:
- params = fallback_expand(description, duration)
+ params, audio_codes = fallback_expand(description, duration)
print(f"\n[Simple Mode] Parameters:")
print(f" Caption: {params['caption'][:100]}...")
@@ -128,6 +133,7 @@ def main():
print(f" BPM: {params['bpm']}, Keyscale: {params['keyscale']}")
print(f" Language: {params['language']}, Time Sig: {params['timesignature']}")
print(f" Duration: {params['duration']}s")
+ print(f" Audio codes: {len(audio_codes)} chars" if audio_codes else " Audio codes: None (fallback)")
# 3. Load Pipeline
print(f"\n[Pipeline] Loading Ace-Step 1.5 (turbo)...")
@@ -141,21 +147,17 @@ def main():
),
ModelConfig(
model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="acestep-v15-turbo/model.safetensors"
+ origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
),
ModelConfig(
model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
+ origin_file_pattern="vae/diffusion_pytorch_model.safetensors"
),
],
- tokenizer_config=ModelConfig(
+ text_tokenizer_config=ModelConfig(
model_id="ACE-Step/Ace-Step1.5",
origin_file_pattern="Qwen3-Embedding-0.6B/"
),
- vae_config=ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="vae/"
- ),
)
# 4. Generate
@@ -164,6 +166,7 @@ def main():
prompt=params["caption"],
lyrics=params["lyrics"],
duration=params["duration"],
+ audio_codes=audio_codes if audio_codes else None,
seed=42,
num_inference_steps=8,
cfg_scale=1.0,
diff --git a/examples/ace_step/model_inference/Ace-Step1.5.py b/examples/ace_step/model_inference/Ace-Step1.5.py
index ca3616f1..5442a050 100644
--- a/examples/ace_step/model_inference/Ace-Step1.5.py
+++ b/examples/ace_step/model_inference/Ace-Step1.5.py
@@ -1,16 +1,6 @@
-"""
-Ace-Step 1.5 — Text-to-Music (Turbo) inference example.
-
-Demonstrates the standard text2music pipeline with structured parameters
-(caption, lyrics, duration, etc.) — no LLM expansion needed.
-
-For Simple Mode (LLM expands a short description), see:
- - Ace-Step1.5-SimpleMode.py
-"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
import torch
-import soundfile as sf
-
+from diffsynth.utils.data.audio import save_audio
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
@@ -35,29 +25,21 @@
),
)
-prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline."
-lyrics = """[Intro - Synth Brass Fanfare]
-
-[Verse 1]
-黑夜里的风吹过耳畔
-甜蜜时光转瞬即逝
-脚步飘摇在星光上
-
-[Chorus]
-心电感应在震动间
-拥抱未来勇敢冒险
-
-[Outro - Instrumental]"""
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
- duration=30.0,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
seed=42,
num_inference_steps=8,
cfg_scale=1.0,
- shift=3.0,
)
-sf.write("Ace-Step1.5.wav", audio.cpu().numpy(), pipe.sample_rate)
+save_audio(audio.cpu(), pipe.vae.sampling_rate, "Ace-Step1.5.wav")
print(f"Saved to Ace-Step1.5.wav, shape: {audio.shape}, duration: {audio.shape[-1] / pipe.sample_rate:.1f}s")
From 95cfb77881d68cfd7cfecbe1c09e99dfc44eeffe Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Tue, 21 Apr 2026 19:42:57 +0800
Subject: [PATCH 04/16] t2m
---
diffsynth/configs/model_configs.py | 31 ---
diffsynth/diffusion/base_pipeline.py | 2 +-
diffsynth/models/ace_step_lm.py | 79 --------
diffsynth/pipelines/ace_step.py | 188 +++++++++---------
.../ace_step/model_inference/Ace-Step1.5.py | 45 +++--
.../model_inference/acestep-v15-base.py | 51 ++---
6 files changed, 135 insertions(+), 261 deletions(-)
delete mode 100644 diffsynth/models/ace_step_lm.py
diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py
index ee97fec3..ccaa1460 100644
--- a/diffsynth/configs/model_configs.py
+++ b/diffsynth/configs/model_configs.py
@@ -962,37 +962,6 @@
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.ace_step_conditioner_converter",
},
- # === LLM variants ===
- {
- # Example: ModelConfig(model_id="ACE-Step/acestep-5Hz-lm-0.6B", origin_file_pattern="model.safetensors")
- "model_hash": "f3ab4bef9e00745fd0fea7aa8b2a4041",
- "model_name": "ace_step_lm",
- "model_class": "diffsynth.models.ace_step_lm.AceStepLM",
- "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_lm.ace_step_lm_converter",
- "extra_kwargs": {
- "variant": "acestep-5Hz-lm-0.6B",
- },
- },
- {
- # Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-5Hz-lm-1.7B/model.safetensors")
- "model_hash": "a14b6e422b0faa9b41e7efe0fee46766",
- "model_name": "ace_step_lm",
- "model_class": "diffsynth.models.ace_step_lm.AceStepLM",
- "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_lm.ace_step_lm_converter",
- "extra_kwargs": {
- "variant": "acestep-5Hz-lm-1.7B",
- },
- },
- {
- # Example: ModelConfig(model_id="ACE-Step/acestep-5Hz-lm-4B", origin_file_pattern="model-*.safetensors")
- "model_hash": "046a3934f2e6f2f6d450bad23b1f4933",
- "model_name": "ace_step_lm",
- "model_class": "diffsynth.models.ace_step_lm.AceStepLM",
- "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_lm.ace_step_lm_converter",
- "extra_kwargs": {
- "variant": "acestep-5Hz-lm-4B",
- },
- },
# === Qwen3-Embedding (text encoder) ===
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors")
diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py
index 588f765a..8c87c00c 100644
--- a/diffsynth/diffusion/base_pipeline.py
+++ b/diffsynth/diffusion/base_pipeline.py
@@ -152,7 +152,7 @@ def output_audio_format_check(self, audio_output):
# remove batch dim
if audio_output.ndim == 3:
audio_output = audio_output.squeeze(0)
- return audio_output.float()
+ return audio_output.float().cpu()
def load_models_to_device(self, model_names):
if self.vram_management_enabled:
diff --git a/diffsynth/models/ace_step_lm.py b/diffsynth/models/ace_step_lm.py
deleted file mode 100644
index fc8c0817..00000000
--- a/diffsynth/models/ace_step_lm.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import torch
-
-
-LM_CONFIGS = {
- "acestep-5Hz-lm-0.6B": {
- "hidden_size": 1024,
- "intermediate_size": 3072,
- "num_hidden_layers": 28,
- "num_attention_heads": 16,
- "layer_types": ["full_attention"] * 28,
- "max_window_layers": 28,
- },
- "acestep-5Hz-lm-1.7B": {
- "hidden_size": 2048,
- "intermediate_size": 6144,
- "num_hidden_layers": 28,
- "num_attention_heads": 16,
- "layer_types": ["full_attention"] * 28,
- "max_window_layers": 28,
- },
- "acestep-5Hz-lm-4B": {
- "hidden_size": 2560,
- "intermediate_size": 9728,
- "num_hidden_layers": 36,
- "num_attention_heads": 32,
- "layer_types": ["full_attention"] * 36,
- "max_window_layers": 36,
- },
-}
-
-
-class AceStepLM(torch.nn.Module):
- """
- Language model for ACE-Step.
-
- Converts natural language prompts into structured parameters
- (caption, lyrics, bpm, keyscale, duration, timesignature, etc.)
- for ACE-Step music generation.
-
- Wraps a Qwen3ForCausalLM transformers model. Config is manually
- constructed based on variant type, and model weights are loaded
- via DiffSynth's standard mechanism from safetensors files.
- """
-
- def __init__(
- self,
- variant: str = "acestep-5Hz-lm-1.7B",
- ):
- super().__init__()
- from transformers import Qwen3Config, Qwen3ForCausalLM
-
- config_params = LM_CONFIGS[variant]
-
- config = Qwen3Config(
- attention_bias=False,
- attention_dropout=0.0,
- bos_token_id=151643,
- dtype="bfloat16",
- eos_token_id=151645,
- head_dim=128,
- hidden_act="silu",
- initializer_range=0.02,
- max_position_embeddings=40960,
- model_type="qwen3",
- num_key_value_heads=8,
- pad_token_id=151643,
- rms_norm_eps=1e-06,
- rope_scaling=None,
- rope_theta=1000000,
- sliding_window=None,
- tie_word_embeddings=True,
- use_cache=True,
- use_sliding_window=False,
- vocab_size=217204,
- **config_params,
- )
-
- self.model = Qwen3ForCausalLM(config)
- self.config = config
diff --git a/diffsynth/pipelines/ace_step.py b/diffsynth/pipelines/ace_step.py
index f9254020..bf0f2351 100644
--- a/diffsynth/pipelines/ace_step.py
+++ b/diffsynth/pipelines/ace_step.py
@@ -42,10 +42,10 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
AceStepUnit_PromptEmbedder(),
AceStepUnit_ReferenceAudioEmbedder(),
AceStepUnit_ConditionEmbedder(),
+ AceStepUnit_AudioCodeDecoder(),
AceStepUnit_ContextLatentBuilder(),
AceStepUnit_NoiseInitializer(),
AceStepUnit_InputAudioEmbedder(),
- AceStepUnit_AudioCodeDecoder(),
]
self.model_fn = model_fn_ace_step
self.compilable_models = ["dit"]
@@ -92,27 +92,27 @@ def __call__(
cfg_scale: float = 1.0,
# Lyrics
lyrics: str = "",
- # Reference audio (optional, for timbre conditioning)
+ # Reference audio
reference_audios: List[torch.Tensor] = None,
# Src audio
src_audio: torch.Tensor = None,
denoising_strength: float = 1.0,
- # Simple Mode: LLM-generated audio codes (optional)
- audio_codes: str = None,
+ # Audio codes
+ audio_code_string: Optional[str] = None,
# Shape
duration: int = 60,
# Audio Meta
bpm: Optional[int] = 100,
keyscale: Optional[str] = "B minor",
timesignature: Optional[str] = "4",
- vocal_language: Optional[str] = 'zh',
+ vocal_language: Optional[str] = 'unknown',
# Randomness
seed: int = None,
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 8,
# Scheduler-specific parameters
- shift: float = 3.0,
+ shift: float = 1.0,
# Progress
progress_bar_cmd=tqdm,
):
@@ -120,14 +120,14 @@ def __call__(
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=1.0, shift=shift)
# 2. 三字典输入
- inputs_posi = {"prompt": prompt}
- inputs_nega = {"negative_prompt": negative_prompt}
+ inputs_posi = {"prompt": prompt, "positive": True}
+ inputs_nega = {"positive": False}
inputs_shared = {
"cfg_scale": cfg_scale,
"lyrics": lyrics,
"reference_audios": reference_audios,
"src_audio": src_audio,
- "audio_codes": audio_codes,
+ "audio_code_string": audio_code_string,
"duration": duration,
"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
"seed": seed,
@@ -145,12 +145,13 @@ def __call__(
# 4. Denoise loop
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
+ self.momentum_buffer = MomentumBuffer()
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
- **models, timestep=timestep, progress_id=progress_id
+ **models, timestep=timestep, progress_id=progress_id,
)
inputs_shared["latents"] = self.step(
self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
@@ -163,39 +164,12 @@ def __call__(
vae_output = self.vae.decode(latents)
# VAE returns OobleckDecoderOutput with .sample attribute
audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output
+ audio_output = self.normalize_audio(audio_output, target_db=-1.0)
audio = self.output_audio_format_check(audio_output)
-
- # Peak normalization to match target library behavior
- audio = self.normalize_audio(audio, target_db=-1.0)
-
self.load_models_to_device([])
return audio
- def output_audio_format_check(self, audio_output):
- """Convert VAE output to standard audio format [C, T], float32.
-
- VAE decode outputs [B, C, T] (audio waveform).
- We squeeze batch dim and return [C, T].
- """
- if audio_output.ndim == 3:
- audio_output = audio_output.squeeze(0)
- return audio_output.float()
-
def normalize_audio(self, audio: torch.Tensor, target_db: float = -1.0) -> torch.Tensor:
- """Apply peak normalization to audio data, matching target library behavior.
-
- Target library reference: `acestep/audio_utils.py:normalize_audio()`
- peak = max(abs(audio))
- gain = 10^(target_db/20) / peak
- audio = audio * gain
-
- Args:
- audio: Audio tensor [C, T]
- target_db: Target peak level in dB (default: -1.0)
-
- Returns:
- Normalized audio tensor
- """
peak = torch.max(torch.abs(audio))
if peak < 1e-6:
return audio
@@ -203,17 +177,46 @@ def normalize_audio(self, audio: torch.Tensor, target_db: float = -1.0) -> torch
gain = target_amp / peak
return audio * gain
+
+class AceStepUnit_TaskTypeChecker(PipelineUnit):
+ """Check and compute sequence length from duration."""
+ def __init__(self):
+ super().__init__(
+ input_params=("src_audio", "audio_code_string"),
+ output_params=("task_type",),
+ )
+
+ def process(self, pipe, src_audio, audio_code_string):
+ if audio_code_string is not None:
+ print("audio_code_string detected, setting task_type to 'cover'")
+ task_type = "cover"
+ else:
+ task_type = "text2music"
+ return {"task_type": task_type}
+
+
class AceStepUnit_PromptEmbedder(PipelineUnit):
SFT_GEN_PROMPT = "# Instruction\n{}\n\n# Caption\n{}\n\n# Metas\n{}<|endoftext|>\n"
- INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
+ INSTRUCTION_MAP = {
+ "text2music": "Fill the audio semantic mask based on the given conditions:",
+ "cover": "Generate audio semantic tokens based on the given conditions:",
+
+ "repaint": "Repaint the mask area based on the given conditions:",
+ "extract": "Extract the {TRACK_NAME} track from the audio:",
+ "extract_default": "Extract the track from the audio:",
+ "lego": "Generate the {TRACK_NAME} track based on the audio context:",
+ "lego_default": "Generate the track based on the audio context:",
+ "complete": "Complete the input track with {TRACK_CLASSES}:",
+ "complete_default": "Complete the input track:",
+ }
LYRIC_PROMPT = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|>"
def __init__(self):
super().__init__(
seperate_cfg=True,
- input_params_posi={"prompt": "prompt"},
- input_params_nega={"prompt": "prompt"},
- input_params=("lyrics", "duration", "bpm", "keyscale", "timesignature", "vocal_language"),
+ input_params_posi={"prompt": "prompt", "positive": "positive"},
+ input_params_nega={"prompt": "prompt", "positive": "positive"},
+ input_params=("lyrics", "duration", "bpm", "keyscale", "timesignature", "vocal_language", "task_type"),
output_params=("text_hidden_states", "text_attention_mask", "lyric_hidden_states", "lyric_attention_mask"),
onload_model_names=("text_encoder",)
)
@@ -256,10 +259,13 @@ def _dict_to_meta_string(self, meta_dict: Dict[str, Any]) -> str:
f"- duration: {duration}\n"
)
- def process(self, pipe, prompt, lyrics, duration, bpm, keyscale, timesignature, vocal_language):
+ def process(self, pipe, prompt, positive, lyrics, duration, bpm, keyscale, timesignature, vocal_language, task_type):
+ if not positive:
+ return {}
pipe.load_models_to_device(['text_encoder'])
meta_dict = {"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "duration": duration}
- prompt = self.SFT_GEN_PROMPT.format(self.INSTRUCTION, prompt, self._dict_to_meta_string(meta_dict))
+ INSTRUCTION = self.INSTRUCTION_MAP.get(task_type, self.INSTRUCTION_MAP["text2music"])
+ prompt = self.SFT_GEN_PROMPT.format(INSTRUCTION, prompt, self._dict_to_meta_string(meta_dict))
text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt, max_length=256)
lyric_text = self.LYRIC_PROMPT.format(vocal_language, lyrics)
@@ -350,31 +356,32 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
- seperate_cfg=True,
- input_params_posi={"text_hidden_states": "text_hidden_states", "text_attention_mask": "text_attention_mask", "lyric_hidden_states": "lyric_hidden_states", "lyric_attention_mask": "lyric_attention_mask"},
- input_params_nega={"text_hidden_states": "text_hidden_states", "text_attention_mask": "text_attention_mask", "lyric_hidden_states": "lyric_hidden_states", "lyric_attention_mask": "lyric_attention_mask"},
- input_params=("reference_latents", "refer_audio_order_mask"),
+ take_over=True,
output_params=("encoder_hidden_states", "encoder_attention_mask"),
onload_model_names=("conditioner",),
)
- def process(self, pipe, text_hidden_states, text_attention_mask, lyric_hidden_states, lyric_attention_mask, reference_latents, refer_audio_order_mask):
+ def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
pipe.load_models_to_device(['conditioner'])
encoder_hidden_states, encoder_attention_mask = pipe.conditioner(
- text_hidden_states=text_hidden_states,
- text_attention_mask=text_attention_mask,
- lyric_hidden_states=lyric_hidden_states,
- lyric_attention_mask=lyric_attention_mask,
- reference_latents=reference_latents,
- refer_audio_order_mask=refer_audio_order_mask,
+ text_hidden_states=inputs_posi.get("text_hidden_states", None),
+ text_attention_mask=inputs_posi.get("text_attention_mask", None),
+ lyric_hidden_states=inputs_posi.get("lyric_hidden_states", None),
+ lyric_attention_mask=inputs_posi.get("lyric_attention_mask", None),
+ reference_latents=inputs_shared.get("reference_latents", None),
+ refer_audio_order_mask=inputs_shared.get("refer_audio_order_mask", None),
)
- return {"encoder_hidden_states": encoder_hidden_states, "encoder_attention_mask": encoder_attention_mask}
+ inputs_posi["encoder_hidden_states"] = encoder_hidden_states
+ inputs_posi["encoder_attention_mask"] = encoder_attention_mask
+ inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states)
+ inputs_nega["encoder_attention_mask"] = encoder_attention_mask
+ return inputs_shared, inputs_posi, inputs_nega
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
def __init__(self):
super().__init__(
- input_params=("duration", "src_audio"),
+ input_params=("duration", "src_audio", "lm_hints"),
output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
)
@@ -386,9 +393,15 @@ def _get_silence_latent_slice(self, pipe, length: int) -> torch.Tensor:
tiled = pipe.silence_latent[0].repeat(repeats, 1)
return tiled[:length, :]
- def process(self, pipe, duration, src_audio):
- if src_audio is not None:
- raise NotImplementedError("Src audio conditioning is not implemented yet. Please set src_audio to None.")
+ def process(self, pipe, duration, src_audio, lm_hints):
+ if lm_hints is not None:
+ max_latent_length = lm_hints.shape[1]
+ src_latents = lm_hints.clone()
+ chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
+ attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
+ context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
+ elif src_audio is not None:
+ raise NotImplementedError("src_audio conditioning is not implemented yet. Please set lm_hints to None.")
else:
max_latent_length = duration * pipe.sample_rate // 1920
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
@@ -410,6 +423,7 @@ def process(self, pipe, context_latents, seed, rand_device):
noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
+
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
@@ -427,8 +441,8 @@ def process(self, pipe, noise, input_audio):
class AceStepUnit_AudioCodeDecoder(PipelineUnit):
def __init__(self):
super().__init__(
- input_params=("audio_codes", "seq_len", "silence_latent"),
- output_params=("lm_hints_25Hz",),
+ input_params=("audio_code_string",),
+ output_params=("lm_hints",),
onload_model_names=("tokenizer_model",),
)
@@ -437,45 +451,29 @@ def _parse_audio_code_string(code_str: str) -> list:
"""Extract integer audio codes from tokens like <|audio_code_123|>."""
if not code_str:
return []
- codes = []
- max_audio_code = 63999
- for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
- code_value = int(x)
- codes.append(max(0, min(code_value, max_audio_code)))
+ try:
+ codes = []
+ max_audio_code = 63999
+ for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
+ code_value = int(x)
+ codes.append(max(0, min(code_value, max_audio_code)))
+ except Exception as e:
+ raise ValueError(f"Invalid audio_code_string format: {e}")
return codes
- def process(self, pipe, audio_codes, seq_len, silence_latent):
- if audio_codes is None or not audio_codes.strip():
- return {"lm_hints_25Hz": None}
-
- code_ids = self._parse_audio_code_string(audio_codes)
+ def process(self, pipe, audio_code_string):
+ if audio_code_string is None or not audio_code_string.strip():
+ return {"lm_hints": None}
+ code_ids = self._parse_audio_code_string(audio_code_string)
if len(code_ids) == 0:
- return {"lm_hints_25Hz": None}
+ return {"lm_hints": None}
pipe.load_models_to_device(["tokenizer_model"])
-
- quantizer = pipe.tokenizer_model.tokenizer.quantizer
- detokenizer = pipe.tokenizer_model.detokenizer
-
indices = torch.tensor(code_ids, device=pipe.device, dtype=torch.long)
indices = indices.unsqueeze(0).unsqueeze(-1) # [1, N, 1]
-
- quantized = quantizer.get_output_from_indices(indices) # [1, N, 2048]
- if quantized.dtype != pipe.torch_dtype:
- quantized = quantized.to(pipe.torch_dtype)
-
- lm_hints = detokenizer(quantized) # [1, N*5, 64]
-
- # Pad or truncate to seq_len
- current_len = lm_hints.shape[1]
- if current_len < seq_len:
- pad_len = seq_len - current_len
- pad = silence_latent[:, :pad_len, :]
- lm_hints = torch.cat([lm_hints, pad], dim=1)
- elif current_len > seq_len:
- lm_hints = lm_hints[:, :seq_len, :]
-
- return {"lm_hints_25Hz": lm_hints}
+ quantized = pipe.tokenizer_model.tokenizer.quantizer.get_output_from_indices(indices).to(pipe.torch_dtype) # [1, N, 2048]
+ lm_hints = pipe.tokenizer_model.detokenizer(quantized) # [1, N*5, 64]
+ return {"lm_hints": lm_hints}
def model_fn_ace_step(
@@ -499,5 +497,7 @@ def model_fn_ace_step(
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
context_latents=context_latents,
+ use_gradient_checkpointing=use_gradient_checkpointing,
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)[0]
return decoder_outputs
diff --git a/examples/ace_step/model_inference/Ace-Step1.5.py b/examples/ace_step/model_inference/Ace-Step1.5.py
index 5442a050..219cb318 100644
--- a/examples/ace_step/model_inference/Ace-Step1.5.py
+++ b/examples/ace_step/model_inference/Ace-Step1.5.py
@@ -1,33 +1,20 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
-import torch
from diffsynth.utils.data.audio import save_audio
+import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
- ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="acestep-v15-turbo/model.safetensors"
- ),
- ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
- ),
- ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="vae/diffusion_pytorch_model.safetensors"
- ),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
- text_tokenizer_config=ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="Qwen3-Embedding-0.6B/"
- ),
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
-
audio = pipe(
prompt=prompt,
lyrics=lyrics,
@@ -41,5 +28,23 @@
cfg_scale=1.0,
)
-save_audio(audio.cpu(), pipe.vae.sampling_rate, "Ace-Step1.5.wav")
-print(f"Saved to Ace-Step1.5.wav, shape: {audio.shape}, duration: {audio.shape[-1] / pipe.sample_rate:.1f}s")
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
+
+# input audio codes as reference
+with open("data/diffsynth_example_dataset/ace_step/Ace-Step1.5/audio_codes_input.txt", "r") as f:
+ audio_code_string = f.read().strip()
+
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ audio_code_string=audio_code_string,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=42,
+ num_inference_steps=8,
+ cfg_scale=1.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo5-with-audio-codes.wav")
diff --git a/examples/ace_step/model_inference/acestep-v15-base.py b/examples/ace_step/model_inference/acestep-v15-base.py
index 480a6fec..28b72ea9 100644
--- a/examples/ace_step/model_inference/acestep-v15-base.py
+++ b/examples/ace_step/model_inference/acestep-v15-base.py
@@ -1,52 +1,31 @@
-"""
-Ace-Step 1.5 Base (non-turbo, 24 layers) — Text-to-Music inference example.
-
-Uses cfg_scale=7.0 (standard CFG guidance) and more steps for higher quality.
-"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
import torch
-import soundfile as sf
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
- ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="acestep-v15-base/model.safetensors"
- ),
- ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="acestep-v15-base/model.safetensors"
- ),
- ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
- ),
+ ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
- tokenizer_config=ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="Qwen3-Embedding-0.6B/"
- ),
- vae_config=ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="vae/"
- ),
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
-prompt = "A cinematic orchestral piece with soaring strings and heroic brass"
-lyrics = "[Intro - Orchestra]\n\n[Verse 1]\nAcross the mountains, through the valley\nA journey of a thousand miles\n\n[Chorus]\nRise above the stormy skies\nLet the music carry you"
-
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
- duration=30.0,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
seed=42,
- num_inference_steps=20,
- cfg_scale=7.0, # Base model uses CFG
- shift=3.0,
+ num_inference_steps=30,
+ cfg_scale=4.0,
)
-
-sf.write("acestep-v15-base.wav", audio.cpu().numpy(), pipe.sample_rate)
-print(f"Saved, shape: {audio.shape}")
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base.wav")
From f5a3201d425393be40057d8dd5bbba6c3aa35ffc Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Tue, 21 Apr 2026 20:12:15 +0800
Subject: [PATCH 05/16] t2m
---
diffsynth/pipelines/ace_step.py | 1 -
.../model_inference/Ace-Step1.5-SimpleMode.py | 183 ------------------
.../model_inference/acestep-v15-sft.py | 48 ++---
.../acestep-v15-turbo-continuous.py | 36 ++++
.../acestep-v15-turbo-shift1.py | 46 ++---
.../acestep-v15-turbo-shift3.py | 47 ++---
.../model_inference/acestep-v15-xl-base.py | 46 ++---
.../model_inference/acestep-v15-xl-sft.py | 47 ++---
.../model_inference/acestep-v15-xl-turbo.py | 48 ++---
9 files changed, 133 insertions(+), 369 deletions(-)
delete mode 100644 examples/ace_step/model_inference/Ace-Step1.5-SimpleMode.py
create mode 100644 examples/ace_step/model_inference/acestep-v15-turbo-continuous.py
diff --git a/diffsynth/pipelines/ace_step.py b/diffsynth/pipelines/ace_step.py
index bf0f2351..2f3256ce 100644
--- a/diffsynth/pipelines/ace_step.py
+++ b/diffsynth/pipelines/ace_step.py
@@ -145,7 +145,6 @@ def __call__(
# 4. Denoise loop
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
- self.momentum_buffer = MomentumBuffer()
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
diff --git a/examples/ace_step/model_inference/Ace-Step1.5-SimpleMode.py b/examples/ace_step/model_inference/Ace-Step1.5-SimpleMode.py
deleted file mode 100644
index 261c548d..00000000
--- a/examples/ace_step/model_inference/Ace-Step1.5-SimpleMode.py
+++ /dev/null
@@ -1,183 +0,0 @@
-"""
-Ace-Step 1.5 — Text-to-Music with Simple Mode (LLM expansion).
-
-Uses the ACE-Step LLM to expand a simple description into structured
-parameters (caption, lyrics, bpm, keyscale, etc.) AND audio codes,
-then feeds them to the DiffSynth Pipeline.
-
-The LLM expansion uses the target library's LLMHandler. If vLLM is
-not available, it falls back to using pre-structured parameters.
-
-Usage:
- python examples/ace_step/model_inference/Ace-Step1.5-SimpleMode.py
-"""
-import os
-import sys
-import json
-import torch
-import soundfile as sf
-
-from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
-
-
-# ---------------------------------------------------------------------------
-# Simple Mode: LLM expansion
-# ---------------------------------------------------------------------------
-
-def try_load_llm_handler(checkpoint_dir: str, lm_model_path: str = "acestep-5Hz-lm-1.7B",
- backend: str = "vllm"):
- """Try to load the target library's LLMHandler. Returns (handler, success)."""
- try:
- from acestep.llm_inference import LLMHandler
- handler = LLMHandler()
- status, success = handler.initialize(
- checkpoint_dir=checkpoint_dir,
- lm_model_path=lm_model_path,
- backend=backend,
- )
- if success:
- print(f"[Simple Mode] LLM loaded via {backend} backend: {status}")
- return handler, True
- else:
- print(f"[Simple Mode] LLM init failed: {status}")
- return None, False
- except Exception as e:
- print(f"[Simple Mode] LLMHandler not available: {e}")
- return None, False
-
-
-def expand_with_llm(llm_handler, description: str, duration: float = 30.0):
- """Expand a simple description using LLM Chain-of-Thought.
-
- Returns (params_dict, audio_codes_string).
- """
- result = llm_handler.generate_with_stop_condition(
- caption=description,
- lyrics="",
- infer_type="dit", # metadata + audio codes
- temperature=0.85,
- cfg_scale=1.0,
- use_cot_metas=True,
- use_cot_caption=True,
- use_cot_language=True,
- user_metadata={"duration": int(duration)},
- )
-
- if result.get("success") and result.get("metadata"):
- meta = result["metadata"]
- params = {
- "caption": meta.get("caption", description),
- "lyrics": meta.get("lyrics", ""),
- "bpm": meta.get("bpm", 100),
- "keyscale": meta.get("keyscale", ""),
- "language": meta.get("language", "en"),
- "timesignature": meta.get("timesignature", "4"),
- "duration": meta.get("duration", duration),
- }
- audio_codes = result.get("audio_codes", "")
- return params, audio_codes
-
- print(f"[Simple Mode] LLM expansion failed: {result.get('error', 'unknown')}")
- return None, ""
-
-
-def fallback_expand(description: str, duration: float = 30.0):
- """Fallback: use description as caption with default parameters."""
- print(f"[Simple Mode] LLM not available. Using description as caption.")
- return {
- "caption": description,
- "lyrics": "",
- "bpm": 100,
- "keyscale": "",
- "language": "en",
- "timesignature": "4",
- "duration": duration,
- }, ""
-
-
-# ---------------------------------------------------------------------------
-# Main
-# ---------------------------------------------------------------------------
-
-def main():
- # Target library path (for LLMHandler)
- TARGET_LIB = os.path.join(os.path.dirname(__file__), "../../../../ACE-Step-1.5")
- if TARGET_LIB not in sys.path:
- sys.path.insert(0, TARGET_LIB)
-
- description = "a soft Bengali love song for a quiet evening"
- duration = 30.0
-
- # 1. Try to load LLM
- print("=" * 60)
- print("Ace-Step 1.5 — Simple Mode (LLM expansion)")
- print("=" * 60)
- print(f"\n[Simple Mode] Input: '{description}'")
-
- llm_handler, llm_ok = try_load_llm_handler(
- checkpoint_dir=TARGET_LIB,
- lm_model_path="acestep-5Hz-lm-1.7B",
- )
-
- # 2. Expand parameters + audio codes
- if llm_ok:
- params, audio_codes = expand_with_llm(llm_handler, description, duration=duration)
- if params is None:
- params, audio_codes = fallback_expand(description, duration)
- else:
- params, audio_codes = fallback_expand(description, duration)
-
- print(f"\n[Simple Mode] Parameters:")
- print(f" Caption: {params['caption'][:100]}...")
- print(f" Lyrics: {len(params['lyrics'])} chars")
- print(f" BPM: {params['bpm']}, Keyscale: {params['keyscale']}")
- print(f" Language: {params['language']}, Time Sig: {params['timesignature']}")
- print(f" Duration: {params['duration']}s")
- print(f" Audio codes: {len(audio_codes)} chars" if audio_codes else " Audio codes: None (fallback)")
-
- # 3. Load Pipeline
- print(f"\n[Pipeline] Loading Ace-Step 1.5 (turbo)...")
- pipe = AceStepPipeline.from_pretrained(
- torch_dtype=torch.bfloat16,
- device="cuda",
- model_configs=[
- ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="acestep-v15-turbo/model.safetensors"
- ),
- ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
- ),
- ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="vae/diffusion_pytorch_model.safetensors"
- ),
- ],
- text_tokenizer_config=ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="Qwen3-Embedding-0.6B/"
- ),
- )
-
- # 4. Generate
- print(f"\n[Generation] Running Pipeline...")
- audio = pipe(
- prompt=params["caption"],
- lyrics=params["lyrics"],
- duration=params["duration"],
- audio_codes=audio_codes if audio_codes else None,
- seed=42,
- num_inference_steps=8,
- cfg_scale=1.0,
- shift=3.0,
- )
-
- output_path = "Ace-Step1.5-SimpleMode.wav"
- sf.write(output_path, audio.cpu().numpy(), pipe.sample_rate)
- print(f"\n[Done] Saved to {output_path}")
- print(f" Shape: {audio.shape}, Duration: {audio.shape[-1] / pipe.sample_rate:.1f}s")
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/ace_step/model_inference/acestep-v15-sft.py b/examples/ace_step/model_inference/acestep-v15-sft.py
index c9ec0fff..0c573ec2 100644
--- a/examples/ace_step/model_inference/acestep-v15-sft.py
+++ b/examples/ace_step/model_inference/acestep-v15-sft.py
@@ -1,52 +1,38 @@
"""
-Ace-Step 1.5 SFT (supervised fine-tuned, 24 layers) — Text-to-Music inference example.
+Ace-Step 1.5 SFT (supervised fine-tuned) — Text-to-Music inference example.
SFT variant is fine-tuned for specific music styles.
+Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
import torch
-import soundfile as sf
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
- ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="acestep-v15-sft/model.safetensors"
- ),
- ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="acestep-v15-sft/model.safetensors"
- ),
- ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
- ),
+ ModelConfig(model_id="ACE-Step/acestep-v15-sft", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
- tokenizer_config=ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="Qwen3-Embedding-0.6B/"
- ),
- vae_config=ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="vae/"
- ),
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
-prompt = "A jazzy lo-fi beat with smooth saxophone and vinyl crackle, late night vibes"
-lyrics = "[Intro - Vinyl crackle]\n\n[Verse 1]\nMidnight city, neon glow\nSmooth jazz flowing to and fro\n\n[Chorus]\nLay back, let the music play\nJazzy nights, dreams drift away"
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
- duration=30.0,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
seed=42,
- num_inference_steps=20,
- cfg_scale=7.0,
- shift=3.0,
+ num_inference_steps=30,
+ cfg_scale=4.0,
)
-
-sf.write("acestep-v15-sft.wav", audio.cpu().numpy(), pipe.sample_rate)
-print(f"Saved, shape: {audio.shape}")
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-sft.wav")
diff --git a/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py b/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py
new file mode 100644
index 00000000..f587a8f8
--- /dev/null
+++ b/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py
@@ -0,0 +1,36 @@
+"""
+Ace-Step 1.5 Turbo (continuous, shift 1-5) — Text-to-Music inference example.
+
+Turbo model: no num_inference_steps or cfg_scale (use defaults).
+Continuous variant: handles shift range internally, no shift parameter needed.
+"""
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-turbo-continuous", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=42,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-continuous.wav")
diff --git a/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py b/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py
index 447f6b0d..cdebafe7 100644
--- a/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py
+++ b/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py
@@ -1,52 +1,36 @@
"""
Ace-Step 1.5 Turbo (shift=1) — Text-to-Music inference example.
-Uses shift=1.0 (no timestep transformation) for smoother, slower denoising.
+Turbo model: no num_inference_steps or cfg_scale (use defaults).
+shift=1: default value, no need to pass.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
import torch
-import soundfile as sf
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
- ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="acestep-v15-turbo/model.safetensors"
- ),
- ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="acestep-v15-turbo/model.safetensors"
- ),
- ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
- ),
+ ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift1", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
- tokenizer_config=ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="Qwen3-Embedding-0.6B/"
- ),
- vae_config=ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="vae/"
- ),
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
-prompt = "A gentle acoustic guitar melody with soft piano accompaniment, peaceful and warm atmosphere"
-lyrics = "[Verse 1]\nSunlight filtering through the trees\nA quiet moment, just the breeze\n\n[Chorus]\nPeaceful heart, open mind\nLeaving all the noise behind"
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
- duration=30.0,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
seed=42,
- num_inference_steps=8,
- cfg_scale=1.0,
- shift=1.0, # shift=1: no timestep transformation
)
-
-sf.write("acestep-v15-turbo-shift1.wav", audio.cpu().numpy(), pipe.sample_rate)
-print(f"Saved, shape: {audio.shape}")
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift1.wav")
diff --git a/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py b/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py
index 8091500c..7b761659 100644
--- a/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py
+++ b/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py
@@ -1,52 +1,37 @@
"""
Ace-Step 1.5 Turbo (shift=3) — Text-to-Music inference example.
-Uses shift=3.0 (default turbo shift) for faster denoising convergence.
+Turbo model: no num_inference_steps or cfg_scale (use defaults).
+shift=3: explicitly passed for this variant.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
import torch
-import soundfile as sf
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
- ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="acestep-v15-turbo/model.safetensors"
- ),
- ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="acestep-v15-turbo/model.safetensors"
- ),
- ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
- ),
+ ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift3", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
- tokenizer_config=ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="Qwen3-Embedding-0.6B/"
- ),
- vae_config=ModelConfig(
- model_id="ACE-Step/Ace-Step1.5",
- origin_file_pattern="vae/"
- ),
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
-prompt = "An explosive, high-energy pop-rock track with anime theme song feel"
-lyrics = "[Intro]\n\n[Verse 1]\nRunning through the neon lights\nChasing dreams across the night\n\n[Chorus]\nFeel the fire in my soul\nMusic takes complete control"
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
- duration=30.0,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
seed=42,
- num_inference_steps=8,
- cfg_scale=1.0,
- shift=3.0,
+ shift=3,
)
-
-sf.write("acestep-v15-turbo-shift3.wav", audio.cpu().numpy(), pipe.sample_rate)
-print(f"Saved, shape: {audio.shape}")
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift3.wav")
diff --git a/examples/ace_step/model_inference/acestep-v15-xl-base.py b/examples/ace_step/model_inference/acestep-v15-xl-base.py
index f1c5b4ec..fac5b0c3 100644
--- a/examples/ace_step/model_inference/acestep-v15-xl-base.py
+++ b/examples/ace_step/model_inference/acestep-v15-xl-base.py
@@ -2,51 +2,37 @@
Ace-Step 1.5 XL Base (32 layers, hidden_size=2560) — Text-to-Music inference example.
XL variant with larger capacity for higher quality generation.
+Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
import torch
-import soundfile as sf
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
- ModelConfig(
- model_id="ACE-Step/acestep-v15-xl-base",
- origin_file_pattern="model-*.safetensors"
- ),
- ModelConfig(
- model_id="ACE-Step/acestep-v15-xl-base",
- origin_file_pattern="model-*.safetensors"
- ),
- ModelConfig(
- model_id="ACE-Step/acestep-v15-xl-base",
- origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
- ),
+ ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
- tokenizer_config=ModelConfig(
- model_id="ACE-Step/acestep-v15-xl-base",
- origin_file_pattern="Qwen3-Embedding-0.6B/"
- ),
- vae_config=ModelConfig(
- model_id="ACE-Step/acestep-v15-xl-base",
- origin_file_pattern="vae/"
- ),
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
-prompt = "An epic symphonic metal track with double bass drums and soaring vocals"
-lyrics = "[Intro - Heavy guitar riff]\n\n[Verse 1]\nSteel and thunder, fire and rain\nBurning through the endless pain\n\n[Chorus]\nRise up, break the chains\nUnleash the fire in your veins"
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
- duration=30.0,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
seed=42,
- num_inference_steps=20,
- cfg_scale=7.0,
- shift=3.0,
+ num_inference_steps=30,
+ cfg_scale=4.0,
)
-
-sf.write("acestep-v15-xl-base.wav", audio.cpu().numpy(), pipe.sample_rate)
-print(f"Saved, shape: {audio.shape}")
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-base.wav")
diff --git a/examples/ace_step/model_inference/acestep-v15-xl-sft.py b/examples/ace_step/model_inference/acestep-v15-xl-sft.py
index 73d54d96..d62508dc 100644
--- a/examples/ace_step/model_inference/acestep-v15-xl-sft.py
+++ b/examples/ace_step/model_inference/acestep-v15-xl-sft.py
@@ -1,50 +1,37 @@
"""
Ace-Step 1.5 XL SFT (32 layers, supervised fine-tuned) — Text-to-Music inference example.
+
+Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
import torch
-import soundfile as sf
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
- ModelConfig(
- model_id="ACE-Step/acestep-v15-xl-sft",
- origin_file_pattern="model-*.safetensors"
- ),
- ModelConfig(
- model_id="ACE-Step/acestep-v15-xl-sft",
- origin_file_pattern="model-*.safetensors"
- ),
- ModelConfig(
- model_id="ACE-Step/acestep-v15-xl-sft",
- origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
- ),
+ ModelConfig(model_id="ACE-Step/acestep-v15-xl-sft", origin_file_pattern="model-*.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
- tokenizer_config=ModelConfig(
- model_id="ACE-Step/acestep-v15-xl-sft",
- origin_file_pattern="Qwen3-Embedding-0.6B/"
- ),
- vae_config=ModelConfig(
- model_id="ACE-Step/acestep-v15-xl-sft",
- origin_file_pattern="vae/"
- ),
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
-prompt = "A beautiful piano ballad with lush strings and emotional vocals, cinematic feel"
-lyrics = "[Intro - Solo piano]\n\n[Verse 1]\nWhispers of a distant shore\nMemories I hold so dear\n\n[Chorus]\nIn your eyes I see the dawn\nAll my fears are gone"
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
- duration=30.0,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
seed=42,
- num_inference_steps=20,
- cfg_scale=7.0,
- shift=3.0,
+ num_inference_steps=30,
+ cfg_scale=4.0,
)
-
-sf.write("acestep-v15-xl-sft.wav", audio.cpu().numpy(), pipe.sample_rate)
-print(f"Saved, shape: {audio.shape}")
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-sft.wav")
diff --git a/examples/ace_step/model_inference/acestep-v15-xl-turbo.py b/examples/ace_step/model_inference/acestep-v15-xl-turbo.py
index 9116567f..c23c6111 100644
--- a/examples/ace_step/model_inference/acestep-v15-xl-turbo.py
+++ b/examples/ace_step/model_inference/acestep-v15-xl-turbo.py
@@ -1,52 +1,36 @@
"""
-Ace-Step 1.5 XL Turbo (32 layers) — Text-to-Music inference example.
+Ace-Step 1.5 XL Turbo (32 layers, fast generation) — Text-to-Music inference example.
-XL turbo with fast generation (8 steps, shift=3.0, no CFG).
+Turbo model: no num_inference_steps or cfg_scale (use defaults).
+shift=3: explicitly passed for this variant.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
import torch
-import soundfile as sf
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
- ModelConfig(
- model_id="ACE-Step/acestep-v15-xl-turbo",
- origin_file_pattern="model-*.safetensors"
- ),
- ModelConfig(
- model_id="ACE-Step/acestep-v15-xl-turbo",
- origin_file_pattern="model-*.safetensors"
- ),
- ModelConfig(
- model_id="ACE-Step/acestep-v15-xl-turbo",
- origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"
- ),
+ ModelConfig(model_id="ACE-Step/acestep-v15-xl-turbo", origin_file_pattern="model-*.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
- tokenizer_config=ModelConfig(
- model_id="ACE-Step/acestep-v15-xl-turbo",
- origin_file_pattern="Qwen3-Embedding-0.6B/"
- ),
- vae_config=ModelConfig(
- model_id="ACE-Step/acestep-v15-xl-turbo",
- origin_file_pattern="vae/"
- ),
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
-prompt = "An upbeat electronic dance track with pulsing synths and driving bassline"
-lyrics = "[Intro - Synth build]\n\n[Verse 1]\nFeel the rhythm in the air\nElectric beats are everywhere\n\n[Drop]\n\n[Chorus]\nDance until the break of dawn\nMove your body, carry on"
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
- duration=30.0,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
seed=42,
- num_inference_steps=8,
- cfg_scale=1.0, # turbo: no CFG
- shift=3.0,
)
-
-sf.write("acestep-v15-xl-turbo.wav", audio.cpu().numpy(), pipe.sample_rate)
-print(f"Saved, shape: {audio.shape}")
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-turbo.wav")
From b0680ef711192545e05060ed21b63c178aae8776 Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Wed, 22 Apr 2026 12:47:38 +0800
Subject: [PATCH 06/16] low_vram
---
.../configs/vram_management_module_maps.py | 37 +++++++++++
diffsynth/models/ace_step_dit.py | 4 +-
diffsynth/models/ace_step_tokenizer.py | 2 +-
diffsynth/models/ace_step_vae.py | 8 ++-
diffsynth/pipelines/ace_step.py | 19 ++++--
.../ace_step/model_inference/Ace-Step1.5.py | 4 --
.../model_inference_low_vram/Ace-Step1.5.py | 66 +++++++++++++++++++
.../acestep-v15-base.py | 49 ++++++++++++++
.../acestep-v15-sft.py | 51 ++++++++++++++
.../acestep-v15-turbo-continuous.py | 49 ++++++++++++++
.../acestep-v15-turbo-shift1.py | 49 ++++++++++++++
.../acestep-v15-turbo-shift3.py | 50 ++++++++++++++
.../acestep-v15-xl-base.py | 51 ++++++++++++++
.../acestep-v15-xl-sft.py | 50 ++++++++++++++
.../acestep-v15-xl-turbo.py | 48 ++++++++++++++
15 files changed, 523 insertions(+), 14 deletions(-)
create mode 100644 examples/ace_step/model_inference_low_vram/Ace-Step1.5.py
create mode 100644 examples/ace_step/model_inference_low_vram/acestep-v15-base.py
create mode 100644 examples/ace_step/model_inference_low_vram/acestep-v15-sft.py
create mode 100644 examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py
create mode 100644 examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py
create mode 100644 examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py
create mode 100644 examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py
create mode 100644 examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py
create mode 100644 examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py
diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py
index 8d4800b3..299a830e 100644
--- a/diffsynth/configs/vram_management_module_maps.py
+++ b/diffsynth/configs/vram_management_module_maps.py
@@ -295,6 +295,43 @@
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
+ # ACE-Step module maps
+ "diffsynth.models.ace_step_dit.AceStepDiTModel": {
+ "diffsynth.models.ace_step_dit.AceStepDiTLayer": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
+ "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
+ },
+ "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder": {
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
+ },
+ "diffsynth.models.ace_step_text_encoder.AceStepTextEncoder": {
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
+ },
+ "diffsynth.models.ace_step_vae.AceStepVAE": {
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
+ "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "diffsynth.models.ace_step_vae.Snake1d": "diffsynth.core.vram.layers.AutoWrappedModule",
+ },
+ "diffsynth.models.ace_step_tokenizer.AceStepTokenizer": {
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
+ },
}
def QwenImageTextEncoder_Module_Map_Updater():
diff --git a/diffsynth/models/ace_step_dit.py b/diffsynth/models/ace_step_dit.py
index d9172771..16669dcb 100644
--- a/diffsynth/models/ace_step_dit.py
+++ b/diffsynth/models/ace_step_dit.py
@@ -522,7 +522,7 @@ def forward(
# Extract scale-shift parameters for adaptive layer norm from timestep embeddings
# 6 values: (shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
- self.scale_shift_table + temb
+ self.scale_shift_table.to(temb.device) + temb
).chunk(6, dim=1)
# Step 1: Self-attention with adaptive layer norm (AdaLN)
@@ -889,7 +889,7 @@ def forward(
return hidden_states
# Extract scale-shift parameters for adaptive output normalization
- shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+ shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
shift = shift.to(hidden_states.device)
scale = scale.to(hidden_states.device)
diff --git a/diffsynth/models/ace_step_tokenizer.py b/diffsynth/models/ace_step_tokenizer.py
index c01e9d50..5bd0e74e 100644
--- a/diffsynth/models/ace_step_tokenizer.py
+++ b/diffsynth/models/ace_step_tokenizer.py
@@ -594,7 +594,7 @@ def forward(
x = self.embed_tokens(x)
x = x.unsqueeze(2).repeat(1, 1, self.pool_window_size, 1)
special_tokens = self.special_tokens.expand(B, T, -1, -1)
- x = x + special_tokens
+ x = x + special_tokens.to(x.device)
x = rearrange(x, "b t p c -> (b t) p c")
cache_position = torch.arange(0, x.shape[1], device=x.device)
diff --git a/diffsynth/models/ace_step_vae.py b/diffsynth/models/ace_step_vae.py
index 168f8517..ae5b501a 100644
--- a/diffsynth/models/ace_step_vae.py
+++ b/diffsynth/models/ace_step_vae.py
@@ -22,7 +22,7 @@
import torch
import torch.nn as nn
-from torch.nn.utils import weight_norm
+from torch.nn.utils import weight_norm, remove_weight_norm
class Snake1d(nn.Module):
@@ -240,3 +240,9 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
"""Full round-trip: encode → decode."""
z = self.encode(sample)
return self.decoder(z)
+
+ def remove_weight_norm(self):
+ """Remove weight normalization from all conv layers (for export/inference)."""
+ for module in self.modules():
+ if isinstance(module, nn.Conv1d) or isinstance(module, nn.ConvTranspose1d):
+ remove_weight_norm(module)
diff --git a/diffsynth/pipelines/ace_step.py b/diffsynth/pipelines/ace_step.py
index 2f3256ce..d369da63 100644
--- a/diffsynth/pipelines/ace_step.py
+++ b/diffsynth/pipelines/ace_step.py
@@ -69,6 +69,7 @@ def from_pretrained(
pipe.conditioner = model_pool.fetch_model("ace_step_conditioner")
pipe.dit = model_pool.fetch_model("ace_step_dit")
pipe.vae = model_pool.fetch_model("ace_step_vae")
+ pipe.vae.remove_weight_norm()
pipe.tokenizer_model = model_pool.fetch_model("ace_step_tokenizer")
if text_tokenizer_config is not None:
@@ -372,8 +373,9 @@ def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
)
inputs_posi["encoder_hidden_states"] = encoder_hidden_states
inputs_posi["encoder_attention_mask"] = encoder_attention_mask
- inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states)
- inputs_nega["encoder_attention_mask"] = encoder_attention_mask
+ if inputs_shared["cfg_scale"] != 1.0:
+ inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device)
+ inputs_nega["encoder_attention_mask"] = encoder_attention_mask
return inputs_shared, inputs_posi, inputs_nega
@@ -468,10 +470,15 @@ def process(self, pipe, audio_code_string):
return {"lm_hints": None}
pipe.load_models_to_device(["tokenizer_model"])
- indices = torch.tensor(code_ids, device=pipe.device, dtype=torch.long)
- indices = indices.unsqueeze(0).unsqueeze(-1) # [1, N, 1]
- quantized = pipe.tokenizer_model.tokenizer.quantizer.get_output_from_indices(indices).to(pipe.torch_dtype) # [1, N, 2048]
- lm_hints = pipe.tokenizer_model.detokenizer(quantized) # [1, N*5, 64]
+ quantizer = pipe.tokenizer_model.tokenizer.quantizer
+ detokenizer = pipe.tokenizer_model.detokenizer
+
+ indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
+ codes = quantizer.get_codes_from_indices(indices)
+ quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
+ quantized = quantizer.project_out(quantized)
+
+ lm_hints = detokenizer(quantized).to(pipe.device)
return {"lm_hints": lm_hints}
diff --git a/examples/ace_step/model_inference/Ace-Step1.5.py b/examples/ace_step/model_inference/Ace-Step1.5.py
index 219cb318..f0983968 100644
--- a/examples/ace_step/model_inference/Ace-Step1.5.py
+++ b/examples/ace_step/model_inference/Ace-Step1.5.py
@@ -24,8 +24,6 @@
timesignature="4",
vocal_language="zh",
seed=42,
- num_inference_steps=8,
- cfg_scale=1.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
@@ -44,7 +42,5 @@
timesignature="4",
vocal_language="zh",
seed=42,
- num_inference_steps=8,
- cfg_scale=1.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo5-with-audio-codes.wav")
diff --git a/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py b/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py
new file mode 100644
index 00000000..3ccd39dd
--- /dev/null
+++ b/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py
@@ -0,0 +1,66 @@
+"""
+Ace-Step 1.5 (main model, turbo) — Text-to-Music inference example (Low VRAM).
+
+Low VRAM version: models are offloaded to CPU and loaded on-demand.
+Turbo model: no num_inference_steps or cfg_scale needed (defaults: 8 steps, cfg_scale=1.0).
+"""
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=42,
+)
+
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-low-vram.wav")
+
+# input audio codes as reference
+with open("data/diffsynth_example_dataset/ace_step/Ace-Step1.5/audio_codes_input.txt", "r") as f:
+ audio_code_string = f.read().strip()
+
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ audio_code_string=audio_code_string,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=42,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo5-with-audio-codes-low-vram.wav")
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-base.py b/examples/ace_step/model_inference_low_vram/acestep-v15-base.py
new file mode 100644
index 00000000..fc997f2e
--- /dev/null
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-base.py
@@ -0,0 +1,49 @@
+"""
+Ace-Step 1.5 Base — Text-to-Music inference example (Low VRAM).
+
+Low VRAM version: models are offloaded to CPU and loaded on-demand. Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
+"""
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=42,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-low-vram.wav")
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py b/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py
new file mode 100644
index 00000000..189c26a6
--- /dev/null
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py
@@ -0,0 +1,51 @@
+"""
+Ace-Step 1.5 SFT (supervised fine-tuned) — Text-to-Music inference example (Low VRAM).
+
+Low VRAM version: models are offloaded to CPU and loaded on-demand.
+SFT variant is fine-tuned for specific music styles.
+Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
+"""
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-sft", origin_file_pattern="model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=42,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-sft-low-vram.wav")
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py b/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py
new file mode 100644
index 00000000..420bc933
--- /dev/null
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py
@@ -0,0 +1,49 @@
+"""
+Ace-Step 1.5 Turbo (continuous, shift 1-5) — Text-to-Music inference example (Low VRAM).
+
+Low VRAM version: models are offloaded to CPU and loaded on-demand.
+Turbo model: no num_inference_steps or cfg_scale (use defaults).
+Continuous variant: handles shift range internally, no shift parameter needed.
+"""
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-turbo-continuous", origin_file_pattern="model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=42,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-continuous-low-vram.wav")
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py b/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py
new file mode 100644
index 00000000..cfa1583c
--- /dev/null
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py
@@ -0,0 +1,49 @@
+"""
+Ace-Step 1.5 Turbo (shift=1) — Text-to-Music inference example (Low VRAM).
+
+Low VRAM version: models are offloaded to CPU and loaded on-demand.
+Turbo model: no num_inference_steps or cfg_scale (use defaults).
+shift=1: default value, no need to pass.
+"""
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift1", origin_file_pattern="model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=42,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift1-low-vram.wav")
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py b/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py
new file mode 100644
index 00000000..aa2af9c9
--- /dev/null
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py
@@ -0,0 +1,50 @@
+"""
+Ace-Step 1.5 Turbo (shift=3) — Text-to-Music inference example (Low VRAM).
+
+Low VRAM version: models are offloaded to CPU and loaded on-demand.
+Turbo model: no num_inference_steps or cfg_scale (use defaults).
+shift=3: explicitly passed for this variant.
+"""
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift3", origin_file_pattern="model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=42,
+ shift=3,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift3-low-vram.wav")
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py b/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py
new file mode 100644
index 00000000..dc772ba9
--- /dev/null
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py
@@ -0,0 +1,51 @@
+"""
+Ace-Step 1.5 XL Base — Text-to-Music inference example (Low VRAM).
+
+Low VRAM version: models are offloaded to CPU and loaded on-demand.
+Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
+"""
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+torch.cuda.reset_peak_memory_stats("cuda")
+
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=42,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-base-low-vram.wav")
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py b/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py
new file mode 100644
index 00000000..5ac17b08
--- /dev/null
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py
@@ -0,0 +1,50 @@
+"""
+Ace-Step 1.5 XL SFT (supervised fine-tuned) — Text-to-Music inference example (Low VRAM).
+
+Low VRAM version: models are offloaded to CPU and loaded on-demand.
+Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
+"""
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-xl-sft", origin_file_pattern="model-*.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=42,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-sft-low-vram.wav")
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py b/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py
new file mode 100644
index 00000000..53a5ec5e
--- /dev/null
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py
@@ -0,0 +1,48 @@
+"""
+Ace-Step 1.5 XL Turbo — Text-to-Music inference example (Low VRAM).
+
+Low VRAM version: models are offloaded to CPU and loaded on-demand.
+Turbo model: no num_inference_steps or cfg_scale (use defaults).
+"""
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-xl-turbo", origin_file_pattern="model-*.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=42,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-turbo-low-vram.wav")
From c53c813c12a50c3e03cd973c7ab43b7408f30b97 Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Wed, 22 Apr 2026 17:58:10 +0800
Subject: [PATCH 07/16] ace-step train
---
diffsynth/core/data/operators.py | 25 +++
diffsynth/models/ace_step_dit.py | 21 +--
diffsynth/models/ace_step_vae.py | 47 ++++-
diffsynth/pipelines/ace_step.py | 78 ++++++--
.../ace_step/model_inference/Ace-Step1.5.py | 1 +
.../model_inference_low_vram/Ace-Step1.5.py | 1 +
.../model_training/full/Ace-Step1.5.sh | 18 ++
.../model_training/full/acestep-v15-base.sh | 18 ++
.../model_training/full/acestep-v15-sft.sh | 18 ++
.../full/acestep-v15-turbo-continuous.sh | 18 ++
.../full/acestep-v15-turbo-shift1.sh | 18 ++
.../full/acestep-v15-turbo-shift3.sh | 18 ++
.../full/acestep-v15-xl-base.sh | 18 ++
.../model_training/full/acestep-v15-xl-sft.sh | 18 ++
.../full/acestep-v15-xl-turbo.sh | 18 ++
.../model_training/lora/Ace-Step1.5.sh | 20 ++
.../model_training/lora/acestep-v15-base.sh | 20 ++
.../model_training/lora/acestep-v15-sft.sh | 20 ++
.../lora/acestep-v15-turbo-continuous.sh | 20 ++
.../lora/acestep-v15-turbo-shift1.sh | 20 ++
.../lora/acestep-v15-turbo-shift3.sh | 20 ++
.../lora/acestep-v15-xl-base.sh | 20 ++
.../model_training/lora/acestep-v15-xl-sft.sh | 20 ++
.../lora/acestep-v15-xl-turbo.sh | 20 ++
examples/ace_step/model_training/train.py | 173 ++++++++++++++++++
.../validate_full/Ace-Step1.5.py | 35 ++++
.../validate_full/acestep-v15-base.py | 35 ++++
.../validate_full/acestep-v15-sft.py | 35 ++++
.../acestep-v15-turbo-continuous.py | 35 ++++
.../validate_full/acestep-v15-turbo-shift1.py | 35 ++++
.../validate_full/acestep-v15-turbo-shift3.py | 35 ++++
.../validate_full/acestep-v15-xl-base.py | 35 ++++
.../validate_full/acestep-v15-xl-sft.py | 35 ++++
.../validate_lora/Ace-Step1.5.py | 33 ++++
.../validate_lora/acestep-v15-base.py | 33 ++++
.../validate_lora/acestep-v15-sft.py | 33 ++++
.../acestep-v15-turbo-continuous.py | 33 ++++
.../validate_lora/acestep-v15-turbo-shift1.py | 33 ++++
.../validate_lora/acestep-v15-turbo-shift3.py | 33 ++++
.../validate_lora/acestep-v15-xl-base.py | 33 ++++
.../validate_lora/acestep-v15-xl-sft.py | 33 ++++
.../validate_lora/acestep-v15-xl-turbo.py | 33 ++++
42 files changed, 1235 insertions(+), 30 deletions(-)
create mode 100644 examples/ace_step/model_training/full/Ace-Step1.5.sh
create mode 100644 examples/ace_step/model_training/full/acestep-v15-base.sh
create mode 100644 examples/ace_step/model_training/full/acestep-v15-sft.sh
create mode 100644 examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh
create mode 100644 examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh
create mode 100644 examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh
create mode 100644 examples/ace_step/model_training/full/acestep-v15-xl-base.sh
create mode 100644 examples/ace_step/model_training/full/acestep-v15-xl-sft.sh
create mode 100644 examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh
create mode 100644 examples/ace_step/model_training/lora/Ace-Step1.5.sh
create mode 100644 examples/ace_step/model_training/lora/acestep-v15-base.sh
create mode 100644 examples/ace_step/model_training/lora/acestep-v15-sft.sh
create mode 100644 examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh
create mode 100644 examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh
create mode 100644 examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh
create mode 100644 examples/ace_step/model_training/lora/acestep-v15-xl-base.sh
create mode 100644 examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh
create mode 100644 examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh
create mode 100644 examples/ace_step/model_training/train.py
create mode 100644 examples/ace_step/model_training/validate_full/Ace-Step1.5.py
create mode 100644 examples/ace_step/model_training/validate_full/acestep-v15-base.py
create mode 100644 examples/ace_step/model_training/validate_full/acestep-v15-sft.py
create mode 100644 examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py
create mode 100644 examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py
create mode 100644 examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py
create mode 100644 examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py
create mode 100644 examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py
create mode 100644 examples/ace_step/model_training/validate_lora/Ace-Step1.5.py
create mode 100644 examples/ace_step/model_training/validate_lora/acestep-v15-base.py
create mode 100644 examples/ace_step/model_training/validate_lora/acestep-v15-sft.py
create mode 100644 examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py
create mode 100644 examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py
create mode 100644 examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py
create mode 100644 examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py
create mode 100644 examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py
create mode 100644 examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py
diff --git a/diffsynth/core/data/operators.py b/diffsynth/core/data/operators.py
index 9288705f..53bdb7ed 100644
--- a/diffsynth/core/data/operators.py
+++ b/diffsynth/core/data/operators.py
@@ -3,6 +3,7 @@
import imageio.v3 as iio
from PIL import Image
import torchaudio
+from diffsynth.utils.data.audio import read_audio
class DataProcessingPipeline:
@@ -276,3 +277,27 @@ def __call__(self, data: str):
except:
warnings.warn(f"Cannot load audio in {data}. The audio will be `None`.")
return None
+
+
+class LoadPureAudioWithTorchaudio(DataProcessingOperator):
+
+ def __init__(self, target_sample_rate=None, target_duration=None):
+ self.target_sample_rate = target_sample_rate
+ self.target_duration = target_duration
+ self.resample = True if target_sample_rate is not None else False
+
+ def __call__(self, data: str):
+ try:
+ waveform, sample_rate = read_audio(data, resample=self.resample, resample_rate=self.target_sample_rate)
+ if self.target_duration is not None:
+ target_samples = int(self.target_duration * sample_rate)
+ current_samples = waveform.shape[-1]
+ if current_samples > target_samples:
+ waveform = waveform[..., :target_samples]
+ elif current_samples < target_samples:
+ padding = target_samples - current_samples
+ waveform = torch.nn.functional.pad(waveform, (0, padding))
+ return waveform, sample_rate
+ except Exception as e:
+ warnings.warn(f"Cannot load audio in '{data}' due to '{e}'. The audio will be `None`.")
+ return None
diff --git a/diffsynth/models/ace_step_dit.py b/diffsynth/models/ace_step_dit.py
index 16669dcb..d8f270e7 100644
--- a/diffsynth/models/ace_step_dit.py
+++ b/diffsynth/models/ace_step_dit.py
@@ -864,20 +864,13 @@ def forward(
layer_kwargs = flash_attn_kwargs
# Use gradient checkpointing if enabled
- if use_gradient_checkpointing or use_gradient_checkpointing_offload:
- layer_outputs = gradient_checkpoint_forward(
- layer_module,
- use_gradient_checkpointing,
- use_gradient_checkpointing_offload,
- *layer_args,
- **layer_kwargs,
- )
- else:
- layer_outputs = layer_module(
- *layer_args,
- **layer_kwargs,
- )
-
+ layer_outputs = gradient_checkpoint_forward(
+ layer_module,
+ use_gradient_checkpointing,
+ use_gradient_checkpointing_offload,
+ *layer_args,
+ **layer_kwargs,
+ )
hidden_states = layer_outputs[0]
if output_attentions and self.layers[index_block].use_cross_attention:
diff --git a/diffsynth/models/ace_step_vae.py b/diffsynth/models/ace_step_vae.py
index ae5b501a..047e199f 100644
--- a/diffsynth/models/ace_step_vae.py
+++ b/diffsynth/models/ace_step_vae.py
@@ -191,6 +191,43 @@ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
return self.conv2(hidden_state)
+class OobleckDiagonalGaussianDistribution(object):
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
+ self.parameters = parameters
+ self.mean, self.scale = parameters.chunk(2, dim=1)
+ self.std = nn.functional.softplus(self.scale) + 1e-4
+ self.var = self.std * self.std
+ self.logvar = torch.log(self.var)
+ self.deterministic = deterministic
+
+ def sample(self, generator: torch.Generator | None = None) -> torch.Tensor:
+ # ensure the sample is on the same device and has the same dtype as the parameters
+ sample = torch.randn(
+ self.mean.shape,
+ generator=generator,
+ device=self.parameters.device,
+ dtype=self.parameters.dtype,
+ )
+ x = self.mean + self.std * sample
+ return x
+
+ def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor:
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean()
+ else:
+ normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var
+ var_ratio = self.var / other.var
+ logvar_diff = self.logvar - other.logvar
+
+ kl = normalized_diff + var_ratio + logvar_diff - 1
+
+ kl = kl.sum(1).mean()
+ return kl
+
+
class AceStepVAE(nn.Module):
"""Audio VAE for ACE-Step (AutoencoderOobleck architecture).
@@ -229,17 +266,19 @@ def __init__(
self.sampling_rate = sampling_rate
def encode(self, x: torch.Tensor) -> torch.Tensor:
- """Audio waveform [B, audio_channels, T] → latent [B, encoder_hidden_size, T']."""
- return self.encoder(x)
+ """Audio waveform [B, audio_channels, T] → latent [B, decoder_input_channels, T']."""
+ h = self.encoder(x)
+ output = OobleckDiagonalGaussianDistribution(h).sample()
+ return output
def decode(self, z: torch.Tensor) -> torch.Tensor:
- """Latent [B, encoder_hidden_size, T] → audio waveform [B, audio_channels, T']."""
+ """Latent [B, decoder_input_channels, T] → audio waveform [B, audio_channels, T']."""
return self.decoder(z)
def forward(self, sample: torch.Tensor) -> torch.Tensor:
"""Full round-trip: encode → decode."""
z = self.encode(sample)
- return self.decoder(z)
+ return self.decode(z)
def remove_weight_norm(self):
"""Remove weight normalization from all conv layers (for export/inference)."""
diff --git a/diffsynth/pipelines/ace_step.py b/diffsynth/pipelines/ace_step.py
index d369da63..71d180c0 100644
--- a/diffsynth/pipelines/ace_step.py
+++ b/diffsynth/pipelines/ace_step.py
@@ -7,6 +7,7 @@
import torch
from typing import Optional, Dict, Any, List, Tuple
from tqdm import tqdm
+import random
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
@@ -89,13 +90,14 @@ def __call__(
self,
# Prompt
prompt: str,
- negative_prompt: str = "",
cfg_scale: float = 1.0,
# Lyrics
lyrics: str = "",
+ # Task type
+ task_type: Optional[str] = "text2music",
# Reference audio
reference_audios: List[torch.Tensor] = None,
- # Src audio
+ # Source audio
src_audio: torch.Tensor = None,
denoising_strength: float = 1.0,
# Audio codes
@@ -126,6 +128,7 @@ def __call__(
inputs_shared = {
"cfg_scale": cfg_scale,
"lyrics": lyrics,
+ "task_type": task_type,
"reference_audios": reference_audios,
"src_audio": src_audio,
"audio_code_string": audio_code_string,
@@ -147,7 +150,7 @@ def __call__(
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
- timestep = timestep.to(dtype=self.torch_dtype, device=self.device)
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
@@ -182,13 +185,14 @@ class AceStepUnit_TaskTypeChecker(PipelineUnit):
"""Check and compute sequence length from duration."""
def __init__(self):
super().__init__(
- input_params=("src_audio", "audio_code_string"),
+ input_params=("audio_code_string",),
output_params=("task_type",),
)
- def process(self, pipe, src_audio, audio_code_string):
+ def process(self, pipe, audio_code_string):
+ if pipe.scheduler.training:
+ return {"task_type": "text2music"}
if audio_code_string is not None:
- print("audio_code_string detected, setting task_type to 'cover'")
task_type = "cover"
else:
task_type = "text2music"
@@ -200,7 +204,6 @@ class AceStepUnit_PromptEmbedder(PipelineUnit):
INSTRUCTION_MAP = {
"text2music": "Fill the audio semantic mask based on the given conditions:",
"cover": "Generate audio semantic tokens based on the given conditions:",
-
"repaint": "Repaint the mask area based on the given conditions:",
"extract": "Extract the {TRACK_NAME} track from the audio:",
"extract_default": "Extract the track from the audio:",
@@ -292,6 +295,7 @@ def __init__(self):
def process(self, pipe, reference_audios):
pipe.load_models_to_device(['vae'])
if reference_audios is not None and len(reference_audios) > 0:
+ raise NotImplementedError("Reference audio embedding is not implemented yet.")
# TODO: implement reference audio embedding using VAE encode, and generate refer_audio_order_mask
pass
else:
@@ -299,6 +303,49 @@ def process(self, pipe, reference_audios):
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, reference_audios)
return {"reference_latents": reference_latents, "refer_audio_order_mask": refer_audio_order_mask}
+ # def process_reference_audio(self, reference_audios) -> Optional[torch.Tensor]:
+
+ # try:
+ # audio_np, sr = _read_audio_file(audio_file)
+ # audio = self._numpy_to_channels_first(audio_np)
+
+ # logger.debug(
+ # f"[process_reference_audio] Reference audio shape: {audio.shape}"
+ # )
+ # logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}")
+ # logger.debug(
+ # f"[process_reference_audio] Reference audio duration: {audio.shape[-1] / sr:.6f} seconds"
+ # )
+
+ # audio = self._normalize_audio_to_stereo_48k(audio, sr)
+ # if self.is_silence(audio):
+ # return None
+
+ # target_frames = 30 * 48000
+ # segment_frames = 10 * 48000
+
+ # if audio.shape[-1] < target_frames:
+ # repeat_times = math.ceil(target_frames / audio.shape[-1])
+ # audio = audio.repeat(1, repeat_times)
+
+ # total_frames = audio.shape[-1]
+ # segment_size = total_frames // 3
+
+ # front_start = random.randint(0, max(0, segment_size - segment_frames))
+ # front_audio = audio[:, front_start : front_start + segment_frames]
+
+ # middle_start = segment_size + random.randint(
+ # 0, max(0, segment_size - segment_frames)
+ # )
+ # middle_audio = audio[:, middle_start : middle_start + segment_frames]
+
+ # back_start = 2 * segment_size + random.randint(
+ # 0, max(0, (total_frames - 2 * segment_size) - segment_frames)
+ # )
+ # back_audio = audio[:, back_start : back_start + segment_frames]
+
+ # return torch.cat([front_audio, middle_audio, back_audio], dim=-1)
+
def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
"""Infer packed reference-audio latents and order mask."""
refer_audio_order_mask = []
@@ -401,8 +448,8 @@ def process(self, pipe, duration, src_audio, lm_hints):
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
- elif src_audio is not None:
- raise NotImplementedError("src_audio conditioning is not implemented yet. Please set lm_hints to None.")
+ # elif src_audio is not None:
+ # raise NotImplementedError("src_audio conditioning is not implemented yet. Please set lm_hints to None.")
else:
max_latent_length = duration * pipe.sample_rate // 1920
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
@@ -435,8 +482,16 @@ def __init__(self):
def process(self, pipe, noise, input_audio):
if input_audio is None:
return {"latents": noise}
- # TODO: support for train
- return {"latents": noise, "input_latents": None}
+ if pipe.scheduler.training:
+ pipe.load_models_to_device(['vae'])
+ input_audio, sample_rate = input_audio
+ input_audio = torch.clamp(input_audio, -1.0, 1.0)
+ if input_audio.dim() == 2:
+ input_audio = input_audio.unsqueeze(0)
+ input_latents = pipe.vae.encode(input_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
+ # prevent potential size mismatch between context_latents and input_latents by cropping input_latents to the same temporal length as noise
+ input_latents = input_latents[:, :noise.shape[1]]
+ return {"input_latents": input_latents}
class AceStepUnit_AudioCodeDecoder(PipelineUnit):
@@ -494,7 +549,6 @@ def model_fn_ace_step(
use_gradient_checkpointing_offload=False,
**kwargs,
):
- timestep = timestep.unsqueeze(0)
decoder_outputs = dit(
hidden_states=latents,
timestep=timestep,
diff --git a/examples/ace_step/model_inference/Ace-Step1.5.py b/examples/ace_step/model_inference/Ace-Step1.5.py
index f0983968..ae41f11c 100644
--- a/examples/ace_step/model_inference/Ace-Step1.5.py
+++ b/examples/ace_step/model_inference/Ace-Step1.5.py
@@ -35,6 +35,7 @@
audio = pipe(
prompt=prompt,
lyrics=lyrics,
+ task_type="cover",
audio_code_string=audio_code_string,
duration=160,
bpm=100,
diff --git a/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py b/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py
index 3ccd39dd..4bc2e5e1 100644
--- a/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py
+++ b/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py
@@ -55,6 +55,7 @@
audio = pipe(
prompt=prompt,
lyrics=lyrics,
+ task_type="cover",
audio_code_string=audio_code_string,
duration=160,
bpm=100,
diff --git a/examples/ace_step/model_training/full/Ace-Step1.5.sh b/examples/ace_step/model_training/full/Ace-Step1.5.sh
new file mode 100644
index 00000000..bc2558e4
--- /dev/null
+++ b/examples/ace_step/model_training/full/Ace-Step1.5.sh
@@ -0,0 +1,18 @@
+# Dataset: data/diffsynth_example_dataset/ace_step/Ace-Step1.5/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/Ace-Step1.5/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/ace_step/model_training/train.py \
+ --learning_rate 1e-5 \
+ --num_epochs 2 \
+ --trainable_models "dit" \
+ --use_gradient_checkpointing \
+ --find_unused_parameters \
+ --dataset_base_path "./data/diffsynth_example_dataset/ace_step/Ace-Step1.5" \
+ --dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/Ace-Step1.5/metadata.json" \
+ --model_id_with_origin_paths "ACE-Step/Ace-Step1.5:acestep-v15-turbo/model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
+ --tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
+ --silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --dataset_repeat 50 \
+ --output_path "./models/train/Ace-Step1.5_full" \
+ --data_file_keys "audio"
diff --git a/examples/ace_step/model_training/full/acestep-v15-base.sh b/examples/ace_step/model_training/full/acestep-v15-base.sh
new file mode 100644
index 00000000..77330995
--- /dev/null
+++ b/examples/ace_step/model_training/full/acestep-v15-base.sh
@@ -0,0 +1,18 @@
+# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-base/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-base/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/ace_step/model_training/train.py \
+ --learning_rate 1e-5 \
+ --num_epochs 2 \
+ --trainable_models "dit" \
+ --use_gradient_checkpointing \
+ --find_unused_parameters \
+ --dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-base" \
+ --dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-base/metadata.json" \
+ --model_id_with_origin_paths "ACE-Step/acestep-v15-base:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
+ --tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
+ --silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --dataset_repeat 50 \
+ --output_path "./models/train/acestep-v15-base_full" \
+ --data_file_keys "audio"
diff --git a/examples/ace_step/model_training/full/acestep-v15-sft.sh b/examples/ace_step/model_training/full/acestep-v15-sft.sh
new file mode 100644
index 00000000..e94aa46e
--- /dev/null
+++ b/examples/ace_step/model_training/full/acestep-v15-sft.sh
@@ -0,0 +1,18 @@
+# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-sft/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-sft/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/ace_step/model_training/train.py \
+ --learning_rate 1e-5 \
+ --num_epochs 2 \
+ --trainable_models "dit" \
+ --use_gradient_checkpointing \
+ --find_unused_parameters \
+ --dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-sft" \
+ --dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-sft/metadata.json" \
+ --model_id_with_origin_paths "ACE-Step/acestep-v15-sft:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
+ --tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
+ --silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --dataset_repeat 50 \
+ --output_path "./models/train/acestep-v15-sft_full" \
+ --data_file_keys "audio"
diff --git a/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh b/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh
new file mode 100644
index 00000000..d772a279
--- /dev/null
+++ b/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh
@@ -0,0 +1,18 @@
+# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-continuous/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/ace_step/model_training/train.py \
+ --learning_rate 1e-5 \
+ --num_epochs 2 \
+ --trainable_models "dit" \
+ --use_gradient_checkpointing \
+ --find_unused_parameters \
+ --dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous" \
+ --dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous/metadata.json" \
+ --model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-continuous:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
+ --tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
+ --silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --dataset_repeat 50 \
+ --output_path "./models/train/acestep-v15-turbo-continuous_full" \
+ --data_file_keys "audio"
diff --git a/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh b/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh
new file mode 100644
index 00000000..6840fa60
--- /dev/null
+++ b/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh
@@ -0,0 +1,18 @@
+# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-shift1/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/ace_step/model_training/train.py \
+ --learning_rate 1e-5 \
+ --num_epochs 2 \
+ --trainable_models "dit" \
+ --use_gradient_checkpointing \
+ --find_unused_parameters \
+ --dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1" \
+ --dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1/metadata.json" \
+ --model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-shift1:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
+ --tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
+ --silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --dataset_repeat 50 \
+ --output_path "./models/train/acestep-v15-turbo-shift1_full" \
+ --data_file_keys "audio"
diff --git a/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh b/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh
new file mode 100644
index 00000000..a255e222
--- /dev/null
+++ b/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh
@@ -0,0 +1,18 @@
+# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-shift3/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/ace_step/model_training/train.py \
+ --learning_rate 1e-5 \
+ --num_epochs 2 \
+ --trainable_models "dit" \
+ --use_gradient_checkpointing \
+ --find_unused_parameters \
+ --dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3" \
+ --dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3/metadata.json" \
+ --model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-shift3:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
+ --tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
+ --silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --dataset_repeat 50 \
+ --output_path "./models/train/acestep-v15-turbo-shift3_full" \
+ --data_file_keys "audio"
diff --git a/examples/ace_step/model_training/full/acestep-v15-xl-base.sh b/examples/ace_step/model_training/full/acestep-v15-xl-base.sh
new file mode 100644
index 00000000..40e30d99
--- /dev/null
+++ b/examples/ace_step/model_training/full/acestep-v15-xl-base.sh
@@ -0,0 +1,18 @@
+# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-base/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/ace_step/model_training/train.py \
+ --learning_rate 1e-5 \
+ --num_epochs 2 \
+ --trainable_models "dit" \
+ --use_gradient_checkpointing \
+ --find_unused_parameters \
+ --dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base" \
+ --dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base/metadata.json" \
+ --model_id_with_origin_paths "ACE-Step/acestep-v15-xl-base:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
+ --tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
+ --silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --dataset_repeat 50 \
+ --output_path "./models/train/acestep-v15-xl-base_full" \
+ --data_file_keys "audio"
diff --git a/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh b/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh
new file mode 100644
index 00000000..8dd6969a
--- /dev/null
+++ b/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh
@@ -0,0 +1,18 @@
+# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-sft/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/ace_step/model_training/train.py \
+ --learning_rate 1e-5 \
+ --num_epochs 2 \
+ --trainable_models "dit" \
+ --use_gradient_checkpointing \
+ --find_unused_parameters \
+ --dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft" \
+ --dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft/metadata.json" \
+ --model_id_with_origin_paths "ACE-Step/acestep-v15-xl-sft:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
+ --tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
+ --silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --dataset_repeat 50 \
+ --output_path "./models/train/acestep-v15-xl-sft_full" \
+ --data_file_keys "audio"
diff --git a/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh b/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh
new file mode 100644
index 00000000..2f768d84
--- /dev/null
+++ b/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh
@@ -0,0 +1,18 @@
+# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-turbo/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/ace_step/model_training/train.py \
+ --learning_rate 1e-5 \
+ --num_epochs 2 \
+ --trainable_models "dit" \
+ --use_gradient_checkpointing \
+ --find_unused_parameters \
+ --dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo" \
+ --dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo/metadata.json" \
+ --model_id_with_origin_paths "ACE-Step/acestep-v15-xl-turbo:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
+ --tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
+ --silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --dataset_repeat 50 \
+ --output_path "./models/train/acestep-v15-xl-turbo_full" \
+ --data_file_keys "audio"
diff --git a/examples/ace_step/model_training/lora/Ace-Step1.5.sh b/examples/ace_step/model_training/lora/Ace-Step1.5.sh
new file mode 100644
index 00000000..c8f7207d
--- /dev/null
+++ b/examples/ace_step/model_training/lora/Ace-Step1.5.sh
@@ -0,0 +1,20 @@
+# Dataset: data/diffsynth_example_dataset/ace_step/Ace-Step1.5/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/Ace-Step1.5/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/ace_step/model_training/train.py \
+ --learning_rate 1e-4 \
+ --num_epochs 20 \
+ --lora_rank 32 \
+ --use_gradient_checkpointing \
+ --find_unused_parameters \
+ --dataset_base_path "./data/diffsynth_example_dataset/ace_step/Ace-Step1.5" \
+ --dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/Ace-Step1.5/metadata.json" \
+ --model_id_with_origin_paths "ACE-Step/Ace-Step1.5:acestep-v15-turbo/model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
+ --tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
+ --silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
+ --lora_base_model "dit" \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --dataset_repeat 50 \
+ --output_path "./models/train/Ace-Step1.5_lora" \
+ --lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
+ --data_file_keys "audio"
diff --git a/examples/ace_step/model_training/lora/acestep-v15-base.sh b/examples/ace_step/model_training/lora/acestep-v15-base.sh
new file mode 100644
index 00000000..7c10325d
--- /dev/null
+++ b/examples/ace_step/model_training/lora/acestep-v15-base.sh
@@ -0,0 +1,20 @@
+# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-base/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-base/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/ace_step/model_training/train.py \
+ --learning_rate 1e-4 \
+ --num_epochs 20 \
+ --lora_rank 32 \
+ --use_gradient_checkpointing \
+ --find_unused_parameters \
+ --dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-base" \
+ --dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-base/metadata.json" \
+ --model_id_with_origin_paths "ACE-Step/acestep-v15-base:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
+ --tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
+ --silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
+ --lora_base_model "dit" \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --dataset_repeat 50 \
+ --output_path "./models/train/acestep-v15-base_lora" \
+ --lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
+ --data_file_keys "audio"
diff --git a/examples/ace_step/model_training/lora/acestep-v15-sft.sh b/examples/ace_step/model_training/lora/acestep-v15-sft.sh
new file mode 100644
index 00000000..ac4bf8a3
--- /dev/null
+++ b/examples/ace_step/model_training/lora/acestep-v15-sft.sh
@@ -0,0 +1,20 @@
+# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-sft/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-sft/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/ace_step/model_training/train.py \
+ --learning_rate 1e-4 \
+ --num_epochs 20 \
+ --lora_rank 32 \
+ --use_gradient_checkpointing \
+ --find_unused_parameters \
+ --dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-sft" \
+ --dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-sft/metadata.json" \
+ --model_id_with_origin_paths "ACE-Step/acestep-v15-sft:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
+ --tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
+ --silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
+ --lora_base_model "dit" \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --dataset_repeat 50 \
+ --output_path "./models/train/acestep-v15-sft_lora" \
+ --lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
+ --data_file_keys "audio"
diff --git a/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh b/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh
new file mode 100644
index 00000000..778c2d7c
--- /dev/null
+++ b/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh
@@ -0,0 +1,20 @@
+# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-continuous/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/ace_step/model_training/train.py \
+ --learning_rate 1e-4 \
+ --num_epochs 20 \
+ --lora_rank 32 \
+ --use_gradient_checkpointing \
+ --find_unused_parameters \
+ --dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous" \
+ --dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous/metadata.json" \
+ --model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-continuous:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
+ --tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
+ --silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
+ --lora_base_model "dit" \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --dataset_repeat 50 \
+ --output_path "./models/train/acestep-v15-turbo-continuous_lora" \
+ --lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
+ --data_file_keys "audio"
diff --git a/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh b/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh
new file mode 100644
index 00000000..82368ba7
--- /dev/null
+++ b/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh
@@ -0,0 +1,20 @@
+# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-shift1/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/ace_step/model_training/train.py \
+ --learning_rate 1e-4 \
+ --num_epochs 20 \
+ --lora_rank 32 \
+ --use_gradient_checkpointing \
+ --find_unused_parameters \
+ --dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1" \
+ --dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1/metadata.json" \
+ --model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-shift1:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
+ --tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
+ --silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
+ --lora_base_model "dit" \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --dataset_repeat 50 \
+ --output_path "./models/train/acestep-v15-turbo-shift1_lora" \
+ --lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
+ --data_file_keys "audio"
diff --git a/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh b/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh
new file mode 100644
index 00000000..b45ece4b
--- /dev/null
+++ b/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh
@@ -0,0 +1,20 @@
+# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-shift3/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/ace_step/model_training/train.py \
+ --learning_rate 1e-4 \
+ --num_epochs 20 \
+ --lora_rank 32 \
+ --use_gradient_checkpointing \
+ --find_unused_parameters \
+ --dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3" \
+ --dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3/metadata.json" \
+ --model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-shift3:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
+ --tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
+ --silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
+ --lora_base_model "dit" \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --dataset_repeat 50 \
+ --output_path "./models/train/acestep-v15-turbo-shift3_lora" \
+ --lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
+ --data_file_keys "audio"
diff --git a/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh b/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh
new file mode 100644
index 00000000..829ebe72
--- /dev/null
+++ b/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh
@@ -0,0 +1,20 @@
+# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-base/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/ace_step/model_training/train.py \
+ --learning_rate 1e-4 \
+ --num_epochs 20 \
+ --lora_rank 32 \
+ --use_gradient_checkpointing \
+ --find_unused_parameters \
+ --dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base" \
+ --dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base/metadata.json" \
+ --model_id_with_origin_paths "ACE-Step/acestep-v15-xl-base:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
+ --tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
+ --silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
+ --lora_base_model "dit" \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --dataset_repeat 50 \
+ --output_path "./models/train/acestep-v15-xl-base_lora" \
+ --lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
+ --data_file_keys "audio"
diff --git a/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh b/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh
new file mode 100644
index 00000000..18985925
--- /dev/null
+++ b/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh
@@ -0,0 +1,20 @@
+# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-sft/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/ace_step/model_training/train.py \
+ --learning_rate 1e-4 \
+ --num_epochs 20 \
+ --lora_rank 32 \
+ --use_gradient_checkpointing \
+ --find_unused_parameters \
+ --dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft" \
+ --dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft/metadata.json" \
+ --model_id_with_origin_paths "ACE-Step/acestep-v15-xl-sft:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
+ --tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
+ --silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
+ --lora_base_model "dit" \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --dataset_repeat 50 \
+ --output_path "./models/train/acestep-v15-xl-sft_lora" \
+ --lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
+ --data_file_keys "audio"
diff --git a/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh b/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh
new file mode 100644
index 00000000..94c53cb8
--- /dev/null
+++ b/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh
@@ -0,0 +1,20 @@
+# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-turbo/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/ace_step/model_training/train.py \
+ --learning_rate 1e-4 \
+ --num_epochs 20 \
+ --lora_rank 32 \
+ --use_gradient_checkpointing \
+ --find_unused_parameters \
+ --dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo" \
+ --dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo/metadata.json" \
+ --model_id_with_origin_paths "ACE-Step/acestep-v15-xl-turbo:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
+ --tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
+ --silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
+ --lora_base_model "dit" \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --dataset_repeat 50 \
+ --output_path "./models/train/acestep-v15-xl-turbo_lora" \
+ --lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
+ --data_file_keys "audio"
diff --git a/examples/ace_step/model_training/train.py b/examples/ace_step/model_training/train.py
new file mode 100644
index 00000000..a24da2c3
--- /dev/null
+++ b/examples/ace_step/model_training/train.py
@@ -0,0 +1,173 @@
+import torch, os, argparse, accelerate, warnings, torchaudio
+import math
+from diffsynth.core import UnifiedDataset
+from diffsynth.core.data.operators import ToAbsolutePath, RouteByType, DataProcessingOperator, LoadPureAudioWithTorchaudio
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.diffusion import *
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+class LoadAceStepAudio(DataProcessingOperator):
+    """Load audio file and return waveform tensor [2, T] at 48kHz. NOTE(review): appears unused — the dataset below uses LoadPureAudioWithTorchaudio; confirm before removing."""
+ def __init__(self, target_sr=48000):
+ self.target_sr = target_sr
+
+ def __call__(self, data: str):
+ try:
+ waveform, sample_rate = torchaudio.load(data)
+ if sample_rate != self.target_sr:
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.target_sr)
+ waveform = resampler(waveform)
+ if waveform.shape[0] == 1:
+ waveform = waveform.repeat(2, 1)
+ return waveform
+ except Exception as e:
+ warnings.warn(f"Cannot load audio from {data}: {e}")
+ return None
+
+
+class AceStepTrainingModule(DiffusionTrainingModule):
+ def __init__(
+ self,
+ model_paths=None, model_id_with_origin_paths=None,
+ tokenizer_path=None, silence_latent_path=None,
+ trainable_models=None,
+ lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
+ preset_lora_path=None, preset_lora_model=None,
+ use_gradient_checkpointing=True,
+ use_gradient_checkpointing_offload=False,
+ extra_inputs=None,
+ fp8_models=None,
+ offload_models=None,
+ device="cpu",
+ task="sft",
+ ):
+ super().__init__()
+        # ===== Parse model configs (boilerplate) =====
+ model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
+        # ===== Tokenizer configuration =====
+ text_tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"))
+ silence_latent_config = self.parse_path_or_model_id(silence_latent_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"))
+        # ===== Build the pipeline =====
+ self.pipe = AceStepPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, text_tokenizer_config=text_tokenizer_config, silence_latent_config=silence_latent_config)
+        # ===== Split pipeline units (boilerplate) =====
+ self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
+
+        # ===== Switch to training mode (boilerplate) =====
+ self.switch_pipe_to_training_mode(
+ self.pipe, trainable_models,
+ lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
+ preset_lora_path, preset_lora_model,
+ task=task,
+ )
+
+        # ===== Other settings (boilerplate) =====
+ self.use_gradient_checkpointing = use_gradient_checkpointing
+ self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
+ self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
+ self.fp8_models = fp8_models
+ self.task = task
+        # ===== Task-mode routing (boilerplate) =====
+ self.task_to_loss = {
+ "sft:data_process": lambda pipe, *args: args,
+ "sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
+ "sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
+ }
+
+ def get_pipeline_inputs(self, data):
+ inputs_posi = {"prompt": data["prompt"], "positive": True}
+ inputs_nega = {"positive": False}
+ duration = math.floor(data['audio'][0].shape[1] / data['audio'][1]) if data.get("audio") is not None else data.get("duration", 60)
+        # ===== Shared inputs =====
+ inputs_shared = {
+            # ===== Core field mapping =====
+ "input_audio": data["audio"],
+            # ===== Metadata required for the audio generation task =====
+ "lyrics": data["lyrics"],
+ "task_type": "text2music",
+ "duration": duration,
+ "bpm": data.get("bpm", 100),
+ "keyscale": data.get("keyscale", "C major"),
+ "timesignature": data.get("timesignature", "4"),
+ "vocal_language": data.get("vocal_language", "unknown"),
+            # ===== Framework control parameters (boilerplate) =====
+ "cfg_scale": 1,
+ "rand_device": self.pipe.device,
+ "use_gradient_checkpointing": self.use_gradient_checkpointing,
+ "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
+ }
+        # ===== Inject extra fields: dataset column names configured via --extra_inputs (boilerplate) =====
+ inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
+ return inputs_shared, inputs_posi, inputs_nega
+
+ def forward(self, data, inputs=None):
+        # ===== Standard implementation, do not modify (boilerplate) =====
+ if inputs is None: inputs = self.get_pipeline_inputs(data)
+ inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
+ for unit in self.pipe.units:
+ inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
+ loss = self.task_to_loss[self.task](self.pipe, *inputs)
+ return loss
+
+
+def ace_step_parser():
+ parser = argparse.ArgumentParser(description="ACE-Step training.")
+ parser = add_general_config(parser)
+ parser.add_argument("--tokenizer_path", type=str, default=None, help="Tokenizer path in format model_id:origin_pattern.")
+ parser.add_argument("--silence_latent_path", type=str, default=None, help="Silence latent path in format model_id:origin_pattern.")
+ parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
+ return parser
+
+
+if __name__ == "__main__":
+ parser = ace_step_parser()
+ args = parser.parse_args()
+    # ===== Accelerator configuration (boilerplate) =====
+ accelerator = accelerate.Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
+ )
+    # ===== Dataset definition =====
+ dataset = UnifiedDataset(
+ base_path=args.dataset_base_path,
+ metadata_path=args.dataset_metadata_path,
+ repeat=args.dataset_repeat,
+ data_file_keys=args.data_file_keys.split(","),
+ main_data_operator=None,
+ special_operator_map={
+ "audio": ToAbsolutePath(args.dataset_base_path) >> LoadPureAudioWithTorchaudio(target_sample_rate=48000),
+ },
+ )
+ # ===== TrainingModule =====
+ model = AceStepTrainingModule(
+ model_paths=args.model_paths,
+ model_id_with_origin_paths=args.model_id_with_origin_paths,
+ tokenizer_path=args.tokenizer_path,
+ silence_latent_path=args.silence_latent_path,
+ trainable_models=args.trainable_models,
+ lora_base_model=args.lora_base_model,
+ lora_target_modules=args.lora_target_modules,
+ lora_rank=args.lora_rank,
+ lora_checkpoint=args.lora_checkpoint,
+ preset_lora_path=args.preset_lora_path,
+ preset_lora_model=args.preset_lora_model,
+ use_gradient_checkpointing=args.use_gradient_checkpointing,
+ use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
+ extra_inputs=args.extra_inputs,
+ fp8_models=args.fp8_models,
+ offload_models=args.offload_models,
+ task=args.task,
+ device="cpu" if args.initialize_model_on_cpu else accelerator.device,
+ )
+    # ===== ModelLogger (boilerplate) =====
+ model_logger = ModelLogger(
+ args.output_path,
+ remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
+ )
+    # ===== Task routing (boilerplate) =====
+ launcher_map = {
+ "sft:data_process": launch_data_process_task,
+ "sft": launch_training_task,
+ "sft:train": launch_training_task,
+ }
+ launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
diff --git a/examples/ace_step/model_training/validate_full/Ace-Step1.5.py b/examples/ace_step/model_training/validate_full/Ace-Step1.5.py
new file mode 100644
index 00000000..b50caad2
--- /dev/null
+++ b/examples/ace_step/model_training/validate_full/Ace-Step1.5.py
@@ -0,0 +1,35 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+from diffsynth import load_state_dict
+import torch
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
+)
+state_dict = load_state_dict("models/train/Ace-Step1.5_full/epoch-1.safetensors")
+pipe.dit.load_state_dict(state_dict)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=1,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "Ace-Step1.5_full.wav")
diff --git a/examples/ace_step/model_training/validate_full/acestep-v15-base.py b/examples/ace_step/model_training/validate_full/acestep-v15-base.py
new file mode 100644
index 00000000..4dd20e78
--- /dev/null
+++ b/examples/ace_step/model_training/validate_full/acestep-v15-base.py
@@ -0,0 +1,35 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+from diffsynth import load_state_dict
+import torch
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
+)
+state_dict = load_state_dict("models/train/acestep-v15-base_full/epoch-1.safetensors")
+pipe.dit.load_state_dict(state_dict)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+    seed=1,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base_full.wav")
diff --git a/examples/ace_step/model_training/validate_full/acestep-v15-sft.py b/examples/ace_step/model_training/validate_full/acestep-v15-sft.py
new file mode 100644
index 00000000..b4b2e9c3
--- /dev/null
+++ b/examples/ace_step/model_training/validate_full/acestep-v15-sft.py
@@ -0,0 +1,35 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+from diffsynth import load_state_dict
+import torch
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-sft", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
+)
+state_dict = load_state_dict("models/train/acestep-v15-sft_full/epoch-1.safetensors")
+pipe.dit.load_state_dict(state_dict)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=1,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-sft_full.wav")
diff --git a/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py b/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py
new file mode 100644
index 00000000..42532e44
--- /dev/null
+++ b/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py
@@ -0,0 +1,35 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+from diffsynth import load_state_dict
+import torch
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-turbo-continuous", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
+)
+state_dict = load_state_dict("models/train/acestep-v15-turbo-continuous_full/epoch-1.safetensors")
+pipe.dit.load_state_dict(state_dict)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=1,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-continuous_full.wav")
diff --git a/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py b/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py
new file mode 100644
index 00000000..821af1df
--- /dev/null
+++ b/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py
@@ -0,0 +1,35 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+from diffsynth import load_state_dict
+import torch
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift1", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
+)
+state_dict = load_state_dict("models/train/acestep-v15-turbo-shift1_full/epoch-1.safetensors")
+pipe.dit.load_state_dict(state_dict)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=1,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift1_full.wav")
diff --git a/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py b/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py
new file mode 100644
index 00000000..5f70e5c4
--- /dev/null
+++ b/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py
@@ -0,0 +1,35 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+from diffsynth import load_state_dict
+import torch
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift3", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
+)
+state_dict = load_state_dict("models/train/acestep-v15-turbo-shift3_full/epoch-1.safetensors")
+pipe.dit.load_state_dict(state_dict)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=1,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift3_full.wav")
diff --git a/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py b/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py
new file mode 100644
index 00000000..815f1a4d
--- /dev/null
+++ b/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py
@@ -0,0 +1,35 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+from diffsynth import load_state_dict
+import torch
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
+)
+state_dict = load_state_dict("models/train/acestep-v15-xl-base_full/epoch-1.safetensors")
+pipe.dit.load_state_dict(state_dict)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=1,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-base_full.wav")
diff --git a/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py b/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py
new file mode 100644
index 00000000..55eea5fa
--- /dev/null
+++ b/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py
@@ -0,0 +1,35 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+from diffsynth import load_state_dict
+import torch
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-xl-sft", origin_file_pattern="model-*.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
+)
+state_dict = load_state_dict("models/train/acestep-v15-xl-sft_full/epoch-1.safetensors")
+pipe.dit.load_state_dict(state_dict)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=1,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-sft_full.wav")
diff --git a/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py b/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py
new file mode 100644
index 00000000..5445ec1c
--- /dev/null
+++ b/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py
@@ -0,0 +1,33 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
+)
+pipe.load_lora(pipe.dit, "models/train/Ace-Step1.5_lora/epoch-9.safetensors", alpha=1)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=1,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "Ace-Step1.5_lora.wav")
diff --git a/examples/ace_step/model_training/validate_lora/acestep-v15-base.py b/examples/ace_step/model_training/validate_lora/acestep-v15-base.py
new file mode 100644
index 00000000..b3abedab
--- /dev/null
+++ b/examples/ace_step/model_training/validate_lora/acestep-v15-base.py
@@ -0,0 +1,33 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
+)
+pipe.load_lora(pipe.dit, "models/train/acestep-v15-base_lora/epoch-9.safetensors", alpha=1)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=1,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base_lora.wav")
diff --git a/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py b/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py
new file mode 100644
index 00000000..2a770104
--- /dev/null
+++ b/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py
@@ -0,0 +1,33 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-sft", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
+)
+pipe.load_lora(pipe.dit, "models/train/acestep-v15-sft_lora/epoch-9.safetensors", alpha=1)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=1,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-sft_lora.wav")
diff --git a/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py b/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py
new file mode 100644
index 00000000..d1512509
--- /dev/null
+++ b/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py
@@ -0,0 +1,33 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-turbo-continuous", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
+)
+pipe.load_lora(pipe.dit, "models/train/acestep-v15-turbo-continuous_lora/epoch-9.safetensors", alpha=1)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=1,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-continuous_lora.wav")
diff --git a/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py b/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py
new file mode 100644
index 00000000..ddefdf35
--- /dev/null
+++ b/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py
@@ -0,0 +1,33 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift1", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
+)
+pipe.load_lora(pipe.dit, "models/train/acestep-v15-turbo-shift1_lora/epoch-9.safetensors", alpha=1)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=1,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift1_lora.wav")
diff --git a/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py b/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py
new file mode 100644
index 00000000..303d161c
--- /dev/null
+++ b/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py
@@ -0,0 +1,33 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift3", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
+)
+pipe.load_lora(pipe.dit, "models/train/acestep-v15-turbo-shift3_lora/epoch-9.safetensors", alpha=1)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=1,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift3_lora.wav")
diff --git a/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py b/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py
new file mode 100644
index 00000000..d2603a77
--- /dev/null
+++ b/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py
@@ -0,0 +1,33 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
+)
+pipe.load_lora(pipe.dit, "models/train/acestep-v15-xl-base_lora/epoch-9.safetensors", alpha=1)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=1,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-base_lora.wav")
diff --git a/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py b/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py
new file mode 100644
index 00000000..87a88586
--- /dev/null
+++ b/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py
@@ -0,0 +1,33 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-xl-sft", origin_file_pattern="model-*.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
+)
+pipe.load_lora(pipe.dit, "models/train/acestep-v15-xl-sft_lora/epoch-9.safetensors", alpha=1)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=1,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-sft_lora.wav")
diff --git a/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py b/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py
new file mode 100644
index 00000000..c3450da0
--- /dev/null
+++ b/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py
@@ -0,0 +1,33 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-xl-turbo", origin_file_pattern="model-*.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
+)
+pipe.load_lora(pipe.dit, "models/train/acestep-v15-xl-turbo_lora/epoch-9.safetensors", alpha=1)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=1,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-turbo_lora.wav")
From f2e3427566835a56f02922bcf9959b8db335652d Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Wed, 22 Apr 2026 19:16:04 +0800
Subject: [PATCH 08/16] reference audio input
---
diffsynth/models/ace_step_conditioner.py | 1 -
diffsynth/pipelines/ace_step.py | 109 ++++++-----------------
2 files changed, 25 insertions(+), 85 deletions(-)
diff --git a/diffsynth/models/ace_step_conditioner.py b/diffsynth/models/ace_step_conditioner.py
index 76cc502b..1279d793 100644
--- a/diffsynth/models/ace_step_conditioner.py
+++ b/diffsynth/models/ace_step_conditioner.py
@@ -506,7 +506,6 @@ def __init__(
for layer_idx in range(num_timbre_encoder_hidden_layers)
])
-
def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask):
N, d = timbre_embs_packed.shape
device = timbre_embs_packed.device
diff --git a/diffsynth/pipelines/ace_step.py b/diffsynth/pipelines/ace_step.py
index 71d180c0..4730592a 100644
--- a/diffsynth/pipelines/ace_step.py
+++ b/diffsynth/pipelines/ace_step.py
@@ -8,6 +8,7 @@
from typing import Optional, Dict, Any, List, Tuple
from tqdm import tqdm
import random
+import math
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
@@ -43,7 +44,7 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
AceStepUnit_PromptEmbedder(),
AceStepUnit_ReferenceAudioEmbedder(),
AceStepUnit_ConditionEmbedder(),
- AceStepUnit_AudioCodeDecoder(),
+ AceStepUnit_AudioCodeDecoder(),
AceStepUnit_ContextLatentBuilder(),
AceStepUnit_NoiseInitializer(),
AceStepUnit_InputAudioEmbedder(),
@@ -293,107 +294,47 @@ def __init__(self):
)
def process(self, pipe, reference_audios):
- pipe.load_models_to_device(['vae'])
- if reference_audios is not None and len(reference_audios) > 0:
- raise NotImplementedError("Reference audio embedding is not implemented yet.")
- # TODO: implement reference audio embedding using VAE encode, and generate refer_audio_order_mask
- pass
+ if reference_audios is not None:
+ pipe.load_models_to_device(['vae'])
+ reference_audios = [self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device) for reference_audio in reference_audios]
+ reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, [reference_audios])
else:
reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, reference_audios)
return {"reference_latents": reference_latents, "refer_audio_order_mask": refer_audio_order_mask}
- # def process_reference_audio(self, reference_audios) -> Optional[torch.Tensor]:
-
- # try:
- # audio_np, sr = _read_audio_file(audio_file)
- # audio = self._numpy_to_channels_first(audio_np)
-
- # logger.debug(
- # f"[process_reference_audio] Reference audio shape: {audio.shape}"
- # )
- # logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}")
- # logger.debug(
- # f"[process_reference_audio] Reference audio duration: {audio.shape[-1] / sr:.6f} seconds"
- # )
-
- # audio = self._normalize_audio_to_stereo_48k(audio, sr)
- # if self.is_silence(audio):
- # return None
-
- # target_frames = 30 * 48000
- # segment_frames = 10 * 48000
-
- # if audio.shape[-1] < target_frames:
- # repeat_times = math.ceil(target_frames / audio.shape[-1])
- # audio = audio.repeat(1, repeat_times)
-
- # total_frames = audio.shape[-1]
- # segment_size = total_frames // 3
-
- # front_start = random.randint(0, max(0, segment_size - segment_frames))
- # front_audio = audio[:, front_start : front_start + segment_frames]
-
- # middle_start = segment_size + random.randint(
- # 0, max(0, segment_size - segment_frames)
- # )
- # middle_audio = audio[:, middle_start : middle_start + segment_frames]
-
- # back_start = 2 * segment_size + random.randint(
- # 0, max(0, (total_frames - 2 * segment_size) - segment_frames)
- # )
- # back_audio = audio[:, back_start : back_start + segment_frames]
-
- # return torch.cat([front_audio, middle_audio, back_audio], dim=-1)
+ def process_reference_audio(self, audio) -> Optional[torch.Tensor]:
+ if audio.ndim == 3 and audio.shape[0] == 1:
+ audio = audio.squeeze(0)
+ target_frames = 30 * 48000
+ segment_frames = 10 * 48000
+ if audio.shape[-1] < target_frames:
+ repeat_times = math.ceil(target_frames / audio.shape[-1])
+ audio = audio.repeat(1, repeat_times)
+ total_frames = audio.shape[-1]
+ segment_size = total_frames // 3
+ front_start = random.randint(0, max(0, segment_size - segment_frames))
+ front_audio = audio[:, front_start:front_start + segment_frames]
+ middle_start = segment_size + random.randint(0, max(0, segment_size - segment_frames))
+ middle_audio = audio[:, middle_start:middle_start + segment_frames]
+ back_start = 2 * segment_size + random.randint(0, max(0, (total_frames - 2 * segment_size) - segment_frames))
+ back_audio = audio[:, back_start:back_start + segment_frames]
+ return torch.cat([front_audio, middle_audio, back_audio], dim=-1).unsqueeze(0)
def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
"""Infer packed reference-audio latents and order mask."""
refer_audio_order_mask = []
refer_audio_latents = []
-
- def _normalize_audio_2d(a: torch.Tensor) -> torch.Tensor:
- if not isinstance(a, torch.Tensor):
- raise TypeError(f"refer_audio must be a torch.Tensor, got {type(a)!r}")
- if a.dim() == 3 and a.shape[0] == 1:
- a = a.squeeze(0)
- if a.dim() == 1:
- a = a.unsqueeze(0)
- if a.dim() != 2:
- raise ValueError(f"refer_audio must be 1D/2D/3D(1,2,T); got shape={tuple(a.shape)}")
- if a.shape[0] == 1:
- a = torch.cat([a, a], dim=0)
- return a[:2]
-
- def _ensure_latent_3d(z: torch.Tensor) -> torch.Tensor:
- if z.dim() == 4 and z.shape[0] == 1:
- z = z.squeeze(0)
- if z.dim() == 2:
- z = z.unsqueeze(0)
- return z
-
- refer_encode_cache: Dict[int, torch.Tensor] = {}
for batch_idx, refer_audios in enumerate(refer_audioss):
if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
- refer_audio_latent = _ensure_latent_3d(pipe.silence_latent[:, :750, :])
+ refer_audio_latent = pipe.silence_latent[:, :750, :]
refer_audio_latents.append(refer_audio_latent)
refer_audio_order_mask.append(batch_idx)
else:
- # TODO: check
for refer_audio in refer_audios:
- cache_key = refer_audio.data_ptr()
- if cache_key in refer_encode_cache:
- refer_audio_latent = refer_encode_cache[cache_key].clone()
- else:
- refer_audio = _normalize_audio_2d(refer_audio)
- refer_audio_latent = pipe.vae.encode(refer_audio)
- refer_audio_latent = refer_audio_latent.to(dtype=pipe.torch_dtype, device=pipe.device)
- if refer_audio_latent.dim() == 2:
- refer_audio_latent = refer_audio_latent.unsqueeze(0)
- refer_audio_latent = _ensure_latent_3d(refer_audio_latent.transpose(1, 2))
- refer_encode_cache[cache_key] = refer_audio_latent
+ refer_audio_latent = pipe.vae.encode(refer_audio).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
refer_audio_latents.append(refer_audio_latent)
refer_audio_order_mask.append(batch_idx)
-
refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=pipe.device, dtype=torch.long)
return refer_audio_latents, refer_audio_order_mask
From 11863791390ab5efb6312b0ea2ab05cd59aa3dbd Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Wed, 22 Apr 2026 21:36:30 +0800
Subject: [PATCH 09/16] noncover
---
diffsynth/pipelines/ace_step.py | 165 ++++++++++--------
.../model_inference/acestep-v15-base-cover.py | 36 ++++
2 files changed, 131 insertions(+), 70 deletions(-)
create mode 100644 examples/ace_step/model_inference/acestep-v15-base-cover.py
diff --git a/diffsynth/pipelines/ace_step.py b/diffsynth/pipelines/ace_step.py
index 4730592a..688c0da2 100644
--- a/diffsynth/pipelines/ace_step.py
+++ b/diffsynth/pipelines/ace_step.py
@@ -9,6 +9,8 @@
from tqdm import tqdm
import random
import math
+import torch.nn.functional as F
+from einops import rearrange
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
@@ -41,11 +43,11 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
self.in_iteration_models = ("dit",)
self.units = [
+ AceStepUnit_TaskTypeChecker(),
AceStepUnit_PromptEmbedder(),
AceStepUnit_ReferenceAudioEmbedder(),
- AceStepUnit_ConditionEmbedder(),
- AceStepUnit_AudioCodeDecoder(),
AceStepUnit_ContextLatentBuilder(),
+ AceStepUnit_ConditionEmbedder(),
AceStepUnit_NoiseInitializer(),
AceStepUnit_InputAudioEmbedder(),
]
@@ -100,7 +102,8 @@ def __call__(
reference_audios: List[torch.Tensor] = None,
# Source audio
src_audio: torch.Tensor = None,
- denoising_strength: float = 1.0,
+ denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength
+ audio_cover_strength: float = 1.0,
# Audio codes
audio_code_string: Optional[str] = None,
# Shape
@@ -121,7 +124,7 @@ def __call__(
progress_bar_cmd=tqdm,
):
# 1. Scheduler
- self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=1.0, shift=shift)
+ self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=denoising_strength, shift=shift)
# 2. 三字典输入
inputs_posi = {"prompt": prompt, "positive": True}
@@ -132,6 +135,7 @@ def __call__(
"task_type": task_type,
"reference_audios": reference_audios,
"src_audio": src_audio,
+ "audio_cover_strength": audio_cover_strength,
"audio_code_string": audio_code_string,
"duration": duration,
"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
@@ -152,6 +156,7 @@ def __call__(
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+ self.switch_noncover_condition(inputs_shared, inputs_posi, inputs_nega, progress_id)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
@@ -181,23 +186,28 @@ def normalize_audio(self, audio: torch.Tensor, target_db: float = -1.0) -> torch
gain = target_amp / peak
return audio * gain
+ def switch_noncover_condition(self, inputs_shared, inputs_posi, inputs_nega, progress_id):
+ if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0 or inputs_shared.get("shared_noncover", None) is None:
+ return
+ cover_steps = int(len(self.scheduler.timesteps) * inputs_shared["audio_cover_strength"])
+ if progress_id >= cover_steps:
+ inputs_shared.update(inputs_shared.pop("shared_noncover", {}))
+ inputs_posi.update(inputs_shared.pop("posi_noncover", {}))
+ if inputs_shared["cfg_scale"] != 1.0:
+ inputs_nega.update(inputs_shared.pop("nega_noncover", {}))
+
class AceStepUnit_TaskTypeChecker(PipelineUnit):
"""Check and compute sequence length from duration."""
def __init__(self):
super().__init__(
- input_params=("audio_code_string"),
+ input_params=("task_type",),
output_params=("task_type",),
)
- def process(self, pipe, audio_code_string):
- if pipe.scheduler.training:
- return {"task_type": "text2music"}
- if audio_code_string is not None:
- task_type = "cover"
- else:
- task_type = "text2music"
- return {"task_type": task_type}
+ def process(self, pipe, task_type):
+ assert task_type in ["text2music", "cover", "repaint"], f"Unsupported task_type: {task_type}"
+ return {}
class AceStepUnit_PromptEmbedder(PipelineUnit):
@@ -364,14 +374,34 @@ def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
if inputs_shared["cfg_scale"] != 1.0:
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device)
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
+ if inputs_shared["task_type"] == "cover" and inputs_shared["audio_cover_strength"] < 1.0:
+ hidden_states_noncover = AceStepUnit_PromptEmbedder().process(
+ pipe, inputs_posi["prompt"], True, inputs_shared["lyrics"], inputs_shared["duration"],
+ inputs_shared["bpm"], inputs_shared["keyscale"], inputs_shared["timesignature"],
+ inputs_shared["vocal_language"], "text2music")
+ encoder_hidden_states_noncover, encoder_attention_mask_noncover = pipe.conditioner(
+ **hidden_states_noncover,
+ reference_latents=inputs_shared.get("reference_latents", None),
+ refer_audio_order_mask=inputs_shared.get("refer_audio_order_mask", None),
+ )
+ duration = inputs_shared["context_latents"].shape[1] * 1920 / pipe.vae.sampling_rate
+ context_latents_noncover = AceStepUnit_ContextLatentBuilder().process(pipe, duration, None, None)["context_latents"]
+ inputs_shared["shared_noncover"] = {"context_latents": context_latents_noncover}
+ inputs_shared["posi_noncover"] = {"encoder_hidden_states": encoder_hidden_states_noncover, "encoder_attention_mask": encoder_attention_mask_noncover}
+ if inputs_shared["cfg_scale"] != 1.0:
+ inputs_shared["nega_noncover"] = {
+ "encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device),
+ "encoder_attention_mask": encoder_attention_mask_noncover,
+ }
return inputs_shared, inputs_posi, inputs_nega
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
def __init__(self):
super().__init__(
- input_params=("duration", "src_audio", "lm_hints"),
+ input_params=("duration", "src_audio", "audio_code_string"),
output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
+ onload_model_names=("vae", "tokenizer_model",),
)
def _get_silence_latent_slice(self, pipe, length: int) -> torch.Tensor:
@@ -382,21 +412,55 @@ def _get_silence_latent_slice(self, pipe, length: int) -> torch.Tensor:
tiled = pipe.silence_latent[0].repeat(repeats, 1)
return tiled[:length, :]
- def process(self, pipe, duration, src_audio, lm_hints):
- if lm_hints is not None:
- max_latent_length = lm_hints.shape[1]
- src_latents = lm_hints.clone()
- chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
- attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
- context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
- # elif src_audio is not None:
- # raise NotImplementedError("src_audio conditioning is not implemented yet. Please set lm_hints to None.")
+ def tokenize(self, tokenizer, x, silence_latent, pool_window_size):
+ if x.shape[1] % pool_window_size != 0:
+ pad_len = pool_window_size - (x.shape[1] % pool_window_size)
+ x = torch.cat([x, silence_latent[:1,:pad_len].repeat(x.shape[0],1,1)], dim=1)
+ x = rearrange(x, 'n (t_patch p) d -> n t_patch p d', p=pool_window_size)
+ quantized, indices = tokenizer(x)
+ return quantized
+
+ @staticmethod
+ def _parse_audio_code_string(code_str: str) -> list:
+ """Extract integer audio codes from tokens like <|audio_code_123|>."""
+ if not code_str:
+ return []
+ try:
+ codes = []
+ max_audio_code = 63999
+ for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
+ code_value = int(x)
+ codes.append(max(0, min(code_value, max_audio_code)))
+ except Exception as e:
+ raise ValueError(f"Invalid audio_code_string format: {e}")
+ return codes
+
+ def process(self, pipe, duration, src_audio, audio_code_string):
+ # get src_latents from audio_code_string > src_audio > silence
+ if audio_code_string is not None:
+ pipe.load_models_to_device(self.onload_model_names)
+ code_ids = self._parse_audio_code_string(audio_code_string)
+ quantizer = pipe.tokenizer_model.tokenizer.quantizer
+ indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
+ codes = quantizer.get_codes_from_indices(indices)
+ quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
+ quantized = quantizer.project_out(quantized)
+ src_latents = pipe.tokenizer_model.detokenizer(quantized).to(pipe.device)
+ max_latent_length = src_latents.shape[1]
+ elif src_audio is not None:
+ pipe.load_models_to_device(self.onload_model_names)
+ src_audio = src_audio.unsqueeze(0) if src_audio.dim() == 2 else src_audio
+ src_audio = torch.clamp(src_audio, -1.0, 1.0)
+ src_latents = pipe.vae.encode(src_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
+ lm_hints_5Hz = self.tokenize(pipe.tokenizer_model.tokenizer, src_latents, pipe.silence_latent, pipe.tokenizer_model.tokenizer.pool_window_size)
+ src_latents = pipe.tokenizer_model.detokenizer(lm_hints_5Hz)
+ max_latent_length = src_latents.shape[1]
else:
- max_latent_length = duration * pipe.sample_rate // 1920
+ max_latent_length = int(duration * pipe.sample_rate // 1920)
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
- chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
- attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
- context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
+ chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
+ attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
+ context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
return {"context_latents": context_latents, "attention_mask": attention_mask}
@@ -410,21 +474,24 @@ def __init__(self):
def process(self, pipe, context_latents, seed, rand_device):
src_latents_shape = (context_latents.shape[0], context_latents.shape[1], context_latents.shape[-1] // 2)
noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
+ noise = pipe.scheduler.add_noise(context_latents[:, :, :src_latents_shape[-1]], noise, timestep=pipe.scheduler.timesteps[0])
return {"noise": noise}
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
+    """Encodes input audio to latents; the VAE-encode path runs only during training (inference just passes noise through)."""
def __init__(self):
super().__init__(
input_params=("noise", "input_audio"),
output_params=("latents", "input_latents"),
+ onload_model_names=("vae",),
)
def process(self, pipe, noise, input_audio):
if input_audio is None:
return {"latents": noise}
if pipe.scheduler.training:
- pipe.load_models_to_device(['vae'])
+ pipe.load_models_to_device(self.onload_model_names)
input_audio, sample_rate = input_audio
input_audio = torch.clamp(input_audio, -1.0, 1.0)
if input_audio.dim() == 2:
@@ -435,48 +502,6 @@ def process(self, pipe, noise, input_audio):
return {"input_latents": input_latents}
-class AceStepUnit_AudioCodeDecoder(PipelineUnit):
- def __init__(self):
- super().__init__(
- input_params=("audio_code_string",),
- output_params=("lm_hints",),
- onload_model_names=("tokenizer_model",),
- )
-
- @staticmethod
- def _parse_audio_code_string(code_str: str) -> list:
- """Extract integer audio codes from tokens like <|audio_code_123|>."""
- if not code_str:
- return []
- try:
- codes = []
- max_audio_code = 63999
- for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
- code_value = int(x)
- codes.append(max(0, min(code_value, max_audio_code)))
- except Exception as e:
- raise ValueError(f"Invalid audio_code_string format: {e}")
- return codes
-
- def process(self, pipe, audio_code_string):
- if audio_code_string is None or not audio_code_string.strip():
- return {"lm_hints": None}
- code_ids = self._parse_audio_code_string(audio_code_string)
- if len(code_ids) == 0:
- return {"lm_hints": None}
-
- pipe.load_models_to_device(["tokenizer_model"])
- quantizer = pipe.tokenizer_model.tokenizer.quantizer
- detokenizer = pipe.tokenizer_model.detokenizer
-
- indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
- codes = quantizer.get_codes_from_indices(indices)
- quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
- quantized = quantizer.project_out(quantized)
-
- lm_hints = detokenizer(quantized).to(pipe.device)
- return {"lm_hints": lm_hints}
-
def model_fn_ace_step(
dit: AceStepDiTModel,
diff --git a/examples/ace_step/model_inference/acestep-v15-base-cover.py b/examples/ace_step/model_inference/acestep-v15-base-cover.py
new file mode 100644
index 00000000..288bc3fd
--- /dev/null
+++ b/examples/ace_step/model_inference/acestep-v15-base-cover.py
@@ -0,0 +1,36 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio, read_audio
+import torch
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ task_type="cover",
+ src_audio=src_audio,
+ audio_cover_strength=0.6,
+ denoising_strength=0.9,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=42,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-cover.wav")
From 394db06d86afe026e2eedd9ac669ccb5f623ea3b Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Thu, 23 Apr 2026 16:52:59 +0800
Subject: [PATCH 10/16] codes
---
.../configs/vram_management_module_maps.py | 1 +
diffsynth/models/ace_step_tokenizer.py | 2 +-
diffsynth/pipelines/ace_step.py | 85 +++++++++++++++----
...cover.py => acestep-v15-base-CoverTask.py} | 5 +-
.../acestep-v15-base-RepaintTask.py | 39 +++++++++
.../acestep-v15-base-CoverTask.py | 49 +++++++++++
.../acestep-v15-base-RepaintTask.py | 51 +++++++++++
7 files changed, 212 insertions(+), 20 deletions(-)
rename examples/ace_step/model_inference/{acestep-v15-base-cover.py => acestep-v15-base-CoverTask.py} (89%)
create mode 100644 examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py
create mode 100644 examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py
create mode 100644 examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py
diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py
index 299a830e..94854be2 100644
--- a/diffsynth/configs/vram_management_module_maps.py
+++ b/diffsynth/configs/vram_management_module_maps.py
@@ -328,6 +328,7 @@
"diffsynth.models.ace_step_tokenizer.AceStepTokenizer": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
+ "vector_quantize_pytorch.ResidualFSQ": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
diff --git a/diffsynth/models/ace_step_tokenizer.py b/diffsynth/models/ace_step_tokenizer.py
index 5bd0e74e..935afa13 100644
--- a/diffsynth/models/ace_step_tokenizer.py
+++ b/diffsynth/models/ace_step_tokenizer.py
@@ -349,7 +349,7 @@ def forward(
) -> torch.Tensor:
B, T, P, D = x.shape
x = self.embed_tokens(x)
- special_tokens = self.special_token.expand(B, T, 1, -1)
+ special_tokens = self.special_token.expand(B, T, 1, -1).to(x.device)
x = torch.cat([special_tokens, x], dim=2)
x = rearrange(x, "b t p c -> (b t) p c")
diff --git a/diffsynth/pipelines/ace_step.py b/diffsynth/pipelines/ace_step.py
index 688c0da2..317a5338 100644
--- a/diffsynth/pipelines/ace_step.py
+++ b/diffsynth/pipelines/ace_step.py
@@ -106,6 +106,9 @@ def __call__(
audio_cover_strength: float = 1.0,
# Audio codes
audio_code_string: Optional[str] = None,
+ # Inpainting
+ repainting_ranges: Optional[List[Tuple[float, float]]] = None,
+ repainting_strength: float = 1.0,
# Shape
duration: int = 60,
# Audio Meta
@@ -134,9 +137,8 @@ def __call__(
"lyrics": lyrics,
"task_type": task_type,
"reference_audios": reference_audios,
- "src_audio": src_audio,
- "audio_cover_strength": audio_cover_strength,
- "audio_code_string": audio_code_string,
+ "src_audio": src_audio, "audio_cover_strength": audio_cover_strength, "audio_code_string": audio_code_string,
+ "repainting_ranges": repainting_ranges, "repainting_strength": repainting_strength,
"duration": duration,
"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
"seed": seed,
@@ -162,9 +164,8 @@ def __call__(
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id,
)
- inputs_shared["latents"] = self.step(
- self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
- )
+ inputs_shared["latents"] = self.step(self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None),
+ progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# 5. VAE 解码
self.load_models_to_device(['vae'])
@@ -201,12 +202,17 @@ class AceStepUnit_TaskTypeChecker(PipelineUnit):
"""Check and compute sequence length from duration."""
def __init__(self):
super().__init__(
- input_params=("task_type",),
+ input_params=("task_type", "src_audio", "repainting_ranges", "audio_code_string"),
output_params=("task_type",),
)
- def process(self, pipe, task_type):
+ def process(self, pipe, task_type, src_audio, repainting_ranges, audio_code_string):
assert task_type in ["text2music", "cover", "repaint"], f"Unsupported task_type: {task_type}"
+ if task_type == "cover":
+ assert (src_audio is not None) or (audio_code_string is not None), "For cover task, either src_audio or audio_code_string must be provided."
+ elif task_type == "repaint":
+ assert src_audio is not None, "For repaint task, src_audio must be provided."
            assert repainting_ranges is not None and len(repainting_ranges) > 0, "For repaint task, repainting_ranges must be provided and non-empty."
return {}
@@ -399,7 +405,7 @@ def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
def __init__(self):
super().__init__(
- input_params=("duration", "src_audio", "audio_code_string"),
+ input_params=("duration", "src_audio", "audio_code_string", "task_type", "repainting_ranges", "repainting_strength"),
output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
onload_model_names=("vae", "tokenizer_model",),
)
@@ -435,9 +441,46 @@ def _parse_audio_code_string(code_str: str) -> list:
raise ValueError(f"Invalid audio_code_string format: {e}")
return codes
- def process(self, pipe, duration, src_audio, audio_code_string):
+ def pad_src_audio(self, pipe, src_audio, task_type, repainting_ranges):
+ if task_type != "repaint" or repainting_ranges is None:
+ return src_audio, repainting_ranges, None, None
+ min_left = min([start for start, end in repainting_ranges])
+ max_right = max([end for start, end in repainting_ranges])
+ total_length = src_audio.shape[-1] // pipe.vae.sampling_rate
+ pad_left = max(0, -min_left)
+ pad_right = max(0, max_right - total_length)
+ if pad_left > 0 or pad_right > 0:
+ padding_frames_left, padding_frames_right = pad_left * pipe.vae.sampling_rate, pad_right * pipe.vae.sampling_rate
+ src_audio = F.pad(src_audio, (padding_frames_left, padding_frames_right), value=0.0)
+ repainting_ranges = [(start + pad_left, end + pad_left) for start, end in repainting_ranges]
+ return src_audio, repainting_ranges, pad_left, pad_right
+
+ def parse_repaint_masks(self, pipe, src_latents, task_type, repainting_ranges, repainting_strength, pad_left, pad_right):
+ if task_type != "repaint" or repainting_ranges is None:
+ return None, src_latents
        # set the repainting area of the mask to repainting_strength and the non-repainting area to 0.0.
+ max_latent_length = src_latents.shape[1]
+ denoise_mask = torch.zeros((1, max_latent_length, 1), dtype=pipe.torch_dtype, device=pipe.device)
+ for start, end in repainting_ranges:
+ start_frame = start * pipe.vae.sampling_rate // 1920
+ end_frame = end * pipe.vae.sampling_rate // 1920
+ denoise_mask[:, start_frame:end_frame, :] = repainting_strength
+ # set padding areas to 1.0 (full repaint) to avoid artifacts at the boundaries caused by padding
+ pad_left_frames = pad_left * pipe.vae.sampling_rate // 1920
+ pad_right_frames = pad_right * pipe.vae.sampling_rate // 1920
+ denoise_mask[:, :pad_left_frames, :] = 1
+ denoise_mask[:, max_latent_length - pad_right_frames:, :] = 1
+
+ silent_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
+ src_latents = src_latents * (1 - denoise_mask) + silent_latents * denoise_mask
+ return denoise_mask, src_latents
+
+ def process(self, pipe, duration, src_audio, audio_code_string, task_type=None, repainting_ranges=None, repainting_strength=None):
# get src_latents from audio_code_string > src_audio > silence
+ source_latents = None
+ denoise_mask = None
if audio_code_string is not None:
            # use audio_code_string to get src_latents.
pipe.load_models_to_device(self.onload_model_names)
code_ids = self._parse_audio_code_string(audio_code_string)
quantizer = pipe.tokenizer_model.tokenizer.quantizer
@@ -448,33 +491,42 @@ def process(self, pipe, duration, src_audio, audio_code_string):
src_latents = pipe.tokenizer_model.detokenizer(quantized).to(pipe.device)
max_latent_length = src_latents.shape[1]
elif src_audio is not None:
+ # use src_audio to get src_latents.
pipe.load_models_to_device(self.onload_model_names)
src_audio = src_audio.unsqueeze(0) if src_audio.dim() == 2 else src_audio
src_audio = torch.clamp(src_audio, -1.0, 1.0)
+
+ src_audio, repainting_ranges, pad_left, pad_right = self.pad_src_audio(pipe, src_audio, task_type, repainting_ranges)
+
src_latents = pipe.vae.encode(src_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
- lm_hints_5Hz = self.tokenize(pipe.tokenizer_model.tokenizer, src_latents, pipe.silence_latent, pipe.tokenizer_model.tokenizer.pool_window_size)
- src_latents = pipe.tokenizer_model.detokenizer(lm_hints_5Hz)
+ source_latents = src_latents # cache for potential use in audio inpainting tasks
+ denoise_mask, src_latents = self.parse_repaint_masks(pipe, src_latents, task_type, repainting_ranges, repainting_strength, pad_left, pad_right)
+ if task_type == "cover":
+ lm_hints_5Hz = self.tokenize(pipe.tokenizer_model.tokenizer, src_latents, pipe.silence_latent, pipe.tokenizer_model.tokenizer.pool_window_size)
+ src_latents = pipe.tokenizer_model.detokenizer(lm_hints_5Hz)
max_latent_length = src_latents.shape[1]
else:
+ # use silence latents.
max_latent_length = int(duration * pipe.sample_rate // 1920)
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
- return {"context_latents": context_latents, "attention_mask": attention_mask}
+ return {"context_latents": context_latents, "attention_mask": attention_mask, "src_latents": source_latents, "denoise_mask": denoise_mask}
class AceStepUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
- input_params=("context_latents", "seed", "rand_device"),
+ input_params=("context_latents", "seed", "rand_device", "src_latents"),
output_params=("noise",),
)
- def process(self, pipe, context_latents, seed, rand_device):
+ def process(self, pipe, context_latents, seed, rand_device, src_latents):
src_latents_shape = (context_latents.shape[0], context_latents.shape[1], context_latents.shape[-1] // 2)
noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
- noise = pipe.scheduler.add_noise(context_latents[:, :, :src_latents_shape[-1]], noise, timestep=pipe.scheduler.timesteps[0])
+ if src_latents is not None:
+ noise = pipe.scheduler.add_noise(src_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"noise": noise}
@@ -502,7 +554,6 @@ def process(self, pipe, noise, input_audio):
return {"input_latents": input_latents}
-
def model_fn_ace_step(
dit: AceStepDiTModel,
latents=None,
diff --git a/examples/ace_step/model_inference/acestep-v15-base-cover.py b/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py
similarity index 89%
rename from examples/ace_step/model_inference/acestep-v15-base-cover.py
rename to examples/ace_step/model_inference/acestep-v15-base-CoverTask.py
index 288bc3fd..3f55aa49 100644
--- a/examples/ace_step/model_inference/acestep-v15-base-cover.py
+++ b/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py
@@ -16,12 +16,14 @@
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
+# audio_cover_strength controls the fraction of cover-guided steps: the first num_inference_steps * audio_cover_strength steps are cover steps, and the rest are regular text-to-music generation steps.
+# denoising_strength controls how the output audio is influenced by the source audio in cover tasks.
audio = pipe(
prompt=prompt,
lyrics=lyrics,
task_type="cover",
src_audio=src_audio,
- audio_cover_strength=0.6,
+ audio_cover_strength=0.5,
denoising_strength=0.9,
duration=160,
bpm=100,
@@ -32,5 +34,4 @@
num_inference_steps=30,
cfg_scale=4.0,
)
-
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-cover.wav")
diff --git a/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py b/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py
new file mode 100644
index 00000000..49152457
--- /dev/null
+++ b/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py
@@ -0,0 +1,39 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio, read_audio
+import torch
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
+# repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. A negative start value in repainting_ranges means padding is added before the start of the audio.
+# For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence.
+# Repainting strength denotes the intensity of repainting area, where 0 means no repainting (keep the original audio) and 1 means full repainting.
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ task_type="repaint",
+ src_audio=src_audio,
+ repainting_ranges=[(-10, 30), (150, 200)],
+ repainting_strength=1.0,
+ duration=210,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=1,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-repaint.wav")
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py b/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py
new file mode 100644
index 00000000..f16a4bd9
--- /dev/null
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py
@@ -0,0 +1,49 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio, read_audio
+import torch
+
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
+# audio_cover_strength controls the fraction of cover-guided steps: the first num_inference_steps * audio_cover_strength steps are cover steps, and the rest are regular text-to-music generation steps.
+# denoising_strength controls how the output audio is influenced by the source audio in cover tasks.
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ task_type="cover",
+ src_audio=src_audio,
+ audio_cover_strength=0.5,
+ denoising_strength=0.9,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=42,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-cover.wav")
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py b/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py
new file mode 100644
index 00000000..42a3c2b8
--- /dev/null
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py
@@ -0,0 +1,51 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio, read_audio
+import torch
+
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
+# repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. A negative start value in repainting_ranges means padding is added before the start of the audio.
+# For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence.
+# Repainting strength denotes the intensity of repainting area, where 0 means no repainting (keep the original audio) and 1 means full repainting.
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ task_type="repaint",
+ src_audio=src_audio,
+ repainting_ranges=[(-10, 30), (150, 200)],
+ repainting_strength=1.0,
+ duration=210,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=1,
+ num_inference_steps=30,
+ cfg_scale=4.0,
+)
+
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-repaint.wav")
From a80fb8422005c02f9b4960eddab33e6b749e25a9 Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Thu, 23 Apr 2026 17:31:34 +0800
Subject: [PATCH 11/16] style
---
diffsynth/configs/model_configs.py | 19 ++++---
diffsynth/pipelines/ace_step.py | 52 ++++++++++--------
.../ace_step_conditioner.py | 37 +------------
.../state_dict_converters/ace_step_dit.py | 35 +-----------
.../state_dict_converters/ace_step_lm.py | 55 -------------------
.../ace_step_text_encoder.py | 26 +--------
.../ace_step_tokenizer.py | 21 +------
.../ace_step/model_inference/Ace-Step1.5.py | 6 ++
.../acestep-v15-base-CoverTask.py | 8 +++
.../acestep-v15-base-RepaintTask.py | 8 +++
.../model_inference_low_vram/Ace-Step1.5.py | 6 ++
.../acestep-v15-base-CoverTask.py | 8 +++
.../acestep-v15-base-RepaintTask.py | 8 +++
examples/ace_step/model_training/train.py | 53 ++++--------------
14 files changed, 99 insertions(+), 243 deletions(-)
delete mode 100644 diffsynth/utils/state_dict_converters/ace_step_lm.py
diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py
index ccaa1460..9c26a3c3 100644
--- a/diffsynth/configs/model_configs.py
+++ b/diffsynth/configs/model_configs.py
@@ -925,7 +925,7 @@
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
"model_name": "ace_step_dit",
"model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
- "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.ace_step_dit_converter",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTModelStateDictConverter",
},
# === XL DiT variants (32 layers, hidden_size=2560) ===
# Covers: xl-base, xl-sft, xl-turbo
@@ -934,7 +934,7 @@
"model_hash": "3a28a410c2246f125153ef792d8bc828",
"model_name": "ace_step_dit",
"model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
- "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.ace_step_dit_converter",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTModelStateDictConverter",
"extra_kwargs": {
"hidden_size": 2560,
"intermediate_size": 9728,
@@ -952,7 +952,7 @@
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
"model_name": "ace_step_conditioner",
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
- "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.ace_step_conditioner_converter",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter",
},
# === XL Conditioner (same architecture, but checkpoint includes XL decoder → different file hash) ===
{
@@ -960,7 +960,7 @@
"model_hash": "3a28a410c2246f125153ef792d8bc828",
"model_name": "ace_step_conditioner",
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
- "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.ace_step_conditioner_converter",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter",
},
# === Qwen3-Embedding (text encoder) ===
{
@@ -968,7 +968,7 @@
"model_hash": "3509bea17b0e8cffc3dd4a15cc7899d0",
"model_name": "ace_step_text_encoder",
"model_class": "diffsynth.models.ace_step_text_encoder.AceStepTextEncoder",
- "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_text_encoder.ace_step_text_encoder_converter",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_text_encoder.AceStepTextEncoderStateDictConverter",
},
# === VAE (AutoencoderOobleck CNN) ===
{
@@ -983,7 +983,7 @@
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
"model_name": "ace_step_tokenizer",
"model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
- "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.ace_step_tokenizer_converter",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.AceStepTokenizerStateDictConverter",
},
# === XL Tokenizer (XL models share same tokenizer architecture) ===
{
@@ -991,8 +991,11 @@
"model_hash": "3a28a410c2246f125153ef792d8bc828",
"model_name": "ace_step_tokenizer",
"model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
- "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.ace_step_tokenizer_converter",
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.AceStepTokenizerStateDictConverter",
},
]
-MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series
+MODEL_CONFIGS = (
+ qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series
+ + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series
+)
diff --git a/diffsynth/pipelines/ace_step.py b/diffsynth/pipelines/ace_step.py
index 317a5338..4d9971a0 100644
--- a/diffsynth/pipelines/ace_step.py
+++ b/diffsynth/pipelines/ace_step.py
@@ -3,12 +3,10 @@
Text-to-Music generation pipeline using ACE-Step 1.5 model.
"""
-import re
-import torch
+import re, torch
from typing import Optional, Dict, Any, List, Tuple
from tqdm import tqdm
-import random
-import math
+import random, math
import torch.nn.functional as F
from einops import rearrange
@@ -39,7 +37,7 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
self.conditioner: AceStepConditionEncoder = None
self.dit: AceStepDiTModel = None
self.vae: AceStepVAE = None
- self.tokenizer_model: AceStepTokenizer = None # AceStepTokenizer (tokenizer + detokenizer)
+ self.tokenizer_model: AceStepTokenizer = None
self.in_iteration_models = ("dit",)
self.units = [
@@ -65,7 +63,6 @@ def from_pretrained(
silence_latent_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
vram_limit: float = None,
):
- """Load pipeline from pretrained checkpoints."""
pipe = AceStepPipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
@@ -102,7 +99,7 @@ def __call__(
reference_audios: List[torch.Tensor] = None,
# Source audio
src_audio: torch.Tensor = None,
- denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength
+ denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength
audio_cover_strength: float = 1.0,
# Audio codes
audio_code_string: Optional[str] = None,
@@ -115,7 +112,7 @@ def __call__(
bpm: Optional[int] = 100,
keyscale: Optional[str] = "B minor",
timesignature: Optional[str] = "4",
- vocal_language: Optional[str] = 'unknown',
+ vocal_language: Optional[str] = "unknown",
# Randomness
seed: int = None,
rand_device: str = "cpu",
@@ -126,10 +123,10 @@ def __call__(
# Progress
progress_bar_cmd=tqdm,
):
- # 1. Scheduler
+ # Scheduler
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=denoising_strength, shift=shift)
- # 2. 三字典输入
+ # Parameters
inputs_posi = {"prompt": prompt, "positive": True}
inputs_nega = {"positive": False}
inputs_shared = {
@@ -147,13 +144,12 @@ def __call__(
"shift": shift,
}
- # 3. Unit 链执行
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
unit, self, inputs_shared, inputs_posi, inputs_nega
)
- # 4. Denoise loop
+ # Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
@@ -164,17 +160,17 @@ def __call__(
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id,
)
- inputs_shared["latents"] = self.step(self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None),
- progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
+ inputs_shared["latents"] = self.step(
+ self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None),
+ progress_id=progress_id, noise_pred=noise_pred, **inputs_shared,
+ )
- # 5. VAE 解码
+ # Decode
self.load_models_to_device(['vae'])
# DiT output is [B, T, 64] (channels-last), VAE expects [B, 64, T] (channels-first)
latents = inputs_shared["latents"].transpose(1, 2)
vae_output = self.vae.decode(latents)
- # VAE returns OobleckDecoderOutput with .sample attribute
- audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output
- audio_output = self.normalize_audio(audio_output, target_db=-1.0)
+ audio_output = self.normalize_audio(vae_output, target_db=-1.0)
audio = self.output_audio_format_check(audio_output)
self.load_models_to_device([])
return audio
@@ -188,7 +184,9 @@ def normalize_audio(self, audio: torch.Tensor, target_db: float = -1.0) -> torch
return audio * gain
def switch_noncover_condition(self, inputs_shared, inputs_posi, inputs_nega, progress_id):
- if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0 or inputs_shared.get("shared_noncover", None) is None:
+ if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0:
+ return
+ if inputs_shared.get("shared_noncover", None) is None:
return
cover_steps = int(len(self.scheduler.timesteps) * inputs_shared["audio_cover_strength"])
if progress_id >= cover_steps:
@@ -312,7 +310,10 @@ def __init__(self):
def process(self, pipe, reference_audios):
if reference_audios is not None:
pipe.load_models_to_device(['vae'])
- reference_audios = [self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device) for reference_audio in reference_audios]
+ reference_audios = [
+ self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device)
+ for reference_audio in reference_audios
+ ]
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, [reference_audios])
else:
reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
@@ -357,7 +358,6 @@ def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> T
class AceStepUnit_ConditionEmbedder(PipelineUnit):
-
def __init__(self):
super().__init__(
take_over=True,
@@ -378,7 +378,9 @@ def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
inputs_posi["encoder_hidden_states"] = encoder_hidden_states
inputs_posi["encoder_attention_mask"] = encoder_attention_mask
if inputs_shared["cfg_scale"] != 1.0:
- inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device)
+ inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(
+ dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device,
+ )
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
if inputs_shared["task_type"] == "cover" and inputs_shared["audio_cover_strength"] < 1.0:
hidden_states_noncover = AceStepUnit_PromptEmbedder().process(
@@ -396,7 +398,9 @@ def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
inputs_shared["posi_noncover"] = {"encoder_hidden_states": encoder_hidden_states_noncover, "encoder_attention_mask": encoder_attention_mask_noncover}
if inputs_shared["cfg_scale"] != 1.0:
inputs_shared["nega_noncover"] = {
- "encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device),
+ "encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(
+ dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device,
+ ),
"encoder_attention_mask": encoder_attention_mask_noncover,
}
return inputs_shared, inputs_posi, inputs_nega
@@ -483,7 +487,7 @@ def process(self, pipe, duration, src_audio, audio_code_string, task_type=None,
# use audio_cede_string to get src_latents.
pipe.load_models_to_device(self.onload_model_names)
code_ids = self._parse_audio_code_string(audio_code_string)
- quantizer = pipe.tokenizer_model.tokenizer.quantizer
+ quantizer = pipe.tokenizer_model.tokenizer.quantizer.to(device=pipe.device)
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
codes = quantizer.get_codes_from_indices(indices)
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
diff --git a/diffsynth/utils/state_dict_converters/ace_step_conditioner.py b/diffsynth/utils/state_dict_converters/ace_step_conditioner.py
index b6984b88..041a405c 100644
--- a/diffsynth/utils/state_dict_converters/ace_step_conditioner.py
+++ b/diffsynth/utils/state_dict_converters/ace_step_conditioner.py
@@ -1,38 +1,4 @@
-"""
-State dict converter for ACE-Step Conditioner model.
-
-The original checkpoint stores all model weights in a single file
-(nested in AceStepConditionGenerationModel). The Conditioner weights are
-prefixed with 'encoder.'.
-
-This converter extracts only keys starting with 'encoder.' and strips
-the prefix to match the standalone AceStepConditionEncoder in DiffSynth.
-"""
-
-
-def ace_step_conditioner_converter(state_dict):
- """
- Convert ACE-Step Conditioner checkpoint keys to DiffSynth format.
-
- 参数 state_dict 是 DiskMap 类型。
- 遍历时,key 是 key 名,state_dict[key] 获取实际值。
-
- Original checkpoint contains all model weights under prefixes:
- - decoder.* (DiT)
- - encoder.* (Conditioner)
- - tokenizer.* (Audio Tokenizer)
- - detokenizer.* (Audio Detokenizer)
- - null_condition_emb (CFG null embedding)
-
- This extracts only 'encoder.' keys and strips the prefix.
-
- Example mapping:
- encoder.lyric_encoder.layers.0.self_attn.q_proj.weight -> lyric_encoder.layers.0.self_attn.q_proj.weight
- encoder.attention_pooler.layers.0.self_attn.q_proj.weight -> attention_pooler.layers.0.self_attn.q_proj.weight
- encoder.timbre_encoder.layers.0.self_attn.q_proj.weight -> timbre_encoder.layers.0.self_attn.q_proj.weight
- encoder.audio_tokenizer.audio_acoustic_proj.weight -> audio_tokenizer.audio_acoustic_proj.weight
- encoder.detokenizer.layers.0.self_attn.q_proj.weight -> detokenizer.layers.0.self_attn.q_proj.weight
- """
+def AceStepConditionEncoderStateDictConverter(state_dict):
new_state_dict = {}
prefix = "encoder."
@@ -41,7 +7,6 @@ def ace_step_conditioner_converter(state_dict):
new_key = key[len(prefix):]
new_state_dict[new_key] = state_dict[key]
- # Extract null_condition_emb from top level (used for CFG negative condition)
if "null_condition_emb" in state_dict:
new_state_dict["null_condition_emb"] = state_dict["null_condition_emb"]
diff --git a/diffsynth/utils/state_dict_converters/ace_step_dit.py b/diffsynth/utils/state_dict_converters/ace_step_dit.py
index 758462cc..d5f7cf6d 100644
--- a/diffsynth/utils/state_dict_converters/ace_step_dit.py
+++ b/diffsynth/utils/state_dict_converters/ace_step_dit.py
@@ -1,37 +1,4 @@
-"""
-State dict converter for ACE-Step DiT model.
-
-The original checkpoint stores all model weights in a single file
-(nested in AceStepConditionGenerationModel). The DiT weights are
-prefixed with 'decoder.'.
-
-This converter extracts only keys starting with 'decoder.' and strips
-the prefix to match the standalone AceStepDiTModel in DiffSynth.
-"""
-
-
-def ace_step_dit_converter(state_dict):
- """
- Convert ACE-Step DiT checkpoint keys to DiffSynth format.
-
- 参数 state_dict 是 DiskMap 类型。
- 遍历时,key 是 key 名,state_dict[key] 获取实际值。
-
- Original checkpoint contains all model weights under prefixes:
- - decoder.* (DiT)
- - encoder.* (Conditioner)
- - tokenizer.* (Audio Tokenizer)
- - detokenizer.* (Audio Detokenizer)
- - null_condition_emb (CFG null embedding)
-
- This extracts only 'decoder.' keys and strips the prefix.
-
- Example mapping:
- decoder.layers.0.self_attn.q_proj.weight -> layers.0.self_attn.q_proj.weight
- decoder.proj_in.0.linear_1.weight -> proj_in.0.linear_1.weight
- decoder.time_embed.linear_1.weight -> time_embed.linear_1.weight
- decoder.rotary_emb.inv_freq -> rotary_emb.inv_freq
- """
+def AceStepDiTModelStateDictConverter(state_dict):
new_state_dict = {}
prefix = "decoder."
diff --git a/diffsynth/utils/state_dict_converters/ace_step_lm.py b/diffsynth/utils/state_dict_converters/ace_step_lm.py
deleted file mode 100644
index 2067cb16..00000000
--- a/diffsynth/utils/state_dict_converters/ace_step_lm.py
+++ /dev/null
@@ -1,55 +0,0 @@
-"""
-State dict converter for ACE-Step LLM (Qwen3-based).
-
-The safetensors file stores Qwen3 model weights. Different checkpoints
-may have different key formats:
-- Qwen3ForCausalLM format: model.embed_tokens.weight, model.layers.0.*
-- Qwen3Model format: embed_tokens.weight, layers.0.*
-
-Qwen3ForCausalLM wraps a .model attribute (Qwen3Model), so its
-state_dict() has keys:
- model.model.embed_tokens.weight
- model.model.layers.0.self_attn.q_proj.weight
- model.model.norm.weight
- model.lm_head.weight (tied to model.model.embed_tokens)
-
-This converter normalizes all keys to the Qwen3ForCausalLM format.
-
-Example mapping:
- model.embed_tokens.weight -> model.model.embed_tokens.weight
- embed_tokens.weight -> model.model.embed_tokens.weight
- model.layers.0.self_attn.q_proj.weight -> model.model.layers.0.self_attn.q_proj.weight
- layers.0.self_attn.q_proj.weight -> model.model.layers.0.self_attn.q_proj.weight
- model.norm.weight -> model.model.norm.weight
- norm.weight -> model.model.norm.weight
-"""
-
-
-def ace_step_lm_converter(state_dict):
- """
- Convert ACE-Step LLM checkpoint keys to match Qwen3ForCausalLM state dict.
-
- 参数 state_dict 是 DiskMap 类型。
- 遍历时,key 是 key 名,state_dict[key] 获取实际值。
- """
- new_state_dict = {}
- model_prefix = "model."
- nested_prefix = "model.model."
-
- for key in state_dict:
- if key.startswith(nested_prefix):
- # Already has model.model., keep as is
- new_key = key
- elif key.startswith(model_prefix):
- # Has model., add another model.
- new_key = "model." + key
- else:
- # No prefix, add model.model.
- new_key = "model.model." + key
- new_state_dict[new_key] = state_dict[key]
-
- # Handle tied word embeddings: lm_head.weight shares with embed_tokens
- if "model.model.embed_tokens.weight" in new_state_dict:
- new_state_dict["model.lm_head.weight"] = new_state_dict["model.model.embed_tokens.weight"]
-
- return new_state_dict
diff --git a/diffsynth/utils/state_dict_converters/ace_step_text_encoder.py b/diffsynth/utils/state_dict_converters/ace_step_text_encoder.py
index de0b6c7b..4ed1c016 100644
--- a/diffsynth/utils/state_dict_converters/ace_step_text_encoder.py
+++ b/diffsynth/utils/state_dict_converters/ace_step_text_encoder.py
@@ -1,28 +1,4 @@
-"""
-State dict converter for ACE-Step Text Encoder (Qwen3-Embedding-0.6B).
-
-The safetensors stores Qwen3Model weights with keys:
- embed_tokens.weight
- layers.0.self_attn.q_proj.weight
- norm.weight
-
-AceStepTextEncoder wraps a .model attribute (Qwen3Model), so its
-state_dict() has keys with 'model.' prefix:
- model.embed_tokens.weight
- model.layers.0.self_attn.q_proj.weight
- model.norm.weight
-
-This converter adds 'model.' prefix to match the nested structure.
-"""
-
-
-def ace_step_text_encoder_converter(state_dict):
- """
- Convert ACE-Step Text Encoder checkpoint keys to match Qwen3Model wrapped state dict.
-
- 参数 state_dict 是 DiskMap 类型。
- 遍历时,key 是 key 名,state_dict[key] 获取实际值。
- """
+def AceStepTextEncoderStateDictConverter(state_dict):
new_state_dict = {}
prefix = "model."
nested_prefix = "model.model."
diff --git a/diffsynth/utils/state_dict_converters/ace_step_tokenizer.py b/diffsynth/utils/state_dict_converters/ace_step_tokenizer.py
index d4cb2bab..66e014c0 100644
--- a/diffsynth/utils/state_dict_converters/ace_step_tokenizer.py
+++ b/diffsynth/utils/state_dict_converters/ace_step_tokenizer.py
@@ -1,23 +1,4 @@
-"""
-State dict converter for ACE-Step Tokenizer model.
-
-The original checkpoint stores tokenizer and detokenizer weights at the top level:
-- tokenizer.* (AceStepAudioTokenizer: audio_acoustic_proj, attention_pooler, quantizer)
-- detokenizer.* (AudioTokenDetokenizer: embed_tokens, layers, proj_out)
-
-These map directly to the AceStepTokenizer class which wraps both as
-self.tokenizer and self.detokenizer submodules.
-"""
-
-
-def ace_step_tokenizer_converter(state_dict):
- """
- Convert ACE-Step Tokenizer checkpoint keys to DiffSynth format.
-
- The checkpoint keys `tokenizer.*` and `detokenizer.*` already match
- the DiffSynth AceStepTokenizer module structure (self.tokenizer, self.detokenizer).
- No key remapping needed — just extract the relevant keys.
- """
+def AceStepTokenizerStateDictConverter(state_dict):
new_state_dict = {}
for key in state_dict:
diff --git a/examples/ace_step/model_inference/Ace-Step1.5.py b/examples/ace_step/model_inference/Ace-Step1.5.py
index ae41f11c..ff40d881 100644
--- a/examples/ace_step/model_inference/Ace-Step1.5.py
+++ b/examples/ace_step/model_inference/Ace-Step1.5.py
@@ -1,5 +1,6 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
+from modelscope import dataset_snapshot_download
import torch
pipe = AceStepPipeline.from_pretrained(
@@ -29,6 +30,11 @@
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
# input audio codes as reference
+dataset_snapshot_download(
+ dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
+ local_dir="data/diffsynth_example_dataset",
+ allow_file_pattern="ace_step/Ace-Step1.5/audio_codes_input.txt",
+)
with open("data/diffsynth_example_dataset/ace_step/Ace-Step1.5/audio_codes_input.txt", "r") as f:
audio_code_string = f.read().strip()
diff --git a/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py b/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py
index 3f55aa49..9b8a8a2a 100644
--- a/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py
+++ b/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py
@@ -1,5 +1,6 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio, read_audio
+from modelscope import dataset_snapshot_download
import torch
pipe = AceStepPipeline.from_pretrained(
@@ -15,6 +16,13 @@
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+
+dataset_snapshot_download(
+ dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
+ local_dir="data/diffsynth_example_dataset",
+ allow_file_pattern="ace_step/acestep-v15-base/audio.wav",
+)
+
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# audio_cover_strength controls the steps of doing cover tasks. [0, num_inference_steps * audio_cover_strength] steps will be cover steps, and the rest will be regular text-to-music generation steps.
# denoising_strength controls how the output audio is influenced by the source audio in cover tasks.
diff --git a/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py b/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py
index 49152457..68a0a362 100644
--- a/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py
+++ b/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py
@@ -1,5 +1,6 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio, read_audio
+from modelscope import dataset_snapshot_download
import torch
pipe = AceStepPipeline.from_pretrained(
@@ -15,6 +16,13 @@
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+
+dataset_snapshot_download(
+ dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
+ local_dir="data/diffsynth_example_dataset",
+ allow_file_pattern="ace_step/acestep-v15-base/audio.wav",
+)
+
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. The negative value in repainting_ranges means the padding from the start of the audio.
# For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence.
diff --git a/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py b/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py
index 4bc2e5e1..0160bcf2 100644
--- a/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py
+++ b/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py
@@ -6,6 +6,7 @@
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
+from modelscope import dataset_snapshot_download
import torch
@@ -49,6 +50,11 @@
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-low-vram.wav")
# input audio codes as reference
+dataset_snapshot_download(
+ dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
+ local_dir="data/diffsynth_example_dataset",
+ allow_file_pattern="ace_step/Ace-Step1.5/audio_codes_input.txt",
+)
with open("data/diffsynth_example_dataset/ace_step/Ace-Step1.5/audio_codes_input.txt", "r") as f:
audio_code_string = f.read().strip()
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py b/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py
index f16a4bd9..2ae06fef 100644
--- a/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py
@@ -1,5 +1,6 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio, read_audio
+from modelscope import dataset_snapshot_download
import torch
vram_config = {
@@ -27,6 +28,13 @@
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+
+dataset_snapshot_download(
+ dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
+ local_dir="data/diffsynth_example_dataset",
+ allow_file_pattern="ace_step/acestep-v15-base/audio.wav",
+)
+
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# audio_cover_strength controls the steps of doing cover tasks. [0, num_inference_steps * audio_cover_strength] steps will be cover steps, and the rest will be regular text-to-music generation steps.
# denoising_strength controls how the output audio is influenced by the source audio in cover tasks.
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py b/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py
index 42a3c2b8..6cbe1074 100644
--- a/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py
@@ -1,5 +1,6 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio, read_audio
+from modelscope import dataset_snapshot_download
import torch
vram_config = {
@@ -27,6 +28,13 @@
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+
+dataset_snapshot_download(
+ dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
+ local_dir="data/diffsynth_example_dataset",
+ allow_file_pattern="ace_step/acestep-v15-base/audio.wav",
+)
+
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. The negative value in repainting_ranges means the padding from the start of the audio.
# For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence.
diff --git a/examples/ace_step/model_training/train.py b/examples/ace_step/model_training/train.py
index a24da2c3..21d57dd7 100644
--- a/examples/ace_step/model_training/train.py
+++ b/examples/ace_step/model_training/train.py
@@ -1,31 +1,15 @@
-import torch, os, argparse, accelerate, warnings, torchaudio
+import os
+import torch
import math
+import argparse
+import accelerate
from diffsynth.core import UnifiedDataset
-from diffsynth.core.data.operators import ToAbsolutePath, RouteByType, DataProcessingOperator, LoadPureAudioWithTorchaudio
+from diffsynth.core.data.operators import ToAbsolutePath, LoadPureAudioWithTorchaudio
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.diffusion import *
os.environ["TOKENIZERS_PARALLELISM"] = "false"
-class LoadAceStepAudio(DataProcessingOperator):
- """Load audio file and return waveform tensor [2, T] at 48kHz."""
- def __init__(self, target_sr=48000):
- self.target_sr = target_sr
-
- def __call__(self, data: str):
- try:
- waveform, sample_rate = torchaudio.load(data)
- if sample_rate != self.target_sr:
- resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.target_sr)
- waveform = resampler(waveform)
- if waveform.shape[0] == 1:
- waveform = waveform.repeat(2, 1)
- return waveform
- except Exception as e:
- warnings.warn(f"Cannot load audio from {data}: {e}")
- return None
-
-
class AceStepTrainingModule(DiffusionTrainingModule):
def __init__(
self,
@@ -43,17 +27,15 @@ def __init__(
task="sft",
):
super().__init__()
- # ===== 解析模型配置(固定写法) =====
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
- # ===== Tokenizer 配置 =====
text_tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"))
silence_latent_config = self.parse_path_or_model_id(silence_latent_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"))
- # ===== 构建 Pipeline =====
- self.pipe = AceStepPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, text_tokenizer_config=text_tokenizer_config, silence_latent_config=silence_latent_config)
- # ===== 拆分 Pipeline Units(固定写法) =====
+ self.pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16, device=device, model_configs=model_configs,
+ text_tokenizer_config=text_tokenizer_config, silence_latent_config=silence_latent_config,
+ )
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
- # ===== 切换到训练模式(固定写法) =====
self.switch_pipe_to_training_mode(
self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
@@ -61,13 +43,11 @@ def __init__(
task=task,
)
- # ===== 其他配置(固定写法) =====
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
self.fp8_models = fp8_models
self.task = task
- # ===== 任务模式路由(固定写法) =====
self.task_to_loss = {
"sft:data_process": lambda pipe, *args: args,
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
@@ -78,11 +58,8 @@ def get_pipeline_inputs(self, data):
inputs_posi = {"prompt": data["prompt"], "positive": True}
inputs_nega = {"positive": False}
duration = math.floor(data['audio'][0].shape[1] / data['audio'][1]) if data.get("audio") is not None else data.get("duration", 60)
- # ===== 共享参数 =====
inputs_shared = {
- # ===== 核心字段映射 =====
"input_audio": data["audio"],
- # ===== 音频生成任务所需元数据 =====
"lyrics": data["lyrics"],
"task_type": "text2music",
"duration": duration,
@@ -90,18 +67,15 @@ def get_pipeline_inputs(self, data):
"keyscale": data.get("keyscale", "C major"),
"timesignature": data.get("timesignature", "4"),
"vocal_language": data.get("vocal_language", "unknown"),
- # ===== 框架控制参数(固定写法) =====
"cfg_scale": 1,
"rand_device": self.pipe.device,
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
}
- # ===== 额外字段注入:通过 --extra_inputs 配置的数据集列名(固定写法) =====
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
return inputs_shared, inputs_posi, inputs_nega
def forward(self, data, inputs=None):
- # ===== 标准实现,不要修改(固定写法) =====
if inputs is None: inputs = self.get_pipeline_inputs(data)
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
for unit in self.pipe.units:
@@ -122,12 +96,10 @@ def ace_step_parser():
if __name__ == "__main__":
parser = ace_step_parser()
args = parser.parse_args()
- # ===== Accelerator 配置(固定写法) =====
accelerator = accelerate.Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
)
- # ===== 数据集定义 =====
dataset = UnifiedDataset(
base_path=args.dataset_base_path,
metadata_path=args.dataset_metadata_path,
@@ -135,10 +107,11 @@ def ace_step_parser():
data_file_keys=args.data_file_keys.split(","),
main_data_operator=None,
special_operator_map={
- "audio": ToAbsolutePath(args.dataset_base_path) >> LoadPureAudioWithTorchaudio(target_sample_rate=48000),
+ "audio": ToAbsolutePath(args.dataset_base_path) >> LoadPureAudioWithTorchaudio(
+ target_sample_rate=48000,
+ ),
},
)
- # ===== TrainingModule =====
model = AceStepTrainingModule(
model_paths=args.model_paths,
model_id_with_origin_paths=args.model_id_with_origin_paths,
@@ -159,12 +132,10 @@ def ace_step_parser():
task=args.task,
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
)
- # ===== ModelLogger(固定写法) =====
model_logger = ModelLogger(
args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
)
- # ===== 任务路由(固定写法) =====
launcher_map = {
"sft:data_process": launch_data_process_task,
"sft": launch_training_task,
From 002e3cdb74c4e2b37f5baf2d5f08334374119e69 Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Thu, 23 Apr 2026 18:02:58 +0800
Subject: [PATCH 12/16] docs
---
README.md | 82 +++++++++
README_zh.md | 82 +++++++++
docs/en/Model_Details/ACE-Step.md | 164 ++++++++++++++++++
docs/en/index.rst | 1 +
docs/zh/Model_Details/ACE-Step.md | 164 ++++++++++++++++++
docs/zh/index.rst | 1 +
.../validate_full/acestep-v15-xl-turbo.py | 35 ++++
7 files changed, 529 insertions(+)
create mode 100644 docs/en/Model_Details/ACE-Step.md
create mode 100644 docs/zh/Model_Details/ACE-Step.md
create mode 100644 examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py
diff --git a/README.md b/README.md
index ff469777..1ed3aef8 100644
--- a/README.md
+++ b/README.md
@@ -34,6 +34,8 @@ We believe that a well-developed open-source code framework can lower the thresh
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
+- **April 23, 2026** ACE-Step open-sourced, welcome a new member to the audio model family! Support includes text-to-music generation, low VRAM inference, and LoRA training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/ACE-Step.md) and [example code](/examples/ace_step/).
+
- **April 14, 2026** JoyAI-Image open-sourced, welcome a new member to the image editing model family! Support includes instruction-guided image editing, low VRAM inference, and training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/JoyAI-Image.md) and [example code](/examples/joyai_image/).
- **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
@@ -1016,6 +1018,86 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
+### Audio Synthesis
+
+#### ACE-Step: [/docs/en/Model_Details/ACE-Step.md](/docs/en/Model_Details/ACE-Step.md)
+
+
+
+Quick Start
+
+Running the following code will quickly load the [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
+
+```python
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=42,
+)
+
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
+```
+
+
+
+
+
+Examples
+
+Example code for ACE-Step is available at: [/examples/ace_step/](/examples/ace_step/)
+
+| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
+|-|-|-|-|-|-|-|
+|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
+|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
+|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
+|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
+|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
+|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
+|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
+|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
+|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
+|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
+|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|
+
+
+
## Innovative Achievements
DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements.
diff --git a/README_zh.md b/README_zh.md
index 77dfaf0f..bd3b0d37 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -34,6 +34,8 @@ DiffSynth 目前包括两个开源项目:
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 和 [mi804](https://github.com/mi804) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
+- **2026年4月23日** ACE-Step 开源,欢迎加入音频生成模型家族!支持文生音乐推理、低显存推理和 LoRA 训练能力。详情请参考[文档](/docs/zh/Model_Details/ACE-Step.md)和[示例代码](/examples/ace_step/)。
+
- **2026年4月14日** JoyAI-Image 开源,欢迎加入图像编辑模型家族!支持指令引导的图像编辑推理、低显存推理和训练能力。详情请参考[文档](/docs/zh/Model_Details/JoyAI-Image.md)和[示例代码](/examples/joyai_image/)。
- **2026年3月19日** 新增对 [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) 和 [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/Wan.md)和[示例代码](/examples/mova/)现已可用。
@@ -1016,6 +1018,86 @@ Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/)
+### 音频生成模型
+
+#### ACE-Step: [/docs/zh/Model_Details/ACE-Step.md](/docs/zh/Model_Details/ACE-Step.md)
+
+
+
+快速开始
+
+运行以下代码可以快速加载 [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) 模型并进行推理。显存管理已启用,框架会根据剩余显存自动控制模型参数的加载,最低 3GB 显存即可运行。
+
+```python
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=42,
+)
+
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
+```
+
+
+
+
+
+示例代码
+
+ACE-Step 的示例代码位于:[/examples/ace_step/](/examples/ace_step/)
+
+| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
+|-|-|-|-|-|-|-|
+|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
+|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
+|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
+|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
+|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
+|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
+|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
+|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
+|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
+|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
+|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|
+
+
+
## 创新成果
DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。
diff --git a/docs/en/Model_Details/ACE-Step.md b/docs/en/Model_Details/ACE-Step.md
new file mode 100644
index 00000000..73ed301a
--- /dev/null
+++ b/docs/en/Model_Details/ACE-Step.md
@@ -0,0 +1,164 @@
+# ACE-Step
+
+ACE-Step 1.5 is an open-source music generation model based on DiT architecture, supporting text-to-music, audio cover, repainting and other functionalities, running efficiently on consumer-grade hardware.
+
+## Installation
+
+Before performing model inference and training, please install DiffSynth-Studio first.
+
+```shell
+git clone https://github.com/modelscope/DiffSynth-Studio.git
+cd DiffSynth-Studio
+pip install -e .
+```
+
+For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
+
+## Quick Start
+
+Running the following code will load the [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) model for inference. With VRAM management enabled, the framework automatically controls parameter loading based on available VRAM; a minimum of 3GB of VRAM is required.
+
+```python
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=42,
+)
+
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
+```
+
+## Model Overview
+
+|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
+|-|-|-|-|-|-|-|
+|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
+|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
+|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
+|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
+|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
+|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
+|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
+|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
+|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
+|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
+|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|
+
+## Model Inference
+
+The model is loaded via `AceStepPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
+
+The input parameters for `AceStepPipeline` inference include:
+
+* `prompt`: Text description of the music.
+* `cfg_scale`: Classifier-free guidance scale, defaults to 1.0.
+* `lyrics`: Lyrics text.
+* `task_type`: Task type, available values include `"text2music"` (text-to-music), `"cover"` (audio cover), `"repaint"` (repainting), defaults to `"text2music"`.
+* `reference_audios`: List of reference audio tensors for timbre reference.
+* `src_audio`: Source audio tensor for cover or repaint tasks.
+* `denoising_strength`: Denoising strength, controlling how much the output is influenced by source audio, defaults to 1.0.
+* `audio_cover_strength`: Audio cover step ratio, controlling how many steps use cover condition in cover tasks, defaults to 1.0.
+* `audio_code_string`: Input audio code string for cover tasks with discrete audio codes.
+* `repainting_ranges`: List of repainting time ranges (tuples of floats, in seconds) for repaint tasks.
+* `repainting_strength`: Repainting intensity, controlling the degree of change in repainted areas, defaults to 1.0.
+* `duration`: Audio duration in seconds, defaults to 60.
+* `bpm`: Beats per minute, defaults to 100.
+* `keyscale`: Musical key scale, defaults to "B minor".
+* `timesignature`: Time signature, defaults to "4".
+* `vocal_language`: Vocal language, defaults to "unknown".
+* `seed`: Random seed.
+* `rand_device`: Device for noise generation, defaults to "cpu".
+* `num_inference_steps`: Number of inference steps, defaults to 8.
+* `shift`: Timestep shift parameter for the scheduler, defaults to 1.0.
+
+## Model Training
+
+Models in the ace_step series are trained uniformly via `examples/ace_step/model_training/train.py`. The script parameters include:
+
+* General Training Parameters
+ * Dataset Configuration
+ * `--dataset_base_path`: Root directory of the dataset.
+ * `--dataset_metadata_path`: Path to the dataset metadata file.
+ * `--dataset_repeat`: Number of dataset repeats per epoch.
+ * `--dataset_num_workers`: Number of processes per DataLoader.
+    * `--data_file_keys`: Field names to load from metadata, typically paths to data files such as audio, image, or video, separated by `,`.
+ * Model Loading Configuration
+ * `--model_paths`: Paths to load models from, in JSON format.
+ * `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
+ * `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
+ * `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
+ * Basic Training Configuration
+ * `--learning_rate`: Learning rate.
+ * `--num_epochs`: Number of epochs.
+ * `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
+ * `--find_unused_parameters`: Whether unused parameters exist in DDP training.
+ * `--weight_decay`: Weight decay magnitude.
+ * `--task`: Training task, defaults to `sft`.
+ * Output Configuration
+ * `--output_path`: Path to save the model.
+ * `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
+ * `--save_steps`: Interval in training steps to save the model.
+ * LoRA Configuration
+ * `--lora_base_model`: Which model to add LoRA to.
+ * `--lora_target_modules`: Which layers to add LoRA to.
+ * `--lora_rank`: Rank of LoRA.
+ * `--lora_checkpoint`: Path to LoRA checkpoint.
+ * `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
+ * `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
+ * Gradient Configuration
+ * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
+ * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
+ * `--gradient_accumulation_steps`: Number of gradient accumulation steps.
+ * Resolution Configuration
+ * `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
+ * `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
+ * `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
+ * `--num_frames`: Number of frames for video (video generation models only).
+* ACE-Step Specific Parameters
+ * `--tokenizer_path`: Tokenizer path, in format model_id:origin_pattern.
+ * `--silence_latent_path`: Silence latent path, in format model_id:origin_pattern.
+ * `--initialize_model_on_cpu`: Whether to initialize models on CPU.
+
+### Example Dataset
+
+```shell
+modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
+```
+
+We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).
diff --git a/docs/en/index.rst b/docs/en/index.rst
index ad333be0..aaa7e4d7 100644
--- a/docs/en/index.rst
+++ b/docs/en/index.rst
@@ -32,6 +32,7 @@ Welcome to DiffSynth-Studio's Documentation
Model_Details/LTX-2
Model_Details/ERNIE-Image
Model_Details/JoyAI-Image
+ Model_Details/ACE-Step
.. toctree::
:maxdepth: 2
diff --git a/docs/zh/Model_Details/ACE-Step.md b/docs/zh/Model_Details/ACE-Step.md
new file mode 100644
index 00000000..d43e61fa
--- /dev/null
+++ b/docs/zh/Model_Details/ACE-Step.md
@@ -0,0 +1,164 @@
+# ACE-Step
+
+ACE-Step 1.5 是一个开源音乐生成模型,基于 DiT 架构,支持文生音乐、音频翻唱、局部重绘等多种功能,可在消费级硬件上高效运行。
+
+## 安装
+
+在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
+
+```shell
+git clone https://github.com/modelscope/DiffSynth-Studio.git
+cd DiffSynth-Studio
+pip install -e .
+```
+
+更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
+
+## 快速开始
+
+运行以下代码可以快速加载 [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) 模型并进行推理。显存管理已启用,框架会自动根据剩余显存控制模型参数的加载,最低 3GB 显存即可运行。
+
+```python
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+import torch
+
+
+vram_config = {
+ "offload_dtype": torch.bfloat16,
+ "offload_device": "cpu",
+ "onload_dtype": torch.bfloat16,
+ "onload_device": "cpu",
+ "preparing_dtype": torch.bfloat16,
+ "preparing_device": "cuda",
+ "computation_dtype": torch.bfloat16,
+ "computation_device": "cuda",
+}
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=42,
+)
+
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
+```
+
+## 模型总览
+
+|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
+|-|-|-|-|-|-|-|
+|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
+|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
+|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
+|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
+|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
+|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
+|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
+|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
+|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
+|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
+|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|
+
+## 模型推理
+
+模型通过 `AceStepPipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。
+
+`AceStepPipeline` 推理的输入参数包括:
+
+* `prompt`: 音乐文本描述。
+* `cfg_scale`: 无分类器引导(Classifier-Free Guidance)系数,默认为 1.0。
+* `lyrics`: 歌词文本。
+* `task_type`: 任务类型,可选值包括 `"text2music"`(文生音乐)、`"cover"`(音频翻唱)、`"repaint"`(局部重绘),默认为 `"text2music"`。
+* `reference_audios`: 参考音频列表(Tensor 列表),用于提供音色参考。
+* `src_audio`: 源音频(Tensor),用于 cover 或 repaint 任务。
+* `denoising_strength`: 降噪强度,控制输出受源音频的影响程度,默认为 1.0。
+* `audio_cover_strength`: 音频翻唱步数比例,控制 cover 任务中前多少步使用翻唱条件,默认为 1.0。
+* `audio_code_string`: 输入音频码字符串,用于 cover 任务中直接传入离散音频码。
+* `repainting_ranges`: 重绘时间区间(浮点元组列表,单位为秒),用于 repaint 任务。
+* `repainting_strength`: 重绘强度,控制重绘区域的变化程度,默认为 1.0。
+* `duration`: 音频时长(秒),默认为 60。
+* `bpm`: 每分钟节拍数,默认为 100。
+* `keyscale`: 音阶调式,默认为 "B minor"。
+* `timesignature`: 拍号,默认为 "4"。
+* `vocal_language`: 演唱语言,默认为 "unknown"。
+* `seed`: 随机种子。
+* `rand_device`: 噪声生成设备,默认为 "cpu"。
+* `num_inference_steps`: 推理步数,默认为 8。
+* `shift`: 调度器时间偏移参数,默认为 1.0。
+
+## 模型训练
+
+ace_step 系列模型统一通过 `examples/ace_step/model_training/train.py` 进行训练,脚本的参数包括:
+
+* 通用训练参数
+ * 数据集基础配置
+ * `--dataset_base_path`: 数据集的根目录。
+ * `--dataset_metadata_path`: 数据集的元数据文件路径。
+ * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
+ * `--dataset_num_workers`: 每个 Dataloader 的进程数量。
+    * `--data_file_keys`: 元数据中需要加载的字段名称,通常是音频、图像或视频等数据文件的路径,以 `,` 分隔。
+ * 模型加载配置
+ * `--model_paths`: 要加载的模型路径。JSON 格式。
+ * `--model_id_with_origin_paths`: 带原始路径的模型 ID。用逗号分隔。
+ * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,以 `,` 分隔。
+ * `--fp8_models`: 以 FP8 格式加载的模型,目前仅支持参数不被梯度更新的模型。
+ * 训练基础配置
+ * `--learning_rate`: 学习率。
+ * `--num_epochs`: 轮数(Epoch)。
+ * `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。
+ * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。
+ * `--weight_decay`: 权重衰减大小。
+ * `--task`: 训练任务,默认为 `sft`。
+ * 输出配置
+ * `--output_path`: 模型保存路径。
+ * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
+ * `--save_steps`: 保存模型的训练步数间隔。
+ * LoRA 配置
+ * `--lora_base_model`: LoRA 添加到哪个模型上。
+ * `--lora_target_modules`: LoRA 添加到哪些层上。
+ * `--lora_rank`: LoRA 的秩(Rank)。
+ * `--lora_checkpoint`: LoRA 检查点的路径。
+ * `--preset_lora_path`: 预置 LoRA 检查点路径,用于 LoRA 差分训练。
+ * `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。
+ * 梯度配置
+ * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
+ * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
+ * `--gradient_accumulation_steps`: 梯度累积步数。
+ * 分辨率配置
+ * `--height`: 图像/视频的高度。留空启用动态分辨率。
+ * `--width`: 图像/视频的宽度。留空启用动态分辨率。
+ * `--max_pixels`: 最大像素面积,动态分辨率时大于此值的图片会被缩小。
+ * `--num_frames`: 视频的帧数(仅视频生成模型)。
+* ACE-Step 专有参数
+ * `--tokenizer_path`: Tokenizer 路径,格式为 model_id:origin_pattern。
+ * `--silence_latent_path`: 静音隐变量路径,格式为 model_id:origin_pattern。
+ * `--initialize_model_on_cpu`: 是否在 CPU 上初始化模型。
+
+### 样例数据集
+
+```shell
+modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
+```
+
+我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。
diff --git a/docs/zh/index.rst b/docs/zh/index.rst
index 526f0fba..1a29be52 100644
--- a/docs/zh/index.rst
+++ b/docs/zh/index.rst
@@ -32,6 +32,7 @@
Model_Details/LTX-2
Model_Details/ERNIE-Image
Model_Details/JoyAI-Image
+ Model_Details/ACE-Step
.. toctree::
:maxdepth: 2
diff --git a/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py b/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py
new file mode 100644
index 00000000..42c86160
--- /dev/null
+++ b/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py
@@ -0,0 +1,35 @@
+from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
+from diffsynth.utils.data.audio import save_audio
+from diffsynth import load_state_dict
+import torch
+
+
+pipe = AceStepPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="ACE-Step/acestep-v15-xl-turbo", origin_file_pattern="model-*.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
+ ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
+ silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
+)
+state_dict = load_state_dict("models/train/acestep-v15-xl-turbo_full/epoch-1.safetensors")
+pipe.dit.load_state_dict(state_dict)
+
+prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
+lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
+audio = pipe(
+ prompt=prompt,
+ lyrics=lyrics,
+ duration=160,
+ bpm=100,
+ keyscale="B minor",
+ timesignature="4",
+ vocal_language="zh",
+ seed=1,
+ num_inference_steps=8,
+ cfg_scale=1.0,
+)
+save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-turbo_full.wav")
From 3da625432ef8a910c0bdf57792ee2b020dce5d98 Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Thu, 23 Apr 2026 18:09:16 +0800
Subject: [PATCH 13/16] path
---
.../ace_step/model_inference/acestep-v15-base-CoverTask.py | 4 ++--
.../ace_step/model_inference/acestep-v15-base-RepaintTask.py | 4 ++--
.../model_inference_low_vram/acestep-v15-base-CoverTask.py | 4 ++--
.../model_inference_low_vram/acestep-v15-base-RepaintTask.py | 4 ++--
4 files changed, 8 insertions(+), 8 deletions(-)
diff --git a/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py b/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py
index 9b8a8a2a..bfb4bd89 100644
--- a/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py
+++ b/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py
@@ -20,10 +20,10 @@
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
- allow_file_pattern="ace_step/acestep-v15-base/audio.wav",
+ allow_file_pattern="ace_step/acestep-v15-base-CoverTask/audio.wav",
)
-src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
+src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base-CoverTask/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# audio_cover_strength controls the steps of doing cover tasks. [0, num_inference_steps * audio_cover_strength] steps will be cover steps, and the rest will be regular text-to-music generation steps.
# denoising_strength controls how the output audio is influenced by the source audio in cover tasks.
audio = pipe(
diff --git a/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py b/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py
index 68a0a362..9e10c7ec 100644
--- a/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py
+++ b/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py
@@ -20,10 +20,10 @@
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
- allow_file_pattern="ace_step/acestep-v15-base/audio.wav",
+ allow_file_pattern="ace_step/acestep-v15-base-RepaintTask/audio.wav",
)
-src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
+src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base-RepaintTask/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. The negative value in repainting_ranges means the padding from the start of the audio.
# For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence.
# Repainting strength denotes the intensity of repainting area, where 0 means no repainting (keep the original audio) and 1 means full repainting.
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py b/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py
index 2ae06fef..c0e9e5cd 100644
--- a/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py
@@ -32,10 +32,10 @@
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
- allow_file_pattern="ace_step/acestep-v15-base/audio.wav",
+ allow_file_pattern="ace_step/acestep-v15-base-CoverTask/audio.wav",
)
-src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
+src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base-CoverTask/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# audio_cover_strength controls the steps of doing cover tasks. [0, num_inference_steps * audio_cover_strength] steps will be cover steps, and the rest will be regular text-to-music generation steps.
# denoising_strength controls how the output audio is influenced by the source audio in cover tasks.
audio = pipe(
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py b/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py
index 6cbe1074..6437e4a0 100644
--- a/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py
@@ -32,10 +32,10 @@
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
- allow_file_pattern="ace_step/acestep-v15-base/audio.wav",
+ allow_file_pattern="ace_step/acestep-v15-base-RepaintTask/audio.wav",
)
-src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
+src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base-RepaintTask/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. The negative value in repainting_ranges means the padding from the start of the audio.
# For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence.
# Repainting strength denotes the intensity of repainting area, where 0 means no repainting (keep the original audio) and 1 means full repainting.
From 641801e01b6c1fd8283b75cc3055582317901464 Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Thu, 23 Apr 2026 18:15:34 +0800
Subject: [PATCH 14/16] remove comments
---
diffsynth/diffusion/flow_match.py | 10 ----------
diffsynth/models/ace_step_tokenizer.py | 10 ----------
diffsynth/models/ace_step_vae.py | 6 ------
.../ace_step/model_inference_low_vram/Ace-Step1.5.py | 6 ------
.../model_inference_low_vram/acestep-v15-base.py | 5 -----
.../model_inference_low_vram/acestep-v15-sft.py | 7 -------
.../acestep-v15-turbo-continuous.py | 7 -------
.../acestep-v15-turbo-shift1.py | 7 -------
.../acestep-v15-turbo-shift3.py | 7 -------
.../model_inference_low_vram/acestep-v15-xl-base.py | 6 ------
.../model_inference_low_vram/acestep-v15-xl-sft.py | 6 ------
.../model_inference_low_vram/acestep-v15-xl-turbo.py | 6 ------
12 files changed, 83 deletions(-)
diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py
index 8e77ea80..a29d30dc 100644
--- a/diffsynth/diffusion/flow_match.py
+++ b/diffsynth/diffusion/flow_match.py
@@ -145,16 +145,6 @@ def set_timesteps_ernie_image(num_inference_steps=50, denoising_strength=1.0, sh
@staticmethod
def set_timesteps_ace_step(num_inference_steps=8, denoising_strength=1.0, shift=3.0):
- """ACE-Step Flow Matching scheduler.
-
- Timesteps range from 1.0 to 0.0 (not multiplied by 1000).
- Shift transformation: t = shift * t / (1 + (shift - 1) * t)
-
- Args:
- num_inference_steps: Number of diffusion steps.
- denoising_strength: Denoising strength (1.0 = full denoising).
- shift: Timestep shift parameter (default 3.0 for turbo).
- """
num_train_timesteps = 1000
sigma_start = denoising_strength
sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps + 1)[:-1]
diff --git a/diffsynth/models/ace_step_tokenizer.py b/diffsynth/models/ace_step_tokenizer.py
index 935afa13..ebf30fb6 100644
--- a/diffsynth/models/ace_step_tokenizer.py
+++ b/diffsynth/models/ace_step_tokenizer.py
@@ -11,14 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""ACE-Step Audio Tokenizer — VAE latent discretization pathway.
-
-Contains:
-- AceStepAudioTokenizer: continuous VAE latent → discrete FSQ tokens
-- AudioTokenDetokenizer: discrete tokens → continuous VAE-latent-shaped features
-
-Only used in cover song mode (is_covers=True). Bypassed in text-to-music.
-"""
from typing import Optional
import torch
@@ -671,8 +663,6 @@ def __init__(
**kwargs,
):
super().__init__()
- # Default layer_types matches target library config (24 alternating entries).
- # Sub-modules (pooler/detokenizer) slice first N entries for their own layer count.
if layer_types is None:
layer_types = ["sliding_attention", "full_attention"] * 12
self.tokenizer = AceStepAudioTokenizer(
diff --git a/diffsynth/models/ace_step_vae.py b/diffsynth/models/ace_step_vae.py
index 047e199f..64f2f9cb 100644
--- a/diffsynth/models/ace_step_vae.py
+++ b/diffsynth/models/ace_step_vae.py
@@ -11,12 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""ACE-Step Audio VAE (AutoencoderOobleck CNN architecture).
-
-This is a CNN-based VAE for audio waveform encoding/decoding.
-It uses weight-normalized convolutions and Snake1d activations.
-Does NOT depend on diffusers — pure nn.Module implementation.
-"""
import math
from typing import Optional
diff --git a/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py b/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py
index 0160bcf2..72c8296b 100644
--- a/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py
+++ b/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py
@@ -1,9 +1,3 @@
-"""
-Ace-Step 1.5 (main model, turbo) — Text-to-Music inference example (Low VRAM).
-
-Low VRAM version: models are offloaded to CPU and loaded on-demand.
-Turbo model: uses num_inference_steps=8, cfg_scale=1.0.
-"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from modelscope import dataset_snapshot_download
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-base.py b/examples/ace_step/model_inference_low_vram/acestep-v15-base.py
index fc997f2e..11fd7180 100644
--- a/examples/ace_step/model_inference_low_vram/acestep-v15-base.py
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-base.py
@@ -1,8 +1,3 @@
-"""
-Ace-Step 1.5 Base — Text-to-Music inference example (Low VRAM).
-
-Low VRAM version: models are offloaded to CPU and loaded on-demand.
-"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py b/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py
index 189c26a6..6a8e2648 100644
--- a/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py
@@ -1,10 +1,3 @@
-"""
-Ace-Step 1.5 SFT (supervised fine-tuned) — Text-to-Music inference example (Low VRAM).
-
-Low VRAM version: models are offloaded to CPU and loaded on-demand.
-SFT variant is fine-tuned for specific music styles.
-Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
-"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py b/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py
index 420bc933..8c4f5b2f 100644
--- a/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py
@@ -1,10 +1,3 @@
-"""
-Ace-Step 1.5 Turbo (continuous, shift 1-5) — Text-to-Music inference example (Low VRAM).
-
-Low VRAM version: models are offloaded to CPU and loaded on-demand.
-Turbo model: no num_inference_steps or cfg_scale (use defaults).
-Continuous variant: handles shift range internally, no shift parameter needed.
-"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py b/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py
index cfa1583c..8495439d 100644
--- a/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py
@@ -1,10 +1,3 @@
-"""
-Ace-Step 1.5 Turbo (shift=1) — Text-to-Music inference example (Low VRAM).
-
-Low VRAM version: models are offloaded to CPU and loaded on-demand.
-Turbo model: no num_inference_steps or cfg_scale (use defaults).
-shift=1: default value, no need to pass.
-"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py b/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py
index aa2af9c9..7d3a552f 100644
--- a/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py
@@ -1,10 +1,3 @@
-"""
-Ace-Step 1.5 Turbo (shift=3) — Text-to-Music inference example (Low VRAM).
-
-Low VRAM version: models are offloaded to CPU and loaded on-demand.
-Turbo model: no num_inference_steps or cfg_scale (use defaults).
-shift=3: explicitly passed for this variant.
-"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py b/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py
index dc772ba9..d8d0c47f 100644
--- a/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py
@@ -1,9 +1,3 @@
-"""
-Ace-Step 1.5 XL Base — Text-to-Music inference example (Low VRAM).
-
-Low VRAM version: models are offloaded to CPU and loaded on-demand.
-Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
-"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py b/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py
index 5ac17b08..963c4b9a 100644
--- a/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py
@@ -1,9 +1,3 @@
-"""
-Ace-Step 1.5 XL SFT (supervised fine-tuned) — Text-to-Music inference example (Low VRAM).
-
-Low VRAM version: models are offloaded to CPU and loaded on-demand.
-Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
-"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
diff --git a/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py b/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py
index 53a5ec5e..cf7893d1 100644
--- a/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py
+++ b/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py
@@ -1,9 +1,3 @@
-"""
-Ace-Step 1.5 XL Turbo — Text-to-Music inference example (Low VRAM).
-
-Low VRAM version: models are offloaded to CPU and loaded on-demand.
-Turbo model: no num_inference_steps or cfg_scale (use defaults).
-"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
From 85bd87b4ab126b2272764e578f65d4c558a796f7 Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Fri, 24 Apr 2026 10:33:56 +0800
Subject: [PATCH 15/16] remove modelconfigs comments
---
diffsynth/configs/model_configs.py | 11 -----------
1 file changed, 11 deletions(-)
diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py
index 9c26a3c3..2855e954 100644
--- a/diffsynth/configs/model_configs.py
+++ b/diffsynth/configs/model_configs.py
@@ -917,9 +917,6 @@
]
ace_step_series = [
- # === Standard DiT variants (24 layers, hidden_size=2048) ===
- # Covers: turbo, turbo-shift1, turbo-shift3, turbo-continuous, base, sft
- # All share identical state_dict structure → same hash
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
@@ -927,8 +924,6 @@
"model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTModelStateDictConverter",
},
- # === XL DiT variants (32 layers, hidden_size=2560) ===
- # Covers: xl-base, xl-sft, xl-turbo
{
# Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
"model_hash": "3a28a410c2246f125153ef792d8bc828",
@@ -946,7 +941,6 @@
"layer_types": ["sliding_attention", "full_attention"] * 16,
},
},
- # === Conditioner (shared by all DiT variants, same architecture) ===
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
@@ -954,7 +948,6 @@
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter",
},
- # === XL Conditioner (same architecture, but checkpoint includes XL decoder → different file hash) ===
{
# Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
"model_hash": "3a28a410c2246f125153ef792d8bc828",
@@ -962,7 +955,6 @@
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter",
},
- # === Qwen3-Embedding (text encoder) ===
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors")
"model_hash": "3509bea17b0e8cffc3dd4a15cc7899d0",
@@ -970,14 +962,12 @@
"model_class": "diffsynth.models.ace_step_text_encoder.AceStepTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_text_encoder.AceStepTextEncoderStateDictConverter",
},
- # === VAE (AutoencoderOobleck CNN) ===
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "51420834e54474986a7f4be0e4d6f687",
"model_name": "ace_step_vae",
"model_class": "diffsynth.models.ace_step_vae.AceStepVAE",
},
- # === Tokenizer (VAE latent discretization: tokenizer + detokenizer) ===
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
@@ -985,7 +975,6 @@
"model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.AceStepTokenizerStateDictConverter",
},
- # === XL Tokenizer (XL models share same tokenizer architecture) ===
{
# Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
"model_hash": "3a28a410c2246f125153ef792d8bc828",
From 15f474968c3a6d4a09b9e230654ed75d138a4947 Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Fri, 24 Apr 2026 14:42:53 +0800
Subject: [PATCH 16/16] minor fix
---
diffsynth/pipelines/ace_step.py | 2 --
diffsynth/utils/data/audio.py | 2 +-
2 files changed, 1 insertion(+), 3 deletions(-)
diff --git a/diffsynth/pipelines/ace_step.py b/diffsynth/pipelines/ace_step.py
index 4d9971a0..fbbb332e 100644
--- a/diffsynth/pipelines/ace_step.py
+++ b/diffsynth/pipelines/ace_step.py
@@ -289,8 +289,6 @@ def process(self, pipe, prompt, positive, lyrics, duration, bpm, keyscale, times
lyric_text = self.LYRIC_PROMPT.format(vocal_language, lyrics)
lyric_hidden_states, lyric_attention_mask = self._encode_lyrics(pipe, lyric_text, max_length=2048)
- # TODO: remove this
- newtext = prompt + "\n\n" + lyric_text
return {
"text_hidden_states": text_hidden_states,
"text_attention_mask": text_attention_mask,
diff --git a/diffsynth/utils/data/audio.py b/diffsynth/utils/data/audio.py
index 1add550a..44bd5b16 100644
--- a/diffsynth/utils/data/audio.py
+++ b/diffsynth/utils/data/audio.py
@@ -99,7 +99,7 @@ def save_audio(waveform: torch.Tensor, sample_rate: int, save_path: str, backend
"""
if waveform.dim() == 3:
waveform = waveform[0]
- waveform.cpu()
+ waveform = waveform.cpu()
if backend == "torchcodec":
from torchcodec.encoders import AudioEncoder