diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index aea5069f..55e00263 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -11,11 +11,26 @@ from openfeature.hook import Hook from .metadata import Metadata +from .multi_provider import ( + EvaluationStrategy, + FirstMatchStrategy, + MultiProvider, + ProviderEntry, +) if typing.TYPE_CHECKING: from openfeature.flag_evaluation import FlagValueType -__all__ = ["AbstractProvider", "FeatureProvider", "Metadata", "ProviderStatus"] +__all__ = [ + "AbstractProvider", + "EvaluationStrategy", + "FeatureProvider", + "FirstMatchStrategy", + "Metadata", + "MultiProvider", + "ProviderEntry", + "ProviderStatus", +] class ProviderStatus(Enum): diff --git a/openfeature/provider/multi_provider.py b/openfeature/provider/multi_provider.py new file mode 100644 index 00000000..07aba99e --- /dev/null +++ b/openfeature/provider/multi_provider.py @@ -0,0 +1,488 @@ +""" +Multi-Provider implementation for OpenFeature Python SDK. + +This provider wraps multiple underlying providers, allowing a single client +to interact with multiple flag sources simultaneously. 
+ +See: https://openfeature.dev/specification/appendix-a/#multi-provider +""" + +from __future__ import annotations + +import typing +from collections.abc import Awaitable, Callable, Mapping, Sequence +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass + +from openfeature.evaluation_context import EvaluationContext +from openfeature.event import ProviderEvent, ProviderEventDetails +from openfeature.exception import ErrorCode, GeneralError +from openfeature.flag_evaluation import FlagResolutionDetails, FlagValueType, Reason +from openfeature.hook import Hook +from openfeature.provider import AbstractProvider, FeatureProvider, Metadata + +__all__ = ["MultiProvider", "ProviderEntry", "FirstMatchStrategy", "EvaluationStrategy"] + + +@dataclass +class ProviderEntry: + """Configuration for a provider in the Multi-Provider.""" + + provider: FeatureProvider + name: str | None = None + + +class EvaluationStrategy(typing.Protocol): + """ + Strategy interface for determining which provider's result to use. + + Supports 'sequential' mode (evaluate one at a time, stop early when strategy + is satisfied) and 'parallel' mode (evaluate all providers, then select best + result). Note: Both modes currently execute provider calls sequentially; + true concurrent evaluation using asyncio.gather or ThreadPoolExecutor is + planned for a future enhancement. + """ + + run_mode: typing.Literal["sequential", "parallel"] + + def should_use_result( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails, + ) -> bool: + """ + Determine if this result should be used (and stop evaluation if sequential). + + :param flag_key: The flag being evaluated + :param provider_name: Name of the provider that returned this result + :param result: The resolution details from the provider + :return: True if this result should be used as the final result + """ + ... 
class FirstMatchStrategy:
    """
    Uses the first successful result from providers (in order).

    In sequential mode, stops at the first non-error result.
    In parallel mode, picks the first successful result from the ordered list.
    """

    run_mode: typing.Literal["sequential", "parallel"] = "sequential"

    def should_use_result(
        self,
        flag_key: str,
        provider_name: str,
        result: FlagResolutionDetails,
    ) -> bool:
        """Use the first result that doesn't have an error."""
        return result.reason != Reason.ERROR


