diff --git a/openfeature/api.py b/openfeature/api.py index 7e06c886..6d4987c9 100644 --- a/openfeature/api.py +++ b/openfeature/api.py @@ -31,6 +31,7 @@ "remove_handler", "set_evaluation_context", "set_provider", + "set_provider_and_wait", "set_transaction_context", "set_transaction_context_propagator", "shutdown", @@ -44,12 +45,36 @@ def get_client( def set_provider(provider: FeatureProvider, domain: str | None = None) -> None: + """Set the provider, calling initialize() synchronously. + + Note: In a future major version, this function should run initialize() + in a background thread to match the non-blocking semantics of + set_provider() in the Java, Go, and Node.js SDKs. Callers who need + blocking behavior should migrate to set_provider_and_wait(). + """ if domain is None: provider_registry.set_default_provider(provider) else: provider_registry.set_provider(domain, provider) +def set_provider_and_wait(provider: FeatureProvider, domain: str | None = None) -> None: + """Set the provider and wait for initialization to complete. + + Blocks the calling thread until the provider's initialize() method + returns successfully or raises an exception. If initialization fails, + the exception is re-raised to the caller. + + Spec reference: Requirement 1.1.2.4 - "The API SHOULD provide functions + to set a provider and wait for the initialize function to return or + abnormally terminate." + """ + if domain is None: + provider_registry.set_default_provider_and_wait(provider) + else: + provider_registry.set_provider_and_wait(domain, provider) + + def clear_providers() -> None: provider_registry.clear_providers() _event_support.clear() diff --git a/openfeature/provider/_registry.py b/openfeature/provider/_registry.py index bf8fa9a8..b52714d6 100644 --- a/openfeature/provider/_registry.py +++ b/openfeature/provider/_registry.py @@ -39,6 +39,24 @@ def set_provider(self, domain: str, provider: FeatureProvider) -> None: self._initialize_provider(provider) providers[domain] = provider + def set_provider_and_wait(self, domain: str, provider: FeatureProvider) -> None: + if provider is None: + raise GeneralError(error_message="No provider") + if domain is None: + raise GeneralError(error_message="No domain") + providers = self._providers + if domain in providers: + old_provider = providers[domain] + del providers[domain] + if ( + old_provider != self._default_provider + and old_provider not in providers.values() + ): + self._shutdown_provider(old_provider) + if provider != self._default_provider and provider not in providers.values(): + self._initialize_provider_and_wait(provider) + providers[domain] = provider + def get_provider(self, domain: str | None) -> FeatureProvider: if domain is None: return self._default_provider @@ -57,6 +75,19 @@ def set_default_provider(self, provider: FeatureProvider) -> None: if self._default_provider not in self._providers.values(): self._initialize_provider(provider) + def set_default_provider_and_wait(self, provider: FeatureProvider) -> None: + if provider is None: + raise GeneralError(error_message="No provider") + if ( + self._default_provider + and self._default_provider not in self._providers.values() + ): + self._shutdown_provider(self._default_provider) + self._default_provider = provider + + if self._default_provider not in self._providers.values(): + self._initialize_provider_and_wait(provider) + def get_default_provider(self) -> FeatureProvider: return self._default_provider @@ -76,6 +107,39 @@ def _get_evaluation_context(self) -> EvaluationContext: return get_evaluation_context() def _initialize_provider(self, provider: FeatureProvider) -> None: + """Initialize the provider synchronously. Errors are dispatched as + PROVIDER_ERROR events but not re-raised to the caller. + + This is the original behavior of set_provider(). + """ + provider.attach(self.dispatch_event) + try: + if hasattr(provider, "initialize"): + provider.initialize(self._get_evaluation_context()) + self.dispatch_event( + provider, ProviderEvent.PROVIDER_READY, ProviderEventDetails() + ) + except Exception as err: + error_code = ( + err.error_code + if isinstance(err, OpenFeatureError) + else ErrorCode.GENERAL + ) + self.dispatch_event( + provider, + ProviderEvent.PROVIDER_ERROR, + ProviderEventDetails( + message=f"Provider initialization failed: {err}", + error_code=error_code, + ), + ) + + def _initialize_provider_and_wait(self, provider: FeatureProvider) -> None: + """Initialize the provider synchronously and re-raise on failure. + + Same as _initialize_provider but propagates exceptions to the caller, + used by set_provider_and_wait() / set_default_provider_and_wait(). + """ provider.attach(self.dispatch_event) try: if hasattr(provider, "initialize"): @@ -97,6 +161,7 @@ def _initialize_provider(self, provider: FeatureProvider) -> None: error_code=error_code, ), ) + raise def _shutdown_provider(self, provider: FeatureProvider) -> None: try: diff --git a/tests/test_api.py b/tests/test_api.py index 019037db..6e50c22b 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -14,6 +14,7 @@ remove_handler, set_evaluation_context, set_provider, + set_provider_and_wait, shutdown, ) from openfeature.evaluation_context import EvaluationContext @@ -330,6 +331,51 @@ def test_provider_error_handlers_run_if_provider_initialize_function_terminates_ spy.provider_error.assert_called_once() +def test_set_provider_and_wait_blocks_until_initialize_completes(): + # Given + evaluation_context = EvaluationContext("targeting_key", {"attr1": "val1"}) + provider = MagicMock(spec=FeatureProvider) + + # When + set_evaluation_context(evaluation_context) + set_provider_and_wait(provider) + + # Then - initialize should have been called synchronously + provider.initialize.assert_called_with(evaluation_context) + # Provider should be READY after set_provider_and_wait returns + client = get_client() + assert client.get_provider_status() == ProviderStatus.READY + + +def test_set_provider_and_wait_raises_on_initialization_failure(): + # Given + provider = MagicMock(spec=FeatureProvider) + provider.initialize.side_effect = ProviderFatalError() + + spy = MagicMock() + add_handler(ProviderEvent.PROVIDER_ERROR, spy.provider_error) + + # When / Then - should propagate the exception to the caller + with pytest.raises(ProviderFatalError): + set_provider_and_wait(provider) + + # Error handler should still have been called + spy.provider_error.assert_called_once() + + +def test_set_provider_and_wait_with_domain(): + # Given + provider = MagicMock(spec=FeatureProvider) + + # When + set_provider_and_wait(provider, domain="test") + + # Then + provider.initialize.assert_called_once() + test_client = get_client("test") + assert test_client.provider == provider + + def test_provider_status_is_updated_after_provider_emits_event(): # Given provider = NoOpProvider()