|
4 | 4 | from collections.abc import Callable |
5 | 5 | from collections.abc import Iterable |
6 | 6 | from collections.abc import Iterator |
| 7 | +from collections.abc import MutableMapping |
7 | 8 | from concurrent.futures import FIRST_COMPLETED |
8 | 9 | from concurrent.futures import Future |
9 | 10 | from concurrent.futures import ThreadPoolExecutor |
10 | 11 | from concurrent.futures import wait |
11 | 12 | import copy |
12 | 13 | from functools import partial |
13 | 14 | import itertools |
| 15 | +from multiprocessing.managers import DictProxy |
14 | 16 | import threading |
15 | 17 | from typing import Any |
16 | 18 | from typing import Union |
@@ -101,25 +103,41 @@ def from_transformer[T, U]( |
101 | 103 |
|
102 | 104 | def __call__(self, data: Iterable[In], context: PipelineContext | None = None) -> Iterator[Out]: |
103 | 105 | """ |
104 | | - Executes the transformer on data concurrently. |
105 | | -
|
106 | | - A new `threading.Lock` is created and added to the context for each call |
107 | | - to ensure execution runs are isolated and thread-safe. |
| 106 | + Executes the transformer on data concurrently. It uses the shared |
| 107 | + context provided by the Pipeline, if available. |
108 | 108 | """ |
109 | | - # Determine the context for this run, passing it by reference as requested. |
110 | | - run_context = context or self.context |
111 | | - # Add a per-call lock for thread safety. |
112 | | - run_context["lock"] = threading.Lock() |
113 | | - |
114 | | - def process_chunk(chunk: list[In], shared_context: PipelineContext) -> list[Out]: |
| 109 | + run_context = context if context is not None else self.context |
| 110 | + |
| 111 | + # Detect if the context is already managed by the Pipeline. |
| 112 | + is_managed_context = isinstance(run_context, DictProxy) |
| 113 | + |
| 114 | + if is_managed_context: |
| 115 | + # Use the existing shared context and lock from the Pipeline. |
| 116 | + shared_context = run_context |
| 117 | + yield from self._execute_with_context(data, shared_context) |
| 118 | + # The context is live, so no need to update it here. |
| 119 | + # The Pipeline's __del__ will handle final state. |
| 120 | + else: |
| 121 | + # Fallback for standalone use: attach a lock to the plain context if it has none. |
| 122 | + # Since threads share memory, we can use the context directly with a lock. |
| 123 | + if "lock" not in run_context: |
| 124 | + run_context["lock"] = threading.Lock() |
| 125 | + |
| 126 | + yield from self._execute_with_context(data, run_context) |
| 127 | + # Threads share memory, so the context has already been updated in place. |
| 128 | + |
| 129 | + def _execute_with_context(self, data: Iterable[In], shared_context: MutableMapping[str, Any]) -> Iterator[Out]: |
| 130 | + """Helper to run the execution logic with a given context.""" |
| 131 | + |
| 132 | + def process_chunk(chunk: list[In], shared_context: MutableMapping[str, Any]) -> list[Out]: |
115 | 133 | """ |
116 | 134 | Process a single chunk by passing the chunk and context explicitly |
117 | 135 | to the transformer chain. This is safer and avoids mutating self. |
118 | 136 | """ |
119 | | - return self.transformer(chunk, shared_context) |
| 137 | + return self.transformer(chunk, shared_context) # type: ignore |
120 | 138 |
|
121 | | - # Create a partial function with the run_context "baked in". |
122 | | - process_chunk_with_context = partial(process_chunk, shared_context=run_context) |
| 139 | + # Create a partial function with the shared_context "baked in". |
| 140 | + process_chunk_with_context = partial(process_chunk, shared_context=shared_context) |
123 | 141 |
|
124 | 142 | def _ordered_generator(chunks_iter: Iterator[list[In]], executor: ThreadPoolExecutor) -> Iterator[list[Out]]: |
125 | 143 | """Generate results in their original order.""" |
|
0 commit comments