|
3 | 3 | multiprocessing.Manager to share state across processes. |
4 | 4 | """ |
5 | 5 |
|
| 6 | +from collections.abc import Callable |
6 | 7 | from collections.abc import Iterator |
7 | 8 | import multiprocessing as mp |
8 | | -from multiprocessing.managers import BaseManager |
9 | 9 | from multiprocessing.managers import DictProxy |
10 | | -from multiprocessing.synchronize import Lock |
| 10 | +from threading import Lock |
11 | 11 | from typing import Any |
| 12 | +from typing import TypeVar |
12 | 13 |
|
13 | 14 | from laygo.context.types import IContextHandle |
14 | 15 | from laygo.context.types import IContextManager |
15 | 16 |
|
16 | | - |
17 | | -class _ParallelStateManager(BaseManager): |
18 | | - """A custom manager to expose a shared dictionary and lock.""" |
19 | | - |
20 | | - pass |
| 17 | +R = TypeVar("R") |
21 | 18 |
|
22 | 19 |
|
23 | 20 | class ParallelContextHandle(IContextHandle): |
24 | 21 | """ |
25 | | - A lightweight, picklable "blueprint" for recreating a connection to the |
26 | | - shared context in a different process. |
| 22 | + A lightweight, picklable handle that carries the actual shared objects |
| 23 | + (the DictProxy and Lock) to worker processes. |
27 | 24 | """ |
28 | 25 |
|
29 | | - def __init__(self, address: tuple[str, int], manager_class: type["ParallelContextManager"]): |
30 | | - self.address = address |
31 | | - self.manager_class = manager_class |
| 26 | + def __init__(self, shared_dict: DictProxy, lock: Lock): |
| 27 | + self._shared_dict = shared_dict |
| 28 | + self._lock = lock |
32 | 29 |
|
33 | 30 | def create_proxy(self) -> "IContextManager": |
34 | 31 | """ |
35 | | - Creates a new instance of the ParallelContextManager in "proxy" mode |
36 | | - by initializing it with this handle. |
| 32 | + Creates a new ParallelContextManager instance that wraps the shared |
| 33 | + objects received by the worker process. |
37 | 34 | """ |
38 | | - return self.manager_class(handle=self) |
| 35 | + return ParallelContextManager(handle=self) |
39 | 36 |
|
40 | 37 |
|
41 | 38 | class ParallelContextManager(IContextManager): |
42 | 39 | """ |
43 | | - A context manager that uses a background multiprocessing.Manager to enable |
44 | | - state sharing across different processes. |
45 | | -
|
46 | | - This single class operates in two modes: |
47 | | - 1. Server Mode (when created normally): It starts and manages the background |
48 | | - server process that holds the shared state. |
49 | | - 2. Proxy Mode (when created with a handle): It acts as a client, connecting |
50 | | - to an existing server process to manipulate the shared state. |
| 40 | + A context manager that enables state sharing across processes. |
| 41 | +
|
| 42 | + It operates in two modes: |
| 43 | + 1. Main Mode: When created normally, it starts a multiprocessing.Manager |
| 44 | + and creates a shared dictionary and lock. |
| 45 | + 2. Proxy Mode: When created from a handle, it wraps a DictProxy and Lock |
| 46 | + that were received from another process. It does not own the manager. |
51 | 47 | """ |
52 | 48 |
|
53 | 49 | def __init__(self, initial_context: dict[str, Any] | None = None, handle: ParallelContextHandle | None = None): |
54 | 50 | """ |
55 | 51 | Initializes the manager. If a handle is provided, it initializes in |
56 | | - proxy mode; otherwise, it starts a new server. |
| 52 | + proxy mode; otherwise, it starts a new manager. |
57 | 53 | """ |
58 | 54 | if handle: |
59 | 55 | # --- PROXY MODE INITIALIZATION --- |
60 | | - # This instance is a client connecting to an existing server. |
61 | | - self._is_proxy = True |
62 | | - self._manager_server = None # Proxies do not own the server process. |
63 | | - |
64 | | - manager = _ParallelStateManager(address=handle.address) |
65 | | - manager.connect() |
66 | | - self._manager = manager |
67 | | - |
| 56 | + # This instance is a client wrapping objects from an existing server. |
| 57 | + self._manager = None # Proxies do not own the manager process. |
| 58 | + self._shared_dict = handle._shared_dict |
| 59 | + self._lock = handle._lock |
68 | 60 | else: |
69 | | - # --- SERVER MODE INITIALIZATION --- |
70 | | - # This is the main instance that owns the server process. |
71 | | - self._is_proxy = False |
72 | | - manager = mp.Manager() # type: ignore |
73 | | - _ParallelStateManager.register("get_dict", callable=lambda: manager.dict(initial_context or {})) |
74 | | - _ParallelStateManager.register("get_lock", callable=lambda: manager.Lock()) |
75 | | - |
76 | | - self._manager_server = _ParallelStateManager(address=("", 0)) |
77 | | - self._manager_server.start() |
78 | | - self._manager = self._manager_server |
79 | | - |
80 | | - # Common setup for both modes |
81 | | - self._shared_dict: DictProxy = self._manager.get_dict() # type: ignore |
82 | | - self._lock: Lock = self._manager.get_lock() # type: ignore |
| 61 | + # --- MAIN MODE INITIALIZATION --- |
| 62 | + # This instance owns the manager and its shared objects. |
| 63 | + self._manager = mp.Manager() |
| 64 | + self._shared_dict = self._manager.dict(initial_context or {}) |
| 65 | + self._lock = self._manager.Lock() |
| 66 | + |
| 67 | + self._is_locked = False |
| 68 | + |
| 69 | + def _lock_context(self) -> None: |
| 70 | + """Acquire the lock for this context manager.""" |
| 71 | + if not self._is_locked: |
| 72 | + self._lock.acquire() |
| 73 | + self._is_locked = True |
| 74 | + |
| 75 | + def _unlock_context(self) -> None: |
| 76 | + """Release the lock for this context manager.""" |
| 77 | + if self._is_locked: |
| 78 | + self._lock.release() |
| 79 | + self._is_locked = False |
| 80 | + |
| 81 | + def _execute_locked(self, operation: Callable[[], R]) -> R: |
| 82 | + """A private helper to execute an operation within a lock.""" |
| 83 | + if not self._is_locked: |
| 84 | + self._lock_context() |
| 85 | + try: |
| 86 | + return operation() |
| 87 | + finally: |
| 88 | + self._unlock_context() |
| 89 | + else: |
| 90 | + return operation() |
83 | 91 |
|
84 | 92 | def get_handle(self) -> ParallelContextHandle: |
85 | 93 | """ |
86 | | - Returns a picklable handle for reconstruction in a worker. |
87 | | - Only the main server instance can generate handles. |
| 94 | + Returns a picklable handle containing the shared dict and lock. |
| 95 | + Only the main instance can generate handles. |
88 | 96 | """ |
89 | | - if self._is_proxy or not self._manager_server: |
| 97 | + if not self._manager: |
90 | 98 | raise TypeError("Cannot get a handle from a proxy context instance.") |
91 | 99 |
|
92 | | - return ParallelContextHandle( |
93 | | - address=self._manager_server.address, # type: ignore |
94 | | - manager_class=self.__class__, # Pass its own class for reconstruction |
95 | | - ) |
| 100 | + return ParallelContextHandle(self._shared_dict, self._lock) |
96 | 101 |
|
97 | 102 | def shutdown(self) -> None: |
98 | 103 | """ |
99 | 104 | Shuts down the background manager process. |
100 | | - This is a no-op for proxy instances, as only the main instance |
101 | | - should control the server's lifecycle. |
| 105 | + This is a no-op for proxy instances. |
102 | 106 | """ |
103 | | - if not self._is_proxy and self._manager_server: |
104 | | - self._manager_server.shutdown() |
| 107 | + if self._manager: |
| 108 | + self._manager.shutdown() |
105 | 109 |
|
106 | 110 | def __enter__(self) -> "ParallelContextManager": |
107 | 111 | """Acquires the lock for use in a 'with' statement.""" |
108 | | - self._lock.acquire() |
| 112 | + self._lock_context() |
109 | 113 | return self |
110 | 114 |
|
111 | 115 | def __exit__(self, exc_type, exc_val, exc_tb) -> None: |
112 | 116 | """Releases the lock.""" |
113 | | - self._lock.release() |
| 117 | + self._unlock_context() |
114 | 118 |
|
115 | 119 | def __getitem__(self, key: str) -> Any: |
116 | | - with self._lock: |
117 | | - return self._shared_dict[key] |
| 120 | + return self._execute_locked(lambda: self._shared_dict[key]) |
118 | 121 |
|
119 | 122 | def __setitem__(self, key: str, value: Any) -> None: |
120 | | - with self._lock: |
121 | | - self._shared_dict[key] = value |
| 123 | + self._execute_locked(lambda: self._shared_dict.__setitem__(key, value)) |
122 | 124 |
|
123 | 125 | def __delitem__(self, key: str) -> None: |
124 | | - with self._lock: |
125 | | - del self._shared_dict[key] |
| 126 | + self._execute_locked(lambda: self._shared_dict.__delitem__(key)) |
126 | 127 |
|
127 | 128 | def __iter__(self) -> Iterator[str]: |
128 | | - with self._lock: |
129 | | - return iter(list(self._shared_dict.keys())) |
| 129 | + # Iteration needs to copy the keys to be safe across processes |
| 130 | + return self._execute_locked(lambda: iter(list(self._shared_dict.keys()))) |
130 | 131 |
|
131 | 132 | def __len__(self) -> int: |
132 | | - with self._lock: |
133 | | - return len(self._shared_dict) |
| 133 | + return self._execute_locked(lambda: len(self._shared_dict)) |
0 commit comments