diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 87c9150575..fd7d13826b 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -24,6 +24,7 @@ ) from ..register import register_provider_adapter +from .request_retry import retry_provider_request, retry_provider_request_context @register_provider_adapter( @@ -366,8 +367,11 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: self._apply_thinking_config(payloads) try: - completion = await self.client.messages.create( - **payloads, stream=False, extra_body=extra_body + completion = await retry_provider_request( + "Anthropic", + lambda: self.client.messages.create( + **payloads, stream=False, extra_body=extra_body + ), ) except httpx.RequestError as e: proxy = self.provider_config.get("proxy", "") @@ -459,8 +463,9 @@ async def _query_stream( payloads["max_tokens"] = 65536 self._apply_thinking_config(payloads) - async with self.client.messages.stream( - **payloads, extra_body=extra_body + async with retry_provider_request_context( + "Anthropic", + lambda: self.client.messages.stream(**payloads, extra_body=extra_body), ) as stream: assert isinstance(stream, anthropic.AsyncMessageStream) async for event in stream: @@ -838,7 +843,10 @@ def get_current_key(self) -> str: async def get_models(self) -> list[str]: models_str = [] - models = await self.client.models.list() + models = await retry_provider_request( + "Anthropic", + lambda: self.client.models.list(), + ) models = sorted(models.data, key=lambda x: x.id) for model in models: models_str.append(model.id) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index f38fcfc359..0a4e3beecb 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -28,6 +28,7 @@ from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure from ..register import register_provider_adapter +from .request_retry import retry_provider_request class SuppressNonTextPartsWarning(logging.Filter): @@ -630,10 +631,14 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: modalities, temperature, ) - result = await self.client.models.generate_content( - model=model, - contents=cast(types.ContentListUnion, conversation), - config=config, + result = await retry_provider_request( + "Gemini", + lambda: self.client.models.generate_content( + model=model, + contents=cast(types.ContentListUnion, conversation), + config=config, + ), + retry_rate_limits=False, ) logger.debug(f"genai result: {result}") @@ -710,10 +715,14 @@ async def _query_stream( payloads.get("tool_choice", "auto"), system_instruction, ) - result = await self.client.models.generate_content_stream( - model=model, - contents=cast(types.ContentListUnion, conversation), - config=config, + result = await retry_provider_request( + "Gemini", + lambda: self.client.models.generate_content_stream( + model=model, + contents=cast(types.ContentListUnion, conversation), + config=config, + ), + retry_rate_limits=False, ) break except APIError as e: @@ -940,7 +949,10 @@ async def text_chat_stream( async def get_models(self): try: - models = await self.client.models.list() + models = await retry_provider_request( + "Gemini", + lambda: self.client.models.list(), + ) return [ m.name.replace("models/", "") for m in models diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 8aa2778f1b..cb8c6bb786 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -48,6 +48,7 @@ from astrbot.core.utils.string_utils import normalize_and_dedupe_strings from ..register import register_provider_adapter +from .request_retry import retry_provider_request @register_provider_adapter( @@ -560,7 +561,10 @@ def _apply_provider_specific_extra_body_overrides( async def get_models(self): try: models_str = [] - models = await self.client.models.list() + models = await retry_provider_request( + "OpenAI", + lambda: self.client.models.list(), + ) models = sorted(models.data, key=lambda x: x.id) for model in models: models_str.append(model.id) @@ -636,10 +640,14 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: self._sanitize_assistant_messages(payloads) - completion = await self.client.chat.completions.create( - **payloads, - stream=False, - extra_body=extra_body, + completion = await retry_provider_request( + "OpenAI", + lambda: self.client.chat.completions.create( + **payloads, + stream=False, + extra_body=extra_body, + ), + retry_rate_limits=False, ) if not isinstance(completion, ChatCompletion): @@ -688,11 +696,15 @@ async def _query_stream( self._sanitize_assistant_messages(payloads) - stream = await self.client.chat.completions.create( - **payloads, - stream=True, - extra_body=extra_body, - stream_options={"include_usage": True}, + stream = await retry_provider_request( + "OpenAI", + lambda: self.client.chat.completions.create( + **payloads, + stream=True, + extra_body=extra_body, + stream_options={"include_usage": True}, + ), + retry_rate_limits=False, ) llm_response = LLMResponse("assistant", is_chunk=True) diff --git a/astrbot/core/provider/sources/request_retry.py b/astrbot/core/provider/sources/request_retry.py new file mode 100644 index 0000000000..ddbed21c09 --- /dev/null +++ b/astrbot/core/provider/sources/request_retry.py @@ -0,0 +1,141 @@ +from collections.abc import AsyncIterator, Awaitable, Callable +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import TypeVar + +from tenacity import ( + AsyncRetrying, + RetryCallState, + retry_if_exception, + stop_after_attempt, + wait_exponential, +) + +from astrbot import logger +from astrbot.core.utils.network_utils import is_connection_error + +T = TypeVar("T") + +REQUEST_RETRY_ATTEMPTS = 5 +REQUEST_RETRY_WAIT_MIN_S = 1 +REQUEST_RETRY_WAIT_MAX_S = 8 +REQUEST_RETRY_STATUS_CODES = {408, 409, 429, 500, 502, 503, 504, 529} + + +def _get_status_code(error: BaseException) -> int | None: + for attr in ("status_code", "status", "code"): + value = getattr(error, attr, None) + if isinstance(value, int): + return value + + response = getattr(error, "response", None) + if response is not None: + status_code = getattr(response, "status_code", None) + if isinstance(status_code, int): + return status_code + + return None + + +def _is_retryable_provider_request_error( + error: BaseException, + *, + retry_rate_limits: bool, +) -> bool: + if is_connection_error(error): + return True + + error_type_name = type(error).__name__ + if error_type_name in {"APIConnectionError", "APITimeoutError"}: + return True + + status_code = _get_status_code(error) + if status_code is None: + return False + + if status_code == 429 and not retry_rate_limits: + return False + + return status_code in REQUEST_RETRY_STATUS_CODES or 500 <= status_code <= 599 + + +def _log_retry(provider_label: str, retry_state: RetryCallState) -> None: + error = retry_state.outcome.exception() if retry_state.outcome else None + logger.warning( + f"[{provider_label}] Request failed with retryable error; " + f"retrying ({retry_state.attempt_number + 1}/{REQUEST_RETRY_ATTEMPTS}): " + f"{error}" + ) + + +def _build_retrying( + provider_label: str, + *, + retry_rate_limits: bool, +) -> AsyncRetrying: + return AsyncRetrying( + retry=retry_if_exception( + lambda error: _is_retryable_provider_request_error( + error, + retry_rate_limits=retry_rate_limits, + ) + ), + stop=stop_after_attempt(REQUEST_RETRY_ATTEMPTS), + wait=wait_exponential( + multiplier=1, + min=REQUEST_RETRY_WAIT_MIN_S, + max=REQUEST_RETRY_WAIT_MAX_S, + ), + before_sleep=lambda retry_state: _log_retry(provider_label, retry_state), + reraise=True, + ) + + +async def retry_provider_request( + provider_label: str, + request_factory: Callable[[], Awaitable[T]], + *, + retry_rate_limits: bool = True, +) -> T: + retrying = _build_retrying( + provider_label, + retry_rate_limits=retry_rate_limits, + ) + + async for attempt in retrying: + with attempt: + return await request_factory() + + raise RuntimeError("Provider request retry loop exited unexpectedly.") + + +@asynccontextmanager +async def retry_provider_request_context( + provider_label: str, + context_manager_factory: Callable[[], AbstractAsyncContextManager[T]], + *, + retry_rate_limits: bool = True, +) -> AsyncIterator[T]: + manager: AbstractAsyncContextManager[T] | None = None + + async def _enter_context() -> T: + nonlocal manager + manager = context_manager_factory() + return await manager.__aenter__() + + value = await retry_provider_request( + provider_label, + _enter_context, + retry_rate_limits=retry_rate_limits, + ) + + if manager is None: + raise RuntimeError("Provider request context was not created.") + + try: + yield value + except BaseException as error: + if await manager.__aexit__(type(error), error, error.__traceback__): + return + raise + else: + await manager.__aexit__(None, None, None) diff --git a/tests/test_anthropic_kimi_code_provider.py b/tests/test_anthropic_kimi_code_provider.py index 3958550269..0ae61f61a0 100644 --- a/tests/test_anthropic_kimi_code_provider.py +++ b/tests/test_anthropic_kimi_code_provider.py @@ -1,9 +1,12 @@ import builtins +from types import SimpleNamespace +import httpx import pytest import astrbot.core.provider.sources.anthropic_source as anthropic_source import astrbot.core.provider.sources.kimi_code_source as kimi_code_source +import astrbot.core.provider.sources.request_retry as request_retry from astrbot.core.exceptions import EmptyModelOutputError from astrbot.core.provider.entities import LLMResponse @@ -171,6 +174,36 @@ def fake_import(name, globals=None, locals=None, fromlist=(), level=0): assert captured["httpx_module"] is anthropic_source.httpx +@pytest.mark.asyncio +async def test_anthropic_get_models_retries_transient_request_error(monkeypatch): + monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MIN_S", 0) + monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MAX_S", 0) + + class FakeModels: + def __init__(self): + self.calls = 0 + + async def list(self): + self.calls += 1 + if self.calls == 1: + raise httpx.ConnectError("temporary connection failure") + return SimpleNamespace( + data=[ + SimpleNamespace(id="claude-b"), + SimpleNamespace(id="claude-a"), + ] + ) + + models = FakeModels() + provider = anthropic_source.ProviderAnthropic.__new__( + anthropic_source.ProviderAnthropic + ) + provider.client = SimpleNamespace(models=models) + + assert await provider.get_models() == ["claude-a", "claude-b"] + assert models.calls == 2 + + @pytest.mark.asyncio async def test_text_chat_wraps_string_system_prompt_as_list(monkeypatch): monkeypatch.setattr(anthropic_source, "AsyncAnthropic", _FakeAsyncAnthropic) diff --git a/tests/test_gemini_source.py b/tests/test_gemini_source.py index 4db8e92bfe..9294ea46b2 100644 --- a/tests/test_gemini_source.py +++ b/tests/test_gemini_source.py @@ -1,6 +1,10 @@ +from types import SimpleNamespace + +import httpx import pytest from astrbot.core.exceptions import EmptyModelOutputError +import astrbot.core.provider.sources.request_retry as request_retry from astrbot.core.provider.entities import LLMResponse from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI @@ -27,3 +31,35 @@ def test_gemini_reasoning_only_output_is_allowed(): response_id="resp_reasoning", finish_reason="STOP", ) + + +@pytest.mark.asyncio +async def test_gemini_get_models_retries_transient_request_error(monkeypatch): + monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MIN_S", 0) + monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MAX_S", 0) + + class FakeModels: + def __init__(self): + self.calls = 0 + + async def list(self): + self.calls += 1 + if self.calls == 1: + raise httpx.ConnectError("temporary connection failure") + return [ + SimpleNamespace( + name="models/gemini-a", + supported_actions=["generateContent"], + ), + SimpleNamespace( + name="models/gemini-b", + supported_actions=["embedContent"], + ), + ] + + models = FakeModels() + provider = ProviderGoogleGenAI.__new__(ProviderGoogleGenAI) + provider.client = SimpleNamespace(models=models) + + assert await provider.get_models() == ["gemini-a"] + assert models.calls == 2 diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py index b5587ffb14..478ee7d37b 100644 --- a/tests/test_openai_source.py +++ b/tests/test_openai_source.py @@ -3,12 +3,14 @@ from io import BytesIO from types import SimpleNamespace +import httpx import pytest from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from PIL import Image as PILImage import astrbot.core.provider.sources.openai_source as openai_source_module +import astrbot.core.provider.sources.request_retry as request_retry from astrbot.core.exceptions import EmptyModelOutputError from astrbot.core.provider.sources.groq_source import ProviderGroq from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial @@ -116,6 +118,34 @@ def fake_import(name, globals=None, locals=None, fromlist=(), level=0): assert captured["httpx_module"] is openai_source_module.httpx +@pytest.mark.asyncio +async def test_get_models_retries_transient_request_error(monkeypatch): + monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MIN_S", 0) + monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MAX_S", 0) + + class FakeModels: + def __init__(self): + self.calls = 0 + + async def list(self): + self.calls += 1 + if self.calls == 1: + raise httpx.ConnectError("temporary connection failure") + return SimpleNamespace( + data=[ + SimpleNamespace(id="gpt-b"), + SimpleNamespace(id="gpt-a"), + ] + ) + + models = FakeModels() + provider = ProviderOpenAIOfficial.__new__(ProviderOpenAIOfficial) + provider.client = SimpleNamespace(models=models) + + assert await provider.get_models() == ["gpt-a", "gpt-b"] + assert models.calls == 2 + + @pytest.mark.asyncio async def test_handle_api_error_content_moderated_removes_images(): provider = _make_provider(