Skip to content

Commit 8b61a14

Browse files
committed
chore: implemented use of the new context manager
1 parent 17bbe44 commit 8b61a14

8 files changed

Lines changed: 201 additions & 254 deletions

File tree

laygo/context/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,19 @@
1+
"""
2+
Laygo Context Management Package.
3+
4+
This package provides different strategies for managing state (context)
5+
within a data pipeline, from simple in-memory dictionaries to
6+
process-safe managers for parallel execution.
7+
"""
8+
9+
from .parallel import ParallelContextManager
10+
from .simple import SimpleContextManager
11+
from .types import IContextHandle
12+
from .types import IContextManager
13+
14+
__all__ = [
15+
"IContextManager",
16+
"IContextHandle",
17+
"SimpleContextManager",
18+
"ParallelContextManager",
19+
]

laygo/errors.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,11 @@
11
from collections.abc import Callable
22

3-
from laygo.helpers import PipelineContext
3+
from laygo.context.types import IContextManager
44

5-
ChunkErrorHandler = Callable[[list, Exception, PipelineContext], None]
5+
ChunkErrorHandler = Callable[[list, Exception, IContextManager], None]
66

77

8-
def raise_error(chunk: list, error: Exception, context: PipelineContext) -> None:
8+
def raise_error(chunk: list, error: Exception, context: IContextManager) -> None:
99
"""Handler that always re-raises the error, stopping execution.
1010
1111
This is a default error handler that provides fail-fast behavior by
@@ -47,7 +47,7 @@ def on_error(self, handler: ChunkErrorHandler) -> "ErrorHandler":
4747
self._handlers.insert(0, handler)
4848
return self
4949

50-
def handle(self, chunk: list, error: Exception, context: PipelineContext) -> None:
50+
def handle(self, chunk: list, error: Exception, context: IContextManager) -> None:
5151
"""Execute all handlers in the chain.
5252
5353
Handlers are executed in reverse order of addition. Execution stops

laygo/helpers.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,10 +3,15 @@
33
from typing import Any
44
from typing import TypeGuard
55

6+
from laygo.context.types import IContextManager
7+
68

79
class PipelineContext(dict[str, Any]):
810
"""Generic, untyped context available to all pipeline operations.
911
12+
DEPRECATED: This class is deprecated and will be removed in a future version.
13+
Use IContextManager implementations (SimpleContextManager, etc.) instead.
14+
1015
A dictionary-based context that can store arbitrary data shared across
1116
pipeline operations. This allows passing state and configuration between
1217
different stages of data processing.
@@ -16,14 +21,14 @@ class PipelineContext(dict[str, Any]):
1621

1722

1823
# Define the specific callables for clarity
19-
ContextAwareCallable = Callable[[Any, PipelineContext], Any]
20-
ContextAwareReduceCallable = Callable[[Any, Any, PipelineContext], Any]
24+
ContextAwareCallable = Callable[[Any, IContextManager], Any]
25+
ContextAwareReduceCallable = Callable[[Any, Any, IContextManager], Any]
2126

2227

2328
def is_context_aware(func: Callable[..., Any]) -> TypeGuard[ContextAwareCallable]:
2429
"""Check if a function is context-aware by inspecting its signature.
2530
26-
A context-aware function accepts a PipelineContext as its second parameter,
31+
A context-aware function accepts an IContextManager as its second parameter,
2732
allowing it to access shared state during pipeline execution.
2833
2934
Args:
@@ -40,7 +45,7 @@ def is_context_aware_reduce(func: Callable[..., Any]) -> TypeGuard[ContextAwareR
4045
"""Check if a reduce function is context-aware by inspecting its signature.
4146
4247
A context-aware reduce function accepts an accumulator, current value,
43-
and PipelineContext as its three parameters.
48+
and IContextManager as its three parameters.
4449
4550
Args:
4651
func: The reduce function to inspect for context awareness.

laygo/pipeline.py

Lines changed: 23 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -5,13 +5,13 @@
55
from concurrent.futures import ThreadPoolExecutor
66
from concurrent.futures import as_completed
77
import itertools
8-
import multiprocessing as mp
98
from queue import Queue
109
from typing import Any
1110
from typing import TypeVar
1211
from typing import overload
1312

14-
from laygo.helpers import PipelineContext
13+
from laygo.context import IContextManager
14+
from laygo.context import SimpleContextManager
1515
from laygo.helpers import is_context_aware
1616
from laygo.transformers.transformer import Transformer
1717
from laygo.transformers.transformer import passthrough_chunks
@@ -44,12 +44,14 @@ class Pipeline[T]:
4444
pipeline effectively single-use unless the data source is re-initialized.
4545
"""
4646

47-
def __init__(self, *data: Iterable[T]) -> None:
47+
def __init__(self, *data: Iterable[T], context_manager: IContextManager | None = None) -> None:
4848
"""Initialize a pipeline with one or more data sources.
4949
5050
Args:
5151
*data: One or more iterable data sources. If multiple sources are
5252
provided, they will be chained together.
53+
context_manager: An instance of a class that implements IContextManager.
54+
If None, a SimpleContextManager is used by default.
5355
5456
Raises:
5557
ValueError: If no data sources are provided.
@@ -59,25 +61,16 @@ def __init__(self, *data: Iterable[T]) -> None:
5961
self.data_source: Iterable[T] = itertools.chain.from_iterable(data) if len(data) > 1 else data[0]
6062
self.processed_data: Iterator = iter(self.data_source)
6163

62-
# Always create a shared context with multiprocessing manager
63-
self._manager = mp.Manager()
64-
self.ctx = self._manager.dict()
65-
# Add a shared lock to the context for safe concurrent updates
66-
self.ctx["lock"] = self._manager.Lock()
67-
68-
# Store reference to original context for final synchronization
69-
self._original_context_ref: PipelineContext | None = None
64+
# Rule 1: Pipeline creates a simple context manager by default.
65+
self.context_manager = context_manager or SimpleContextManager()
7066

7167
def __del__(self) -> None:
72-
"""Clean up the multiprocessing manager when the pipeline is destroyed."""
73-
try:
74-
self._sync_context_back()
75-
self._manager.shutdown()
76-
except Exception:
77-
pass
68+
"""Clean up the context manager when the pipeline is destroyed."""
69+
if hasattr(self, "context_manager"):
70+
self.context_manager.shutdown()
7871

79-
def context(self, ctx: PipelineContext) -> "Pipeline[T]":
80-
"""Update the pipeline context and store a reference to the original context.
72+
def context(self, ctx: dict[str, Any]) -> "Pipeline[T]":
73+
"""Update the pipeline's context manager with values from a dictionary.
8174
8275
The provided context will be used during pipeline execution and any
8376
modifications made by transformers will be synchronized back to the
@@ -96,10 +89,7 @@ def context(self, ctx: PipelineContext) -> "Pipeline[T]":
9689
automatically synchronized back to the original context object
9790
when the pipeline is destroyed or processing completes.
9891
"""
99-
# Store reference to the original context
100-
self._original_context_ref = ctx
101-
# Copy the context data to the pipeline's shared context
102-
self.ctx.update(ctx)
92+
self.context_manager.update(ctx)
10393
return self
10494

10595
def _sync_context_back(self) -> None:
@@ -108,12 +98,9 @@ def _sync_context_back(self) -> None:
10898
This is called after processing is complete to update the original
10999
context with any changes made during pipeline execution.
110100
"""
111-
if self._original_context_ref is not None:
112-
# Copy the final context state back to the original context reference
113-
final_context_state = dict(self.ctx)
114-
final_context_state.pop("lock", None) # Remove non-serializable lock
115-
self._original_context_ref.clear()
116-
self._original_context_ref.update(final_context_state)
101+
# This method is kept for backward compatibility but is no longer needed
102+
# since we use the context manager directly
103+
pass
117104

118105
def transform[U](self, t: Callable[[Transformer[T, T]], Transformer[T, U]]) -> "Pipeline[U]":
119106
"""Apply a transformation using a lambda function.
@@ -146,13 +133,13 @@ def apply[U](self, transformer: Transformer[T, U]) -> "Pipeline[U]": ...
146133
def apply[U](self, transformer: Callable[[Iterable[T]], Iterator[U]]) -> "Pipeline[U]": ...
147134

148135
@overload
149-
def apply[U](self, transformer: Callable[[Iterable[T], PipelineContext], Iterator[U]]) -> "Pipeline[U]": ...
136+
def apply[U](self, transformer: Callable[[Iterable[T], IContextManager], Iterator[U]]) -> "Pipeline[U]": ...
150137

151138
def apply[U](
152139
self,
153140
transformer: Transformer[T, U]
154141
| Callable[[Iterable[T]], Iterator[U]]
155-
| Callable[[Iterable[T], PipelineContext], Iterator[U]],
142+
| Callable[[Iterable[T], IContextManager], Iterator[U]],
156143
) -> "Pipeline[U]":
157144
"""Apply a transformer to the current data source.
158145
@@ -181,10 +168,11 @@ def apply[U](
181168
"""
182169
match transformer:
183170
case Transformer():
184-
self.processed_data = transformer(self.processed_data, self.ctx) # type: ignore
171+
# Pass the pipeline's context manager to the transformer
172+
self.processed_data = transformer(self.processed_data, self.context_manager) # type: ignore
185173
case _ if callable(transformer):
186174
if is_context_aware(transformer):
187-
self.processed_data = transformer(self.processed_data, self.ctx) # type: ignore
175+
self.processed_data = transformer(self.processed_data, self.context_manager) # type: ignore
188176
else:
189177
self.processed_data = transformer(self.processed_data) # type: ignore
190178
case _:
@@ -256,12 +244,12 @@ def consumer(transformer: Transformer, queue: Queue) -> list[Any]:
256244

257245
def stream_from_queue() -> Iterator[T]:
258246
while (batch := queue.get()) is not None:
259-
yield batch
247+
yield from batch
260248

261249
if use_queue_chunks:
262250
transformer = transformer.set_chunker(passthrough_chunks)
263251

264-
result_iterator = transformer(stream_from_queue(), self.ctx) # type: ignore
252+
result_iterator = transformer(stream_from_queue(), self.context_manager) # type: ignore
265253
return list(result_iterator)
266254

267255
with ThreadPoolExecutor(max_workers=num_branches + 1) as executor:

laygo/transformers/http.py

Lines changed: 29 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,9 @@
1616

1717
import requests
1818

19+
from laygo.context import IContextManager
20+
from laygo.context import SimpleContextManager
1921
from laygo.errors import ErrorHandler
20-
from laygo.helpers import PipelineContext
2122
from laygo.transformers.transformer import ChunkErrorHandler
2223
from laygo.transformers.transformer import PipelineFunction
2324
from laygo.transformers.transformer import Transformer
@@ -85,6 +86,8 @@ def __init__(
8586
self.max_workers = max_workers
8687
self.session = requests.Session()
8788
self._worker_url: str | None = None
89+
# HTTP transformers always use a simple context manager to avoid serialization issues
90+
self._default_context = SimpleContextManager()
8891

8992
def _finalize_config(self) -> None:
9093
"""Determine the final worker URL, generating one if needed.
@@ -107,19 +110,22 @@ def _finalize_config(self) -> None:
107110
self.endpoint = path.lstrip("/")
108111
self._worker_url = f"{self.base_url}/{self.endpoint}"
109112

110-
def __call__(self, data: Iterable[In], context: PipelineContext | None = None) -> Iterator[Out]:
113+
def __call__(self, data: Iterable[In], context: IContextManager | None = None) -> Iterator[Out]:
111114
"""Execute distributed processing on the data (CLIENT-SIDE).
112115
113116
This method is called by the Pipeline to start distributed processing.
114117
It sends data chunks to worker endpoints via HTTP.
115118
116119
Args:
117120
data: The input data to process.
118-
context: Optional pipeline context (currently not used in HTTP mode).
121+
context: Optional pipeline context. HTTP transformers always use their
122+
internal SimpleContextManager regardless of the provided context.
119123
120124
Returns:
121125
An iterator over the processed data.
122126
"""
127+
run_context = context or self._default_context
128+
123129
self._finalize_config()
124130

125131
def process_chunk(chunk: list) -> list:
@@ -143,18 +149,24 @@ def process_chunk(chunk: list) -> list:
143149
print(f"Error calling worker {self._worker_url}: {e}")
144150
return []
145151

146-
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
147-
chunk_iterator = self._chunk_generator(data)
148-
futures = {executor.submit(process_chunk, chunk) for chunk in itertools.islice(chunk_iterator, self.max_workers)}
149-
while futures:
150-
done, futures = wait(futures, return_when=FIRST_COMPLETED)
151-
for future in done:
152-
yield from future.result()
153-
try:
154-
new_chunk = next(chunk_iterator)
155-
futures.add(executor.submit(process_chunk, new_chunk))
156-
except StopIteration:
157-
continue
152+
try:
153+
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
154+
chunk_iterator = self._chunk_generator(data)
155+
futures = {
156+
executor.submit(process_chunk, chunk) for chunk in itertools.islice(chunk_iterator, self.max_workers)
157+
}
158+
while futures:
159+
done, futures = wait(futures, return_when=FIRST_COMPLETED)
160+
for future in done:
161+
yield from future.result()
162+
try:
163+
new_chunk = next(chunk_iterator)
164+
futures.add(executor.submit(process_chunk, new_chunk))
165+
except StopIteration:
166+
continue
167+
finally:
168+
# Always clean up our context since we always use the default one
169+
run_context.shutdown()
158170

159171
def get_route(self):
160172
"""Get the route configuration for registering this transformer as a worker.
@@ -167,7 +179,7 @@ def get_route(self):
167179
"""
168180
self._finalize_config()
169181

170-
def worker_view_func(chunk: list, context: PipelineContext):
182+
def worker_view_func(chunk: list, context: IContextManager):
171183
"""The actual worker logic for this transformer.
172184
173185
Args:
@@ -226,6 +238,6 @@ def catch[U](
226238
super().catch(sub_pipeline_builder, on_error)
227239
return self # type: ignore
228240

229-
def short_circuit(self, function: Callable[[PipelineContext], bool | None]) -> "HTTPTransformer[In, Out]":
241+
def short_circuit(self, function: Callable[[IContextManager], bool | None]) -> "HTTPTransformer[In, Out]":
230242
super().short_circuit(function)
231243
return self

0 commit comments

Comments (0)