-
-
Notifications
You must be signed in to change notification settings - Fork 2.4k
fix: skip STT after audio conversion failure #8704
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -65,6 +65,8 @@ async def process( | |
| logger.debug(f"路径映射: {url} -> {component.url}") | ||
| message_chain[idx] = component | ||
|
|
||
| failed_record_ids: set[int] = set() | ||
|
|
||
| # In here, we convert all Record components to wav format and update the file path. | ||
| message_chain = event.get_messages() | ||
| for idx, component in enumerate(message_chain): | ||
|
Comment on lines
+68
to
72
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (bug_risk): Tracking failed records by This relies on the same |
||
|
|
@@ -78,6 +80,7 @@ async def process( | |
| component.path = record_path | ||
| message_chain[idx] = component | ||
| except Exception as e: | ||
| failed_record_ids.add(id(component)) | ||
| logger.warning(f"Voice processing failed: {e}") | ||
|
|
||
| # Also process Record components inside Reply chains (wav conversion) | ||
|
|
@@ -94,6 +97,7 @@ async def process( | |
| reply_comp.path = record_path | ||
| component.chain[idx] = reply_comp | ||
| except Exception as e: | ||
| failed_record_ids.add(id(reply_comp)) | ||
| logger.warning( | ||
| f"Voice processing in reply chain failed: {e}" | ||
| ) | ||
|
|
@@ -141,7 +145,10 @@ async def _stt_record(record_comp: Record, is_reply: bool = False): | |
|
|
||
| message_chain = event.get_messages() | ||
| for idx, component in enumerate(message_chain): | ||
| if isinstance(component, Record): | ||
| if ( | ||
| isinstance(component, Record) | ||
| and id(component) not in failed_record_ids | ||
| ): | ||
| plain_comp = await _stt_record(component) | ||
| if plain_comp: | ||
| message_chain[idx] = plain_comp | ||
|
|
@@ -152,7 +159,10 @@ async def _stt_record(record_comp: Record, is_reply: bool = False): | |
| for component in event.get_messages(): | ||
| if isinstance(component, Reply) and component.chain: | ||
| for idx, reply_comp in enumerate(component.chain): | ||
| if isinstance(reply_comp, Record): | ||
| if ( | ||
| isinstance(reply_comp, Record) | ||
| and id(reply_comp) not in failed_record_ids | ||
| ): | ||
| plain_comp = await _stt_record(reply_comp, is_reply=True) | ||
| if plain_comp: | ||
| component.chain[idx] = plain_comp | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| from types import SimpleNamespace | ||
| from unittest.mock import AsyncMock, MagicMock | ||
|
|
||
| import pytest | ||
|
|
||
| from astrbot.core.message.components import Plain, Record, Reply | ||
| from astrbot.core.pipeline.preprocess_stage.stage import PreProcessStage | ||
|
|
||
|
|
||
| def _make_stage(stt_provider: AsyncMock) -> PreProcessStage: | ||
| stage = PreProcessStage() | ||
| stage.config = {} | ||
| stage.platform_settings = {} | ||
| stage.stt_settings = {"enable": True} | ||
| stage.plugin_manager = SimpleNamespace( | ||
| context=SimpleNamespace(get_using_stt_provider=lambda _: stt_provider) | ||
| ) | ||
| return stage | ||
|
|
||
|
|
||
| def _make_event(messages: list) -> MagicMock: | ||
| event = MagicMock() | ||
| event.get_platform_name.return_value = "test" | ||
| event.is_at_or_wake_command = False | ||
| event.get_messages.return_value = messages | ||
| event.unified_msg_origin = "test:friend:test" | ||
| event.message_str = "" | ||
| event.message_obj.message_str = "" | ||
| return event | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_failed_audio_conversion_is_not_sent_to_stt(monkeypatch): | ||
| failed_record = Record(file="failed.amr") | ||
| valid_record = Record(file="valid.wav") | ||
| messages = [failed_record, valid_record] | ||
| stt_provider = AsyncMock() | ||
| stt_provider.get_text.return_value = "transcribed" | ||
|
|
||
| async def convert_to_file_path(record): | ||
| return record.file | ||
|
|
||
| async def convert_to_wav(path): | ||
| if path == "failed.amr": | ||
| raise RuntimeError("ffmpeg not found") | ||
| return path | ||
|
|
||
| monkeypatch.setattr(Record, "convert_to_file_path", convert_to_file_path) | ||
| monkeypatch.setattr( | ||
| "astrbot.core.pipeline.preprocess_stage.stage.ensure_wav", | ||
| convert_to_wav, | ||
| ) | ||
|
|
||
| await _make_stage(stt_provider).process(_make_event(messages)) | ||
|
|
||
| assert messages[0] is failed_record | ||
| assert isinstance(messages[1], Plain) | ||
| assert messages[1].text == "transcribed" | ||
| stt_provider.get_text.assert_awaited_once_with(audio_url="valid.wav") | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_failed_reply_audio_conversion_is_not_sent_to_stt(monkeypatch): | ||
| failed_record = Record(file="failed.amr") | ||
| reply = Reply(id="reply-id", chain=[failed_record]) | ||
| stt_provider = AsyncMock() | ||
|
|
||
| async def convert_to_file_path(record): | ||
| return record.file | ||
|
|
||
| async def convert_to_wav(_): | ||
| raise RuntimeError("ffmpeg not found") | ||
|
|
||
| monkeypatch.setattr(Record, "convert_to_file_path", convert_to_file_path) | ||
| monkeypatch.setattr( | ||
| "astrbot.core.pipeline.preprocess_stage.stage.ensure_wav", | ||
| convert_to_wav, | ||
| ) | ||
|
|
||
| await _make_stage(stt_provider).process(_make_event([reply])) | ||
|
|
||
| assert reply.chain == [failed_record] | ||
| stt_provider.get_text.assert_not_awaited() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is significant code duplication between processing direct
Recordcomponents (lines 70-84) and processingRecordcomponents insideReplychains (lines 87-103). Both blocks perform the exact same WAV conversion, temporary file tracking, and error handling.Following the general rule to avoid code duplication when implementing similar functionality for direct vs. quoted attachments, we should refactor this logic into a shared helper function.
Here is an example of how you can refactor this:
Then, the loops can be simplified to:
References