diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 8aa2778f1b..34fbf398b5 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -367,6 +367,51 @@ async def _audio_ref_to_local_path(self, audio_ref: str) -> tuple[str, list[Path return str(target_path), cleanup_paths if audio_ref.startswith("file://"): return self._file_uri_to_path(audio_ref), cleanup_paths + if audio_ref.startswith("data:"): + # data URI 格式: data:audio/wav;base64, + # 或 data:audio/mpeg;base64, + try: + # 防止过大的 base64 payload 导致内存耗尽 + if len(audio_ref) > 10 * 1024 * 1024: # 10 MB + logger.warning( + "data URI 音频过大 (%.1f MB),将忽略", + len(audio_ref) / (1024 * 1024), + ) + return audio_ref, cleanup_paths + + header, base64_data = audio_ref.split(",", 1) + + # 从 data URI header 中提取 MIME 类型 + mime_parts = header.removeprefix("data:").split(";") + mime_type = mime_parts[0] if mime_parts else "audio/wav" + + # 使用显式映射避免 audio/mpeg 被错误转换为 .mpeg 而非 .mp3 + mime_to_suffix = { + "audio/wav": ".wav", + "audio/x-wav": ".wav", + "audio/mpeg": ".mp3", + "audio/mp3": ".mp3", + "audio/ogg": ".ogg", + "audio/m4a": ".m4a", + "audio/x-m4a": ".m4a", + "audio/aac": ".aac", + "audio/flac": ".flac", + "audio/x-flac": ".flac", + } + suffix = mime_to_suffix.get(mime_type, ".wav") + + audio_bytes = base64.b64decode(base64_data) + temp_dir = Path(get_astrbot_temp_path()) + temp_dir.mkdir(parents=True, exist_ok=True) + target_path = temp_dir / f"provider_audio_{uuid.uuid4().hex}{suffix}" + target_path.write_bytes(audio_bytes) + cleanup_paths.append(target_path) + return str(target_path), cleanup_paths + except (ValueError, binascii.Error, OSError) as exc: + logger.warning( + "解析 data URI 音频失败: %s,错误: %s", audio_ref[:100], exc + ) + return audio_ref, cleanup_paths return audio_ref, cleanup_paths async def _resolve_audio_part(self, audio_ref: str) -> dict | None: diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py index b5587ffb14..5fb7a60482 100644 --- a/tests/test_openai_source.py +++ b/tests/test_openai_source.py @@ -1,6 +1,7 @@ import base64 import builtins from io import BytesIO +from pathlib import Path from types import SimpleNamespace import pytest @@ -1809,3 +1810,166 @@ async def fake_create(**kwargs): assert messages[1] == {"role": "user", "content": "again"} finally: await provider.terminate() + + +# --------------------------------------------------------------------------- +# _audio_ref_to_local_path - data URI 处理测试 +# --------------------------------------------------------------------------- + + +def _minimal_wav_bytes() -> bytes: + """生成一个最小的有效 WAV 头(44 字节空数据)。""" + import struct + + data_size = 0 + file_size = 36 + data_size # 4 + 24 + 8 + data_size + header = bytearray(44) + header[0:4] = b"RIFF" + struct.pack_into("