diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 2f19434c9d..48d33d634e 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -28,6 +28,8 @@ import sys import uuid from enum import Enum +from pathlib import Path +from urllib.parse import unquote, urlparse if sys.version_info >= (3, 14): from pydantic import BaseModel @@ -133,12 +135,43 @@ def fromFileSystem(path, **_): def fromURL(url: str, **_): if url.startswith("http://") or url.startswith("https://"): return Record(file=url, **_) - raise Exception("not a valid url") + raise ValueError("not a valid url") @staticmethod def fromBase64(bs64_data: str, **_): return Record(file=f"base64://{bs64_data}", **_) + @staticmethod + def _get_audio_suffix(url: str) -> str: + suffix = Path(unquote(urlparse(url).path)).suffix + return suffix or ".amr" + + def _resolve_audio_source(self) -> str: + source = self.url or self.file + if not source: + raise ValueError("No valid file or URL provided") + return source + + async def _download_audio_url(self, url: str) -> str: + temp_dir = Path(get_astrbot_temp_path()) + temp_dir.mkdir(parents=True, exist_ok=True) + file_path = ( + temp_dir / f"recordseg_{uuid.uuid4().hex}{self._get_audio_suffix(url)}" + ) + await download_file(url, str(file_path)) + if file_path.exists(): + return str(file_path.resolve()) + raise RuntimeError(f"download failed: {url}") + + def _write_base64_audio_to_file(self, url: str) -> str: + bs64_data = url.removeprefix("base64://") + audio_bytes = base64.b64decode(bs64_data) + temp_dir = Path(get_astrbot_temp_path()) + temp_dir.mkdir(parents=True, exist_ok=True) + file_path = temp_dir / f"recordseg_{uuid.uuid4().hex}.amr" + file_path.write_bytes(audio_bytes) + return str(file_path.resolve()) + async def convert_to_file_path(self) -> str: """将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。 @@ -146,25 +179,16 @@ async def convert_to_file_path(self) -> str: str: 语音的本地路径,以绝对路径表示。 """ - if not self.file: - raise Exception(f"not a valid file: {self.file}") - if self.file.startswith("file:///"): - return self.file[8:] - if self.file.startswith("http"): - file_path = await download_image_by_url(self.file) - return os.path.abspath(file_path) - if self.file.startswith("base64://"): - bs64_data = self.file.removeprefix("base64://") - image_bytes = base64.b64decode(bs64_data) - file_path = os.path.join( - get_astrbot_temp_path(), f"recordseg_{uuid.uuid4()}.jpg" - ) - with open(file_path, "wb") as f: - f.write(image_bytes) - return os.path.abspath(file_path) - if os.path.exists(self.file): - return os.path.abspath(self.file) - raise Exception(f"not a valid file: {self.file}") + url = self._resolve_audio_source() + if url.startswith("file:///"): + return url[8:] + if url.startswith("http"): + return await self._download_audio_url(url) + if url.startswith("base64://"): + return self._write_base64_audio_to_file(url) + if os.path.exists(url): + return os.path.abspath(url) + raise FileNotFoundError(f"not a valid file: {url}") async def convert_to_base64(self) -> str: """将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。 @@ -174,19 +198,18 @@ async def convert_to_base64(self) -> str: """ # convert to base64 - if not self.file: - raise Exception(f"not a valid file: {self.file}") - if self.file.startswith("file:///"): - bs64_data = file_to_base64(self.file[8:]) - elif self.file.startswith("http"): - file_path = await download_image_by_url(self.file) + url = self._resolve_audio_source() + if url.startswith("file:///"): + bs64_data = file_to_base64(url[8:]) + elif url.startswith("http"): + file_path = await self._download_audio_url(url) bs64_data = file_to_base64(file_path) - elif self.file.startswith("base64://"): - bs64_data = self.file - elif os.path.exists(self.file): - bs64_data = file_to_base64(self.file) + elif url.startswith("base64://"): + bs64_data = url + elif os.path.exists(url): + bs64_data = file_to_base64(url) else: - raise Exception(f"not a valid file: {self.file}") + raise FileNotFoundError(f"not a valid file: {url}") bs64_data = bs64_data.removeprefix("base64://") return bs64_data diff --git a/tests/unit/test_message_record_component.py b/tests/unit/test_message_record_component.py new file mode 100644 index 0000000000..47eb97dc8d --- /dev/null +++ b/tests/unit/test_message_record_component.py @@ -0,0 +1,81 @@ +import base64 +from pathlib import Path + +import pytest + +import astrbot.core.message.components as components +from astrbot.core.message.components import Record + + +@pytest.mark.asyncio +async def test_record_convert_to_file_path_prefers_url_when_file_is_name( + monkeypatch, + tmp_path, +): + calls: list[tuple[str, Path]] = [] + audio_bytes = b"audio-content" + audio_url = "http://napcat.local/nt_data/Ptt/2026-04/Ori/voice.amr" + + async def fake_download_file(url: str, path: str, *_, **__) -> None: + target = Path(path) + calls.append((url, target)) + target.write_bytes(audio_bytes) + + monkeypatch.setattr(components, "download_file", fake_download_file) + monkeypatch.setattr(components, "get_astrbot_temp_path", lambda: str(tmp_path)) + + record = Record(file="voice.amr", url=audio_url) + + file_path = Path(await record.convert_to_file_path()) + + assert file_path.read_bytes() == audio_bytes + assert file_path.suffix == ".amr" + assert calls == [(audio_url, file_path)] + + +@pytest.mark.asyncio +async def test_record_convert_to_base64_prefers_url_when_file_is_name( + monkeypatch, + tmp_path, +): + audio_bytes = b"audio-content" + audio_url = "http://napcat.local/nt_data/Ptt/2026-04/Ori/voice.amr" + + async def fake_download_file(url: str, path: str, *_, **__) -> None: + assert url == audio_url + Path(path).write_bytes(audio_bytes) + + monkeypatch.setattr(components, "download_file", fake_download_file) + monkeypatch.setattr(components, "get_astrbot_temp_path", lambda: str(tmp_path)) + + record = Record(file="voice.amr", url=audio_url) + + assert await record.convert_to_base64() == base64.b64encode(audio_bytes).decode() + + +@pytest.mark.asyncio +async def test_record_convert_to_file_path_prefers_base64_url_when_file_is_name( + monkeypatch, + tmp_path, +): + audio_bytes = b"audio-content" + audio_url = f"base64://{base64.b64encode(audio_bytes).decode()}" + + monkeypatch.setattr(components, "get_astrbot_temp_path", lambda: str(tmp_path)) + + record = Record(file="voice.amr", url=audio_url) + + file_path = Path(await record.convert_to_file_path()) + + assert file_path.read_bytes() == audio_bytes + assert file_path.suffix == ".amr" + + +@pytest.mark.asyncio +async def test_record_convert_to_base64_prefers_base64_url_when_file_is_name(): + audio_bytes = b"audio-content" + audio_base64 = base64.b64encode(audio_bytes).decode() + + record = Record(file="voice.amr", url=f"base64://{audio_base64}") + + assert await record.convert_to_base64() == audio_base64