Skip to content

Commit bf30d00

Browse files
committed
chore: simplified parallel context
1 parent 8b61a14 commit bf30d00

3 files changed

Lines changed: 87 additions & 85 deletions

File tree

laygo/context/parallel.py

Lines changed: 69 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -3,131 +3,131 @@
33
multiprocessing.Manager to share state across processes.
44
"""
55

6+
from collections.abc import Callable
67
from collections.abc import Iterator
78
import multiprocessing as mp
8-
from multiprocessing.managers import BaseManager
99
from multiprocessing.managers import DictProxy
10-
from multiprocessing.synchronize import Lock
10+
from threading import Lock
1111
from typing import Any
12+
from typing import TypeVar
1213

1314
from laygo.context.types import IContextHandle
1415
from laygo.context.types import IContextManager
1516

16-
17-
class _ParallelStateManager(BaseManager):
18-
"""A custom manager to expose a shared dictionary and lock."""
19-
20-
pass
17+
R = TypeVar("R")
2118

2219

2320
class ParallelContextHandle(IContextHandle):
    """
    A lightweight, picklable handle that carries the actual shared objects
    (the DictProxy and its lock) to worker processes.

    A worker calls :meth:`create_proxy` to obtain a context manager that
    operates on the very same shared state as the main process.
    """

    def __init__(self, shared_dict: DictProxy, lock: Lock):
        # The manager proxy objects are themselves picklable, so this handle
        # can be shipped to worker processes as-is.
        # NOTE(review): at runtime the lock is a manager AcquirerProxy, not a
        # threading lock — the `Lock` annotation is nominal; confirm callers.
        self._shared_dict = shared_dict
        self._lock = lock

    def create_proxy(self) -> "IContextManager":
        """Build a proxy-mode ParallelContextManager wrapping the shared objects."""
        return ParallelContextManager(handle=self)
3936

4037

4138
class ParallelContextManager(IContextManager):
    """
    A context manager that enables state sharing across processes.

    It operates in two modes:

    1. Main Mode: when created normally, it starts a ``multiprocessing.Manager``
       and creates a shared dictionary and lock.
    2. Proxy Mode: when created from a handle, it wraps a ``DictProxy`` and
       lock that were received from another process. It does not own the
       manager, so ``shutdown()`` is a no-op and ``get_handle()`` raises.

    Lock bookkeeping uses a re-entrancy depth counter instead of a boolean
    flag: with a boolean, a nested ``with`` block would release the shared
    lock at the *innermost* exit, silently dropping mutual exclusion for the
    rest of the outer block (whose own ``__exit__`` then becomes a no-op).

    NOTE(review): the depth counter assumes each instance is used by one
    thread at a time (one proxy per worker process) — confirm against callers.
    """

    def __init__(self, initial_context: dict[str, Any] | None = None, handle: ParallelContextHandle | None = None):
        """
        Initializes the manager. If a handle is provided, it initializes in
        proxy mode; otherwise, it starts a new manager.

        Args:
            initial_context: Seed values for the shared dict (Main Mode only).
            handle: A handle from another process's manager (Proxy Mode).
        """
        if handle:
            # --- PROXY MODE ---
            # This instance is a client wrapping objects from an existing manager.
            self._manager = None  # Proxies do not own the manager process.
            self._shared_dict = handle._shared_dict
            self._lock = handle._lock
        else:
            # --- MAIN MODE ---
            # This instance owns the manager and its shared objects.
            self._manager = mp.Manager()
            self._shared_dict = self._manager.dict(initial_context or {})
            self._lock = self._manager.Lock()

        # Number of nested acquisitions held by this instance; > 0 means the
        # underlying shared lock is currently held by us.
        self._lock_depth = 0

    @property
    def _is_locked(self) -> bool:
        # Backward-compatible view of the former boolean flag.
        return self._lock_depth > 0

    def _lock_context(self) -> None:
        """Acquire the shared lock; re-entrant for this instance."""
        if self._lock_depth == 0:
            self._lock.acquire()
        self._lock_depth += 1

    def _unlock_context(self) -> None:
        """Release one nesting level; frees the lock at the outermost exit."""
        if self._lock_depth > 0:
            self._lock_depth -= 1
            if self._lock_depth == 0:
                self._lock.release()

    def _execute_locked(self, operation: Callable[[], R]) -> R:
        """Execute *operation* while holding the shared lock."""
        if self._lock_depth:
            # Already inside a locked section owned by this instance.
            return operation()
        self._lock_context()
        try:
            return operation()
        finally:
            self._unlock_context()

    def get_handle(self) -> ParallelContextHandle:
        """
        Returns a picklable handle containing the shared dict and lock.
        Only the main instance can generate handles.

        Raises:
            TypeError: If called on a proxy instance.
        """
        if not self._manager:
            raise TypeError("Cannot get a handle from a proxy context instance.")

        return ParallelContextHandle(self._shared_dict, self._lock)

    def shutdown(self) -> None:
        """
        Shuts down the background manager process.
        This is a no-op for proxy instances.
        """
        if self._manager:
            self._manager.shutdown()

    def __enter__(self) -> "ParallelContextManager":
        """Acquires the lock for use in a 'with' statement."""
        self._lock_context()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Releases the lock (actually freed only at the outermost exit)."""
        self._unlock_context()

    def __getitem__(self, key: str) -> Any:
        return self._execute_locked(lambda: self._shared_dict[key])

    def __setitem__(self, key: str, value: Any) -> None:
        self._execute_locked(lambda: self._shared_dict.__setitem__(key, value))

    def __delitem__(self, key: str) -> None:
        self._execute_locked(lambda: self._shared_dict.__delitem__(key))

    def __iter__(self) -> Iterator[str]:
        # Copy the keys so iteration stays safe while other processes mutate.
        return self._execute_locked(lambda: iter(list(self._shared_dict.keys())))

    def __len__(self) -> int:
        return self._execute_locked(lambda: len(self._shared_dict))

