From 388f6aadc70f34ae99bc21d5279416689e3b222c Mon Sep 17 00:00:00 2001 From: bugkeep <1921817430@qq.com> Date: Wed, 22 Apr 2026 15:29:29 +0800 Subject: [PATCH 1/2] fix: prefer record url for audio downloads --- astrbot/core/message/components.py | 78 +++++++++++++-------- tests/unit/test_message_record_component.py | 53 ++++++++++++++ 2 files changed, 101 insertions(+), 30 deletions(-) create mode 100644 tests/unit/test_message_record_component.py diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 2f19434c9d..2637b19433 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -28,6 +28,8 @@ import sys import uuid from enum import Enum +from pathlib import Path +from urllib.parse import unquote, urlparse if sys.version_info >= (3, 14): from pydantic import BaseModel @@ -139,6 +141,22 @@ def fromURL(url: str, **_): def fromBase64(bs64_data: str, **_): return Record(file=f"base64://{bs64_data}", **_) + @staticmethod + def _get_audio_suffix(url: str) -> str: + suffix = Path(unquote(urlparse(url).path)).suffix + return suffix or ".amr" + + async def _download_audio_url(self, url: str) -> str: + temp_dir = Path(get_astrbot_temp_path()) + temp_dir.mkdir(parents=True, exist_ok=True) + file_path = ( + temp_dir / f"recordseg_{uuid.uuid4().hex}{self._get_audio_suffix(url)}" + ) + await download_file(url, str(file_path)) + if file_path.exists(): + return str(file_path.resolve()) + raise Exception(f"download failed: {url}") + async def convert_to_file_path(self) -> str: """将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。 @@ -146,25 +164,24 @@ async def convert_to_file_path(self) -> str: str: 语音的本地路径,以绝对路径表示。 """ - if not self.file: - raise Exception(f"not a valid file: {self.file}") - if self.file.startswith("file:///"): - return self.file[8:] - if self.file.startswith("http"): - file_path = await download_image_by_url(self.file) - return os.path.abspath(file_path) - if self.file.startswith("base64://"): - bs64_data = self.file.removeprefix("base64://") - image_bytes = base64.b64decode(bs64_data) - file_path = os.path.join( - get_astrbot_temp_path(), f"recordseg_{uuid.uuid4()}.jpg" - ) - with open(file_path, "wb") as f: - f.write(image_bytes) - return os.path.abspath(file_path) - if os.path.exists(self.file): - return os.path.abspath(self.file) - raise Exception(f"not a valid file: {self.file}") + url = self.url or self.file + if not url: + raise Exception(f"not a valid file: {url}") + if url.startswith("file:///"): + return url[8:] + if url.startswith("http"): + return await self._download_audio_url(url) + if url.startswith("base64://"): + bs64_data = url.removeprefix("base64://") + audio_bytes = base64.b64decode(bs64_data) + temp_dir = Path(get_astrbot_temp_path()) + temp_dir.mkdir(parents=True, exist_ok=True) + file_path = temp_dir / f"recordseg_{uuid.uuid4().hex}.amr" + file_path.write_bytes(audio_bytes) + return str(file_path.resolve()) + if os.path.exists(url): + return os.path.abspath(url) + raise Exception(f"not a valid file: {url}") async def convert_to_base64(self) -> str: """将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。 @@ -174,19 +191,20 @@ async def convert_to_base64(self) -> str: """ # convert to base64 - if not self.file: - raise Exception(f"not a valid file: {self.file}") - if self.file.startswith("file:///"): - bs64_data = file_to_base64(self.file[8:]) - elif self.file.startswith("http"): - file_path = await download_image_by_url(self.file) + url = self.url or self.file + if not url: + raise Exception(f"not a valid file: {url}") + if url.startswith("file:///"): + bs64_data = file_to_base64(url[8:]) + elif url.startswith("http"): + file_path = await self._download_audio_url(url) bs64_data = file_to_base64(file_path) - elif self.file.startswith("base64://"): - bs64_data = self.file - elif os.path.exists(self.file): - bs64_data = file_to_base64(self.file) + elif url.startswith("base64://"): + bs64_data = url + elif os.path.exists(url): + bs64_data = file_to_base64(url) else: - raise Exception(f"not a valid file: {self.file}") + raise Exception(f"not a valid file: {url}") bs64_data = bs64_data.removeprefix("base64://") return bs64_data diff --git a/tests/unit/test_message_record_component.py b/tests/unit/test_message_record_component.py new file mode 100644 index 0000000000..b8dbaeff64 --- /dev/null +++ b/tests/unit/test_message_record_component.py @@ -0,0 +1,53 @@ +import base64 +from pathlib import Path + +import pytest + +import astrbot.core.message.components as components +from astrbot.core.message.components import Record + + +@pytest.mark.asyncio +async def test_record_convert_to_file_path_prefers_url_when_file_is_name( + monkeypatch, + tmp_path, +): + calls: list[tuple[str, Path]] = [] + audio_bytes = b"audio-content" + audio_url = "http://napcat.local/nt_data/Ptt/2026-04/Ori/voice.amr" + + async def fake_download_file(url: str, path: str, *_, **__) -> None: + target = Path(path) + calls.append((url, target)) + target.write_bytes(audio_bytes) + + monkeypatch.setattr(components, "download_file", fake_download_file) + monkeypatch.setattr(components, "get_astrbot_temp_path", lambda: str(tmp_path)) + + record = Record(file="voice.amr", url=audio_url) + + file_path = Path(await record.convert_to_file_path()) + + assert file_path.read_bytes() == audio_bytes + assert file_path.suffix == ".amr" + assert calls == [(audio_url, file_path)] + + +@pytest.mark.asyncio +async def test_record_convert_to_base64_prefers_url_when_file_is_name( + monkeypatch, + tmp_path, +): + audio_bytes = b"audio-content" + audio_url = "http://napcat.local/nt_data/Ptt/2026-04/Ori/voice.amr" + + async def fake_download_file(url: str, path: str, *_, **__) -> None: + assert url == audio_url + Path(path).write_bytes(audio_bytes) + + monkeypatch.setattr(components, "download_file", fake_download_file) + monkeypatch.setattr(components, "get_astrbot_temp_path", lambda: str(tmp_path)) + + record = Record(file="voice.amr", url=audio_url) + + assert await record.convert_to_base64() == base64.b64encode(audio_bytes).decode() From e2ba5e90dabb7f6b161958ccee2f79cb12526d55 Mon Sep 17 00:00:00 2001 From: bugkeep <1921817430@qq.com> Date: Thu, 23 Apr 2026 12:56:41 +0800 Subject: [PATCH 2/2] fix: harden record audio source handling --- astrbot/core/message/components.py | 39 ++++++++++++--------- tests/unit/test_message_record_component.py | 28 +++++++++++++++ 2 files changed, 50 insertions(+), 17 deletions(-) diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 2637b19433..48d33d634e 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -135,7 +135,7 @@ def fromFileSystem(path, **_): def fromURL(url: str, **_): if url.startswith("http://") or url.startswith("https://"): return Record(file=url, **_) - raise Exception("not a valid url") + raise ValueError("not a valid url") @staticmethod def fromBase64(bs64_data: str, **_): @@ -146,6 +146,12 @@ def _get_audio_suffix(url: str) -> str: suffix = Path(unquote(urlparse(url).path)).suffix return suffix or ".amr" + def _resolve_audio_source(self) -> str: + source = self.url or self.file + if not source: + raise ValueError("No valid file or URL provided") + return source + async def _download_audio_url(self, url: str) -> str: temp_dir = Path(get_astrbot_temp_path()) temp_dir.mkdir(parents=True, exist_ok=True) @@ -155,7 +161,16 @@ async def _download_audio_url(self, url: str) -> str: await download_file(url, str(file_path)) if file_path.exists(): return str(file_path.resolve()) - raise Exception(f"download failed: {url}") + raise RuntimeError(f"download failed: {url}") + + def _write_base64_audio_to_file(self, url: str) -> str: + bs64_data = url.removeprefix("base64://") + audio_bytes = base64.b64decode(bs64_data) + temp_dir = Path(get_astrbot_temp_path()) + temp_dir.mkdir(parents=True, exist_ok=True) + file_path = temp_dir / f"recordseg_{uuid.uuid4().hex}.amr" + file_path.write_bytes(audio_bytes) + return str(file_path.resolve()) async def convert_to_file_path(self) -> str: """将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。 @@ -164,24 +179,16 @@ async def convert_to_file_path(self) -> str: str: 语音的本地路径,以绝对路径表示。 """ - url = self.url or self.file - if not url: - raise Exception(f"not a valid file: {url}") + url = self._resolve_audio_source() if url.startswith("file:///"): return url[8:] if url.startswith("http"): return await self._download_audio_url(url) if url.startswith("base64://"): - bs64_data = url.removeprefix("base64://") - audio_bytes = base64.b64decode(bs64_data) - temp_dir = Path(get_astrbot_temp_path()) - temp_dir.mkdir(parents=True, exist_ok=True) - file_path = temp_dir / f"recordseg_{uuid.uuid4().hex}.amr" - file_path.write_bytes(audio_bytes) - return str(file_path.resolve()) + return self._write_base64_audio_to_file(url) if os.path.exists(url): return os.path.abspath(url) - raise Exception(f"not a valid file: {url}") + raise FileNotFoundError(f"not a valid file: {url}") async def convert_to_base64(self) -> str: """将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。 @@ -191,9 +198,7 @@ async def convert_to_base64(self) -> str: """ # convert to base64 - url = self.url or self.file - if not url: - raise Exception(f"not a valid file: {url}") + url = self._resolve_audio_source() if url.startswith("file:///"): bs64_data = file_to_base64(url[8:]) elif url.startswith("http"): @@ -204,7 +209,7 @@ async def convert_to_base64(self) -> str: elif os.path.exists(url): bs64_data = file_to_base64(url) else: - raise Exception(f"not a valid file: {url}") + raise FileNotFoundError(f"not a valid file: {url}") bs64_data = bs64_data.removeprefix("base64://") return bs64_data diff --git a/tests/unit/test_message_record_component.py b/tests/unit/test_message_record_component.py index b8dbaeff64..47eb97dc8d 100644 --- a/tests/unit/test_message_record_component.py +++ b/tests/unit/test_message_record_component.py @@ -51,3 +51,31 @@ async def fake_download_file(url: str, path: str, *_, **__) -> None: record = Record(file="voice.amr", url=audio_url) assert await record.convert_to_base64() == base64.b64encode(audio_bytes).decode() + + +@pytest.mark.asyncio +async def test_record_convert_to_file_path_prefers_base64_url_when_file_is_name( + monkeypatch, + tmp_path, +): + audio_bytes = b"audio-content" + audio_url = f"base64://{base64.b64encode(audio_bytes).decode()}" + + monkeypatch.setattr(components, "get_astrbot_temp_path", lambda: str(tmp_path)) + + record = Record(file="voice.amr", url=audio_url) + + file_path = Path(await record.convert_to_file_path()) + + assert file_path.read_bytes() == audio_bytes + assert file_path.suffix == ".amr" + + +@pytest.mark.asyncio +async def test_record_convert_to_base64_prefers_base64_url_when_file_is_name(): + audio_bytes = b"audio-content" + audio_base64 = base64.b64encode(audio_bytes).decode() + + record = Record(file="voice.amr", url=f"base64://{audio_base64}") + + assert await record.convert_to_base64() == audio_base64