diff --git a/astrbot/core/pipeline/preprocess_stage/stage.py b/astrbot/core/pipeline/preprocess_stage/stage.py index 0265973b9a..6e9526dad6 100644 --- a/astrbot/core/pipeline/preprocess_stage/stage.py +++ b/astrbot/core/pipeline/preprocess_stage/stage.py @@ -65,6 +65,8 @@ async def process( logger.debug(f"路径映射: {url} -> {component.url}") message_chain[idx] = component + failed_record_ids: set[int] = set() + # In here, we convert all Record components to wav format and update the file path. message_chain = event.get_messages() for idx, component in enumerate(message_chain): @@ -78,6 +80,7 @@ async def process( component.path = record_path message_chain[idx] = component except Exception as e: + failed_record_ids.add(id(component)) logger.warning(f"Voice processing failed: {e}") # Also process Record components inside Reply chains (wav conversion) @@ -94,6 +97,7 @@ async def process( reply_comp.path = record_path component.chain[idx] = reply_comp except Exception as e: + failed_record_ids.add(id(reply_comp)) logger.warning( f"Voice processing in reply chain failed: {e}" ) @@ -141,7 +145,10 @@ async def _stt_record(record_comp: Record, is_reply: bool = False): message_chain = event.get_messages() for idx, component in enumerate(message_chain): - if isinstance(component, Record): + if ( + isinstance(component, Record) + and id(component) not in failed_record_ids + ): plain_comp = await _stt_record(component) if plain_comp: message_chain[idx] = plain_comp @@ -152,7 +159,10 @@ async def _stt_record(record_comp: Record, is_reply: bool = False): for component in event.get_messages(): if isinstance(component, Reply) and component.chain: for idx, reply_comp in enumerate(component.chain): - if isinstance(reply_comp, Record): + if ( + isinstance(reply_comp, Record) + and id(reply_comp) not in failed_record_ids + ): plain_comp = await _stt_record(reply_comp, is_reply=True) if plain_comp: component.chain[idx] = plain_comp diff --git a/tests/test_preprocess_stage.py b/tests/test_preprocess_stage.py new file mode 100644 index 0000000000..24f9840ef1 --- /dev/null +++ b/tests/test_preprocess_stage.py @@ -0,0 +1,83 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from astrbot.core.message.components import Plain, Record, Reply +from astrbot.core.pipeline.preprocess_stage.stage import PreProcessStage + + +def _make_stage(stt_provider: AsyncMock) -> PreProcessStage: + stage = PreProcessStage() + stage.config = {} + stage.platform_settings = {} + stage.stt_settings = {"enable": True} + stage.plugin_manager = SimpleNamespace( + context=SimpleNamespace(get_using_stt_provider=lambda _: stt_provider) + ) + return stage + + +def _make_event(messages: list) -> MagicMock: + event = MagicMock() + event.get_platform_name.return_value = "test" + event.is_at_or_wake_command = False + event.get_messages.return_value = messages + event.unified_msg_origin = "test:friend:test" + event.message_str = "" + event.message_obj.message_str = "" + return event + + +@pytest.mark.asyncio +async def test_failed_audio_conversion_is_not_sent_to_stt(monkeypatch): + failed_record = Record(file="failed.amr") + valid_record = Record(file="valid.wav") + messages = [failed_record, valid_record] + stt_provider = AsyncMock() + stt_provider.get_text.return_value = "transcribed" + + async def convert_to_file_path(record): + return record.file + + async def convert_to_wav(path): + if path == "failed.amr": + raise RuntimeError("ffmpeg not found") + return path + + monkeypatch.setattr(Record, "convert_to_file_path", convert_to_file_path) + monkeypatch.setattr( + "astrbot.core.pipeline.preprocess_stage.stage.ensure_wav", + convert_to_wav, + ) + + await _make_stage(stt_provider).process(_make_event(messages)) + + assert messages[0] is failed_record + assert isinstance(messages[1], Plain) + assert messages[1].text == "transcribed" + stt_provider.get_text.assert_awaited_once_with(audio_url="valid.wav") + + +@pytest.mark.asyncio +async def test_failed_reply_audio_conversion_is_not_sent_to_stt(monkeypatch): + failed_record = Record(file="failed.amr") + reply = Reply(id="reply-id", chain=[failed_record]) + stt_provider = AsyncMock() + + async def convert_to_file_path(record): + return record.file + + async def convert_to_wav(_): + raise RuntimeError("ffmpeg not found") + + monkeypatch.setattr(Record, "convert_to_file_path", convert_to_file_path) + monkeypatch.setattr( + "astrbot.core.pipeline.preprocess_stage.stage.ensure_wav", + convert_to_wav, + ) + + await _make_stage(stt_provider).process(_make_event([reply])) + + assert reply.chain == [failed_record] + stt_provider.get_text.assert_not_awaited()