laygo/transformers/threaded.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def __call__(self, data: Iterable[In], context: IContextManager | None = None) -
119119
Returns:
120120
An iterator over the transformed data.
121121
"""
122-
run_context = context if context is not None else self._default_context
122+
run_context = context or self._default_context
123123

124124
# Since threads share memory, we can pass the context manager directly.
125125
# No handle/proxy mechanism is needed, but the locking inside

tests/test_parallel_transformer.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
from laygo import ErrorHandler
77
from laygo import ParallelTransformer
8-
from laygo import PipelineContext
8+
from laygo.context import IContextManager
9+
from laygo.context import ParallelContextManager
910
from laygo.transformers.parallel import createParallelTransformer
1011
from laygo.transformers.transformer import createTransformer
1112

@@ -82,16 +83,18 @@ def test_tap_side_effects(self):
8283
assert sorted(side_effects) == [1, 2, 3, 4]
8384

8485

85-
def safe_increment(x: int, ctx: PipelineContext) -> int:
86-
with ctx["lock"]:
86+
def safe_increment(x: int, ctx: IContextManager) -> int:
87+
# Safe cast since we know ParallelContextManager implements context manager protocol
88+
with ctx: # type: ignore
8789
current_items = ctx["items"]
8890
time.sleep(0.001)
8991
ctx["items"] = current_items + 1
9092
return x * 2
9193

9294

93-
def update_stats(x: int, ctx: PipelineContext) -> int:
94-
with ctx["lock"]:
95+
def update_stats(x: int, ctx: IContextManager) -> int:
96+
# Safe cast since we know ParallelContextManager implements context manager protocol
97+
with ctx: # type: ignore
9598
ctx["total_sum"] += x
9699
ctx["item_count"] += 1
97100
ctx["max_value"] = max(ctx["max_value"], x)
@@ -103,25 +106,24 @@ class TestParallelTransformerContextSupport:
103106

104107
def test_map_with_context(self):
105108
"""Test map with context-aware function in concurrent execution."""
106-
context = PipelineContext({"multiplier": 3})
109+
context = ParallelContextManager({"multiplier": 3})
107110
transformer = createParallelTransformer(int).map(lambda x, ctx: x * ctx["multiplier"])
108111
result = list(transformer([1, 2, 3], context))
109112
assert result == [3, 6, 9]
110113

111-
def test_context_modification_with_locking(self):
112-
"""Test safe context modification with locking in concurrent execution."""
113-
context = PipelineContext({"items": 0})
114+
def test_context_aware_complex_operation(self):
115+
"""Test complex context-aware operations with shared state."""
116+
context = ParallelContextManager({"multiplier": 3, "stats": {"total": 0, "count": 0}})
114117

115-
transformer = createParallelTransformer(int, max_workers=4, chunk_size=1).map(safe_increment)
116-
data = list(range(1, 11))
118+
transformer = createParallelTransformer(int, max_workers=2, chunk_size=2).map(update_stats)
119+
data = [1, 2, 3, 4, 5]
117120
result = list(transformer(data, context))
118121

119-
assert sorted(result) == sorted([x * 2 for x in data])
120-
assert context["items"] == len(data)
121-
122122
def test_multiple_context_values_modification(self):
123123
"""Test modifying multiple context values safely."""
124-
context = PipelineContext({"total_sum": 0, "item_count": 0, "max_value": 0})
124+
from laygo.context import ParallelContextManager
125+
126+
context = ParallelContextManager({"total_sum": 0, "item_count": 0, "max_value": 0})
125127

126128
transformer = createParallelTransformer(int, max_workers=3, chunk_size=2).map(update_stats)
127129
data = [1, 5, 3, 8, 2, 7, 4, 6]

0 commit comments

Comments
 (0)