Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/openai/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,8 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
_strict_response_validation: bool
_idempotency_header: str | None
_default_stream_cls: type[_DefaultStreamT] | None = None
_backoff_factor: float
_max_backoff: float

def __init__(
self,
Expand All @@ -382,6 +384,8 @@ def __init__(
timeout: float | Timeout | None = DEFAULT_TIMEOUT,
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
backoff_factor: float | None = None,
max_backoff: float | None = None,
) -> None:
self._version = version
self._base_url = self._enforce_trailing_slash(URL(base_url))
Expand All @@ -392,6 +396,8 @@ def __init__(
self._strict_response_validation = _strict_response_validation
self._idempotency_header = None
self._platform: Platform | None = None
self._backoff_factor = backoff_factor if backoff_factor is not None else INITIAL_RETRY_DELAY
self._max_backoff = max_backoff if max_backoff is not None else MAX_RETRY_DELAY

if max_retries is None: # pyright: ignore[reportUnnecessaryComparison]
raise TypeError(
Expand Down Expand Up @@ -763,7 +769,7 @@ def _calculate_retry_timeout(
nb_retries = min(max_retries - remaining_retries, 1000)

# Apply exponential backoff, but not more than the max.
sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY)
sleep_seconds = min(self._backoff_factor * pow(2.0, nb_retries), self._max_backoff)

# Apply some jitter, plus-or-minus half a second.
jitter = 1 - 0.25 * random()
Expand Down Expand Up @@ -855,6 +861,8 @@ def __init__(
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
_strict_response_validation: bool,
backoff_factor: float | None = None,
max_backoff: float | None = None,
) -> None:
if not is_given(timeout):
# if the user passed in a custom http client with a non-default
Expand Down Expand Up @@ -883,6 +891,8 @@ def __init__(
custom_query=custom_query,
custom_headers=custom_headers,
_strict_response_validation=_strict_response_validation,
backoff_factor=backoff_factor,
max_backoff=max_backoff,
)
self._client = http_client or SyncHttpxClientWrapper(
base_url=base_url,
Expand Down Expand Up @@ -1452,6 +1462,8 @@ def __init__(
http_client: httpx.AsyncClient | None = None,
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
backoff_factor: float | None = None,
max_backoff: float | None = None,
) -> None:
if not is_given(timeout):
# if the user passed in a custom http client with a non-default
Expand Down Expand Up @@ -1480,6 +1492,8 @@ def __init__(
custom_query=custom_query,
custom_headers=custom_headers,
_strict_response_validation=_strict_response_validation,
backoff_factor=backoff_factor,
max_backoff=max_backoff,
)
self._client = http_client or AsyncHttpxClientWrapper(
base_url=base_url,
Expand Down
16 changes: 16 additions & 0 deletions src/openai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def __init__(
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = not_given,
max_retries: int = DEFAULT_MAX_RETRIES,
backoff_factor: float | None = None,
max_backoff: float | None = None,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
# Configure a custom httpx client.
Expand Down Expand Up @@ -174,6 +176,8 @@ def __init__(
custom_headers=default_headers,
custom_query=default_query,
_strict_response_validation=_strict_response_validation,
backoff_factor=backoff_factor,
max_backoff=max_backoff,
)

self._default_stream_cls = Stream
Expand Down Expand Up @@ -376,6 +380,8 @@ def copy(
timeout: float | Timeout | None | NotGiven = not_given,
http_client: httpx.Client | None = None,
max_retries: int | NotGiven = not_given,
backoff_factor: float | None = None,
max_backoff: float | None = None,
default_headers: Mapping[str, str] | None = None,
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
Expand Down Expand Up @@ -414,6 +420,8 @@ def copy(
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
http_client=http_client,
max_retries=max_retries if is_given(max_retries) else self.max_retries,
backoff_factor=backoff_factor if backoff_factor is not None else self._backoff_factor,
max_backoff=max_backoff if max_backoff is not None else self._max_backoff,
Comment on lines +423 to +424
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Avoid forwarding backoff kwargs into incompatible subclasses

OpenAI.copy() now always forwards backoff_factor/max_backoff into self.__class__(...), which breaks subclasses whose __init__ signatures do not accept those keywords. In this repo, AzureOpenAI.copy() and AsyncAzureOpenAI.copy() delegate to super().copy(), so calling copy()/with_options() on either Azure client now raises TypeError for unexpected keyword arguments, regressing an existing workflow for Azure users.

Useful? React with 👍 / 👎.

default_headers=headers,
default_query=params,
**_extra_kwargs,
Expand Down Expand Up @@ -484,6 +492,8 @@ def __init__(
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = not_given,
max_retries: int = DEFAULT_MAX_RETRIES,
backoff_factor: float | None = None,
max_backoff: float | None = None,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
# Configure a custom httpx client.
Expand Down Expand Up @@ -549,6 +559,8 @@ def __init__(
custom_headers=default_headers,
custom_query=default_query,
_strict_response_validation=_strict_response_validation,
backoff_factor=backoff_factor,
max_backoff=max_backoff,
)

self._default_stream_cls = AsyncStream
Expand Down Expand Up @@ -751,6 +763,8 @@ def copy(
timeout: float | Timeout | None | NotGiven = not_given,
http_client: httpx.AsyncClient | None = None,
max_retries: int | NotGiven = not_given,
backoff_factor: float | None = None,
max_backoff: float | None = None,
default_headers: Mapping[str, str] | None = None,
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
Expand Down Expand Up @@ -789,6 +803,8 @@ def copy(
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
http_client=http_client,
max_retries=max_retries if is_given(max_retries) else self.max_retries,
backoff_factor=backoff_factor if backoff_factor is not None else self._backoff_factor,
max_backoff=max_backoff if max_backoff is not None else self._max_backoff,
default_headers=headers,
default_query=params,
**_extra_kwargs,
Expand Down