Skip to content

Commit 496e8a1

Browse files
committed
feat: implemented branch method
1 parent f484d3e commit 496e8a1

2 files changed

Lines changed: 294 additions & 0 deletions

File tree

laygo/pipeline.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,56 @@ def apply[U](
147147

148148
return self # type: ignore
149149

150+
def branch(self, branches: dict[str, Transformer[T, Any]]) -> dict[str, list[Any]]:
    """Fan the pipeline out to several branches and return each branch's final chunk.

    This is a **terminal operation** implementing a fan-out pattern. It
    consumes the pipeline: the **entire dataset** flows through every
    branch transformer, and each branch keeps **overwriting** a slot in
    the shared context with its most recently processed chunk. The
    returned mapping therefore contains only the **last processed chunk**
    per branch.

    Args:
        branches: Mapping of branch name (str) to the `Transformer` that
            branch should run.

    Returns:
        Mapping of branch name to the list of items from that branch's
        last processed chunk.
    """
    if not branches:
        # Still drain the source so behavior matches the non-empty case.
        self.consume()
        return {}

    # 1. Compose a single fan-out transformer: each branch is attached
    #    as a tap, so every branch observes the full stream as a side effect.
    fan_out = Transformer[T, T]()

    for branch_name, user_transformer in branches.items():
        # Wrap the user's branch logic in a "collector" whose final step
        # records its latest chunk into the shared context.
        collector = Transformer.from_transformer(user_transformer)

        # Side-effect step: stash the latest chunk under this branch's name.
        # `key=branch_name` binds the loop variable at definition time.
        def record_last_chunk(chunk: list[Any], ctx: PipelineContext, key=branch_name) -> list[Any]:
            # Single-key assignment is atomic for manager dicts; no lock needed.
            ctx[key] = chunk
            # Hand the chunk back untouched to satisfy the _pipe interface.
            return chunk

        # Append the recorder as the collector's final stage.
        collector._pipe(record_last_chunk)

        # Attach the collector so it runs as a side-effect of the main stream.
        fan_out.tap(collector)

    # 2. Run the fan-out transformer over the whole pipeline to completion.
    self.apply(fan_out).consume()

    # 3. Read back each branch's last recorded chunk from the context.
    return {branch_name: self.ctx.get(branch_name, []) for branch_name in branches}
150200
def buffer(self, size: int) -> "Pipeline[T]":
151201
"""Buffer the pipeline using threaded processing.
152202

tests/test_pipeline.py

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
"""Tests for the Pipeline class."""
22

3+
import pytest
4+
35
from laygo import Pipeline
6+
from laygo import PipelineContext
47
from laygo.transformers.transformer import createTransformer
58

69

@@ -211,3 +214,244 @@ def second_map(x):
211214

212215
assert sorted(first_map_values) == list(range(10))
213216
assert sorted(second_map_values) == [x * 2 for x in range(10)]
217+
218+
219+
class TestPipelineBranch:
    """Tests for the fan-out `Pipeline.branch` terminal operation."""

    def test_branch_basic_functionality(self):
        """A small dataset reaches every branch in full."""
        source = Pipeline([1, 2, 3, 4, 5])

        outcome = source.branch({
            "doubled": createTransformer(int).map(lambda n: n * 2),
            "squared": createTransformer(int).map(lambda n: n ** 2),
        })

        # Exactly the two requested branches are present.
        assert "doubled" in outcome
        assert "squared" in outcome
        assert len(outcome) == 2

        # Five items fit inside one default-sized (1000) chunk, so the
        # "last chunk" for each branch is the entire dataset.
        assert sorted(outcome["doubled"]) == [2, 4, 6, 8, 10]
        assert sorted(outcome["squared"]) == [1, 4, 9, 16, 25]

    def test_branch_with_empty_input(self):
        """Empty source data yields an empty list per branch."""
        outcome = Pipeline([]).branch({
            "doubled": createTransformer(int).map(lambda n: n * 2),
            "squared": createTransformer(int).map(lambda n: n ** 2),
        })

        assert outcome == {"doubled": [], "squared": []}

    def test_branch_with_empty_branches_dict(self):
        """No branches means an empty result dictionary."""
        assert Pipeline([1, 2, 3]).branch({}) == {}

    def test_branch_with_single_branch(self):
        """A lone branch still works end to end."""
        outcome = Pipeline([1, 2, 3, 4]).branch({
            "tripled": createTransformer(int).map(lambda n: n * 3),
        })

        assert len(outcome) == 1
        assert "tripled" in outcome
        assert sorted(outcome["tripled"]) == [3, 6, 9, 12]

    def test_branch_with_filtering_transformers(self):
        """Branches that drop items report only what survived the filter."""
        outcome = Pipeline([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).branch({
            "evens": createTransformer(int).filter(lambda n: n % 2 == 0),
            "odds": createTransformer(int).filter(lambda n: n % 2 == 1),
        })

        assert sorted(outcome["evens"]) == [2, 4, 6, 8, 10]
        assert sorted(outcome["odds"]) == [1, 3, 5, 7, 9]

    def test_branch_with_multiple_transformations(self):
        """Multi-step branch pipelines compose as expected."""
        # filter evens -> double -> add 1
        multi_step = (
            createTransformer(int)
            .filter(lambda n: n % 2 == 0)
            .map(lambda n: n * 2)
            .map(lambda n: n + 1)
        )

        outcome = Pipeline([1, 2, 3, 4, 5, 6]).branch({
            "complex": multi_step,
            "simple": createTransformer(int).map(lambda n: n * 10),
        })

        # complex: [2, 4, 6] -> [4, 8, 12] -> [5, 9, 13]
        assert sorted(outcome["complex"]) == [5, 9, 13]
        # simple: every input times ten
        assert sorted(outcome["simple"]) == [10, 20, 30, 40, 50, 60]

    def test_branch_with_chunked_data(self):
        """With multiple chunks, only the final chunk is reported."""
        # 20 items with chunk_size=5 -> chunks [1-5], [6-10], [11-15], [16-20]
        outcome = Pipeline(list(range(1, 21))).branch({
            "doubled": createTransformer(int, chunk_size=5).map(lambda n: n * 2),
            "identity": createTransformer(int, chunk_size=5),
        })

        # Only the last chunk [16..20] should remain for each branch.
        assert sorted(outcome["doubled"]) == [32, 34, 36, 38, 40]
        assert sorted(outcome["identity"]) == [16, 17, 18, 19, 20]

    def test_branch_with_flatten_operation(self):
        """Flattening branches operate on list-valued items."""
        outcome = Pipeline([[1, 2], [3, 4], [5, 6]]).branch({
            "flattened": createTransformer(list).flatten(),
            "lengths": createTransformer(list).map(lambda pair: len(pair)),
        })

        assert sorted(outcome["flattened"]) == [1, 2, 3, 4, 5, 6]
        assert sorted(outcome["lengths"]) == [2, 2, 2]

    def test_branch_is_terminal_operation(self):
        """Branching exhausts the pipeline's underlying iterator."""
        source = Pipeline([1, 2, 3, 4, 5])

        outcome = source.branch({
            "doubled": createTransformer(int).map(lambda n: n * 2),
        })
        assert sorted(outcome["doubled"]) == [2, 4, 6, 8, 10]

        # The source was consumed, so a second read produces nothing.
        assert source.to_list() == []

    def test_branch_with_different_chunk_sizes(self):
        """Each branch chunks independently; each reports its own last chunk."""
        outcome = Pipeline(list(range(1, 16))).branch({
            "large_chunk": createTransformer(int, chunk_size=10).map(lambda n: n + 100),
            "small_chunk": createTransformer(int, chunk_size=3).map(lambda n: n + 200),
        })

        # 15 items:
        #   chunk_size=10 -> [1-10], [11-15]; last chunk is [11-15]
        #   chunk_size=3  -> [1-3] ... [13-15]; last chunk is [13-15]
        assert sorted(outcome["large_chunk"]) == [111, 112, 113, 114, 115]
        assert sorted(outcome["small_chunk"]) == [213, 214, 215]

    def test_branch_preserves_data_order_within_chunks(self):
        """Element order inside the final chunk is untouched."""
        outcome = Pipeline([5, 3, 8, 1, 9, 2]).branch({
            "identity": createTransformer(int),
            "negated": createTransformer(int).map(lambda n: -n),
        })

        # Compare without sorting: order must match the input exactly.
        assert outcome["identity"] == [5, 3, 8, 1, 9, 2]
        assert outcome["negated"] == [-5, -3, -8, -1, -9, -2]

    def test_branch_with_error_handling(self):
        """A failing branch propagates its exception out of branch()."""
        source = Pipeline([1, 2, 0, 4, 5])

        # 10 // 0 raises inside the "division" branch.
        with pytest.raises(ZeroDivisionError):
            source.branch({
                "division": createTransformer(int).map(lambda n: 10 // n),
                "safe": createTransformer(int).map(lambda n: n * 2),
            })

    def test_branch_context_isolation(self):
        """Branches share the pipeline context without clobbering each other."""
        source = Pipeline([1, 2, 3])

        def touch_context_a(chunk: list[int], ctx: PipelineContext) -> list[int]:
            ctx["branch_a_processed"] = len(chunk)
            return [n * 2 for n in chunk]

        def touch_context_b(chunk: list[int], ctx: PipelineContext) -> list[int]:
            ctx["branch_b_processed"] = len(chunk)
            return [n * 3 for n in chunk]

        outcome = source.branch({
            "branch_a": createTransformer(int)._pipe(touch_context_a),
            "branch_b": createTransformer(int)._pipe(touch_context_b),
        })

        assert sorted(outcome["branch_a"]) == [2, 4, 6]
        assert sorted(outcome["branch_b"]) == [3, 6, 9]

        # Each branch wrote its own context key; both survive.
        assert source.ctx.get("branch_a_processed") == 3
        assert source.ctx.get("branch_b_processed") == 3

0 commit comments

Comments
 (0)