Skip to content

Commit 5a1ce05

Browse files
committed
chore: pipeline now returns results and the final context
1 parent 6e1ef29 commit 5a1ce05

File tree

9 files changed

+232
-182
lines changed

9 files changed

+232
-182
lines changed

laygo/context/parallel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,6 @@ def __iter__(self) -> Iterator[str]:
131131

132132
def __len__(self) -> int:
133133
return self._execute_locked(lambda: len(self._shared_dict))
134+
135+
def to_dict(self) -> dict[str, Any]:
    """Return an atomic snapshot of the shared context as a plain dict.

    The copy is taken while holding the manager's lock so callers get a
    consistent view and cannot mutate the live shared state.
    """
    def _snapshot() -> dict[str, Any]:
        return dict(self._shared_dict)

    return self._execute_locked(_snapshot)

laygo/context/simple.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,11 @@ def __len__(self) -> int:
8787
def shutdown(self) -> None:
8888
"""No-op for the simple context manager."""
8989
pass
90+
91+
def to_dict(self) -> dict[str, Any]:
    """
    Returns a copy of the entire context as a standard Python dictionary.

    A shallow copy is returned so callers cannot mutate the manager's
    internal state through the returned snapshot.
    """
    # Copy instead of handing out the internal dict: the previous version
    # returned self._context directly, contradicting the documented copy
    # semantics and leaking mutable internal state to callers.
    return dict(self._context)

laygo/context/types.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,18 @@ def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException
8787
exc_tb: The traceback object, if an exception was raised.
8888
"""
8989
self.shutdown()
90+
91+
def to_dict(self) -> dict[str, Any]:
    """
    Returns a copy of the entire shared context as a standard
    Python dictionary.

    Implementations must return a consistent snapshot — e.g. by copying
    under whatever synchronization the concrete manager uses — so callers
    never observe a partially-updated context.

    Returns:
        A standard dict containing a copy of the shared context.
    """
    # Abstract hook: concrete context managers override this with an
    # implementation-appropriate atomic copy.
    raise NotImplementedError

laygo/pipeline.py

Lines changed: 100 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def context(self, ctx: dict[str, Any]) -> "Pipeline[T]":
8989
automatically synchronized back to the original context object
9090
when the pipeline is destroyed or processing completes.
9191
"""
92+
self._user_context = ctx
9293
self.context_manager.update(ctx)
9394
return self
9495

