diff --git a/openfeature/_event_support.py b/openfeature/_event_support.py index 3928be3e..2ed6eeec 100644 --- a/openfeature/_event_support.py +++ b/openfeature/_event_support.py @@ -61,6 +61,13 @@ def add_client_handler( handlers = _client_handlers[client][event] handlers.append(handler) + # outside the lock intentionally: the immediate-fire status check acquires the registry lock, so calling it + # under _client_lock risks lock-order inversion against run_handlers_for_provider (registry lock → _client_lock). + # As a consequence, a narrow double-fire is possible: if dispatch_event(client's event) runs concurrently, it + # sets the matching provider status (enabling the immediate fire below) and then re-runs every handler for this + # client. If _run_immediate_handler lands after that status set but before dispatch snapshots the handler list, + # the handler fires twice — once here, once from dispatch. Only happens when the registered event matches the event + # being dispatched; otherwise the immediate fire is a no-op. _run_immediate_handler(client, event, handler) @@ -78,6 +85,7 @@ def add_global_handler(event: ProviderEvent, handler: EventHandler) -> None: from openfeature.api import get_client # noqa: PLC0415 + # See comment in add_client_handler for why this runs outside the lock. 
_run_immediate_handler(get_client(), event, handler) @@ -133,7 +141,6 @@ def _run_handler(handler: EventHandler, details: EventDetails) -> None: def clear() -> None: - with _global_lock: + with _global_lock, _client_lock: _global_handlers.clear() - with _client_lock: _client_handlers.clear() diff --git a/openfeature/client.py b/openfeature/client.py index 95dc5b6d..80273c3e 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -1,4 +1,5 @@ import logging +import threading import typing from collections.abc import Awaitable, Mapping, Sequence from dataclasses import dataclass @@ -86,6 +87,7 @@ def __init__( self.version = version self.context = context or EvaluationContext() self.hooks = hooks or [] + self._hooks_lock = threading.RLock() @property def provider(self) -> FeatureProvider: @@ -98,7 +100,10 @@ def get_metadata(self) -> ClientMetadata: return ClientMetadata(domain=self.domain) def add_hooks(self, hooks: list[Hook]) -> None: - self.hooks = self.hooks + hooks + # Guards the read-concat-store against a lost update; this practically never races under the default 5ms GIL + # switch interval, but is essential under a no-GIL build. 
+ with self._hooks_lock: + self.hooks = self.hooks + hooks def get_boolean_value( self, @@ -468,8 +473,9 @@ def _establish_hooks_and_provider( def _assert_provider_status( self, + provider: FeatureProvider, ) -> OpenFeatureError | None: - status = self.get_provider_status() + status = provider_registry.get_provider_status(provider) if status == ProviderStatus.NOT_READY: return ProviderNotReadyError() if status == ProviderStatus.FATAL: @@ -589,7 +595,7 @@ async def evaluate_flag_details_async( ) try: - if provider_err := self._assert_provider_status(): + if provider_err := self._assert_provider_status(provider): error_hooks( flag_type, provider_err, @@ -765,7 +771,7 @@ def evaluate_flag_details( ) try: - if provider_err := self._assert_provider_status(): + if provider_err := self._assert_provider_status(provider): error_hooks( flag_type, provider_err, diff --git a/openfeature/hook/__init__.py b/openfeature/hook/__init__.py index 247d316b..35f76c3e 100644 --- a/openfeature/hook/__init__.py +++ b/openfeature/hook/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +import threading import typing from collections.abc import Mapping, MutableMapping, Sequence from datetime import datetime @@ -24,6 +25,7 @@ ] _hooks: list[Hook] = [] +_hooks_lock = threading.RLock() # https://openfeature.dev/specification/sections/hooks/#requirement-461 @@ -151,14 +153,21 @@ def supports_flag_value_type(self, flag_type: FlagType) -> bool: return True +# While the lock guarantees safety, even without it there was never a lost update within 50,000 runs (with the default +# GIL switch interval of 5ms). Only when the switch interval was shortened drastically to 0.1 microseconds were losses +# occasionally observed without the lock.
With a no-GIL Python, the lock would be essential. + + def add_hooks(hooks: list[Hook]) -> None: - global _hooks - _hooks = _hooks + hooks + with _hooks_lock: + global _hooks + _hooks = _hooks + hooks def clear_hooks() -> None: - global _hooks - _hooks = [] + with _hooks_lock: + global _hooks + _hooks = [] def get_hooks() -> list[Hook]: diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index 1b2b5206..e02aab78 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -261,5 +261,6 @@ def emit_provider_stale(self, details: ProviderEventDetails) -> None: self.emit(ProviderEvent.PROVIDER_STALE, details) def emit(self, event: ProviderEvent, details: ProviderEventDetails) -> None: - if hasattr(self, "_on_emit"): - self._on_emit(self, event, details) + on_emit = getattr(self, "_on_emit", None) + if on_emit is not None: + on_emit(self, event, details) diff --git a/openfeature/transaction_context/__init__.py b/openfeature/transaction_context/__init__.py index 15ac7e01..70483347 100644 --- a/openfeature/transaction_context/__init__.py +++ b/openfeature/transaction_context/__init__.py @@ -1,3 +1,5 @@ +import threading + from openfeature.evaluation_context import EvaluationContext from openfeature.transaction_context.context_var_transaction_context_propagator import ( ContextVarsTransactionContextPropagator, @@ -21,13 +23,15 @@ _evaluation_transaction_context_propagator: TransactionContextPropagator = ( NoOpTransactionContextPropagator() ) +_propagator_lock = threading.RLock() def set_transaction_context_propagator( transaction_context_propagator: TransactionContextPropagator, ) -> None: global _evaluation_transaction_context_propagator - _evaluation_transaction_context_propagator = transaction_context_propagator + with _propagator_lock: + _evaluation_transaction_context_propagator = transaction_context_propagator def clear_transaction_context_propagator() -> None: @@ -35,11 +39,12 @@ def
clear_transaction_context_propagator() -> None: def get_transaction_context() -> EvaluationContext: - return _evaluation_transaction_context_propagator.get_transaction_context() + with _propagator_lock: + propagator = _evaluation_transaction_context_propagator + return propagator.get_transaction_context() def set_transaction_context(evaluation_context: EvaluationContext) -> None: - global _evaluation_transaction_context_propagator - _evaluation_transaction_context_propagator.set_transaction_context( - evaluation_context - ) + with _propagator_lock: + propagator = _evaluation_transaction_context_propagator + propagator.set_transaction_context(evaluation_context) diff --git a/tests/test_client.py b/tests/test_client.py index 44b49e5f..64598936 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,11 +4,12 @@ import types import uuid from concurrent.futures import ThreadPoolExecutor -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock import pytest from openfeature import _event_support, api +from openfeature import client as client_module from openfeature.api import ( add_hooks, clear_hooks, @@ -20,7 +21,7 @@ from openfeature.client import OpenFeatureClient, _typecheck_flag_value from openfeature.evaluation_context import EvaluationContext from openfeature.event import EventDetails, ProviderEvent, ProviderEventDetails -from openfeature.exception import ErrorCode, OpenFeatureError +from openfeature.exception import ErrorCode, OpenFeatureError, ProviderFatalError from openfeature.flag_evaluation import FlagResolutionDetails, FlagType, Reason from openfeature.hook import Hook from openfeature.provider import FeatureProvider, ProviderStatus @@ -291,9 +292,10 @@ def test_provider_should_return_error_status_if_failed(): async def test_should_shortcircuit_if_provider_is_not_ready( no_op_provider_client, monkeypatch ): - # Given monkeypatch.setattr( - no_op_provider_client, "get_provider_status", lambda: ProviderStatus.NOT_READY + 
provider_registry, + "get_provider_status", + lambda provider: ProviderStatus.NOT_READY, ) spy_hook = MagicMock(spec=Hook) no_op_provider_client.add_hooks([spy_hook]) @@ -321,9 +323,10 @@ async def test_should_shortcircuit_if_provider_is_not_ready( async def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state( no_op_provider_client, monkeypatch ): - # Given monkeypatch.setattr( - no_op_provider_client, "get_provider_status", lambda: ProviderStatus.FATAL + provider_registry, + "get_provider_status", + lambda provider: ProviderStatus.FATAL, ) spy_hook = MagicMock(spec=Hook) no_op_provider_client.add_hooks([spy_hook]) @@ -768,3 +771,32 @@ def test_should_noop_if_provider_does_not_support_tracking(monkeypatch): set_provider(provider) client = get_client() client.track(tracking_event_name="test") + + +def test_assert_provider_status_uses_passed_provider_not_current_registry_state(): + fatal_provider = NoOpProvider() + ready_provider = NoOpProvider() + + registry_mock = Mock() + registry_mock.get_provider_status.side_effect = lambda p: ( + ProviderStatus.FATAL if p is fatal_provider else ProviderStatus.READY + ) + registry_mock.get_provider.return_value = ready_provider + + original = client_module.provider_registry + client_module.provider_registry = registry_mock + try: + c = OpenFeatureClient(domain=None, version=None) + assert c.provider is ready_provider, ( + "test setup: self.provider should resolve via the patched registry" + ) + + err = c._assert_provider_status(fatal_provider) + assert isinstance(err, ProviderFatalError), ( + "status check used self.provider (READY) instead of the captured " + "fatal_provider — TOCTOU regression" + ) + registry_mock.get_provider_status.assert_any_call(fatal_provider) + assert c._assert_provider_status(ready_provider) is None + finally: + client_module.provider_registry = original