Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions openfeature/_event_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ def add_client_handler(
handlers = _client_handlers[client][event]
handlers.append(handler)

# Intentionally outside the lock: the immediate-fire status check acquires the registry lock, so calling it
# under _client_lock risks lock-order inversion against run_handlers_for_provider (registry lock → _client_lock).
# As a consequence, a narrow double-fire is possible: if dispatch_event for this client's event runs concurrently, it
# sets the matching provider status (enabling the immediate fire below) and then re-runs every handler for this
# client. If _run_immediate_handler lands after that status is set but before dispatch snapshots the handler list,
# the handler fires twice — once here, once from dispatch. This only happens when the registered event matches the
# event being dispatched; otherwise the immediate fire is a no-op.
_run_immediate_handler(client, event, handler)


Expand All @@ -78,6 +85,7 @@ def add_global_handler(event: ProviderEvent, handler: EventHandler) -> None:

from openfeature.api import get_client # noqa: PLC0415

# See comment in add_client_handler for why this runs outside the lock.
_run_immediate_handler(get_client(), event, handler)


Expand Down Expand Up @@ -133,7 +141,6 @@ def _run_handler(handler: EventHandler, details: EventDetails) -> None:


def clear() -> None:
    """Remove every registered global and client event handler.

    Both handler locks are acquired together (``_global_lock`` first, then
    ``_client_lock``) and held until both registries have been emptied.
    """
    with _global_lock, _client_lock:
        _global_handlers.clear()
        _client_handlers.clear()
14 changes: 10 additions & 4 deletions openfeature/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import threading
import typing
from collections.abc import Awaitable, Mapping, Sequence
from dataclasses import dataclass
Expand Down Expand Up @@ -86,6 +87,7 @@ def __init__(
self.version = version
self.context = context or EvaluationContext()
self.hooks = hooks or []
self._hooks_lock = threading.RLock()

@property
def provider(self) -> FeatureProvider:
Expand All @@ -98,7 +100,10 @@ def get_metadata(self) -> ClientMetadata:
return ClientMetadata(domain=self.domain)

def add_hooks(self, hooks: list[Hook]) -> None:
    """Append ``hooks`` to this client's hook list.

    The list is rebound to a new concatenated list; the previous list object
    is not mutated.
    """
    # Guards the read-concat-store against a lost update. This practically never races under the default 5 ms GIL
    # switch interval, but is essential on a no-GIL build.
    with self._hooks_lock:
        self.hooks = self.hooks + hooks

def get_boolean_value(
self,
Expand Down Expand Up @@ -468,8 +473,9 @@ def _establish_hooks_and_provider(

def _assert_provider_status(
self,
provider: FeatureProvider,
) -> OpenFeatureError | None:
status = self.get_provider_status()
status = provider_registry.get_provider_status(provider)
if status == ProviderStatus.NOT_READY:
return ProviderNotReadyError()
if status == ProviderStatus.FATAL:
Expand Down Expand Up @@ -589,7 +595,7 @@ async def evaluate_flag_details_async(
)

try:
if provider_err := self._assert_provider_status():
if provider_err := self._assert_provider_status(provider):
error_hooks(
flag_type,
provider_err,
Expand Down Expand Up @@ -765,7 +771,7 @@ def evaluate_flag_details(
)

try:
if provider_err := self._assert_provider_status():
if provider_err := self._assert_provider_status(provider):
error_hooks(
flag_type,
provider_err,
Expand Down
17 changes: 13 additions & 4 deletions openfeature/hook/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import threading
import typing
from collections.abc import Mapping, MutableMapping, Sequence
from datetime import datetime
Expand All @@ -24,6 +25,7 @@
]

# Module-level list of registered hooks; add_hooks/clear_hooks rebind this name rather than mutate the list.
_hooks: list[Hook] = []
# Re-entrant lock taken by add_hooks/clear_hooks while they rebind _hooks.
_hooks_lock = threading.RLock()


# https://openfeature.dev/specification/sections/hooks/#requirement-461
Expand Down Expand Up @@ -151,14 +153,21 @@ def supports_flag_value_type(self, flag_type: FlagType) -> bool:
return True


# While the lock guarantees safety, even without it there was never a lost update within 50,000 runs (with the
# default GIL switch interval of 5 ms). Only when the switch interval was shortened drastically, to 0.1 microseconds,
# were losses occasionally observed without the lock. On a no-GIL Python, the lock would be essential.


def add_hooks(hooks: list[Hook]) -> None:
    """Append ``hooks`` to the module-level hook list.

    The module-level list is rebound to a new concatenated list; the previous
    list object is not mutated.
    """
    global _hooks
    # Hold the lock across the read-concat-store so a concurrent caller cannot overwrite this update.
    with _hooks_lock:
        _hooks = _hooks + hooks


def clear_hooks() -> None:
    """Reset the module-level hook list to a new empty list."""
    global _hooks
    # Same lock as add_hooks, so a clear cannot interleave with a concurrent read-concat-store.
    with _hooks_lock:
        _hooks = []


def get_hooks() -> list[Hook]:
Expand Down
5 changes: 3 additions & 2 deletions openfeature/provider/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,5 +261,6 @@ def emit_provider_stale(self, details: ProviderEventDetails) -> None:
self.emit(ProviderEvent.PROVIDER_STALE, details)

def emit(self, event: ProviderEvent, details: ProviderEventDetails) -> None:
    """Forward ``event`` and ``details`` to the ``_on_emit`` callback, if one is attached.

    Does nothing when the instance has no ``_on_emit`` attribute or when that
    attribute is ``None``. The callback is invoked as
    ``on_emit(self, event, details)``.
    """
    # Read the attribute exactly once: a separate hasattr() check followed by a second lookup could find the
    # attribute removed in between and raise AttributeError.
    on_emit = getattr(self, "_on_emit", None)
    if on_emit is not None:
        on_emit(self, event, details)
Comment on lines -264 to +266

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this change needed?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because of what was pointed out in the original issue:
"
`AbstractProvider._on_emit` (`provider/__init__.py`) - `emit()` does `if hasattr(self, "_on_emit"): self._on_emit(...)`, which is a TOCTOU race; `detach()` during shutdown can delete `_on_emit` while a background thread is between the check and the call
"

17 changes: 11 additions & 6 deletions openfeature/transaction_context/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import threading

from openfeature.evaluation_context import EvaluationContext
from openfeature.transaction_context.context_var_transaction_context_propagator import (
ContextVarsTransactionContextPropagator,
Expand All @@ -21,25 +23,28 @@
# Module-level propagator; starts out as a NoOpTransactionContextPropagator instance.
_evaluation_transaction_context_propagator: TransactionContextPropagator = (
NoOpTransactionContextPropagator()
)
# Re-entrant lock taken by the functions in this module when they read or rebind the propagator above.
_propagator_lock = threading.RLock()


def set_transaction_context_propagator(
    transaction_context_propagator: TransactionContextPropagator,
) -> None:
    """Install ``transaction_context_propagator`` as the module-level propagator.

    The module-level reference is rebound while ``_propagator_lock`` is held.
    """
    global _evaluation_transaction_context_propagator
    with _propagator_lock:
        _evaluation_transaction_context_propagator = transaction_context_propagator


def clear_transaction_context_propagator() -> None:
    """Replace the current propagator with a new no-op propagator."""
    noop_propagator = NoOpTransactionContextPropagator()
    set_transaction_context_propagator(noop_propagator)


def get_transaction_context() -> EvaluationContext:
return _evaluation_transaction_context_propagator.get_transaction_context()
with _propagator_lock:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure about this one here, if this is really needed or not

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, I tried to ask Claude but the answer was sometimes yes and sometimes no 😅

propagator = _evaluation_transaction_context_propagator
return propagator.get_transaction_context()


def set_transaction_context(evaluation_context: EvaluationContext) -> None:
    """Pass ``evaluation_context`` to the currently installed propagator.

    The propagator reference is read while ``_propagator_lock`` is held; the
    propagator's own ``set_transaction_context`` is then called on that
    snapshot.
    """
    # No ``global`` statement is needed: the module-level name is only read here, never rebound.
    with _propagator_lock:
        propagator = _evaluation_transaction_context_propagator
    propagator.set_transaction_context(evaluation_context)
44 changes: 38 additions & 6 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import types
import uuid
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import MagicMock
from unittest.mock import MagicMock, Mock

import pytest

from openfeature import _event_support, api
from openfeature import client as client_module
from openfeature.api import (
add_hooks,
clear_hooks,
Expand All @@ -20,7 +21,7 @@
from openfeature.client import OpenFeatureClient, _typecheck_flag_value
from openfeature.evaluation_context import EvaluationContext
from openfeature.event import EventDetails, ProviderEvent, ProviderEventDetails
from openfeature.exception import ErrorCode, OpenFeatureError
from openfeature.exception import ErrorCode, OpenFeatureError, ProviderFatalError
from openfeature.flag_evaluation import FlagResolutionDetails, FlagType, Reason
from openfeature.hook import Hook
from openfeature.provider import FeatureProvider, ProviderStatus
Expand Down Expand Up @@ -291,9 +292,10 @@ def test_provider_should_return_error_status_if_failed():
async def test_should_shortcircuit_if_provider_is_not_ready(
no_op_provider_client, monkeypatch
):
# Given
monkeypatch.setattr(
no_op_provider_client, "get_provider_status", lambda: ProviderStatus.NOT_READY
provider_registry,
"get_provider_status",
lambda provider: ProviderStatus.NOT_READY,
)
spy_hook = MagicMock(spec=Hook)
no_op_provider_client.add_hooks([spy_hook])
Expand Down Expand Up @@ -321,9 +323,10 @@ async def test_should_shortcircuit_if_provider_is_not_ready(
async def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state(
no_op_provider_client, monkeypatch
):
# Given
monkeypatch.setattr(
no_op_provider_client, "get_provider_status", lambda: ProviderStatus.FATAL
provider_registry,
"get_provider_status",
lambda provider: ProviderStatus.FATAL,
)
spy_hook = MagicMock(spec=Hook)
no_op_provider_client.add_hooks([spy_hook])
Expand Down Expand Up @@ -768,3 +771,32 @@ def test_should_noop_if_provider_does_not_support_tracking(monkeypatch):
set_provider(provider)
client = get_client()
client.track(tracking_event_name="test")


def test_assert_provider_status_uses_passed_provider_not_current_registry_state():
    """The status check must look up the provider it was handed, not the client's current one."""
    fatal_provider = NoOpProvider()
    ready_provider = NoOpProvider()

    def status_for(candidate):
        # Only the dedicated fatal provider reports FATAL; anything else is READY.
        if candidate is fatal_provider:
            return ProviderStatus.FATAL
        return ProviderStatus.READY

    registry_mock = Mock()
    registry_mock.get_provider_status.side_effect = status_for
    registry_mock.get_provider.return_value = ready_provider

    # Swap the registry on the client module by hand and always restore it afterwards.
    saved_registry = client_module.provider_registry
    client_module.provider_registry = registry_mock
    try:
        client = OpenFeatureClient(domain=None, version=None)
        assert client.provider is ready_provider, (
            "test setup: self.provider should resolve via the patched registry"
        )

        fatal_result = client._assert_provider_status(fatal_provider)
        assert isinstance(fatal_result, ProviderFatalError), (
            "status check used self.provider (READY) instead of the captured "
            "fatal_provider — TOCTOU regression"
        )
        registry_mock.get_provider_status.assert_any_call(fatal_provider)

        ready_result = client._assert_provider_status(ready_provider)
        assert ready_result is None
    finally:
        client_module.provider_registry = saved_registry
Loading