diff --git a/openapi/templates/rest.mustache b/openapi/templates/rest.mustache index 17a85a1d9..e65520a45 100644 --- a/openapi/templates/rest.mustache +++ b/openapi/templates/rest.mustache @@ -17,7 +17,10 @@ from rapidata.api_client.exceptions import ApiException, ApiValueError _logger = logging.getLogger("rapidata.api_client") _TRANSIENT_RETRY_MAX_ATTEMPTS = 3 _TRANSIENT_RETRY_BASE_DELAY = 0.5 -_RETRYABLE_STATUS_CODES = {502, 503, 504} +# 429 = rate limit, 502/503/504 = transient upstream / gateway errors. +# All five are safe to retry with exponential backoff. +_RETRYABLE_STATUS_CODES = {429, 502, 503, 504} +_RETRY_AFTER_MAX_SECONDS = 60.0 class RESTResponse(io.IOBase): @@ -139,7 +142,11 @@ class RESTClientObject: continue if r.status_code in _RETRYABLE_STATUS_CODES and attempt < _TRANSIENT_RETRY_MAX_ATTEMPTS: - delay = _TRANSIENT_RETRY_BASE_DELAY * (2 ** attempt) + # Prefer the server's `Retry-After` hint when present + # (common for 429). Fall back to exponential backoff. + delay = self._retry_after_from_response(r) + if delay is None: + delay = _TRANSIENT_RETRY_BASE_DELAY * (2 ** attempt) _logger.warning( "Server error on %s %s (attempt %d/%d): %d. Retrying in %.1fs...", method, url, attempt + 1, _TRANSIENT_RETRY_MAX_ATTEMPTS + 1, @@ -213,6 +220,26 @@ class RESTClientObject: data[key] = value return files, data + @staticmethod + def _retry_after_from_response(response: httpx.Response) -> Optional[float]: + """Parse a `Retry-After` header into a bounded float seconds value. + + Accepts the integer-seconds form (RFC 7231). HTTP-date form is + intentionally not supported; in that case we fall back to the + exponential backoff schedule. Bounded by `_RETRY_AFTER_MAX_SECONDS` + so a hostile or broken server can't extend the retry loop. + """ + raw = response.headers.get("Retry-After") + if not raw: + return None + try: + seconds = float(raw) + except (TypeError, ValueError): + return None + if seconds < 0: + return None + return min(seconds, _RETRY_AFTER_MAX_SECONDS) + @staticmethod def _build_timeout(_request_timeout): """Build a Timeout object from the request timeout parameter.""" diff --git a/src/rapidata/api_client/rest.py b/src/rapidata/api_client/rest.py index 0e2bcee85..9ae39f9d6 100644 --- a/src/rapidata/api_client/rest.py +++ b/src/rapidata/api_client/rest.py @@ -27,7 +27,10 @@ _logger = logging.getLogger("rapidata.api_client") _TRANSIENT_RETRY_MAX_ATTEMPTS = 3 _TRANSIENT_RETRY_BASE_DELAY = 0.5 -_RETRYABLE_STATUS_CODES = {502, 503, 504} +# 429 = rate limit, 502/503/504 = transient upstream / gateway errors. +# All five are safe to retry with exponential backoff. +_RETRYABLE_STATUS_CODES = {429, 502, 503, 504} +_RETRY_AFTER_MAX_SECONDS = 60.0 class RESTResponse(io.IOBase): @@ -149,7 +152,11 @@ def request( continue if r.status_code in _RETRYABLE_STATUS_CODES and attempt < _TRANSIENT_RETRY_MAX_ATTEMPTS: - delay = _TRANSIENT_RETRY_BASE_DELAY * (2 ** attempt) + # Prefer the server's `Retry-After` hint when present + # (common for 429). Fall back to exponential backoff. + delay = self._retry_after_from_response(r) + if delay is None: + delay = _TRANSIENT_RETRY_BASE_DELAY * (2 ** attempt) _logger.warning( "Server error on %s %s (attempt %d/%d): %d. Retrying in %.1fs...", method, url, attempt + 1, _TRANSIENT_RETRY_MAX_ATTEMPTS + 1, @@ -223,6 +230,26 @@ def _parse_multipart_params(post_params): data[key] = value return files, data + @staticmethod + def _retry_after_from_response(response: httpx.Response) -> Optional[float]: + """Parse a `Retry-After` header into a bounded float seconds value. + + Accepts the integer-seconds form (RFC 7231). HTTP-date form is + intentionally not supported; in that case we fall back to the + exponential backoff schedule. Bounded by `_RETRY_AFTER_MAX_SECONDS` + so a hostile or broken server can't extend the retry loop. + """ + raw = response.headers.get("Retry-After") + if not raw: + return None + try: + seconds = float(raw) + except (TypeError, ValueError): + return None + if seconds < 0: + return None + return min(seconds, _RETRY_AFTER_MAX_SECONDS) + @staticmethod def _build_timeout(_request_timeout): """Build a Timeout object from the request timeout parameter."""