Skip to content

Commit 1967dad

Browse files
authored
Merge pull request #12 from ringoldsdev/feat/20250723/chunk-level-reducer
feat: implemented reduce over chunks
2 parents 455849b + 3bc778c commit 1967dad

File tree

2 files changed

+152
-37
lines changed

2 files changed

+152
-37
lines changed

laygo/transformers/transformer.py

Lines changed: 75 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from functools import reduce
88
import itertools
99
from typing import Any
10+
from typing import Literal
1011
from typing import Self
1112
from typing import Union
1213
from typing import overload
@@ -343,42 +344,81 @@ def __call__(self, data: Iterable[In], context: PipelineContext | None = None) -
343344
# The context is now passed explicitly through the transformer chain.
344345
yield from self.transformer(chunk, run_context)
345346

346-
def reduce[U](self, function: PipelineReduceFunction[U, Out], initial: U):
347-
"""Reduce elements to a single value (terminal operation).
348-
349-
Args:
350-
function: The reduction function. Can be context-aware.
351-
initial: The initial value for the reduction.
352-
353-
Returns:
354-
A function that executes the reduction when called with data.
355-
"""
356-
357-
if is_context_aware_reduce(function):
358-
359-
def _reduce_with_context(data: Iterable[In], context: PipelineContext | None = None) -> Iterator[U]:
360-
# The context for the run is determined here.
361-
run_context = context or self.context
362-
363-
data_iterator = self(data, run_context)
364-
365-
def function_wrapper(acc: U, value: Out) -> U:
366-
return function(acc, value, run_context)
367-
368-
yield reduce(function_wrapper, data_iterator, initial)
369-
370-
return _reduce_with_context
371-
372-
# Not context-aware, so we adapt the function to ignore the context.
373-
def _reduce(data: Iterable[In], context: PipelineContext | None = None) -> Iterator[U]:
374-
# The context for the run is determined here.
375-
run_context = context or self.context
376-
377-
data_iterator = self(data, run_context)
378-
379-
yield reduce(function, data_iterator, initial) # type: ignore
347+
@overload
348+
def reduce[U](
349+
self,
350+
function: PipelineReduceFunction[U, Out],
351+
initial: U,
352+
*,
353+
per_chunk: Literal[True],
354+
) -> "Transformer[In, U]":
355+
"""Reduces each chunk to a single value (chainable operation)."""
356+
...
380357

381-
return _reduce
358+
@overload
359+
def reduce[U](
360+
self,
361+
function: PipelineReduceFunction[U, Out],
362+
initial: U,
363+
*,
364+
per_chunk: Literal[False] = False,
365+
) -> Callable[[Iterable[In], PipelineContext | None], Iterator[U]]:
366+
"""Reduces the entire dataset to a single value (terminal operation)."""
367+
...
368+
369+
def reduce[U](
370+
self,
371+
function: PipelineReduceFunction[U, Out],
372+
initial: U,
373+
*,
374+
per_chunk: bool = False,
375+
) -> Union["Transformer[In, U]", Callable[[Iterable[In], PipelineContext | None], Iterator[U]]]: # type: ignore
376+
"""Reduces elements to a single value, either per-chunk or for the entire dataset."""
377+
if per_chunk:
378+
# --- Efficient "per-chunk" logic (chainable) ---
379+
380+
# The context-awareness check is now hoisted and executed only ONCE.
381+
if is_context_aware_reduce(function):
382+
# We define a specialized operation for the context-aware case.
383+
def reduce_chunk_operation(chunk: list[Out], ctx: PipelineContext) -> list[U]:
384+
if not chunk:
385+
return []
386+
# No check happens here; we know the function needs the context.
387+
wrapper = lambda acc, val: function(acc, val, ctx) # noqa: E731, W291
388+
return [reduce(wrapper, chunk, initial)]
389+
else:
390+
# We define a specialized, simpler operation for the non-aware case.
391+
def reduce_chunk_operation(chunk: list[Out], ctx: PipelineContext) -> list[U]:
392+
if not chunk:
393+
return []
394+
# No check happens here; the function is called directly.
395+
return [reduce(function, chunk, initial)] # type: ignore
396+
397+
return self._pipe(reduce_chunk_operation)
398+
399+
# --- "Entire dataset" logic with `match` (terminal) ---
400+
match is_context_aware_reduce(function):
401+
case True:
402+
403+
def _reduce_with_context(data: Iterable[In], context: PipelineContext | None = None) -> Iterator[U]:
404+
run_context = context or self.context
405+
data_iterator = self(data, run_context)
406+
407+
def function_wrapper(acc, val):
408+
return function(acc, val, run_context) # type: ignore
409+
410+
yield reduce(function_wrapper, data_iterator, initial)
411+
412+
return _reduce_with_context
413+
414+
case False:
415+
416+
def _reduce(data: Iterable[In], context: PipelineContext | None = None) -> Iterator[U]:
417+
run_context = context or self.context
418+
data_iterator = self(data, run_context)
419+
yield reduce(function, data_iterator, initial) # type: ignore
420+
421+
return _reduce
382422

383423
def catch[U](
384424
self,

tests/test_transformer.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def test_basic_reduce(self):
265265
"""Test reduce with sum operation."""
266266
transformer = createTransformer(int)
267267
reducer = transformer.reduce(lambda acc, x: acc + x, initial=0)
268-
result = list(reducer([1, 2, 3, 4]))
268+
result = list(reducer([1, 2, 3, 4], None))
269269
assert result == [10]
270270

271271
def test_reduce_with_context(self):
@@ -280,9 +280,84 @@ def test_reduce_after_transformation(self):
280280
"""Test reduce after map transformation."""
281281
transformer = createTransformer(int).map(lambda x: x * 2)
282282
reducer = transformer.reduce(lambda acc, x: acc + x, initial=0)
283-
result = list(reducer([1, 2, 3]))
283+
result = list(reducer([1, 2, 3], None))
284284
assert result == [12] # [2, 4, 6] summed = 12
285285

286+
def test_reduce_per_chunk_basic(self):
    """per_chunk=True sums each chunk independently."""
    pipeline = createTransformer(int, chunk_size=2).reduce(lambda acc, x: acc + x, initial=0, per_chunk=True)
    output = list(pipeline([1, 2, 3, 4, 5]))
    # chunk_size=2 splits the input into [1, 2], [3, 4], [5] -> sums 3, 7, 5
    assert output == [3, 7, 5]
292+
293+
def test_reduce_per_chunk_with_context(self):
    """Context-aware reducers receive the pipeline context in every chunk."""
    run_ctx = PipelineContext({"multiplier": 2})
    pipeline = createTransformer(int, chunk_size=2).reduce(
        lambda acc, x, ctx: acc + (x * ctx["multiplier"]), initial=0, per_chunk=True
    )
    output = list(pipeline([1, 2, 3], run_ctx))
    # chunks [1, 2] and [3]: (1*2) + (2*2) = 6 and 3*2 = 6
    assert output == [6, 6]
302+
303+
def test_reduce_per_chunk_empty_chunks(self):
    """An empty input produces no per-chunk results at all."""
    pipeline = createTransformer(int, chunk_size=5).reduce(lambda acc, x: acc + x, initial=0, per_chunk=True)
    assert list(pipeline([])) == []
308+
309+
def test_reduce_per_chunk_single_element_chunks(self):
    """chunk_size=1 reduces every element on its own, seeded by `initial`."""
    pipeline = createTransformer(int, chunk_size=1).reduce(lambda acc, x: acc + x, initial=10, per_chunk=True)
    output = list(pipeline([1, 2, 3]))
    # one-element chunks: 10+1, 10+2, 10+3
    assert output == [11, 12, 13]
315+
316+
def test_reduce_per_chunk_chaining(self):
    """A per-chunk reduce remains chainable with further transformations."""
    pipeline = (
        createTransformer(int, chunk_size=2)
        .map(lambda x: x * 2)
        .reduce(lambda acc, x: acc + x, initial=0, per_chunk=True)
        .map(lambda x: x * 10)
    )
    output = list(pipeline([1, 2, 3]))
    # doubled -> [2, 4, 6]; chunk sums ([2, 4] and [6]) -> [6, 6]; scaled -> [60, 60]
    assert output == [60, 60]
329+
330+
def test_reduce_per_chunk_different_chunk_sizes(self):
    """The chunk size controls how many partial sums come out."""
    values = [1, 2, 3, 4, 5, 6]

    pairwise = createTransformer(int, chunk_size=2).reduce(lambda acc, x: acc + x, initial=0, per_chunk=True)
    assert list(pairwise(values)) == [3, 7, 11]  # [1,2]->3, [3,4]->7, [5,6]->11

    triplewise = createTransformer(int, chunk_size=3).reduce(lambda acc, x: acc + x, initial=0, per_chunk=True)
    assert list(triplewise(values)) == [6, 15]  # [1,2,3]->6, [4,5,6]->15
343+
344+
def test_reduce_per_chunk_versus_terminal(self):
    """per_chunk=False folds everything; per_chunk=True keeps chunk boundaries."""
    values = [1, 2, 3, 4]

    # Terminal form (per_chunk=False): a callable collapsing the whole dataset.
    fold_all = createTransformer(int, chunk_size=2).reduce(lambda acc, x: acc + x, initial=0, per_chunk=False)
    assert list(fold_all(values, None)) == [10]  # 1+2+3+4

    # Chainable form (per_chunk=True): a transformer emitting one value per chunk.
    fold_chunks = createTransformer(int, chunk_size=2).reduce(
        lambda acc, x: acc + x, initial=0, per_chunk=True
    )
    assert list(fold_chunks(values)) == [3, 7]  # [1,2]->3, [3,4]->7
360+
286361

287362
class TestTransformerEdgeCases:
288363
"""Test edge cases and boundary conditions."""

0 commit comments

Comments
 (0)