diff --git a/src/openai/__init__.py b/src/openai/__init__.py index b2093ada68..d37315a9d9 100644 --- a/src/openai/__init__.py +++ b/src/openai/__init__.py @@ -204,7 +204,7 @@ def webhook_secret(self, value: str | None) -> None: # type: ignore @override def base_url(self) -> _httpx.URL: if base_url is not None: - return _httpx.URL(base_url) + return self._enforce_trailing_slash(_httpx.URL(base_url)) return super().base_url diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index ad64707261..ae52f8454b 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -331,6 +331,7 @@ def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: headers["Authorization"] = f"Bearer {azure_ad_token}" elif self.api_key is not API_KEY_SENTINEL: if headers.get("api-key") is None: + self._refresh_api_key() headers["api-key"] = self.api_key else: # should never be hit @@ -614,6 +615,7 @@ async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOp headers["Authorization"] = f"Bearer {azure_ad_token}" elif self.api_key is not API_KEY_SENTINEL: if headers.get("api-key") is None: + await self._refresh_api_key() headers["api-key"] = self.api_key else: # should never be hit diff --git a/tests/lib/test_azure.py b/tests/lib/test_azure.py index 52c24eba27..ac48a07717 100644 --- a/tests/lib/test_azure.py +++ b/tests/lib/test_azure.py @@ -76,6 +76,133 @@ def test_client_copying_override_options(client: Client) -> None: assert copied._custom_query == {"api-version": "2022-05-01"} +@pytest.mark.respx() +def test_client_api_key_provider_refresh_sync(respx_mock: MockRouter) -> None: + respx_mock.post( + "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01" + ).mock( + side_effect=[ + httpx.Response(500, json={"error": "server error"}), + httpx.Response(200, json={"foo": "bar"}), + ] + ) + + counter = 0 + + def api_key_provider() -> str: + nonlocal counter + + counter += 1 + + if counter == 1: + return "first" + + return "second" + + client = AzureOpenAI( + api_version="2024-02-01", + api_key=api_key_provider, + azure_endpoint="https://example-resource.azure.openai.com", + ) + client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + + assert len(calls) == 2 + + assert calls[0].request.headers.get("api-key") == "first" + assert calls[1].request.headers.get("api-key") == "second" + + +@pytest.mark.asyncio +@pytest.mark.respx() +async def test_client_api_key_provider_refresh_async(respx_mock: MockRouter) -> None: + respx_mock.post( + "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01" + ).mock( + side_effect=[ + httpx.Response(500, json={"error": "server error"}), + httpx.Response(200, json={"foo": "bar"}), + ] + ) + + counter = 0 + + async def api_key_provider() -> str: + nonlocal counter + + counter += 1 + + if counter == 1: + return "first" + + return "second" + + client = AsyncAzureOpenAI( + api_version="2024-02-01", + api_key=api_key_provider, + azure_endpoint="https://example-resource.azure.openai.com", + ) + + await client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + + assert len(calls) == 2 + + assert calls[0].request.headers.get("api-key") == "first" + assert calls[1].request.headers.get("api-key") == "second" + + +@pytest.mark.respx() +def test_client_api_key_provider_skipped_when_azure_ad_token_provider_is_used_sync(respx_mock: MockRouter) -> None: + respx_mock.post( + "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01" + ).mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + def api_key_provider() -> str: + raise AssertionError("api_key provider should not be called when Azure AD auth is used") + + client = AzureOpenAI( + api_version="2024-02-01", + api_key=api_key_provider, + azure_ad_token_provider=lambda: "azure-ad-token", + azure_endpoint="https://example-resource.azure.openai.com", + ) + client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert len(calls) == 1 + assert calls[0].request.headers.get("Authorization") == "Bearer azure-ad-token" + assert calls[0].request.headers.get("api-key") is None + + +@pytest.mark.asyncio +@pytest.mark.respx() +async def test_client_api_key_provider_skipped_when_azure_ad_token_provider_is_used_async( + respx_mock: MockRouter, +) -> None: + respx_mock.post( + "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01" + ).mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + async def api_key_provider() -> str: + raise AssertionError("api_key provider should not be called when Azure AD auth is used") + + client = AsyncAzureOpenAI( + api_version="2024-02-01", + api_key=api_key_provider, + azure_ad_token_provider=lambda: "azure-ad-token", + azure_endpoint="https://example-resource.azure.openai.com", + ) + await client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert len(calls) == 1 + assert calls[0].request.headers.get("Authorization") == "Bearer azure-ad-token" + assert calls[0].request.headers.get("api-key") is None + + @pytest.mark.respx() def test_client_token_provider_refresh_sync(respx_mock: MockRouter) -> None: respx_mock.post( diff --git a/tests/test_module_client.py b/tests/test_module_client.py index 9c9a1addab..2d5a928c60 100644 --- a/tests/test_module_client.py +++ b/tests/test_module_client.py @@ -42,8 +42,15 @@ def test_base_url_option() -> None: openai.base_url = "http://foo.com" - assert openai.base_url == URL("http://foo.com") - assert openai.completions._client.base_url == URL("http://foo.com") + assert openai.base_url == "http://foo.com" + assert openai.completions._client.base_url.raw_path == b"/" + + +def test_base_url_option_without_trailing_slash() -> None: + openai.base_url = "http://foo.com/custom/path" + + assert openai.base_url == "http://foo.com/custom/path" + assert openai.completions._client.base_url == URL("http://foo.com/custom/path/") def test_timeout_option() -> None: