Skip to content

Commit a6b4fc3

Browse files
committed
fix: optimise thread pool
1 parent 03869a9 commit a6b4fc3

File tree

1 file changed

+51
-27
lines changed

1 file changed

+51
-27
lines changed

laygo/transformers/strategies/threaded.py

Lines changed: 51 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,9 @@
55
from concurrent.futures import Future
66
from concurrent.futures import ThreadPoolExecutor
77
from concurrent.futures import wait
8-
from functools import partial
98
import itertools
9+
import threading
10+
from typing import ClassVar
1011

1112
from laygo.context.types import IContextManager
1213
from laygo.transformers.strategies.types import ChunkGenerator
@@ -15,26 +16,38 @@
1516

1617

1718
class ThreadedStrategy[In, Out](ExecutionStrategy[In, Out]):
19+
# Class-level thread pool cache to reuse executors
20+
_thread_pools: ClassVar[dict[int, ThreadPoolExecutor]] = {}
21+
_pool_lock: ClassVar[threading.Lock] = threading.Lock()
22+
1823
def __init__(self, max_workers: int = 4, ordered: bool = True):
1924
self.max_workers = max_workers
2025
self.ordered = ordered
2126

27+
@classmethod
28+
def _get_thread_pool(cls, max_workers: int) -> ThreadPoolExecutor:
29+
"""Get or create a reusable thread pool for the given worker count."""
30+
with cls._pool_lock:
31+
if max_workers not in cls._thread_pools:
32+
cls._thread_pools[max_workers] = ThreadPoolExecutor(
33+
max_workers=max_workers, thread_name_prefix=f"laygo-{max_workers}"
34+
)
35+
return cls._thread_pools[max_workers]
36+
2237
def execute(self, transformer_logic, chunk_generator, data, context):
2338
"""Execute the transformer on data concurrently.
2439
25-
It uses the shared context provided by the Pipeline, if available.
40+
Uses a reusable thread pool to minimize thread creation overhead.
2641
2742
Args:
43+
transformer_logic: The transformation function to apply.
44+
chunk_generator: Function to generate data chunks.
2845
data: The input data to process.
2946
context: Optional pipeline context for shared state.
3047
3148
Returns:
3249
An iterator over the transformed data.
3350
"""
34-
35-
# Since threads share memory, we can pass the context manager directly.
36-
# No handle/proxy mechanism is needed, but the locking inside
37-
# ParallelContextManager is crucial for thread safety.
3851
yield from self._execute_with_context(data, transformer_logic, context, chunk_generator)
3952

4053
def _execute_with_context(
@@ -48,13 +61,15 @@ def _execute_with_context(
4861
4962
Args:
5063
data: The input data to process.
64+
transformer: The transformation function to apply.
5165
shared_context: The shared context for the execution.
66+
chunk_generator: Function to generate data chunks.
5267
5368
Returns:
5469
An iterator over the transformed data.
5570
"""
5671

57-
def process_chunk(chunk: list[In], shared_context: IContextManager) -> list[Out]:
72+
def process_chunk(chunk: list[In]) -> list[Out]:
5873
"""Process a single chunk by calling the transformer with the chunk and the enclosing shared context.
5974
6075
Args:
@@ -66,49 +81,58 @@ def process_chunk(chunk: list[In], shared_context: IContextManager) -> list[Out]
6681
"""
6782
return transformer(chunk, shared_context) # type: ignore
6883

69-
# Create a partial function with the shared_context "baked in".
70-
process_chunk_with_context = partial(process_chunk, shared_context=shared_context)
71-
7284
def _ordered_generator(chunks_iter: Iterator[list[In]], executor: ThreadPoolExecutor) -> Iterator[list[Out]]:
7385
"""Generate results in their original order."""
7486
futures: deque[Future[list[Out]]] = deque()
75-
for _ in range(self.max_workers + 1):
87+
88+
# Pre-submit initial batch of futures
89+
for _ in range(min(self.max_workers, 10)): # Limit initial submissions
7690
try:
7791
chunk = next(chunks_iter)
78-
futures.append(executor.submit(process_chunk_with_context, chunk))
92+
futures.append(executor.submit(process_chunk, chunk))
7993
except StopIteration:
8094
break
95+
8196
while futures:
82-
yield futures.popleft().result()
97+
# Get the next result and submit the next chunk
98+
result = futures.popleft().result()
99+
yield result
100+
83101
try:
84102
chunk = next(chunks_iter)
85-
futures.append(executor.submit(process_chunk_with_context, chunk))
103+
futures.append(executor.submit(process_chunk, chunk))
86104
except StopIteration:
87105
continue
88106

89107
def _unordered_generator(chunks_iter: Iterator[list[In]], executor: ThreadPoolExecutor) -> Iterator[list[Out]]:
90108
"""Generate results as they complete."""
109+
# Pre-submit initial batch
91110
futures = {
92-
executor.submit(process_chunk_with_context, chunk)
93-
for chunk in itertools.islice(chunks_iter, self.max_workers + 1)
111+
executor.submit(process_chunk, chunk) for chunk in itertools.islice(chunks_iter, min(self.max_workers, 10))
94112
}
113+
95114
while futures:
96115
done, futures = wait(futures, return_when=FIRST_COMPLETED)
97116
for future in done:
98117
yield future.result()
99118
try:
100119
chunk = next(chunks_iter)
101-
futures.add(executor.submit(process_chunk_with_context, chunk))
120+
futures.add(executor.submit(process_chunk, chunk))
102121
except StopIteration:
103122
continue
104123

105-
def result_iterator_manager() -> Iterator[Out]:
106-
"""Manage the thread pool and yield flattened results."""
107-
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
108-
chunks_to_process = chunk_generator(data)
109-
gen_func = _ordered_generator if self.ordered else _unordered_generator
110-
processed_chunks_iterator = gen_func(chunks_to_process, executor)
111-
for result_chunk in processed_chunks_iterator:
112-
yield from result_chunk
113-
114-
return result_iterator_manager()
124+
# Use the reusable thread pool instead of creating a new one
125+
executor = self._get_thread_pool(self.max_workers)
126+
chunks_to_process = chunk_generator(data)
127+
gen_func = _ordered_generator if self.ordered else _unordered_generator
128+
129+
# Process chunks using the reusable executor
130+
for result_chunk in gen_func(chunks_to_process, executor):
131+
yield from result_chunk
132+
133+
def __del__(self) -> None:
134+
"""Shut down and clear all cached thread pools. Runs when any instance is finalized; the pools are class-level and shared by every instance."""
135+
with self._pool_lock:
136+
for pool in self._thread_pools.values():
137+
pool.shutdown(wait=True)
138+
self._thread_pools.clear()

0 commit comments

Comments
 (0)