@@ -180,95 +181,6 @@ def apply[U](
180181

181182
return self # type: ignore
182183

183-
def branch(
184-
self,
185-
branches: dict[str, Transformer[T, Any]],
186-
batch_size: int = 1000,
187-
max_batch_buffer: int = 1,
188-
use_queue_chunks: bool = True,
189-
) -> dict[str, list[Any]]:
190-
"""Forks the pipeline into multiple branches for concurrent, parallel processing.
191-
192-
This is a **terminal operation** that implements a fan-out pattern where
193-
the entire dataset is copied to each branch for independent processing.
194-
Each branch processes the complete dataset concurrently using separate
195-
transformers, and results are collected and returned in a dictionary.
196-
197-
Args:
198-
branches: A dictionary where keys are branch names (str) and values
199-
are `Transformer` instances of any subtype.
200-
batch_size: The number of items to batch together when sending data
201-
to branches. Larger batches can improve throughput but
202-
use more memory. Defaults to 1000.
203-
max_batch_buffer: The maximum number of batches to buffer for each
204-
branch queue. Controls memory usage and creates
205-
backpressure. Defaults to 1.
206-
use_queue_chunks: Whether to use passthrough chunking for the
207-
transformers. When True, batches are processed
208-
as chunks. Defaults to True.
209-
210-
Returns:
211-
A dictionary where keys are the branch names and values are lists
212-
of all items processed by that branch's transformer.
213-
214-
Note:
215-
This operation consumes the pipeline's iterator, making subsequent
216-
operations on the same pipeline return empty results.
217-
"""
218-
if not branches:
219-
self.consume()
220-
return {}
221-
222-
source_iterator = self.processed_data
223-
branch_items = list(branches.items())
224-
num_branches = len(branch_items)
225-
final_results: dict[str, list[Any]] = {}
226-
227-
queues = [Queue(maxsize=max_batch_buffer) for _ in range(num_branches)]
228-
229-
def producer() -> None:
230-
"""Reads from the source and distributes batches to ALL branch queues."""
231-
# Use itertools.batched for clean and efficient batch creation.
232-
for batch_tuple in itertools.batched(source_iterator, batch_size):
233-
# The batch is a tuple; convert to a list for consumers.
234-
batch_list = list(batch_tuple)
235-
for q in queues:
236-
q.put(batch_list)
237-
238-
# Signal to all consumers that the stream is finished.
239-
for q in queues:
240-
q.put(None)
241-
242-
def consumer(transformer: Transformer, queue: Queue) -> list[Any]:
243-
"""Consumes batches from a queue and runs them through a transformer."""
244-
245-
def stream_from_queue() -> Iterator[T]:
246-
while (batch := queue.get()) is not None:
247-
yield batch
248-
249-
if use_queue_chunks:
250-
transformer = transformer.set_chunker(passthrough_chunks)
251-
252-
result_iterator = transformer(stream_from_queue(), self.context_manager) # type: ignore
253-
return list(result_iterator)
254-
255-
with ThreadPoolExecutor(max_workers=num_branches + 1) as executor:
256-
executor.submit(producer)
257-
258-
future_to_name = {
259-
executor.submit(consumer, transformer, queues[i]): name for i, (name, transformer) in enumerate(branch_items)
260-
}
261-
262-
for future in as_completed(future_to_name):
263-
name = future_to_name[future]
264-
try:
265-
final_results[name] = future.result()
266-
except Exception as e:
267-
print(f"Branch '{name}' raised an exception: {e}")
268-
final_results[name] = []
269-
270-
return final_results
271-
272184
def buffer(self, size: int, batch_size: int = 1000) -> "Pipeline[T]":
273185
"""Inserts a buffer in the pipeline to allow downstream processing to read ahead.
274186
@@ -328,7 +240,7 @@ def __iter__(self) -> Iterator[T]:
328240
"""
329241
yield from self.processed_data
330242

331-
def to_list(self) -> list[T]:
243+
def to_list(self) -> tuple[list[T], dict[str, Any]]:
    """Execute the pipeline and collect every result (terminal operation).

    Drains the pipeline's iterator completely; subsequent terminal
    operations on the same pipeline will observe an exhausted stream.

    Returns:
        A tuple of (all processed elements as a list, a dict snapshot of
        the final pipeline context).
    """
    collected: list[T] = []
    collected.extend(self.processed_data)
    return collected, self.context_manager.to_dict()
345257

346-
def each(self, function: PipelineFunction[T]) -> None:
258+
def each(self, function: PipelineFunction[T]) -> tuple[None, dict[str, Any]]:
    """Apply a function to each element (terminal operation).

    Invokes ``function`` once per element purely for its side effects;
    no transformed results are collected.

    Args:
        function: Callable invoked with each pipeline element.

    Returns:
        A tuple of ``None`` and a dict snapshot of the final pipeline
        context.
    """
    apply = function
    for element in self.processed_data:
        apply(element)

    return None, self.context_manager.to_dict()
276+
277+
def first(self, n: int = 1) -> tuple[list[T], dict[str, Any]]:
    """Take up to the first *n* elements (terminal operation).

    Consumes at most ``n`` elements from the pipeline's iterator; later
    operations continue from where this one stopped.

    Args:
        n: Number of leading elements to take. Must be at least 1.

    Returns:
        A tuple of (up to ``n`` elements as a list, a dict snapshot of
        the final pipeline context).
    """
    assert n >= 1, "n must be at least 1"
    taken: list[T] = []
    stream = iter(self.processed_data)
    while len(taken) < n:
        try:
            taken.append(next(stream))
        except StopIteration:
            break
    return taken, self.context_manager.to_dict()
385299

386-
def consume(self) -> None:
300+
def consume(self) -> tuple[None, dict[str, Any]]:
    """Exhaust the pipeline without collecting results (terminal operation).

    Drains every element from the underlying iterator purely for its side
    effects, then surfaces the final shared context.

    Returns:
        A tuple of ``None`` and a dict snapshot of the final pipeline
        context.
    """
    stream = iter(self.processed_data)
    while True:
        try:
            next(stream)
        except StopIteration:
            break

    return None, self.context_manager.to_dict()
315+
316+
def branch(
    self,
    branches: dict[str, Transformer[T, Any]],
    batch_size: int = 1000,
    max_batch_buffer: int = 1,
    use_queue_chunks: bool = True,
) -> tuple[dict[str, list[Any]], dict[str, Any]]:
    """Forks the pipeline into multiple branches for concurrent, parallel processing.

    This is a **terminal operation** that implements a fan-out pattern where
    the entire dataset is copied to each branch for independent processing.
    Each branch processes the complete dataset concurrently using separate
    transformers, and results are collected and returned alongside the
    final pipeline context.

    Args:
        branches: A dictionary where keys are branch names (str) and values
            are `Transformer` instances of any subtype.
        batch_size: The number of items to batch together when sending data
            to branches. Larger batches can improve throughput but use more
            memory. Defaults to 1000.
        max_batch_buffer: The maximum number of batches to buffer for each
            branch queue. Controls memory usage and creates backpressure.
            Defaults to 1.
        use_queue_chunks: Whether to use passthrough chunking for the
            transformers. When True, batches are processed as chunks.
            Defaults to True.

    Returns:
        A tuple of (results, context): ``results`` maps each branch name to
        the list of all items produced by that branch's transformer, and
        ``context`` is a dict snapshot of the final pipeline context.

    Note:
        This operation consumes the pipeline's iterator, making subsequent
        operations on the same pipeline return empty results.
    """
    if not branches:
        self.consume()
        # Still surface the final context so the no-branch case stays
        # consistent with the populated case (previously returned {}, {}
        # and silently dropped the context).
        return {}, self.context_manager.to_dict()

    source_iterator = self.processed_data
    branch_items = list(branches.items())
    num_branches = len(branch_items)
    final_results: dict[str, list[Any]] = {}

    # One bounded queue per branch: the small maxsize creates backpressure
    # so a slow branch throttles the producer instead of buffering unboundedly.
    queues = [Queue(maxsize=max_batch_buffer) for _ in range(num_branches)]

    def producer() -> None:
        """Reads from the source and distributes batches to ALL branch queues."""
        # Use itertools.batched for clean and efficient batch creation.
        for batch_tuple in itertools.batched(source_iterator, batch_size):
            # The batch is a tuple; convert to a list for consumers.
            batch_list = list(batch_tuple)
            for q in queues:
                q.put(batch_list)

        # Signal to all consumers that the stream is finished.
        for q in queues:
            q.put(None)

    def consumer(transformer: Transformer, queue: Queue) -> list[Any]:
        """Consumes batches from a queue and runs them through a transformer."""

        def stream_from_queue() -> Iterator[T]:
            # None is the producer's end-of-stream sentinel.
            while (batch := queue.get()) is not None:
                yield batch

        if use_queue_chunks:
            transformer = transformer.set_chunker(passthrough_chunks)

        result_iterator = transformer(stream_from_queue(), self.context_manager)  # type: ignore
        return list(result_iterator)

    # num_branches consumers plus one producer thread.
    with ThreadPoolExecutor(max_workers=num_branches + 1) as executor:
        executor.submit(producer)

        future_to_name = {
            executor.submit(consumer, transformer, queues[i]): name
            for i, (name, transformer) in enumerate(branch_items)
        }

        for future in as_completed(future_to_name):
            name = future_to_name[future]
            try:
                final_results[name] = future.result()
            except Exception as e:
                # NOTE(review): a failed branch is reported and mapped to an
                # empty result rather than aborting the fan-out; consider the
                # logging module (or re-raising) instead of print.
                print(f"Branch '{name}' raised an exception: {e}")
                final_results[name] = []

    return final_results, self.context_manager.to_dict()

tests/test_http_transformer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from laygo import HTTPTransformer
66
from laygo import Pipeline
7-
from laygo import PipelineContext
7+
from laygo.context.simple import SimpleContextManager
88

99

1010
class TestHTTPTransformer:
@@ -42,7 +42,7 @@ def mock_response(request, context):
4242
input_chunk = request.json()
4343
# Call the actual view function logic obtained from get_route()
4444
# We pass None for the context as it's not used in this simple case.
45-
output_chunk = worker_view_func(chunk=input_chunk, context=PipelineContext())
45+
output_chunk = worker_view_func(chunk=input_chunk, context=SimpleContextManager())
4646
return output_chunk
4747

4848
# Use requests_mock context manager
@@ -52,7 +52,7 @@ def mock_response(request, context):
5252
# 5. Run the standard Pipeline with the configured transformer
5353
initial_data = list(range(10)) # [0, 1, 2, ..., 9]
5454
pipeline = Pipeline(initial_data).apply(http_transformer)
55-
result = pipeline.to_list()
55+
result, _ = pipeline.to_list()
5656

5757
# 6. Assert the final result
5858
expected_result = [12, 14, 16, 18]

0 commit comments

Comments
 (0)