class MultiProvider(AbstractProvider):
    """
    A provider that aggregates multiple underlying providers.

    Evaluations are delegated to underlying providers based on the configured
    strategy (default: FirstMatchStrategy in sequential mode).

    Example:
        provider_a = SomeProvider()
        provider_b = AnotherProvider()

        multi = MultiProvider([
            ProviderEntry(provider_a, name="primary"),
            ProviderEntry(provider_b, name="fallback")
        ])

        api.set_provider(multi)
    """

    def __init__(
        self,
        providers: list[ProviderEntry],
        strategy: EvaluationStrategy | None = None,
    ):
        """
        Initialize the Multi-Provider.

        :param providers: List of ProviderEntry objects defining the providers
        :param strategy: Evaluation strategy (defaults to FirstMatchStrategy)
        :raises ValueError: if *providers* is empty or names are not unique
        """
        super().__init__()

        if not providers:
            raise ValueError("At least one provider must be provided")

        self.strategy = strategy or FirstMatchStrategy()
        self._registered_providers: list[tuple[str, FeatureProvider]] = []
        self._register_providers(providers)
        # Lazily populated by get_provider_hooks(); None means "not built yet".
        self._cached_hooks: list[Hook] | None = None

    def _register_providers(self, providers: list[ProviderEntry]) -> None:
        """
        Register providers with unique names.

        Names are determined by:
        1. Explicit name in ProviderEntry
        2. provider.get_metadata().name if unique and not conflicting
        3. {metadata.name}_{index} if not unique or conflicting

        :raises ValueError: if an explicit name is used twice
        """
        # Count providers by their metadata name to detect duplicates.
        name_counts: dict[str, int] = {}
        for entry in providers:
            metadata_name = entry.provider.get_metadata().name or "provider"
            name_counts[metadata_name] = name_counts.get(metadata_name, 0) + 1

        # Track used names to prevent conflicts.
        used_names: set[str] = set()
        name_indices: dict[str, int] = {}

        for entry in providers:
            metadata_name = entry.provider.get_metadata().name or "provider"

            if entry.name:
                # Explicit name provided - must be unique.
                if entry.name in used_names:
                    raise ValueError(f"Provider name '{entry.name}' is not unique")
                final_name = entry.name
            elif name_counts[metadata_name] == 1 and metadata_name not in used_names:
                # Metadata name is unique and not already taken by an explicit name.
                final_name = metadata_name
            else:
                # Multiple providers or collision with an explicit name: add an
                # incrementing index until the name is free.
                while True:
                    name_indices[metadata_name] = name_indices.get(metadata_name, 0) + 1
                    final_name = f"{metadata_name}_{name_indices[metadata_name]}"
                    if final_name not in used_names:
                        break

            used_names.add(final_name)
            self._registered_providers.append((final_name, entry.provider))

    def get_metadata(self) -> Metadata:
        """Return metadata identifying this aggregate provider."""
        return Metadata(name="MultiProvider")

    def get_provider_hooks(self) -> list[Hook]:
        """Aggregate hooks from all providers (cached for efficiency)."""
        if self._cached_hooks is None:
            hooks: list[Hook] = []
            for _, provider in self._registered_providers:
                hooks.extend(provider.get_provider_hooks())
            self._cached_hooks = hooks
        return self._cached_hooks

    def attach(
        self,
        on_emit: Callable[[FeatureProvider, ProviderEvent, ProviderEventDetails], None],
    ) -> None:
        """
        Attach event handler and propagate to all underlying providers.

        Events from underlying providers are forwarded through the MultiProvider.
        This enables features like cache invalidation to work across all providers.
        NOTE(review): the wrapped providers receive the same on_emit callback,
        so events report the *underlying* provider as the source — confirm this
        is the intended attribution.
        """
        super().attach(on_emit)

        # Propagate attach to all wrapped providers.
        for _, provider in self._registered_providers:
            provider.attach(on_emit)

    def detach(self) -> None:
        """Detach event handler and propagate to all underlying providers."""
        super().detach()

        # Propagate detach to all wrapped providers.
        for _, provider in self._registered_providers:
            provider.detach()

    def initialize(self, evaluation_context: EvaluationContext) -> None:
        """
        Initialize all providers in parallel using ThreadPoolExecutor.

        This allows concurrent initialization of I/O-bound providers.

        :raises GeneralError: aggregating every provider that failed to initialize
        """

        def init_provider(entry: tuple[str, FeatureProvider]) -> str | None:
            # Returns an error description instead of raising so that one
            # failing provider does not abort the others' initialization.
            name, provider = entry
            try:
                provider.initialize(evaluation_context)
                return None
            except Exception as e:
                return f"Provider '{name}' initialization failed: {e}"

        with ThreadPoolExecutor(max_workers=len(self._registered_providers)) as executor:
            results = list(executor.map(init_provider, self._registered_providers))

        errors = [r for r in results if r is not None]
        if errors:
            error_msgs = "; ".join(errors)
            raise GeneralError(f"Multi-provider initialization failed: {error_msgs}")

    def shutdown(self) -> None:
        """Shutdown all providers in parallel; failures are logged, not raised."""
        # Local import keeps this block self-contained; logging is only
        # needed on the (rare) shutdown-failure path.
        import logging

        logger = logging.getLogger(__name__)

        def shutdown_provider(entry: tuple[str, FeatureProvider]) -> None:
            name, provider = entry
            try:
                provider.shutdown()
            except Exception as e:
                # Best-effort shutdown: log and keep shutting the rest down.
                logger.error(f"Provider '{name}' shutdown failed: {e}")

        with ThreadPoolExecutor(max_workers=len(self._registered_providers)) as executor:
            list(executor.map(shutdown_provider, self._registered_providers))

    def _select_result(
        self,
        flag_key: str,
        default_value: FlagValueType,
        results: list[tuple[str, FlagResolutionDetails[FlagValueType]]],
    ) -> FlagResolutionDetails[FlagValueType]:
        """
        Pick the final result from the collected per-provider results.

        Shared by the sync and async evaluation paths so the selection rules
        cannot drift between them.

        :param flag_key: The flag key being evaluated
        :param default_value: Default value, used when no provider responded
        :param results: Ordered (provider_name, resolution) pairs
        :return: The first result accepted by the strategy, else the last
            (error) result, else a synthetic GENERAL error result.
        """
        for provider_name, result in results:
            if self.strategy.should_use_result(flag_key, provider_name, result):
                return result

        # No successful result - return last error or a synthetic default.
        if results:
            return results[-1][1]

        return FlagResolutionDetails(
            value=default_value,
            reason=Reason.ERROR,
            error_code=ErrorCode.GENERAL,
            error_message="No providers returned a result",
        )

    def _evaluate_with_providers(
        self,
        flag_key: str,
        default_value: FlagValueType,
        evaluation_context: EvaluationContext | None,
        resolve_fn: Callable[
            [FeatureProvider, str, FlagValueType, EvaluationContext | None],
            FlagResolutionDetails[FlagValueType],
        ],
    ) -> FlagResolutionDetails[FlagValueType]:
        """
        Core evaluation logic that delegates to providers based on strategy.

        Current implementation evaluates providers sequentially regardless of
        strategy.run_mode. True concurrent evaluation for 'parallel' mode is
        planned for a future enhancement.

        :param flag_key: The flag key to evaluate
        :param default_value: Default value for the flag
        :param evaluation_context: Evaluation context
        :param resolve_fn: Function to call on each provider for resolution
        :return: Final resolution details
        """
        results: list[tuple[str, FlagResolutionDetails[FlagValueType]]] = []

        for provider_name, provider in self._registered_providers:
            try:
                result = resolve_fn(provider, flag_key, default_value, evaluation_context)
                results.append((provider_name, result))

                # In sequential mode, stop if strategy says to use this result.
                if self.strategy.run_mode == "sequential" and self.strategy.should_use_result(
                    flag_key, provider_name, result
                ):
                    return result

            except Exception as e:
                # Record error but continue to next provider.
                error_result = FlagResolutionDetails(
                    value=default_value,
                    reason=Reason.ERROR,
                    error_code=ErrorCode.GENERAL,
                    error_message=str(e),
                )
                results.append((provider_name, error_result))

        # All sequential attempts completed (or parallel mode): pick best result.
        return self._select_result(flag_key, default_value, results)

    async def _evaluate_with_providers_async(
        self,
        flag_key: str,
        default_value: FlagValueType,
        evaluation_context: EvaluationContext | None,
        resolve_fn: Callable[
            [FeatureProvider, str, FlagValueType, EvaluationContext | None],
            Awaitable[FlagResolutionDetails[FlagValueType]],
        ],
    ) -> FlagResolutionDetails[FlagValueType]:
        """
        Async evaluation logic that properly awaits provider async methods.

        :param flag_key: The flag key to evaluate
        :param default_value: Default value for the flag
        :param evaluation_context: Evaluation context
        :param resolve_fn: Async function to call on each provider for resolution
        :return: Final resolution details
        """
        results: list[tuple[str, FlagResolutionDetails[FlagValueType]]] = []

        for provider_name, provider in self._registered_providers:
            try:
                result = await resolve_fn(provider, flag_key, default_value, evaluation_context)
                results.append((provider_name, result))

                # In sequential mode, stop if strategy says to use this result.
                if self.strategy.run_mode == "sequential" and self.strategy.should_use_result(
                    flag_key, provider_name, result
                ):
                    return result

            except Exception as e:
                # Record error but continue to next provider.
                error_result = FlagResolutionDetails(
                    value=default_value,
                    reason=Reason.ERROR,
                    error_code=ErrorCode.GENERAL,
                    error_message=str(e),
                )
                results.append((provider_name, error_result))

        # All sequential attempts completed (or parallel mode): pick best result.
        return self._select_result(flag_key, default_value, results)

    def resolve_boolean_details(
        self,
        flag_key: str,
        default_value: bool,
        evaluation_context: EvaluationContext | None = None,
    ) -> FlagResolutionDetails[bool]:
        return self._evaluate_with_providers(
            flag_key,
            default_value,
            evaluation_context,
            lambda p, k, d, ctx: p.resolve_boolean_details(k, d, ctx),
        )

    async def resolve_boolean_details_async(
        self,
        flag_key: str,
        default_value: bool,
        evaluation_context: EvaluationContext | None = None,
    ) -> FlagResolutionDetails[bool]:
        """Async boolean evaluation using provider async methods."""
        return await self._evaluate_with_providers_async(
            flag_key,
            default_value,
            evaluation_context,
            lambda p, k, d, ctx: p.resolve_boolean_details_async(k, d, ctx),
        )

    def resolve_string_details(
        self,
        flag_key: str,
        default_value: str,
        evaluation_context: EvaluationContext | None = None,
    ) -> FlagResolutionDetails[str]:
        return self._evaluate_with_providers(
            flag_key,
            default_value,
            evaluation_context,
            lambda p, k, d, ctx: p.resolve_string_details(k, d, ctx),
        )

    async def resolve_string_details_async(
        self,
        flag_key: str,
        default_value: str,
        evaluation_context: EvaluationContext | None = None,
    ) -> FlagResolutionDetails[str]:
        """Async string evaluation using provider async methods."""
        return await self._evaluate_with_providers_async(
            flag_key,
            default_value,
            evaluation_context,
            lambda p, k, d, ctx: p.resolve_string_details_async(k, d, ctx),
        )

    def resolve_integer_details(
        self,
        flag_key: str,
        default_value: int,
        evaluation_context: EvaluationContext | None = None,
    ) -> FlagResolutionDetails[int]:
        return self._evaluate_with_providers(
            flag_key,
            default_value,
            evaluation_context,
            lambda p, k, d, ctx: p.resolve_integer_details(k, d, ctx),
        )

    async def resolve_integer_details_async(
        self,
        flag_key: str,
        default_value: int,
        evaluation_context: EvaluationContext | None = None,
    ) -> FlagResolutionDetails[int]:
        """Async integer evaluation using provider async methods."""
        return await self._evaluate_with_providers_async(
            flag_key,
            default_value,
            evaluation_context,
            lambda p, k, d, ctx: p.resolve_integer_details_async(k, d, ctx),
        )

    def resolve_float_details(
        self,
        flag_key: str,
        default_value: float,
        evaluation_context: EvaluationContext | None = None,
    ) -> FlagResolutionDetails[float]:
        return self._evaluate_with_providers(
            flag_key,
            default_value,
            evaluation_context,
            lambda p, k, d, ctx: p.resolve_float_details(k, d, ctx),
        )

    async def resolve_float_details_async(
        self,
        flag_key: str,
        default_value: float,
        evaluation_context: EvaluationContext | None = None,
    ) -> FlagResolutionDetails[float]:
        """Async float evaluation using provider async methods."""
        return await self._evaluate_with_providers_async(
            flag_key,
            default_value,
            evaluation_context,
            lambda p, k, d, ctx: p.resolve_float_details_async(k, d, ctx),
        )

    def resolve_object_details(
        self,
        flag_key: str,
        default_value: Sequence[FlagValueType] | Mapping[str, FlagValueType],
        evaluation_context: EvaluationContext | None = None,
    ) -> FlagResolutionDetails[Sequence[FlagValueType] | Mapping[str, FlagValueType]]:
        return self._evaluate_with_providers(
            flag_key,
            default_value,
            evaluation_context,
            lambda p, k, d, ctx: p.resolve_object_details(k, d, ctx),
        )

    async def resolve_object_details_async(
        self,
        flag_key: str,
        default_value: Sequence[FlagValueType] | Mapping[str, FlagValueType],
        evaluation_context: EvaluationContext | None = None,
    ) -> FlagResolutionDetails[Sequence[FlagValueType] | Mapping[str, FlagValueType]]:
        """Async object evaluation using provider async methods."""
        return await self._evaluate_with_providers_async(
            flag_key,
            default_value,
            evaluation_context,
            lambda p, k, d, ctx: p.resolve_object_details_async(k, d, ctx),
        )
import pytest

from unittest.mock import MagicMock

from openfeature import api
from openfeature.evaluation_context import EvaluationContext
from openfeature.exception import GeneralError
from openfeature.flag_evaluation import FlagResolutionDetails, Reason
from openfeature.provider import Metadata
from openfeature.provider.in_memory_provider import InMemoryFlag, InMemoryProvider
from openfeature.provider.multi_provider import (
    FirstMatchStrategy,
    MultiProvider,
    ProviderEntry,
)
from openfeature.provider.no_op_provider import NoOpProvider


def test_multi_provider_requires_at_least_one_provider():
    # Given/When/Then
    with pytest.raises(ValueError, match="At least one provider must be provided"):
        MultiProvider([])


def test_multi_provider_uses_explicit_names():
    # Given
    provider_a = NoOpProvider()
    provider_b = NoOpProvider()

    # When
    multi = MultiProvider([
        ProviderEntry(provider_a, name="first"),
        ProviderEntry(provider_b, name="second"),
    ])

    # Then
    assert len(multi._registered_providers) == 2
    assert multi._registered_providers[0][0] == "first"
    assert multi._registered_providers[1][0] == "second"


def test_multi_provider_generates_unique_names_when_metadata_conflicts():
    # Given
    provider_a = NoOpProvider()
    provider_b = NoOpProvider()

    # When - both have same metadata name "NoOpProvider"
    multi = MultiProvider([
        ProviderEntry(provider_a),
        ProviderEntry(provider_b),
    ])

    # Then - names are auto-indexed
    assert len(multi._registered_providers) == 2
    names = [name for name, _ in multi._registered_providers]
    assert names == ["NoOpProvider_1", "NoOpProvider_2"]


