Skip to content

Commit 434a3ae

Browse files
committed
fix: apply the same context management to threaded transformer
1 parent ec2fe3d commit 434a3ae

3 files changed

Lines changed: 37 additions & 34 deletions

File tree

laygo/transformers/parallel.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,14 @@ def __call__(self, data: Iterable[In], context: PipelineContext | None = None) -
119119

120120
def _execute_with_context(self, data: Iterable[In], shared_context: MutableMapping[str, Any]) -> Iterator[Out]:
121121
"""Helper to run the execution logic with a given context."""
122-
with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
123-
executor = get_reusable_executor(max_workers=self.max_workers)
122+
executor = get_reusable_executor(max_workers=self.max_workers)
124123

125-
chunks_to_process = self._chunk_generator(data)
126-
gen_func = self._ordered_generator if self.ordered else self._unordered_generator
127-
processed_chunks_iterator = gen_func(chunks_to_process, executor, shared_context)
124+
chunks_to_process = self._chunk_generator(data)
125+
gen_func = self._ordered_generator if self.ordered else self._unordered_generator
126+
processed_chunks_iterator = gen_func(chunks_to_process, executor, shared_context)
128127

129-
for result_chunk in processed_chunks_iterator:
130-
yield from result_chunk
128+
for result_chunk in processed_chunks_iterator:
129+
yield from result_chunk
131130

132131
# ... The rest of the file remains the same ...
133132
def _ordered_generator(

laygo/transformers/threaded.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
from collections.abc import Callable
55
from collections.abc import Iterable
66
from collections.abc import Iterator
7+
from collections.abc import MutableMapping
78
from concurrent.futures import FIRST_COMPLETED
89
from concurrent.futures import Future
910
from concurrent.futures import ThreadPoolExecutor
1011
from concurrent.futures import wait
1112
import copy
1213
from functools import partial
1314
import itertools
15+
from multiprocessing.managers import DictProxy
1416
import threading
1517
from typing import Any
1618
from typing import Union
@@ -101,25 +103,41 @@ def from_transformer[T, U](
101103

102104
def __call__(self, data: Iterable[In], context: PipelineContext | None = None) -> Iterator[Out]:
103105
"""
104-
Executes the transformer on data concurrently.
105-
106-
A new `threading.Lock` is created and added to the context for each call
107-
to ensure execution runs are isolated and thread-safe.
106+
Executes the transformer on data concurrently. It uses the shared
107+
context provided by the Pipeline, if available.
108108
"""
109-
# Determine the context for this run, passing it by reference as requested.
110-
run_context = context or self.context
111-
# Add a per-call lock for thread safety.
112-
run_context["lock"] = threading.Lock()
113-
114-
def process_chunk(chunk: list[In], shared_context: PipelineContext) -> list[Out]:
109+
run_context = context if context is not None else self.context
110+
111+
# Detect if the context is already managed by the Pipeline.
112+
is_managed_context = isinstance(run_context, DictProxy)
113+
114+
if is_managed_context:
115+
# Use the existing shared context and lock from the Pipeline.
116+
shared_context = run_context
117+
yield from self._execute_with_context(data, shared_context)
118+
# The context is live, so no need to update it here.
119+
# The Pipeline's __del__ will handle final state.
120+
else:
121+
# Fallback for standalone use: create a thread-safe context.
122+
# Since threads share memory, we can use the context directly with a lock.
123+
if "lock" not in run_context:
124+
run_context["lock"] = threading.Lock()
125+
126+
yield from self._execute_with_context(data, run_context)
127+
# Context is already updated in-place for threads (shared memory)
128+
129+
def _execute_with_context(self, data: Iterable[In], shared_context: MutableMapping[str, Any]) -> Iterator[Out]:
130+
"""Helper to run the execution logic with a given context."""
131+
132+
def process_chunk(chunk: list[In], shared_context: MutableMapping[str, Any]) -> list[Out]:
115133
"""
116134
Process a single chunk by passing the chunk and context explicitly
117135
to the transformer chain. This is safer and avoids mutating self.
118136
"""
119-
return self.transformer(chunk, shared_context)
137+
return self.transformer(chunk, shared_context) # type: ignore
120138

121-
# Create a partial function with the run_context "baked in".
122-
process_chunk_with_context = partial(process_chunk, shared_context=run_context)
139+
# Create a partial function with the shared_context "baked in".
140+
process_chunk_with_context = partial(process_chunk, shared_context=shared_context)
123141

124142
def _ordered_generator(chunks_iter: Iterator[list[In]], executor: ThreadPoolExecutor) -> Iterator[list[Out]]:
125143
"""Generate results in their original order."""

tests/test_parallel_transformer.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import multiprocessing as mp
44
import time
5-
from unittest.mock import patch
65

76
from laygo import ErrorHandler
87
from laygo import ParallelTransformer
@@ -159,19 +158,6 @@ def test_unordered_vs_ordered_same_elements(self):
159158
assert sorted(ordered_result) == sorted(unordered_result)
160159
assert ordered_result == [x * 2 for x in data]
161160

162-
def test_process_pool_management(self):
163-
"""Test that process pool is properly created and cleaned up."""
164-
with patch("laygo.transformers.parallel.ProcessPoolExecutor") as mock_executor:
165-
mock_executor.return_value.__enter__.return_value = mock_executor.return_value
166-
mock_executor.return_value.__exit__.return_value = None
167-
mock_executor.return_value.submit.return_value.result.return_value = [2, 4]
168-
transformer = ParallelTransformer[int, int](max_workers=2, chunk_size=2)
169-
list(transformer([1, 2]))
170-
171-
mock_executor.assert_called_with(max_workers=2)
172-
mock_executor.return_value.__enter__.assert_called_once()
173-
mock_executor.return_value.__exit__.assert_called_once()
174-
175161

176162
class TestParallelTransformerChunkingAndEdgeCases:
177163
"""Test chunking behavior and edge cases."""

0 commit comments

Comments
 (0)