Skip to content

Commit 742e9fe

Browse files
committed
fix: two more tests
1 parent 5a1ce05 commit 742e9fe

File tree

3 files changed

+15
-12
lines changed

3 files changed

+15
-12
lines changed

laygo/context/parallel.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from collections.abc import Iterator
88
import multiprocessing as mp
99
from multiprocessing.managers import DictProxy
10+
import threading
1011
from threading import Lock
1112
from typing import Any
1213
from typing import TypeVar
@@ -64,23 +65,24 @@ def __init__(self, initial_context: dict[str, Any] | None = None, handle: Parall
6465
self._shared_dict = self._manager.dict(initial_context or {})
6566
self._lock = self._manager.Lock()
6667

67-
self._is_locked = False
68+
# Thread-local storage for lock state to handle concurrent access
69+
self._local = threading.local()
6870

6971
def _lock_context(self) -> None:
7072
"""Acquire the lock for this context manager."""
71-
if not self._is_locked:
73+
if not getattr(self._local, "is_locked", False):
7274
self._lock.acquire()
73-
self._is_locked = True
75+
self._local.is_locked = True
7476

7577
def _unlock_context(self) -> None:
7678
"""Release the lock for this context manager."""
77-
if self._is_locked:
79+
if getattr(self._local, "is_locked", False):
7880
self._lock.release()
79-
self._is_locked = False
81+
self._local.is_locked = False
8082

8183
def _execute_locked(self, operation: Callable[[], R]) -> R:
8284
"""A private helper to execute an operation within a lock."""
83-
if not self._is_locked:
85+
if not getattr(self._local, "is_locked", False):
8486
self._lock_context()
8587
try:
8688
return operation()

tests/test_integration.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,13 @@ def test_parallel_transformer_with_context_modification(self):
134134
parallel_transformer = parallel_transformer.map(safe_increment_and_transform)
135135

136136
data = [1, 2, 3, 4, 5]
137-
result, _ = Pipeline(data).context(context).apply(parallel_transformer).to_list()
137+
result, processed_context = Pipeline(data).context(context).apply(parallel_transformer).to_list()
138138

139139
# Verify transformation results
140140
assert sorted(result) == [2, 4, 6, 8, 10]
141141
# Verify context was safely modified
142-
assert context["processed_count"] == len(data)
143-
assert context["sum_total"] == sum(data)
142+
assert processed_context["processed_count"] == len(data)
143+
assert processed_context["sum_total"] == sum(data)
144144

145145
def test_pipeline_accesses_modified_context(self):
146146
"""Test that pipeline can access context data modified by parallel transformer."""

tests/test_threaded_transformer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,10 @@ def test_context_modification_with_locking(self):
100100
context = ParallelContextManager({"items": 0})
101101

102102
def safe_increment(x: int, ctx: IContextManager) -> int:
103-
current_items = ctx["items"]
104-
time.sleep(0.001) # Increase chance of race condition
105-
ctx["items"] = current_items + 1
103+
with ctx:
104+
# Simulate a race condition
105+
time.sleep(0.001) # Increase chance of race condition
106+
ctx["items"] = ctx["items"] + 1
106107
return x * 2
107108

108109
transformer = ThreadedTransformer[int, int](max_workers=4, chunk_size=1)

0 commit comments

Comments
 (0)