diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py index 6bdf3011b6..3026672e8f 100644 --- a/astrbot/core/astr_agent_run_util.py +++ b/astrbot/core/astr_agent_run_util.py @@ -19,6 +19,7 @@ ) from astrbot.core.provider.entities import LLMResponse from astrbot.core.provider.provider import TTSProvider +from astrbot.core.utils.tts_text_filter import FilteredQueue AgentRunner = ToolLoopAgentRunner[AstrAgentContext] @@ -354,6 +355,8 @@ async def run_live_agent( show_tool_use: bool = True, show_tool_call_result: bool = False, show_reasoning: bool = False, + tts_filter_enable: bool = False, + tts_filter_custom_rules: list[str] | None = None, buffer_intermediate_messages: bool = False, ) -> AsyncGenerator[MessageChain | None, None]: """Live Mode 的 Agent 运行器,支持流式 TTS @@ -365,6 +368,8 @@ async def run_live_agent( show_tool_use: 是否显示工具使用 show_tool_call_result: 是否显示工具返回结果 show_reasoning: 是否显示推理过程 + tts_filter_enable: 是否启用 TTS 文本过滤 + tts_filter_custom_rules: 自定义 TTS 过滤正则规则 Yields: MessageChain: 包含文本或音频数据的消息链 @@ -398,15 +403,22 @@ async def run_live_agent( first_chunk_received = False # 创建队列 - text_queue: asyncio.Queue[str | None] = asyncio.Queue() + raw_text_queue: asyncio.Queue[str | None] = asyncio.Queue() # audio_queue stored bytes or (text, bytes) audio_queue: asyncio.Queue[bytes | tuple[str, bytes] | None] = asyncio.Queue() + # 为 TTS 创建过滤队列(Feeder 写入原始文本,TTS 读取过滤后文本) + tts_text_queue: asyncio.Queue[str | None] | FilteredQueue = ( + FilteredQueue(raw_text_queue, tts_filter_custom_rules) + if tts_filter_enable + else raw_text_queue + ) + # 1. 启动 Agent Feeder 任务:负责运行 Agent 并将文本分句喂给 text_queue feeder_task = asyncio.create_task( _run_agent_feeder( agent_runner, - text_queue, + raw_text_queue, max_step, show_tool_use, show_tool_call_result, @@ -415,14 +427,14 @@ async def run_live_agent( ) ) - # 2. 启动 TTS 任务:负责从 text_queue 读取文本并生成音频到 audio_queue + # 2. 
启动 TTS 任务:负责从 tts_text_queue 读取文本并生成音频到 audio_queue if support_stream: tts_task = asyncio.create_task( - _safe_tts_stream_wrapper(tts_provider, text_queue, audio_queue) + _safe_tts_stream_wrapper(tts_provider, tts_text_queue, audio_queue) ) else: tts_task = asyncio.create_task( - _simulated_stream_tts(tts_provider, text_queue, audio_queue) + _simulated_stream_tts(tts_provider, tts_text_queue, audio_queue) ) # 3. 主循环:从 audio_queue 读取音频并 yield diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 903f6c445f..5f355b5730 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -207,6 +207,10 @@ "dual_output": False, "use_file_service": False, "trigger_probability": 1.0, + "tts_text_filter": { + "enable": False, + "custom_rules": [], + }, }, "provider_ltm_settings": { "group_icl_enable": False, @@ -1782,6 +1786,40 @@ class ChatProviderTemplate(TypedDict): "gemini_tts_voice_name": "Leda", "proxy": "", }, + "Qwen TTS Realtime(API)": { + "id": "qwen_tts_realtime", + "type": "qwen_tts_realtime", + "provider": "qwen", + "provider_type": "text_to_speech", + "hint": "千问实时语音合成,支持流式输入输出、低延迟响应。模型可选 qwen3-tts-flash-realtime、qwen3-tts-instruct-flash-realtime、qwen-tts-realtime 等。API Key 从 https://bailian.console.aliyun.com/?tab=model#/api-key 获取", + "enable": False, + "api_key": "", + "model": "qwen3-tts-flash-realtime", + "qwen_tts_voice": "Cherry", + "qwen_tts_instructions": "", + "qwen_tts_optimize_instructions": False, + "qwen_tts_speech_rate": 1.0, + "qwen_tts_volume": 1.0, + "qwen_tts_pitch_rate": 1.0, + "qwen_tts_url": "wss://dashscope.aliyuncs.com/api-ws/v1/realtime", + "timeout": "30", + }, + "CosyVoice TTS(API)": { + "id": "cosyvoice_tts", + "type": "cosyvoice_tts", + "provider": "cosyvoice", + "provider_type": "text_to_speech", + "hint": "CosyVoice 语音合成,支持多种系统音色和复刻音色。模型可选 cosyvoice-v3.5-plus、cosyvoice-v3.5-flash、cosyvoice-v3-plus、cosyvoice-v3-flash、cosyvoice-v2 等。API Key 从 
https://bailian.console.aliyun.com/?tab=model#/api-key 获取", + "enable": False, + "api_key": "", + "model": "cosyvoice-v3-flash", + "cosyvoice_voice": "longanyang", + "cosyvoice_speech_rate": 1.0, + "cosyvoice_volume": 1.0, + "cosyvoice_pitch_rate": 1.0, + "cosyvoice_base_url": "wss://dashscope.aliyuncs.com/api-ws/v1/inference", + "timeout": "20", + }, "OpenAI Embedding": { "id": "openai_embedding", "type": "openai_embedding", @@ -2249,6 +2287,66 @@ class ChatProviderTemplate(TypedDict): "hint": "Azure_TTS 服务的订阅密钥(注意不是令牌)", }, "dashscope_tts_voice": {"description": "音色", "type": "string"}, + "qwen_tts_voice": { + "description": "音色", + "type": "string", + "hint": "Qwen TTS Realtime 音色名称,可选: Cherry(芊悦)、Serena(苏瑶)、Ethan(晨煦) 等。详见 https://help.aliyun.com/zh/model-studio/qwen-tts-realtime-api-reference", + }, + "qwen_tts_instructions": { + "description": "指令控制", + "type": "string", + "hint": "通过自然语言描述控制语音表达效果,如'语速较快,带有明显的上扬语调'。仅 qwen3-tts-instruct-flash-realtime 模型支持。长度不超过 1600 Token。", + }, + "qwen_tts_optimize_instructions": { + "description": "优化指令", + "type": "bool", + "hint": "启用后模型会自动优化指令描述以获得更好的效果。仅 qwen3-tts-instruct-flash-realtime 模型支持。", + }, + "qwen_tts_speech_rate": { + "description": "语速", + "type": "number", + "hint": "语速调节比例,1.0 为正常语速,大于 1.0 加快,小于 1.0 减慢。", + }, + "qwen_tts_volume": { + "description": "音量", + "type": "number", + "hint": "音量调节比例,1.0 为正常音量。", + }, + "qwen_tts_pitch_rate": { + "description": "音调", + "type": "number", + "hint": "音调调节比例,1.0 为正常音调。", + }, + "qwen_tts_url": { + "description": "WebSocket 地址", + "type": "string", + "hint": "Qwen TTS Realtime WebSocket 地址。北京: wss://dashscope.aliyuncs.com/api-ws/v1/realtime;新加坡: wss://dashscope-intl.aliyuncs.com/api-ws/v1/realtime", + }, + "cosyvoice_voice": { + "description": "音色", + "type": "string", + "hint": "CosyVoice 音色名称,可选: longanyang、longxiaochun_v2 等。不同模型版本需使用对应版本的音色。详见 https://help.aliyun.com/zh/model-studio/cosyvoice-voice-list", + }, + "cosyvoice_speech_rate": { + "description": "语速", + 
"type": "number", + "hint": "语速调节比例,1.0 为正常语速。仅部分模型支持。", + }, + "cosyvoice_volume": { + "description": "音量", + "type": "number", + "hint": "音量调节比例,1.0 为正常音量。仅部分模型支持。", + }, + "cosyvoice_pitch_rate": { + "description": "音调", + "type": "number", + "hint": "音调(音高)调节比例,1.0 为正常音调。仅部分模型支持。", + }, + "cosyvoice_base_url": { + "description": "WebSocket 地址", + "type": "string", + "hint": "CosyVoice WebSocket 地址。北京: wss://dashscope.aliyuncs.com/api-ws/v1/inference;新加坡: wss://dashscope-intl.aliyuncs.com/api-ws/v1/inference", + }, "gm_resp_image_modal": { "description": "启用图片模态", "type": "bool", @@ -3127,6 +3225,23 @@ class ChatProviderTemplate(TypedDict): "provider_tts_settings.enable": True, }, }, + "provider_tts_settings.tts_text_filter.enable": { + "description": "过滤 TTS 文本中的括号内容", + "type": "bool", + "hint": "开启后将自动去除 *文字*、【文字】、(文字) 等括号/标记内容,避免 TTS 朗读情绪标记", + "condition": { + "provider_tts_settings.enable": True, + }, + }, + "provider_tts_settings.tts_text_filter.custom_rules": { + "description": "自定义 TTS 过滤正则", + "type": "list", + "items": {"type": "string"}, + "hint": "每行一条正则表达式,将匹配到的内容从 TTS 文本中移除", + "condition": { + "provider_tts_settings.enable": True, + }, + }, "provider_settings.image_caption_prompt": { "description": "图片转述提示词", "type": "text", diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index c1d8826562..992d104299 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -277,6 +277,13 @@ async def process( "[Live Mode] TTS Provider 未配置,将使用普通流式模式" ) + # 获取 TTS 文本过滤配置 + tts_filter_cfg = self.ctx.astrbot_config.get( + "provider_tts_settings", {} + ).get("tts_text_filter", {}) + tts_filter_enable = tts_filter_cfg.get("enable", False) + tts_filter_rules = tts_filter_cfg.get("custom_rules", []) + # 使用 run_live_agent,总是使用流式响应 event.set_result( 
MessageEventResult() @@ -289,6 +296,8 @@ async def process( self.show_tool_use, self.show_tool_call_result, show_reasoning=self.show_reasoning, + tts_filter_enable=tts_filter_enable, + tts_filter_custom_rules=tts_filter_rules, buffer_intermediate_messages=self.buffer_intermediate_messages, ), ), diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 4ee7461305..25d6d79cd8 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -13,6 +13,7 @@ from astrbot.core.star.session_llm_manager import SessionServiceManager from astrbot.core.star.star import star_map from astrbot.core.star.star_handler import EventType, star_handlers_registry +from astrbot.core.utils.tts_text_filter import TTSTextFilter from ..context import PipelineContext from ..stage import Stage, register_stage, registered_stages @@ -296,8 +297,27 @@ async def process( for comp in result.chain: if isinstance(comp, Plain) and len(comp.text) > 1: try: - logger.info(f"TTS 请求: {comp.text}") - audio_path = await tts_provider.get_audio(comp.text) + # 应用 TTS 文本过滤 + tts_filter_config = self.ctx.astrbot_config[ + "provider_tts_settings" + ].get("tts_text_filter", {}) + tts_filter_enable = tts_filter_config.get("enable", False) + tts_custom_rules = tts_filter_config.get("custom_rules", []) + + if tts_filter_enable: + tts_text = TTSTextFilter.apply( + comp.text, tts_custom_rules + ) + else: + tts_text = comp.text + + if not tts_text: + # 过滤后为空,跳过 TTS + new_chain.append(comp) + continue + + logger.info(f"TTS 请求: {tts_text}") + audio_path = await tts_provider.get_audio(tts_text) logger.info(f"TTS 结果: {audio_path}") if not audio_path: logger.error( diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 0dfdbdcf6d..82cc9861c4 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -461,6 +461,14 @@ def dynamic_import_provider(self, type: 
str) -> None: from .sources.gemini_tts_source import ( ProviderGeminiTTSAPI as ProviderGeminiTTSAPI, ) + case "qwen_tts_realtime": + from .sources.qwen_tts_realtime_source import ( + ProviderQwenTTSRealtime as ProviderQwenTTSRealtime, + ) + case "cosyvoice_tts": + from .sources.cosyvoice_tts_source import ( + ProviderCosyVoiceTTS as ProviderCosyVoiceTTS, + ) case "openai_embedding": from .sources.openai_embedding_source import ( OpenAIEmbeddingProvider as OpenAIEmbeddingProvider, diff --git a/astrbot/core/provider/sources/cosyvoice_tts_source.py b/astrbot/core/provider/sources/cosyvoice_tts_source.py new file mode 100644 index 0000000000..55474e38eb --- /dev/null +++ b/astrbot/core/provider/sources/cosyvoice_tts_source.py @@ -0,0 +1,96 @@ +"""CosyVoice TTS provider using DashScope API. + +Supports models: +- cosyvoice-v3.5-plus, cosyvoice-v3.5-flash +- cosyvoice-v3-plus, cosyvoice-v3-flash +- cosyvoice-v2, cosyvoice-v1 +- sambert-* models + +Uses dashscope.audio.tts_v2.SpeechSynthesizer for non-streaming TTS. 
+""" + +import asyncio +import os +import uuid + +from dashscope.audio.tts_v2 import AudioFormat, SpeechSynthesizer + +from astrbot.core import logger +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path + +from ..entities import ProviderType +from ..provider import TTSProvider +from ..register import register_provider_adapter + + +@register_provider_adapter( + "cosyvoice_tts", + "CosyVoice TTS (DashScope)", + provider_type=ProviderType.TEXT_TO_SPEECH, +) +class ProviderCosyVoiceTTS(TTSProvider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + ) -> None: + super().__init__(provider_config, provider_settings) + self.chosen_api_key: str = provider_config.get("api_key", "") + self.voice: str = provider_config.get("cosyvoice_voice", "longanyang") + self.speech_rate: float = provider_config.get("cosyvoice_speech_rate", 1.0) + self.volume: float = provider_config.get("cosyvoice_volume", 1.0) + self.pitch_rate: float = provider_config.get("cosyvoice_pitch_rate", 1.0) + self.timeout_ms: float = float(provider_config.get("timeout", 20)) * 1000 + self.base_url: str = provider_config.get( + "cosyvoice_base_url", + "wss://dashscope.aliyuncs.com/api-ws/v1/inference", + ) + + model = provider_config.get("model", "cosyvoice-v3-flash") + self.set_model(model) + + if not self.base_url.startswith("wss://"): + logger.warning( + f"[CosyVoice TTS] WebSocket URL 未使用 wss:// 协议: {self.base_url}" + ) + + async def get_audio(self, text: str) -> str: + """Synthesize speech using CosyVoice and return the audio file path.""" + temp_dir = get_astrbot_temp_path() + os.makedirs(temp_dir, exist_ok=True) + + audio_bytes = await self._synthesize(text) + if not audio_bytes: + raise RuntimeError( + f"Audio synthesis failed for model '{self.get_model()}'. 
" + "The model may not be supported or the service is unavailable.", + ) + + path = os.path.join(temp_dir, f"cosyvoice_tts_{uuid.uuid4()}.wav") + with open(path, "wb") as f: + f.write(audio_bytes) + return path + + async def _synthesize(self, text: str) -> bytes | None: + """Use CosyVoice SpeechSynthesizer to synthesize speech.""" + loop = asyncio.get_running_loop() + + model = self.get_model() + fmt = AudioFormat.WAV_24000HZ_MONO_16BIT + + synthesizer = SpeechSynthesizer( + model=model, + voice=self.voice, + format=fmt, + api_key=self.chosen_api_key, + url=self.base_url, + ) + + audio_bytes = await loop.run_in_executor( + None, + synthesizer.call, + text, + self.timeout_ms, + ) + + return audio_bytes diff --git a/astrbot/core/provider/sources/qwen_tts_realtime_source.py b/astrbot/core/provider/sources/qwen_tts_realtime_source.py new file mode 100644 index 0000000000..f24b4884a6 --- /dev/null +++ b/astrbot/core/provider/sources/qwen_tts_realtime_source.py @@ -0,0 +1,359 @@ +"""Qwen TTS Realtime - WebSocket streaming TTS provider. + +Supports models: +- qwen3-tts-flash-realtime (and snapshots) +- qwen3-tts-instruct-flash-realtime (and snapshots, with instructions control) +- qwen-tts-realtime (and snapshots) + +Uses dashscope.audio.qwen_tts_realtime.QwenTtsRealtime for WebSocket-based +streaming text-to-speech with low-latency response. 
+""" + +import asyncio +import base64 +import os +import struct +import threading +import uuid + +try: + from dashscope.audio.qwen_tts_realtime import ( + AudioFormat, + QwenTtsRealtime, + QwenTtsRealtimeCallback, + ) +except ImportError: # pragma: no cover + QwenTtsRealtime = None + QwenTtsRealtimeCallback = None + AudioFormat = None + +from astrbot.core import logger +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path + +from ..entities import ProviderType +from ..provider import TTSProvider +from ..register import register_provider_adapter + + +class _QwenRealtimeCallback(QwenTtsRealtimeCallback): + """Callback for Qwen TTS Realtime WebSocket events.""" + + def __init__(self) -> None: + self.complete_event = threading.Event() + self._lock = threading.Lock() + self.audio_chunks: list[bytes] = [] + self.error_msg: str | None = None + + def on_open(self) -> None: + logger.debug("[QwenTTS Realtime] WebSocket connection opened") + + def on_close(self, close_status_code: int, close_msg: str) -> None: + logger.debug( + f"[QwenTTS Realtime] Connection closed: code={close_status_code}, msg={close_msg}", + ) + + def on_event(self, response: dict) -> None: + try: + event_type = response.get("type", "") + if event_type == "session.created": + session_id = response.get("session", {}).get("id", "unknown") + logger.debug(f"[QwenTTS Realtime] Session created: {session_id}") + elif event_type == "response.audio.delta": + audio_b64 = response.get("delta", "") + if audio_b64: + with self._lock: + self.audio_chunks.append(base64.b64decode(audio_b64)) + elif event_type == "response.done": + logger.debug("[QwenTTS Realtime] Response done") + elif event_type == "session.finished": + logger.debug("[QwenTTS Realtime] Session finished") + self.complete_event.set() + elif event_type == "error": + self.error_msg = str(response.get("error", "Unknown error")) + logger.error(f"[QwenTTS Realtime] Error: {self.error_msg}") + self.complete_event.set() + except Exception as e: + 
logger.error(f"[QwenTTS Realtime] Callback error: {e}") + + def drain_audio_chunks(self) -> list[bytes]: + """Thread-safely drain all accumulated audio chunks.""" + with self._lock: + chunks = self.audio_chunks + self.audio_chunks = [] + return chunks + + def wait_for_finished(self, timeout: float = 30) -> bool: + return self.complete_event.wait(timeout=timeout) + + +@register_provider_adapter( + "qwen_tts_realtime", + "Qwen TTS Realtime (WebSocket streaming)", + provider_type=ProviderType.TEXT_TO_SPEECH, +) +class ProviderQwenTTSRealtime(TTSProvider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + ) -> None: + super().__init__(provider_config, provider_settings) + self.chosen_api_key: str = provider_config.get("api_key", "") + self.voice: str = provider_config.get("qwen_tts_voice", "Cherry") + self.instructions: str = provider_config.get("qwen_tts_instructions", "") + self.optimize_instructions: bool = provider_config.get( + "qwen_tts_optimize_instructions", + False, + ) + self.speech_rate: float = provider_config.get("qwen_tts_speech_rate", 1.0) + self.volume: float = provider_config.get("qwen_tts_volume", 1.0) + self.pitch_rate: float = provider_config.get("qwen_tts_pitch_rate", 1.0) + self.qwen_tts_url: str = provider_config.get( + "qwen_tts_url", + "wss://dashscope.aliyuncs.com/api-ws/v1/realtime", + ) + self.timeout: float = float(provider_config.get("timeout", 30)) + + model = provider_config.get("model", "qwen3-tts-flash-realtime") + self.set_model(model) + + if not self.qwen_tts_url.startswith("wss://"): + logger.warning( + f"[QwenTTS Realtime] WebSocket URL 未使用 wss:// 协议: {self.qwen_tts_url}" + ) + + def support_stream(self) -> bool: + return True + + async def get_audio(self, text: str) -> str: + """Synthesize speech and return the audio file path.""" + temp_dir = get_astrbot_temp_path() + os.makedirs(temp_dir, exist_ok=True) + + audio_bytes = await self._synthesize(text) + if not audio_bytes: + raise RuntimeError( + "Audio 
synthesis failed, returned empty content. " + "The model may not be supported or the service is unavailable.", + ) + + path = os.path.join(temp_dir, f"qwen_tts_realtime_{uuid.uuid4()}.wav") + with open(path, "wb") as f: + f.write(audio_bytes) + return path + + async def _synthesize(self, text: str) -> bytes | None: + """Use Qwen TTS Realtime WebSocket API to synthesize speech.""" + if QwenTtsRealtime is None: + raise RuntimeError( + "dashscope SDK missing QwenTtsRealtime. " + "Please upgrade the dashscope package to use Qwen TTS Realtime.", + ) + + callback = _QwenRealtimeCallback() + model = self.get_model() + + qwen_tts = QwenTtsRealtime( + model=model, + callback=callback, + url=self.qwen_tts_url, + api_key=self.chosen_api_key, + ) + + loop = asyncio.get_running_loop() + + def _connect_and_send() -> None: + try: + qwen_tts.connect() + kwargs: dict = { + "voice": self.voice, + "response_format": AudioFormat.PCM_24000HZ_MONO_16BIT, + "mode": "server_commit", + } + if self.instructions: + kwargs["instructions"] = self.instructions + kwargs["optimize_instructions"] = self.optimize_instructions + if self.speech_rate != 1.0: + kwargs["speech_rate"] = self.speech_rate + if self.volume != 1.0: + kwargs["volume"] = self.volume + if self.pitch_rate != 1.0: + kwargs["pitch_rate"] = self.pitch_rate + qwen_tts.update_session(**kwargs) + qwen_tts.append_text(text) + qwen_tts.finish() + except Exception as e: + callback.error_msg = str(e) + callback.complete_event.set() + + await loop.run_in_executor(None, _connect_and_send) + finished = callback.wait_for_finished(timeout=self.timeout) + + if callback.error_msg: + logger.error(f"[QwenTTS Realtime] Synthesis error: {callback.error_msg}") + return None + + if not finished: + logger.error("[QwenTTS Realtime] Synthesis timeout") + return None + + # PCM 24000Hz Mono 16bit -> wrap as WAV + pcm_data = b"".join(callback.audio_chunks) + if not pcm_data: + return None + return self._pcm_to_wav(pcm_data, sample_rate=24000) + + def 
_pcm_to_wav(self, pcm_data: bytes, sample_rate: int = 24000) -> bytes: + """Convert raw PCM to WAV format.""" + num_channels = 1 + bits_per_sample = 16 + byte_rate = sample_rate * num_channels * bits_per_sample // 8 + block_align = num_channels * bits_per_sample // 8 + data_size = len(pcm_data) + + header = struct.pack( + "<4sI4s4sIHHIIHH4sI", + b"RIFF", + 36 + data_size, + b"WAVE", + b"fmt ", + 16, + 1, # PCM + num_channels, + sample_rate, + byte_rate, + block_align, + bits_per_sample, + b"data", + data_size, + ) + return header + pcm_data + + async def get_audio_stream( + self, + text_queue: asyncio.Queue[str | None], + audio_queue: asyncio.Queue[bytes | tuple[str, bytes] | None], + ) -> None: + """Streaming TTS using Qwen TTS Realtime WebSocket API. + + Reads text fragments from text_queue, sends them to the Realtime API + incrementally, and streams audio chunks to audio_queue as they arrive. + Sends None to audio_queue when done. + """ + if QwenTtsRealtime is None: + raise RuntimeError( + "dashscope SDK missing QwenTtsRealtime. 
" + "Please upgrade the dashscope package to use Qwen TTS Realtime.", + ) + + callback = _QwenRealtimeCallback() + model = self.get_model() + + qwen_tts = QwenTtsRealtime( + model=model, + callback=callback, + url=self.qwen_tts_url, + api_key=self.chosen_api_key, + ) + + loop = asyncio.get_running_loop() + + # Connect and configure session on background thread + def _connect() -> None: + try: + qwen_tts.connect() + kwargs: dict = { + "voice": self.voice, + "response_format": AudioFormat.PCM_24000HZ_MONO_16BIT, + "mode": "server_commit", + } + if self.instructions: + kwargs["instructions"] = self.instructions + kwargs["optimize_instructions"] = self.optimize_instructions + if self.speech_rate != 1.0: + kwargs["speech_rate"] = self.speech_rate + if self.volume != 1.0: + kwargs["volume"] = self.volume + if self.pitch_rate != 1.0: + kwargs["pitch_rate"] = self.pitch_rate + qwen_tts.update_session(**kwargs) + except Exception as e: + callback.error_msg = str(e) + callback.complete_event.set() + + await loop.run_in_executor(None, _connect) + + if callback.error_msg: + logger.error(f"[QwenTTS Realtime] Connection error: {callback.error_msg}") + await audio_queue.put(None) + return + + # Background collector: periodically drain audio chunks from callback + # and push to audio_queue + pcm_buffer: list[bytes] = [] + # ~200ms of audio at 24kHz, 16bit, mono = 9600 bytes + chunk_threshold = 9600 + + async def _collector() -> None: + while not callback.complete_event.is_set(): + chunks = callback.drain_audio_chunks() + if chunks: + pcm_buffer.extend(chunks) + total = sum(len(c) for c in pcm_buffer) + if total >= chunk_threshold: + pcm_data = b"".join(pcm_buffer) + pcm_buffer.clear() + wav_data = self._pcm_to_wav(pcm_data, sample_rate=24000) + await audio_queue.put(wav_data) + await asyncio.sleep(0.05) + + # Drain final chunks before exiting + remaining = callback.drain_audio_chunks() + if remaining: + pcm_buffer.extend(remaining) + if pcm_buffer: + pcm_data = 
b"".join(pcm_buffer) + wav_data = self._pcm_to_wav(pcm_data, sample_rate=24000) + await audio_queue.put(wav_data) + + collector_task = asyncio.create_task(_collector(), name="qwen_tts_collector") + + try: + # Main loop: send text fragments to TTS + while True: + text_part = await text_queue.get() + + if text_part is None: + # End of input: finish synthesis + await loop.run_in_executor(None, qwen_tts.finish) + + # Wait for all audio to be generated + finished = await loop.run_in_executor( + None, + callback.wait_for_finished, + self.timeout, + ) + if not finished: + logger.warning("[QwenTTS Realtime] Streaming timeout") + + # Signal end of audio stream + await audio_queue.put(None) + break + + await loop.run_in_executor( + None, + qwen_tts.append_text, + text_part, + ) + + finally: + collector_task.cancel() + try: + await collector_task + except asyncio.CancelledError: + pass + + try: + await loop.run_in_executor(None, qwen_tts.close) + except Exception: + pass diff --git a/astrbot/core/utils/tts_text_filter.py b/astrbot/core/utils/tts_text_filter.py new file mode 100644 index 0000000000..9525325aec --- /dev/null +++ b/astrbot/core/utils/tts_text_filter.py @@ -0,0 +1,96 @@ +"""TTS 文本过滤器:在发送 TTS 前去除括号/标记等内容。""" + +from __future__ import annotations + +import asyncio +import re +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import TypeVar + + T = TypeVar("T") + +from astrbot.core import logger + + +class TTSTextFilter: + """过滤 TTS 文本中的括号内容。""" + + # 内置默认规则:匹配各种括号及其内容 + BUILTIN_PATTERNS: list[str] = [ + r"\*\*[^*]+\*\*", # **文字** + r"\*[^*]+\*", # *文字* + r"\([^)]*\)", # (文字) 英文/半角括号 + r"([^)]*)", # (文字)中文括号 + r"【[^】]*】", # 【文字】 + r"\[[^\]]*\]", # [文字] + ] + + @classmethod + def apply(cls, text: str, custom_rules: list[str] | None = None) -> str: + """应用内置规则和自定义规则,返回过滤后的文本。 + + 如果 custom_rules 中包含无效的正则表达式,会记录警告日志并跳过该规则。 + """ + result = text + all_rules = cls.BUILTIN_PATTERNS + (custom_rules or []) + for i, pattern in enumerate(all_rules): + try: + 
result = re.sub(pattern, "", result) + except re.error: + is_custom = i >= len(cls.BUILTIN_PATTERNS) + if is_custom and custom_rules: + idx = i - len(cls.BUILTIN_PATTERNS) + logger.warning( + f"[TTSTextFilter] 自定义正则规则 #{idx} 无效,已跳过: {pattern}" + ) + # 内置规则出错不记录日志(几乎不会发生) + return re.sub(r" {2,}", " ", result).strip() + + +class FilteredQueue: + """异步队列包装器,在 get() 时自动过滤文本。 + + 用于 TTS 流式场景:Feeder 写入原始文本(用于日志/UI), + TTS 消费者读取过滤后的文本。 + 不继承 asyncio.Queue,而是通过组合模式包装真实队列。 + """ + + def __init__( + self, + real_queue: asyncio.Queue, + custom_rules: list[str] | None = None, + ) -> None: + self._real_queue = real_queue + self._custom_rules = custom_rules + + async def get(self) -> str | None: + while True: + item = await self._real_queue.get() + if item is None: + return None + if isinstance(item, str): + filtered = TTSTextFilter.apply(item, self._custom_rules) + if filtered: + return filtered + continue + return item + + def qsize(self) -> int: + return self._real_queue.qsize() + + def empty(self) -> bool: + return self._real_queue.empty() + + def full(self) -> bool: + return self._real_queue.full() + + async def put(self, item) -> None: + await self._real_queue.put(item) + + def task_done(self) -> None: + self._real_queue.task_done() + + async def join(self) -> None: + await self._real_queue.join() diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index 373182fc15..5b2883d2af 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -73,6 +73,16 @@ }, "trigger_probability": { "description": "TTS Trigger Probability" + }, + "tts_text_filter": { + "enable": { + "description": "Filter bracket content in TTS text", + "hint": "When enabled, automatically removes *text*, 【text】, (text) and other bracket/marker content to avoid TTS reading emotion markers" + }, + "custom_rules": { + "description": "Custom TTS filter regex", 
+ "hint": "One regex per line. Matched content will be removed from TTS text" + } } } }, @@ -1336,6 +1346,54 @@ "dashscope_tts_voice": { "description": "Voice" }, + "qwen_tts_voice": { + "description": "Voice", + "hint": "Qwen TTS Realtime voice name, e.g. Cherry, Serena, Ethan, etc. See Alibaba Cloud documentation for details." + }, + "qwen_tts_instructions": { + "description": "Instructions", + "hint": "Natural language description to control speech expression, e.g. 'fast-paced with an upbeat tone'. Only supported by qwen3-tts-instruct-flash-realtime model, max 1600 tokens." + }, + "qwen_tts_optimize_instructions": { + "description": "Optimize Instructions", + "hint": "When enabled, the model will automatically optimize the instruction description for better results. Only supported by qwen3-tts-instruct-flash-realtime model." + }, + "qwen_tts_speech_rate": { + "description": "Speech Rate", + "hint": "Speech rate ratio, 1.0 is normal speed. Greater than 1.0 is faster, less than 1.0 is slower." + }, + "qwen_tts_volume": { + "description": "Volume", + "hint": "Volume ratio, 1.0 is normal volume." + }, + "qwen_tts_pitch_rate": { + "description": "Pitch", + "hint": "Pitch ratio, 1.0 is normal pitch." + }, + "qwen_tts_url": { + "description": "WebSocket URL", + "hint": "Qwen TTS Realtime WebSocket URL. Beijing: wss://dashscope.aliyuncs.com/api-ws/v1/realtime; Singapore: wss://dashscope-intl.aliyuncs.com/api-ws/v1/realtime" + }, + "cosyvoice_voice": { + "description": "Voice", + "hint": "CosyVoice voice name, e.g. longanyang, longxiaochun_v2, etc. Different model versions require corresponding voices." + }, + "cosyvoice_speech_rate": { + "description": "Speech Rate", + "hint": "Speech rate ratio, 1.0 is normal speed. Only supported by some models." + }, + "cosyvoice_volume": { + "description": "Volume", + "hint": "Volume ratio, 1.0 is normal volume. Only supported by some models." 
+ }, + "cosyvoice_pitch_rate": { + "description": "Pitch", + "hint": "Pitch ratio, 1.0 is normal pitch. Only supported by some models." + }, + "cosyvoice_base_url": { + "description": "WebSocket URL", + "hint": "CosyVoice WebSocket URL. Beijing: wss://dashscope.aliyuncs.com/api-ws/v1/inference; Singapore: wss://dashscope-intl.aliyuncs.com/api-ws/v1/inference" + }, "gm_resp_image_modal": { "description": "Enable image modality", "hint": "When enabled, responses can include images. Requires model support or it will error. See the Google Gemini website for supported models. Tip: if you need image generation, disable the `Enable member recognition` setting for better results." diff --git a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json index c578f79d1c..b10949e8bf 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json +++ b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json @@ -73,6 +73,16 @@ }, "trigger_probability": { "description": "Вероятность срабатывания TTS" + }, + "tts_text_filter": { + "enable": { + "description": "Фильтровать содержимое скобок в тексте TTS", + "hint": "При включении автоматически удаляет *текст*, 【текст】, (текст) и другие скобки/маркеры, чтобы TTS не зачитывал эмоциональные маркеры" + }, + "custom_rules": { + "description": "Пользовательские регулярные выражения TTS-фильтра", + "hint": "По одному регулярному выражению на строку. Совпавший текст будет удалён из текста TTS" + } } } }, @@ -1333,6 +1343,54 @@ "dashscope_tts_voice": { "description": "Голос" }, + "qwen_tts_voice": { + "description": "Голос Qwen TTS", + "hint": "Имя голоса для Qwen TTS Realtime, например Cherry." + }, + "qwen_tts_instructions": { + "description": "Инструкция Qwen TTS", + "hint": "Дополнительные инструкции для модели TTS, управляющие стилем речи." 
+ }, + "qwen_tts_optimize_instructions": { + "description": "Оптимизировать инструкцию Qwen TTS", + "hint": "Если включено, модель автоматически оптимизирует инструкцию для лучшего результата." + }, + "qwen_tts_speech_rate": { + "description": "Скорость речи Qwen TTS", + "hint": "Множитель скорости речи. 1.0 — нормальная скорость." + }, + "qwen_tts_volume": { + "description": "Громкость Qwen TTS", + "hint": "Множитель громкости. 1.0 — нормальная громкость." + }, + "qwen_tts_pitch_rate": { + "description": "Высота голоса Qwen TTS", + "hint": "Множитель высоты голоса. 1.0 — нормальная высота." + }, + "qwen_tts_url": { + "description": "WebSocket URL Qwen TTS", + "hint": "Адрес WebSocket API для Qwen TTS Realtime." + }, + "cosyvoice_voice": { + "description": "Голос CosyVoice", + "hint": "Имя голоса для CosyVoice TTS, например longanyang." + }, + "cosyvoice_speech_rate": { + "description": "Скорость речи CosyVoice", + "hint": "Множитель скорости речи. 1.0 — нормальная скорость." + }, + "cosyvoice_volume": { + "description": "Громкость CosyVoice", + "hint": "Множитель громкости. 1.0 — нормальная громкость." + }, + "cosyvoice_pitch_rate": { + "description": "Высота голоса CosyVoice", + "hint": "Множитель высоты голоса. 1.0 — нормальная высота." + }, + "cosyvoice_base_url": { + "description": "Base URL CosyVoice", + "hint": "Адрес WebSocket API для CosyVoice TTS." + }, "gm_resp_image_modal": { "description": "Включить визуальную модальность", "hint": "Если включено, ответы могут содержать изображения. Требует поддержки моделью. Совет: для генерации изображений отключите 'Распознавание участников'." 
diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index dd8711345c..f454440a3e 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -73,6 +73,16 @@ }, "trigger_probability": { "description": "TTS 触发概率" + }, + "tts_text_filter": { + "enable": { + "description": "过滤 TTS 文本中的括号内容", + "hint": "开启后将自动去除 *文字*、【文字】、(文字) 等括号/标记内容,避免 TTS 朗读情绪标记" + }, + "custom_rules": { + "description": "自定义 TTS 过滤正则", + "hint": "每行一条正则表达式,将匹配到的内容从 TTS 文本中移除" + } } } }, @@ -1338,6 +1348,54 @@ "dashscope_tts_voice": { "description": "音色" }, + "qwen_tts_voice": { + "description": "音色", + "hint": "Qwen TTS Realtime 音色名称,可选: Cherry(芊悦)、Serena(苏瑶)、Ethan(晨煦) 等。详见阿里云文档。" + }, + "qwen_tts_instructions": { + "description": "指令控制", + "hint": "通过自然语言描述控制语音表达效果,如'语速较快,带有明显的上扬语调'。仅 qwen3-tts-instruct-flash-realtime 模型支持,长度不超过 1600 Token。" + }, + "qwen_tts_optimize_instructions": { + "description": "优化指令", + "hint": "启用后模型会自动优化指令描述以获得更好的效果。仅 qwen3-tts-instruct-flash-realtime 模型支持。" + }, + "qwen_tts_speech_rate": { + "description": "语速", + "hint": "语速调节比例,1.0 为正常语速,大于 1.0 加快,小于 1.0 减慢。" + }, + "qwen_tts_volume": { + "description": "音量", + "hint": "音量调节比例,1.0 为正常音量。" + }, + "qwen_tts_pitch_rate": { + "description": "音调", + "hint": "音调调节比例,1.0 为正常音调。" + }, + "qwen_tts_url": { + "description": "WebSocket 地址", + "hint": "Qwen TTS Realtime WebSocket 地址。北京: wss://dashscope.aliyuncs.com/api-ws/v1/realtime;新加坡: wss://dashscope-intl.aliyuncs.com/api-ws/v1/realtime" + }, + "cosyvoice_voice": { + "description": "音色", + "hint": "CosyVoice 音色名称,可选: longanyang、longxiaochun_v2 等。不同模型版本需使用对应版本的音色。" + }, + "cosyvoice_speech_rate": { + "description": "语速", + "hint": "语速调节比例,1.0 为正常语速。仅部分模型支持。" + }, + "cosyvoice_volume": { + "description": "音量", + "hint": "音量调节比例,1.0 为正常音量。仅部分模型支持。" + }, + "cosyvoice_pitch_rate": { + "description": "音调", + 
"hint": "音调(音高)调节比例,1.0 为正常音调。仅部分模型支持。" + }, + "cosyvoice_base_url": { + "description": "WebSocket 地址", + "hint": "CosyVoice WebSocket 地址。北京: wss://dashscope.aliyuncs.com/api-ws/v1/inference;新加坡: wss://dashscope-intl.aliyuncs.com/api-ws/v1/inference" + }, "gm_resp_image_modal": { "description": "启用图片模态", "hint": "启用后,将支持返回图片内容。需要模型支持,否则会报错。具体支持模型请查看 Google Gemini 官方网站。温馨提示,如果您需要生成图片,请关闭 `启用群员识别` 配置获得更好的效果。" diff --git a/tests/unit/test_tts_text_filter.py b/tests/unit/test_tts_text_filter.py new file mode 100644 index 0000000000..8c46bf5a68 --- /dev/null +++ b/tests/unit/test_tts_text_filter.py @@ -0,0 +1,185 @@ +"""Tests for TTS text filter utility.""" + +import asyncio + +import pytest + +from astrbot.core.utils.tts_text_filter import FilteredQueue, TTSTextFilter + + +class TestTTSTextFilter: + """Test TTSTextFilter.apply() with various patterns.""" + + def test_builtin_markdown_bold(self): + """Filter **bold** markdown.""" + result = TTSTextFilter.apply("Hello **world** test") + assert result == "Hello test" + + def test_builtin_markdown_italic(self): + """Filter *italic* markdown.""" + result = TTSTextFilter.apply("Hello *world* test") + assert result == "Hello test" + + def test_builtin_parentheses(self): + """Filter (content) in parentheses.""" + result = TTSTextFilter.apply("Hello (world) test") + assert result == "Hello test" + + def test_builtin_chinese_parentheses(self): + """Filter (content) in Chinese parentheses.""" + result = TTSTextFilter.apply("Hello(world)test") + assert result == "Hellotest" + + def test_builtin_corner_brackets(self): + """Filter 【content】 in corner brackets.""" + result = TTSTextFilter.apply("Hello【world】test") + assert result == "Hellotest" + + def test_builtin_square_brackets(self): + """Filter [content] in square brackets.""" + result = TTSTextFilter.apply("Hello [world] test") + assert result == "Hello test" + + def test_multiple_patterns(self): + """Filter multiple patterns in one text.""" + result = TTSTextFilter.apply( + 
"**bold** and *italic* and (parens) and 【corner】"
+        )
+        assert result == "and and and"
+
+    def test_nested_brackets_simple(self):
+        """Nested brackets - only outermost is removed."""
+        result = TTSTextFilter.apply("Hello (**bold**) test")
+        # The inner **bold** would be filtered first, then () would catch the rest
+        assert "bold" not in result
+
+    def test_no_brackets(self):
+        """Text without brackets passes through unchanged."""
+        result = TTSTextFilter.apply("Hello world, this is a test.")
+        assert result == "Hello world, this is a test."
+
+    def test_empty_string(self):
+        """Empty string returns empty string."""
+        result = TTSTextFilter.apply("")
+        assert result == ""
+
+    def test_only_brackets(self):
+        """Text with only brackets returns empty string."""
+        result = TTSTextFilter.apply("(test)")
+        assert result == ""
+
+    def test_custom_rules(self):
+        """Custom regex rules are applied after built-in rules."""
+        # <b> and </b> are stripped, but "world" between them remains
+        result = TTSTextFilter.apply(
+            "Hello <b>world</b> test",
+            custom_rules=[r"<[^>]*>"],
+        )
+        assert result == "Hello world test"
+
+    def test_invalid_custom_rule_skipped(self):
+        """Invalid custom regex rules are skipped without crashing."""
+        # Should not raise
+        result = TTSTextFilter.apply(
+            "Hello world",
+            custom_rules=[r"[invalid"],  # unterminated bracket set
+        )
+        assert result == "Hello world"
+
+    def test_mixed_content_with_no_match(self):
+        """Content that doesn't match any pattern is preserved."""
+        text = "你好,今天天气不错!"
+        result = TTSTextFilter.apply(text)
+        assert result == text
+
+    def test_whitespace_trimming(self):
+        """Result is stripped of leading/trailing whitespace."""
+        result = TTSTextFilter.apply(" Hello world ")
+        assert result == "Hello world"
+
+
+@pytest.mark.asyncio
+class TestFilteredQueue:
+    """Test FilteredQueue wrapper."""
+
+    async def test_get_filtered_text(self):
+        """Getting text from queue returns filtered text."""
+        real_queue: asyncio.Queue = asyncio.Queue()
+        fq = FilteredQueue(real_queue, custom_rules=[])
+
+        await fq.put("Hello (world) test")
+        result = await fq.get()
+
+        assert result == "Hello test"
+
+    async def test_none_passthrough(self):
+        """None sentinel values pass through unfiltered."""
+        real_queue: asyncio.Queue = asyncio.Queue()
+        fq = FilteredQueue(real_queue)
+
+        await fq.put(None)
+        result = await fq.get()
+
+        assert result is None
+
+    async def test_non_string_passthrough(self):
+        """Non-string values pass through unfiltered."""
+        real_queue: asyncio.Queue = asyncio.Queue()
+        fq = FilteredQueue(real_queue)
+
+        await fq.put(123)
+        result = await fq.get()
+
+        assert result == 123
+
+    async def test_custom_rules_in_queue(self):
+        """Custom rules are applied during get()."""
+        real_queue: asyncio.Queue = asyncio.Queue()
+        fq = FilteredQueue(real_queue, custom_rules=[r"<[^>]*>"])
+
+        await fq.put("Hello <b>world</b>")
+        result = await fq.get()
+
+        assert result == "Hello world"
+
+    async def test_queue_size_methods(self):
+        """qsize, empty, full delegate to real queue."""
+        real_queue: asyncio.Queue = asyncio.Queue(maxsize=10)
+        fq = FilteredQueue(real_queue)
+
+        assert fq.empty() is True
+        assert fq.full() is False
+
+        await fq.put("item")
+        assert fq.qsize() == 1
+        assert fq.empty() is False
+
+    async def test_multiple_items(self):
+        """Multiple items through the queue are all filtered."""
+        real_queue: asyncio.Queue = asyncio.Queue()
+        fq = FilteredQueue(real_queue)
+
+        texts = ["Hello (world)", "Foo **bar**", "Normal text"]
+        for t in texts:
+ await fq.put(t) + + results = [] + for _ in texts: + results.append(await fq.get()) + + assert results[0] == "Hello" + assert results[1] == "Foo" + assert results[2] == "Normal text" + + async def test_filtered_with_mixed_none(self): + """Mix of text and None pass through correctly.""" + real_queue: asyncio.Queue = asyncio.Queue() + fq = FilteredQueue(real_queue) + + await fq.put("Hello (world)") + await fq.put(None) + await fq.put("**bold** text") + + assert await fq.get() == "Hello" + assert await fq.get() is None + assert await fq.get() == "text"