def test_multi_provider_rejects_duplicate_explicit_names():
    # Given
    provider_a = NoOpProvider()
    provider_b = NoOpProvider()

    # When/Then
    with pytest.raises(ValueError, match="Provider name 'duplicate' is not unique"):
        MultiProvider([
            ProviderEntry(provider_a, name="duplicate"),
            ProviderEntry(provider_b, name="duplicate"),
        ])


def test_multi_provider_first_match_strategy_sequential():
    # Given
    flags_a = {
        "flag1": InMemoryFlag("off", {"on": True, "off": False}),
    }
    flags_b = {
        "flag1": InMemoryFlag("on", {"on": True, "off": False}),
        "flag2": InMemoryFlag("on", {"on": True, "off": False}),
    }

    provider_a = InMemoryProvider(flags_a)
    provider_b = InMemoryProvider(flags_b)

    multi = MultiProvider([
        ProviderEntry(provider_a, name="primary"),
        ProviderEntry(provider_b, name="fallback"),
    ], strategy=FirstMatchStrategy())

    # When - flag1 exists in both, should use first (primary)
    result = multi.resolve_boolean_details("flag1", False)

    # Then
    assert result.value is False  # primary provider returns "off" variant
    assert result.reason != Reason.ERROR


def test_multi_provider_fallback_to_second_provider():
    # Given
    flags_a = {}  # primary has no flags
    flags_b = {
        "flag1": InMemoryFlag("on", {"on": True, "off": False}),
    }

    provider_a = InMemoryProvider(flags_a)
    provider_b = InMemoryProvider(flags_b)

    multi = MultiProvider([
        ProviderEntry(provider_a, name="primary"),
        ProviderEntry(provider_b, name="fallback"),
    ])

    # When - flag1 doesn't exist in primary, should fallback
    result = multi.resolve_boolean_details("flag1", False)

    # Then
    assert result.value is True  # fallback provider has the flag
    assert result.reason != Reason.ERROR


def test_multi_provider_all_types_work():
    # Given
    flags = {
        "bool-flag": InMemoryFlag("on", {"on": True, "off": False}),
        "string-flag": InMemoryFlag("greeting", {"greeting": "hello", "farewell": "goodbye"}),
        "int-flag": InMemoryFlag("big", {"small": 10, "big": 100}),
        "float-flag": InMemoryFlag("pi", {"pi": 3.14, "e": 2.71}),
        "object-flag": InMemoryFlag("full", {
            "full": {"name": "test", "value": 42},
            "empty": {},
        }),
    }

    provider = InMemoryProvider(flags)
    multi = MultiProvider([ProviderEntry(provider)])

    # When/Then
    bool_result = multi.resolve_boolean_details("bool-flag", False)
    assert bool_result.value is True

    string_result = multi.resolve_string_details("string-flag", "default")
    assert string_result.value == "hello"

    int_result = multi.resolve_integer_details("int-flag", 0)
    assert int_result.value == 100

    float_result = multi.resolve_float_details("float-flag", 0.0)
    assert float_result.value == 3.14

    object_result = multi.resolve_object_details("object-flag", {})
    assert object_result.value == {"name": "test", "value": 42}


def test_multi_provider_initialize_all_providers():
    # Given
    provider_a = NoOpProvider()
    provider_b = NoOpProvider()

    # Track whether initialize was called on each provider
    a_initialized = False
    b_initialized = False

    def track_a_init(ctx):
        nonlocal a_initialized
        a_initialized = True

    def track_b_init(ctx):
        nonlocal b_initialized
        b_initialized = True

    provider_a.initialize = track_a_init
    provider_b.initialize = track_b_init

    multi = MultiProvider([
        ProviderEntry(provider_a),
        ProviderEntry(provider_b),
    ])

    # When
    multi.initialize(EvaluationContext())

    # Then
    assert a_initialized
    assert b_initialized


