-
-
Notifications
You must be signed in to change notification settings - Fork 2.4k
feat(stt): honor proxy in OpenAI Whisper provider #8668
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,9 +4,16 @@ | |
|
|
||
| import pytest | ||
|
|
||
| import astrbot.core.provider.sources.whisper_api_source as whisper_api_source | ||
| from astrbot.core.provider.sources.whisper_api_source import ProviderOpenAIWhisperAPI | ||
|
|
||
|
|
||
| class _FakeAsyncOpenAI: | ||
| def __init__(self, **kwargs): | ||
| self.kwargs = kwargs | ||
| self.close = AsyncMock() | ||
|
|
||
|
|
||
| def _make_provider() -> ProviderOpenAIWhisperAPI: | ||
| provider = ProviderOpenAIWhisperAPI( | ||
| provider_config={ | ||
|
|
@@ -28,6 +35,134 @@ def _make_provider() -> ProviderOpenAIWhisperAPI: | |
| return provider | ||
|
|
||
|
|
||
| def test_provider_passes_configured_proxy_to_openai_http_client(monkeypatch): | ||
| captured: dict[str, object] = {} | ||
| fake_http_client = SimpleNamespace(aclose=AsyncMock()) | ||
|
|
||
| def fake_create_proxy_client( | ||
| provider_label: str, | ||
| proxy: str | None = None, | ||
| headers: dict[str, str] | None = None, | ||
| verify=None, | ||
| httpx_module=None, | ||
|
Comment on lines
+38
to
+47
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
suggestion (testing): Consider adding a complementary test case for when no proxy is configured.
Please also add a test where the `proxy` key is absent from `provider_config`.
Suggested implementation:
def test_provider_passes_configured_proxy_to_openai_http_client(monkeypatch):
captured: dict[str, object] = {}
fake_http_client = object()
def fake_create_proxy_client(
provider_label: str,
proxy: str | None = None,
headers: dict[str, str] | None = None,
verify=None,
httpx_module=None,
):
captured["provider_label"] = provider_label
captured["proxy"] = proxy
captured["headers"] = headers
captured["verify"] = verify
captured["httpx_module"] = httpx_module
return fake_http_client
# Adjust the target string here to match where `create_proxy_client` is imported/used
monkeypatch.setattr(
"whisper_api_source.create_proxy_client",
fake_create_proxy_client,
)
provider = ProviderOpenAIWhisperAPI(
provider_config={
# include whatever keys are normally required by your provider_config
"provider_label": "openai_whisper_api",
"proxy": "http://configured-proxy.example.com",
}
)
# Trigger the code path that builds the HTTP client. Adjust this call as needed.
http_client = provider._get_http_client()
assert http_client is fake_http_client
assert captured["provider_label"] == "openai_whisper_api"
assert captured["proxy"] == "http://configured-proxy.example.com"
def test_provider_uses_default_http_client_when_proxy_missing(monkeypatch):
captured: dict[str, object] = {}
fake_http_client = object()
def fake_create_proxy_client(
provider_label: str,
proxy: str | None = None,
headers: dict[str, str] | None = None,
verify=None,
httpx_module=None,
):
captured["provider_label"] = provider_label
captured["proxy"] = proxy
captured["headers"] = headers
captured["verify"] = verify
captured["httpx_module"] = httpx_module
return fake_http_client
# Adjust the target string here to match where `create_proxy_client` is imported/used
monkeypatch.setattr(
"whisper_api_source.create_proxy_client",
fake_create_proxy_client,
)
# Construct the provider WITHOUT a `proxy` key in provider_config
provider = ProviderOpenAIWhisperAPI(
provider_config={
# include whatever keys are normally required by your provider_config
"provider_label": "openai_whisper_api",
# NOTE: no "proxy" key here on purpose
}
)
# Trigger the code path that builds the HTTP client. Adjust this call as needed.
http_client = provider._get_http_client()
assert http_client is fake_http_client
assert captured["provider_label"] == "openai_whisper_api"
# Depending on your intended contract, assert the default value here:
# - if default is "", use `== ""`
# - if default is None, use `is None`
assert captured["proxy"] in ("", None)
To integrate this cleanly with your existing codebase, you will likely need to:
|
||
| ): | ||
| captured["provider_label"] = provider_label | ||
| captured["proxy"] = proxy | ||
| captured["headers"] = headers | ||
| captured["httpx_module"] = httpx_module | ||
| return fake_http_client | ||
|
|
||
| monkeypatch.setattr(whisper_api_source, "AsyncOpenAI", _FakeAsyncOpenAI) | ||
| monkeypatch.setattr( | ||
| whisper_api_source, | ||
| "create_proxy_client", | ||
| fake_create_proxy_client, | ||
| ) | ||
|
|
||
| provider = ProviderOpenAIWhisperAPI( | ||
| provider_config={ | ||
| "id": "test-whisper-api", | ||
| "type": "openai_whisper_api", | ||
| "model": "whisper-1", | ||
| "api_key": "test-key", | ||
| "api_base": "https://api.example.com/v1", | ||
| "proxy": "http://127.0.0.1:7890", | ||
| "timeout": 30, | ||
| }, | ||
| provider_settings={}, | ||
| ) | ||
|
|
||
| assert provider.client.kwargs["api_key"] == "test-key" | ||
| assert provider.client.kwargs["base_url"] == "https://api.example.com/v1" | ||
| assert provider.client.kwargs["timeout"] == 30 | ||
| assert provider.client.kwargs["http_client"] is fake_http_client | ||
| assert set(provider.client.kwargs) == { | ||
| "api_key", | ||
| "base_url", | ||
| "timeout", | ||
| "http_client", | ||
| } | ||
| assert provider.http_client is fake_http_client | ||
| assert captured["provider_label"] == "OpenAI Whisper" | ||
| assert captured["proxy"] == "http://127.0.0.1:7890" | ||
| assert captured["headers"] is None | ||
| assert captured["httpx_module"] is not None | ||
|
|
||
|
|
||
| def test_provider_uses_default_http_client_when_proxy_missing(monkeypatch): | ||
| captured: dict[str, object] = {} | ||
| fake_http_client = SimpleNamespace(aclose=AsyncMock()) | ||
|
|
||
| def fake_create_proxy_client( | ||
| provider_label: str, | ||
| proxy: str | None = None, | ||
| headers: dict[str, str] | None = None, | ||
| verify=None, | ||
| httpx_module=None, | ||
| ): | ||
| captured["provider_label"] = provider_label | ||
| captured["proxy"] = proxy | ||
| captured["headers"] = headers | ||
| captured["httpx_module"] = httpx_module | ||
| return fake_http_client | ||
|
|
||
| monkeypatch.setattr(whisper_api_source, "AsyncOpenAI", _FakeAsyncOpenAI) | ||
| monkeypatch.setattr( | ||
| whisper_api_source, | ||
| "create_proxy_client", | ||
| fake_create_proxy_client, | ||
| ) | ||
|
|
||
| provider = ProviderOpenAIWhisperAPI( | ||
| provider_config={ | ||
| "id": "test-whisper-api", | ||
| "type": "openai_whisper_api", | ||
| "model": "whisper-1", | ||
| "api_key": "test-key", | ||
| }, | ||
| provider_settings={}, | ||
| ) | ||
|
|
||
| assert provider.client.kwargs["http_client"] is fake_http_client | ||
| assert set(provider.client.kwargs) == { | ||
| "api_key", | ||
| "base_url", | ||
| "timeout", | ||
| "http_client", | ||
| } | ||
| assert provider.http_client is fake_http_client | ||
| assert captured["provider_label"] == "OpenAI Whisper" | ||
| assert captured["proxy"] is None | ||
| assert captured["headers"] is None | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_terminate_closes_openai_client_and_custom_http_client(monkeypatch): | ||
| fake_http_client = SimpleNamespace(aclose=AsyncMock()) | ||
|
|
||
| monkeypatch.setattr(whisper_api_source, "AsyncOpenAI", _FakeAsyncOpenAI) | ||
| monkeypatch.setattr( | ||
| whisper_api_source, | ||
| "create_proxy_client", | ||
| lambda *args, **kwargs: fake_http_client, | ||
| ) | ||
|
|
||
| provider = ProviderOpenAIWhisperAPI( | ||
| provider_config={ | ||
| "id": "test-whisper-api", | ||
| "type": "openai_whisper_api", | ||
| "model": "whisper-1", | ||
| "api_key": "test-key", | ||
| }, | ||
| provider_settings={}, | ||
| ) | ||
|
|
||
| await provider.terminate() | ||
|
|
||
| provider.client.close.assert_awaited_once() | ||
| fake_http_client.aclose.assert_awaited_once() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_get_text_converts_opus_files_to_wav_before_transcription( | ||
| tmp_path: Path, monkeypatch: pytest.MonkeyPatch | ||
|
|
@@ -38,7 +173,9 @@ async def test_get_text_converts_opus_files_to_wav_before_transcription( | |
|
|
||
| conversions: list[tuple[str, str]] = [] | ||
|
|
||
| async def fake_convert_audio_to_wav(audio_path: str, output_path: str | None = None): | ||
| async def fake_convert_audio_to_wav( | ||
| audio_path: str, output_path: str | None = None | ||
| ): | ||
| assert output_path is not None | ||
| conversions.append((audio_path, output_path)) | ||
| Path(output_path).write_bytes(b"fake wav data") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When a custom `http_client` is passed to `AsyncOpenAI`, calling `self.client.close()` does not close the custom client (the OpenAI SDK explicitly skips closing custom clients to avoid lifecycle conflicts). This leads to unclosed client/connection leaks.
To fix this, store the created client in `self.http_client` so that it can be closed in `terminate()`: