Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions astrbot/core/provider/sources/openai_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,51 @@ async def _audio_ref_to_local_path(self, audio_ref: str) -> tuple[str, list[Path
return str(target_path), cleanup_paths
if audio_ref.startswith("file://"):
return self._file_uri_to_path(audio_ref), cleanup_paths
if audio_ref.startswith("data:"):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

According to the general rules, new functionality such as handling attachments should be accompanied by corresponding unit tests. Please add unit tests to verify that data URIs are correctly parsed, decoded, and saved to temporary files.

References
  1. New functionality, such as handling attachments, should be accompanied by corresponding unit tests.

# data URI 格式: data:audio/wav;base64,<base64_data>
# 或 data:audio/mpeg;base64,<base64_data>
try:
# 防止过大的 base64 payload 导致内存耗尽
if len(audio_ref) > 10 * 1024 * 1024: # 10 MB
logger.warning(
"data URI 音频过大 (%.1f MB),将忽略",
len(audio_ref) / (1024 * 1024),
)
return audio_ref, cleanup_paths

header, base64_data = audio_ref.split(",", 1)

# 从 data URI header 中提取 MIME 类型
mime_parts = header.removeprefix("data:").split(";")
mime_type = mime_parts[0] if mime_parts else "audio/wav"

# 使用显式映射避免 audio/mpeg 被错误转换为 .mpeg 而非 .mp3
mime_to_suffix = {
"audio/wav": ".wav",
"audio/x-wav": ".wav",
"audio/mpeg": ".mp3",
"audio/mp3": ".mp3",
"audio/ogg": ".ogg",
"audio/m4a": ".m4a",
"audio/x-m4a": ".m4a",
"audio/aac": ".aac",
"audio/flac": ".flac",
"audio/x-flac": ".flac",
}
suffix = mime_to_suffix.get(mime_type, ".wav")

audio_bytes = base64.b64decode(base64_data)
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
temp_dir = Path(get_astrbot_temp_path())
temp_dir.mkdir(parents=True, exist_ok=True)
target_path = temp_dir / f"provider_audio_{uuid.uuid4().hex}{suffix}"
target_path.write_bytes(audio_bytes)
cleanup_paths.append(target_path)
return str(target_path), cleanup_paths
except (ValueError, binascii.Error, OSError) as exc:
logger.warning(
"解析 data URI 音频失败: %s,错误: %s", audio_ref[:100], exc
)
return audio_ref, cleanup_paths
Comment thread
tjc6666666666666 marked this conversation as resolved.
return audio_ref, cleanup_paths

async def _resolve_audio_part(self, audio_ref: str) -> dict | None:
Expand Down
164 changes: 164 additions & 0 deletions tests/test_openai_source.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import builtins
from io import BytesIO
from pathlib import Path
from types import SimpleNamespace

import pytest
Expand Down Expand Up @@ -1809,3 +1810,166 @@ async def fake_create(**kwargs):
assert messages[1] == {"role": "user", "content": "again"}
finally:
await provider.terminate()


# ---------------------------------------------------------------------------
# _audio_ref_to_local_path - data URI 处理测试
# ---------------------------------------------------------------------------


def _minimal_wav_bytes() -> bytes:
"""生成一个最小的有效 WAV 头(44 字节空数据)。"""
import struct

data_size = 0
file_size = 36 + data_size # 4 + 24 + 8 + data_size
header = bytearray(44)
header[0:4] = b"RIFF"
struct.pack_into("<I", header, 4, file_size)
header[8:12] = b"WAVE"
header[12:16] = b"fmt "
struct.pack_into("<I", header, 16, 16) # PCM chunk size
struct.pack_into("<H", header, 20, 1) # PCM format
struct.pack_into("<H", header, 22, 1) # mono
struct.pack_into("<I", header, 24, 44100) # sample rate
struct.pack_into("<I", header, 28, 88200) # byte rate
struct.pack_into("<H", header, 32, 2) # block align
struct.pack_into("<H", header, 34, 16) # bits per sample
header[36:40] = b"data"
struct.pack_into("<I", header, 40, data_size)
return bytes(header)


@pytest.mark.asyncio
async def test_audio_ref_data_uri_wav(monkeypatch, tmp_path):
"""data:audio/wav;base64 URI 应被解码为临时 .wav 文件。"""
wav_data = _minimal_wav_bytes()
b64_data = base64.b64encode(wav_data).decode("utf-8")
data_uri = f"data:audio/wav;base64,{b64_data}"

provider = _make_provider()
try:
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path

monkeypatch.setattr(
"astrbot.core.provider.sources.openai_source.get_astrbot_temp_path",
lambda: str(tmp_path),
)

audio_path, cleanup = await provider._audio_ref_to_local_path(data_uri)
assert cleanup, "应该返回需要清理的临时文件路径"
assert audio_path.endswith(".wav"), f"后缀应为 .wav,实际为 {audio_path}"
assert Path(audio_path).parent == tmp_path
assert Path(audio_path).read_bytes() == wav_data, (
"解码后的内容应与原始 WAV 数据一致"
)
# 清理
for p in cleanup:
p.unlink(missing_ok=True)
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_audio_ref_data_uri_mpeg_uses_mp3_suffix(monkeypatch, tmp_path):
"""data:audio/mpeg;base64 应使用 .mp3 后缀,而非 .mpeg。"""
wav_data = _minimal_wav_bytes()
b64_data = base64.b64encode(wav_data).decode("utf-8")
data_uri = f"data:audio/mpeg;base64,{b64_data}"

provider = _make_provider()
try:
monkeypatch.setattr(
"astrbot.core.provider.sources.openai_source.get_astrbot_temp_path",
lambda: str(tmp_path),
)

audio_path, cleanup = await provider._audio_ref_to_local_path(data_uri)
assert cleanup
assert audio_path.endswith(".mp3"), (
f"audio/mpeg 应映射为 .mp3,实际后缀为 {Path(audio_path).suffix}"
)
for p in cleanup:
p.unlink(missing_ok=True)
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_audio_ref_data_uri_unknown_mime_falls_back_to_wav(monkeypatch, tmp_path):
"""未知 MIME 类型应回退为 .wav 后缀。"""
wav_data = _minimal_wav_bytes()
b64_data = base64.b64encode(wav_data).decode("utf-8")
data_uri = f"data:audio/x-unknown;base64,{b64_data}"

provider = _make_provider()
try:
monkeypatch.setattr(
"astrbot.core.provider.sources.openai_source.get_astrbot_temp_path",
lambda: str(tmp_path),
)

audio_path, cleanup = await provider._audio_ref_to_local_path(data_uri)
assert cleanup
assert audio_path.endswith(".wav"), (
f"未知 MIME 类型应回退为 .wav,实际为 {Path(audio_path).suffix}"
)
for p in cleanup:
p.unlink(missing_ok=True)
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_audio_ref_data_uri_invalid_base64_returns_as_is():
"""无效 base64 的 data URI 应原样返回且不崩溃。"""
data_uri = "data:audio/wav;base64,!!!not_valid_base64!!!"

provider = _make_provider()
try:
audio_path, cleanup = await provider._audio_ref_to_local_path(data_uri)
assert not cleanup, "无效数据不应产生清理文件"
assert audio_path == data_uri, "无效 data URI 应原样返回"
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_audio_ref_data_uri_oversized_rejected():
"""超过 10 MB 的 data URI 应被拒绝。"""
# 创建一个 ~11 MB 的 data URI
large_b64 = "A" * (11 * 1024 * 1024) # base64 字符串约 11 MB
data_uri = f"data:audio/wav;base64,{large_b64}"

provider = _make_provider()
try:
audio_path, cleanup = await provider._audio_ref_to_local_path(data_uri)
assert not cleanup, "超大 payload 不应产生临时文件"
assert audio_path == data_uri, "超大 data URI 应原样返回"
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_audio_ref_data_uri_ogg_suffix(monkeypatch, tmp_path):
"""data:audio/ogg;base64 应使用 .ogg 后缀。"""
wav_data = _minimal_wav_bytes()
b64_data = base64.b64encode(wav_data).decode("utf-8")
data_uri = f"data:audio/ogg;base64,{b64_data}"

provider = _make_provider()
try:
monkeypatch.setattr(
"astrbot.core.provider.sources.openai_source.get_astrbot_temp_path",
lambda: str(tmp_path),
)

audio_path, cleanup = await provider._audio_ref_to_local_path(data_uri)
assert cleanup
assert audio_path.endswith(".ogg"), (
f"audio/ogg 应使用 .ogg 后缀,实际为 {Path(audio_path).suffix}"
)
for p in cleanup:
p.unlink(missing_ok=True)
finally:
await provider.terminate()
Loading