def test_multi_provider_initialization_failures_are_aggregated():
    # Given
    provider_a = NoOpProvider()
    provider_b = NoOpProvider()

    def fail_init(ctx):
        raise Exception("Init failed")

    provider_a.initialize = fail_init
    provider_b.initialize = fail_init

    multi = MultiProvider([
        ProviderEntry(provider_a, name="a"),
        ProviderEntry(provider_b, name="b"),
    ])

    # When/Then
    with pytest.raises(GeneralError, match="Multi-provider initialization failed"):
        multi.initialize(EvaluationContext())


def test_multi_provider_returns_error_when_no_providers_have_flag():
    # Given
    provider_a = InMemoryProvider({})
    provider_b = InMemoryProvider({})

    multi = MultiProvider([
        ProviderEntry(provider_a),
        ProviderEntry(provider_b),
    ])

    # When
    result = multi.resolve_boolean_details("nonexistent", False)

    # Then
    assert result.value is False  # default value
    assert result.reason == Reason.ERROR


@pytest.mark.asyncio
async def test_multi_provider_async_methods_work():
    # Given
    flags = {
        "async-flag": InMemoryFlag("on", {"on": True, "off": False}),
    }
    provider = InMemoryProvider(flags)
    multi = MultiProvider([ProviderEntry(provider)])

    # When
    result = await multi.resolve_boolean_details_async("async-flag", False)

    # Then
    assert result.value is True
    assert result.reason != Reason.ERROR


def test_multi_provider_can_be_used_with_api():
    # Given
    api.clear_providers()
    flags = {
        "api-flag": InMemoryFlag("on", {"on": True, "off": False}),
    }
    provider = InMemoryProvider(flags)
    multi = MultiProvider([ProviderEntry(provider)])

    # When
    api.set_provider(multi)
    client = api.get_client()
    value = client.get_boolean_value("api-flag", False)

    # Then
    assert value is True


def test_multi_provider_metadata():
    # Given
    multi = MultiProvider([ProviderEntry(NoOpProvider())])

    # When
    metadata = multi.get_metadata()

    # Then
    assert metadata.name == "MultiProvider"


def test_multi_provider_aggregates_hooks():
    # Given
    provider_a = NoOpProvider()
    provider_b = NoOpProvider()

    hook_a = MagicMock()
    hook_b = MagicMock()

    provider_a.get_provider_hooks = lambda: [hook_a]
    provider_b.get_provider_hooks = lambda: [hook_b]

    multi = MultiProvider([
        ProviderEntry(provider_a),
        ProviderEntry(provider_b),
    ])

    # When
    hooks = multi.get_provider_hooks()

    # Then
    assert len(hooks) == 2
    assert hook_a in hooks
    assert hook_b in hooks