Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions astrbot/core/provider/sources/anthropic_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)

from ..register import register_provider_adapter
from .request_retry import retry_provider_request, retry_provider_request_context


@register_provider_adapter(
Expand Down Expand Up @@ -366,8 +367,11 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
self._apply_thinking_config(payloads)

try:
completion = await self.client.messages.create(
**payloads, stream=False, extra_body=extra_body
completion = await retry_provider_request(
"Anthropic",
lambda: self.client.messages.create(
**payloads, stream=False, extra_body=extra_body
),
)
except httpx.RequestError as e:
proxy = self.provider_config.get("proxy", "")
Expand Down Expand Up @@ -459,8 +463,9 @@ async def _query_stream(
payloads["max_tokens"] = 65536
self._apply_thinking_config(payloads)

async with self.client.messages.stream(
**payloads, extra_body=extra_body
async with retry_provider_request_context(
"Anthropic",
lambda: self.client.messages.stream(**payloads, extra_body=extra_body),
) as stream:
assert isinstance(stream, anthropic.AsyncMessageStream)
async for event in stream:
Expand Down Expand Up @@ -838,7 +843,10 @@ def get_current_key(self) -> str:

async def get_models(self) -> list[str]:
models_str = []
models = await self.client.models.list()
models = await retry_provider_request(
"Anthropic",
lambda: self.client.models.list(),
)
models = sorted(models.data, key=lambda x: x.id)
for model in models:
models_str.append(model.id)
Expand Down
30 changes: 21 additions & 9 deletions astrbot/core/provider/sources/gemini_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure

from ..register import register_provider_adapter
from .request_retry import retry_provider_request


class SuppressNonTextPartsWarning(logging.Filter):
Expand Down Expand Up @@ -630,10 +631,14 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
modalities,
temperature,
)
result = await self.client.models.generate_content(
model=model,
contents=cast(types.ContentListUnion, conversation),
config=config,
result = await retry_provider_request(
"Gemini",
lambda: self.client.models.generate_content(
model=model,
contents=cast(types.ContentListUnion, conversation),
config=config,
),
retry_rate_limits=False,
)
logger.debug(f"genai result: {result}")

Expand Down Expand Up @@ -710,10 +715,14 @@ async def _query_stream(
payloads.get("tool_choice", "auto"),
system_instruction,
)
result = await self.client.models.generate_content_stream(
model=model,
contents=cast(types.ContentListUnion, conversation),
config=config,
result = await retry_provider_request(
"Gemini",
lambda: self.client.models.generate_content_stream(
model=model,
contents=cast(types.ContentListUnion, conversation),
config=config,
),
retry_rate_limits=False,
)
break
except APIError as e:
Expand Down Expand Up @@ -940,7 +949,10 @@ async def text_chat_stream(

async def get_models(self):
try:
models = await self.client.models.list()
models = await retry_provider_request(
"Gemini",
lambda: self.client.models.list(),
)
return [
m.name.replace("models/", "")
for m in models
Expand Down
32 changes: 22 additions & 10 deletions astrbot/core/provider/sources/openai_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings

from ..register import register_provider_adapter
from .request_retry import retry_provider_request


@register_provider_adapter(
Expand Down Expand Up @@ -560,7 +561,10 @@ def _apply_provider_specific_extra_body_overrides(
async def get_models(self):
try:
models_str = []
models = await self.client.models.list()
models = await retry_provider_request(
"OpenAI",
lambda: self.client.models.list(),
)
models = sorted(models.data, key=lambda x: x.id)
for model in models:
models_str.append(model.id)
Expand Down Expand Up @@ -636,10 +640,14 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:

self._sanitize_assistant_messages(payloads)

completion = await self.client.chat.completions.create(
**payloads,
stream=False,
extra_body=extra_body,
completion = await retry_provider_request(
"OpenAI",
lambda: self.client.chat.completions.create(
**payloads,
stream=False,
extra_body=extra_body,
),
retry_rate_limits=False,
)

if not isinstance(completion, ChatCompletion):
Expand Down Expand Up @@ -688,11 +696,15 @@ async def _query_stream(

self._sanitize_assistant_messages(payloads)

stream = await self.client.chat.completions.create(
**payloads,
stream=True,
extra_body=extra_body,
stream_options={"include_usage": True},
stream = await retry_provider_request(
"OpenAI",
lambda: self.client.chat.completions.create(
**payloads,
stream=True,
extra_body=extra_body,
stream_options={"include_usage": True},
),
retry_rate_limits=False,
)

llm_response = LLMResponse("assistant", is_chunk=True)
Expand Down
141 changes: 141 additions & 0 deletions astrbot/core/provider/sources/request_retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import TypeVar

from tenacity import (
AsyncRetrying,
RetryCallState,
retry_if_exception,
stop_after_attempt,
wait_exponential,
)

from astrbot import logger
from astrbot.core.utils.network_utils import is_connection_error

T = TypeVar("T")

# Maximum number of attempts per request; passed to tenacity's
# stop_after_attempt and shown as the denominator in retry log messages.
REQUEST_RETRY_ATTEMPTS = 5
# Lower and upper bounds handed to wait_exponential(min=..., max=...) for the
# backoff between attempts. The "_S" suffix suggests seconds.
# These are read each time a retry policy is built, not at import time.
REQUEST_RETRY_WAIT_MIN_S = 1
REQUEST_RETRY_WAIT_MAX_S = 8
# Status codes treated as retryable. 429 can be opted out per call with
# retry_rate_limits=False. The 5xx entries are also covered by the 500-599
# range check in _is_retryable_provider_request_error.
# NOTE(review): 529 is presumably a provider-specific "overloaded" status - confirm.
REQUEST_RETRY_STATUS_CODES = {408, 409, 429, 500, 502, 503, 504, 529}


def _get_status_code(error: BaseException) -> int | None:
for attr in ("status_code", "status", "code"):
value = getattr(error, attr, None)
if isinstance(value, int):
return value

response = getattr(error, "response", None)
if response is not None:
status_code = getattr(response, "status_code", None)
if isinstance(status_code, int):
return status_code

return None


def _is_retryable_provider_request_error(
    error: BaseException,
    *,
    retry_rate_limits: bool,
) -> bool:
    """Decide whether a failed provider request should be attempted again.

    An error is retryable when any of the following holds:

    * ``is_connection_error`` recognises it;
    * its class is named ``APIConnectionError`` or ``APITimeoutError``
      (matched by name, so no SDK exception class needs importing here);
    * it carries a status code that is listed in
      ``REQUEST_RETRY_STATUS_CODES`` or lies in the 500-599 range.

    Args:
        error: The exception raised by the request.
        retry_rate_limits: When ``False``, a 429 status is not retried.

    Returns:
        ``True`` if the request should be retried, ``False`` otherwise.
    """
    transient_error_names = {"APIConnectionError", "APITimeoutError"}
    if is_connection_error(error) or type(error).__name__ in transient_error_names:
        return True

    status = _get_status_code(error)
    if status is None:
        # Nothing identifies this as a transient failure.
        return False

    rate_limited = status == 429
    if rate_limited and not retry_rate_limits:
        return False

    return 500 <= status <= 599 or status in REQUEST_RETRY_STATUS_CODES


def _log_retry(provider_label: str, retry_state: RetryCallState) -> None:
    """Emit a warning describing the failure that triggered a retry.

    Used as the ``before_sleep`` hook of the retry policy. The message
    reports ``attempt_number + 1`` against ``REQUEST_RETRY_ATTEMPTS``
    together with the exception stored on the retry state's outcome.

    Args:
        provider_label: Provider name shown in brackets at the start of the
            log line.
        retry_state: Retry state supplied by tenacity.
    """
    outcome = retry_state.outcome
    # The outcome may be unset; log "None" rather than failing inside a hook.
    error = outcome.exception() if outcome else None
    next_attempt = retry_state.attempt_number + 1
    logger.warning(
        f"[{provider_label}] Request failed with retryable error; "
        f"retrying ({next_attempt}/{REQUEST_RETRY_ATTEMPTS}): {error}"
    )


def _build_retrying(
    provider_label: str,
    *,
    retry_rate_limits: bool,
) -> AsyncRetrying:
    """Create the tenacity retry policy used for provider requests.

    The policy retries errors accepted by
    ``_is_retryable_provider_request_error``, stops after
    ``REQUEST_RETRY_ATTEMPTS`` attempts, waits with exponential backoff
    bounded by ``REQUEST_RETRY_WAIT_MIN_S`` and ``REQUEST_RETRY_WAIT_MAX_S``,
    and logs through ``_log_retry`` before each sleep.

    Args:
        provider_label: Provider name forwarded to the retry log message.
        retry_rate_limits: Whether a 429 status counts as retryable.

    Returns:
        A freshly configured ``AsyncRetrying`` instance.
    """

    def _should_retry(error: BaseException) -> bool:
        return _is_retryable_provider_request_error(
            error,
            retry_rate_limits=retry_rate_limits,
        )

    def _before_sleep(retry_state: RetryCallState) -> None:
        _log_retry(provider_label, retry_state)

    # The module constants are looked up on every call, so overriding them
    # (as tests do) affects policies built afterwards.
    backoff = wait_exponential(
        multiplier=1,
        min=REQUEST_RETRY_WAIT_MIN_S,
        max=REQUEST_RETRY_WAIT_MAX_S,
    )

    return AsyncRetrying(
        retry=retry_if_exception(_should_retry),
        stop=stop_after_attempt(REQUEST_RETRY_ATTEMPTS),
        wait=backoff,
        before_sleep=_before_sleep,
        # Surface the provider's own exception to callers, not a RetryError.
        reraise=True,
    )


async def retry_provider_request(
    provider_label: str,
    request_factory: Callable[[], Awaitable[T]],
    *,
    retry_rate_limits: bool = True,
) -> T:
    """Await a provider request, retrying it on transient failures.

    Args:
        provider_label: Provider name used in retry log messages.
        request_factory: Zero-argument callable returning a new awaitable
            for the request. It is invoked once per attempt, because an
            awaitable cannot be awaited a second time.
        retry_rate_limits: Whether a 429 status should be retried.

    Returns:
        The result of the first successful attempt.

    Raises:
        RuntimeError: If the retry loop finishes without returning or
            raising, which is not expected to happen.
    """
    policy = _build_retrying(provider_label, retry_rate_limits=retry_rate_limits)

    async for attempt in policy:
        # The return must stay inside the ``with`` block: the attempt
        # context captures a failure and lets the policy decide what to do.
        with attempt:
            result = await request_factory()
            return result

    raise RuntimeError("Provider request retry loop exited unexpectedly.")


@asynccontextmanager
async def retry_provider_request_context(
    provider_label: str,
    context_manager_factory: Callable[[], AbstractAsyncContextManager[T]],
    *,
    retry_rate_limits: bool = True,
) -> AsyncIterator[T]:
    """Enter an async context manager, retrying the *enter* step on failure.

    Only ``__aenter__`` is retried. Once a manager has been entered, errors
    raised by the caller's ``async with`` body are passed to that manager's
    ``__aexit__`` and are not retried.

    Args:
        provider_label: Provider name used in retry log messages.
        context_manager_factory: Zero-argument callable producing a new
            async context manager for each attempt.
        retry_rate_limits: Whether a 429 status should be retried.

    Yields:
        The value returned by the successfully entered manager.

    Raises:
        RuntimeError: If no manager was created despite a successful enter.
    """
    active: AbstractAsyncContextManager[T] | None = None

    async def _open() -> T:
        # Every attempt builds a new manager; the one from the last attempt
        # is the one that must be exited later.
        nonlocal active
        active = context_manager_factory()
        return await active.__aenter__()

    resource = await retry_provider_request(
        provider_label,
        _open,
        retry_rate_limits=retry_rate_limits,
    )

    if active is None:
        raise RuntimeError("Provider request context was not created.")

    try:
        yield resource
    except BaseException as error:
        # Give the wrapped manager the chance to handle or suppress the error.
        suppressed = await active.__aexit__(
            type(error), error, error.__traceback__
        )
        if not suppressed:
            raise
    else:
        await active.__aexit__(None, None, None)
33 changes: 33 additions & 0 deletions tests/test_anthropic_kimi_code_provider.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import builtins
from types import SimpleNamespace

import httpx
import pytest

import astrbot.core.provider.sources.anthropic_source as anthropic_source
import astrbot.core.provider.sources.kimi_code_source as kimi_code_source
import astrbot.core.provider.sources.request_retry as request_retry
from astrbot.core.exceptions import EmptyModelOutputError
from astrbot.core.provider.entities import LLMResponse

Expand Down Expand Up @@ -171,6 +174,36 @@ def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
assert captured["httpx_module"] is anthropic_source.httpx


@pytest.mark.asyncio
async def test_anthropic_get_models_retries_transient_request_error(monkeypatch):
    """get_models retries after one connection error and returns sorted ids."""
    # Remove the backoff delay so the retry happens immediately.
    for wait_bound in ("REQUEST_RETRY_WAIT_MIN_S", "REQUEST_RETRY_WAIT_MAX_S"):
        monkeypatch.setattr(request_retry, wait_bound, 0)

    class FakeModels:
        """Stand-in for ``client.models`` that fails on its first call only."""

        def __init__(self):
            self.calls = 0

        async def list(self):
            self.calls += 1
            if self.calls > 1:
                # Deliberately unsorted to exercise the provider's ordering.
                return SimpleNamespace(
                    data=[
                        SimpleNamespace(id="claude-b"),
                        SimpleNamespace(id="claude-a"),
                    ]
                )
            raise httpx.ConnectError("temporary connection failure")

    models = FakeModels()
    # Skip __init__ so no real client or configuration is needed.
    provider_cls = anthropic_source.ProviderAnthropic
    provider = provider_cls.__new__(provider_cls)
    provider.client = SimpleNamespace(models=models)

    model_ids = await provider.get_models()

    assert model_ids == ["claude-a", "claude-b"]
    assert models.calls == 2


@pytest.mark.asyncio
async def test_text_chat_wraps_string_system_prompt_as_list(monkeypatch):
monkeypatch.setattr(anthropic_source, "AsyncAnthropic", _FakeAsyncAnthropic)
Expand Down
Loading
Loading