|
1 | 1 | """Tests for the Pipeline class.""" |
2 | 2 |
|
| 3 | +import pytest |
| 4 | + |
3 | 5 | from laygo import Pipeline |
| 6 | +from laygo import PipelineContext |
4 | 7 | from laygo.transformers.transformer import createTransformer |
5 | 8 |
|
6 | 9 |
|
@@ -211,3 +214,244 @@ def second_map(x): |
211 | 214 |
|
212 | 215 | assert sorted(first_map_values) == list(range(10)) |
213 | 216 | assert sorted(second_map_values) == [x * 2 for x in range(10)] |
| 217 | + |
| 218 | + |
class TestPipelineBranch:
    """Exercise ``Pipeline.branch`` with various transformers, chunk sizes and edge cases."""

    def test_branch_basic_functionality(self):
        """Two plain ``map`` branches each transform the complete input."""
        source = Pipeline([1, 2, 3, 4, 5])
        branches = {
            "doubled": createTransformer(int).map(lambda n: n * 2),
            "squared": createTransformer(int).map(lambda n: n ** 2),
        }

        outcome = source.branch(branches)

        # The result is keyed by exactly the branch names that were requested.
        assert "doubled" in outcome
        assert "squared" in outcome
        assert len(outcome) == 2

        # Five items are expected to fit in one default-sized chunk (the
        # original author notes a default of 1000), so the "last chunk"
        # reported for each branch is the whole input.
        assert sorted(outcome["doubled"]) == [2, 4, 6, 8, 10]
        assert sorted(outcome["squared"]) == [1, 4, 9, 16, 25]

    def test_branch_with_empty_input(self):
        """An empty source yields an empty list for every branch."""
        source = Pipeline([])
        branches = {
            "doubled": createTransformer(int).map(lambda n: n * 2),
            "squared": createTransformer(int).map(lambda n: n ** 2),
        }

        outcome = source.branch(branches)

        expected = {"doubled": [], "squared": []}
        assert outcome == expected

    def test_branch_with_empty_branches_dict(self):
        """Branching into no transformers at all returns an empty mapping."""
        outcome = Pipeline([1, 2, 3]).branch({})

        assert outcome == {}

    def test_branch_with_single_branch(self):
        """A lone branch behaves like any other: one key, one transformed list."""
        source = Pipeline([1, 2, 3, 4])
        tripler = createTransformer(int).map(lambda n: n * 3)

        outcome = source.branch({"tripled": tripler})

        assert len(outcome) == 1
        assert "tripled" in outcome
        assert sorted(outcome["tripled"]) == [3, 6, 9, 12]

    def test_branch_with_filtering_transformers(self):
        """Branches may drop elements; each branch filters independently."""
        source = Pipeline([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
        branches = {
            "evens": createTransformer(int).filter(lambda n: n % 2 == 0),
            "odds": createTransformer(int).filter(lambda n: n % 2 == 1),
        }

        outcome = source.branch(branches)

        assert sorted(outcome["evens"]) == [2, 4, 6, 8, 10]
        assert sorted(outcome["odds"]) == [1, 3, 5, 7, 9]

    def test_branch_with_multiple_transformations(self):
        """A multi-step branch and a single-step branch run side by side."""
        source = Pipeline([1, 2, 3, 4, 5, 6])

        # Keep the even numbers, double them, then add one.
        keep_evens = createTransformer(int).filter(lambda n: n % 2 == 0)
        chained = keep_evens.map(lambda n: n * 2).map(lambda n: n + 1)

        # A single multiplication by ten.
        times_ten = createTransformer(int).map(lambda n: n * 10)

        outcome = source.branch({"complex": chained, "simple": times_ten})

        # complex: [2, 4, 6] -> [4, 8, 12] -> [5, 9, 13]
        assert sorted(outcome["complex"]) == [5, 9, 13]
        # simple: every input element scaled by ten
        assert sorted(outcome["simple"]) == [10, 20, 30, 40, 50, 60]

    def test_branch_with_chunked_data(self):
        """With several chunks per branch, only the final chunk is reported."""
        source = Pipeline(list(range(1, 21)))  # 1..20

        # chunk_size=5 over 20 items forces four chunks in each branch.
        branches = {
            "doubled": createTransformer(int, chunk_size=5).map(lambda n: n * 2),
            "identity": createTransformer(int, chunk_size=5),
        }

        outcome = source.branch(branches)

        # Chunks are [1-5], [6-10], [11-15], [16-20]; branch hands back the
        # last one, i.e. the items 16..20 after each branch's transformation.
        final_chunk = [16, 17, 18, 19, 20]
        assert sorted(outcome["doubled"]) == [n * 2 for n in final_chunk]
        assert sorted(outcome["identity"]) == final_chunk

    def test_branch_with_flatten_operation(self):
        """A flattening branch and a per-item length branch share nested input."""
        source = Pipeline([[1, 2], [3, 4], [5, 6]])
        branches = {
            "flattened": createTransformer(list).flatten(),
            "lengths": createTransformer(list).map(lambda item: len(item)),
        }

        outcome = source.branch(branches)

        assert sorted(outcome["flattened"]) == [1, 2, 3, 4, 5, 6]
        assert sorted(outcome["lengths"]) == [2, 2, 2]

    def test_branch_is_terminal_operation(self):
        """After ``branch`` runs, the pipeline has nothing left to yield."""
        source = Pipeline([1, 2, 3, 4, 5])
        doubler = createTransformer(int).map(lambda n: n * 2)

        outcome = source.branch({"doubled": doubler})
        assert sorted(outcome["doubled"]) == [2, 4, 6, 8, 10]

        # The underlying iterator was consumed by branch, so a second
        # terminal call on the same pipeline sees no data.
        assert source.to_list() == []

    def test_branch_with_different_chunk_sizes(self):
        """Each branch chunks by its own ``chunk_size`` and reports its own last chunk."""
        source = Pipeline(list(range(1, 16)))  # 1..15
        branches = {
            "large_chunk": createTransformer(int, chunk_size=10).map(lambda n: n + 100),
            "small_chunk": createTransformer(int, chunk_size=3).map(lambda n: n + 200),
        }

        outcome = source.branch(branches)

        # 15 items split as:
        #   chunk_size=10 -> [1-10], [11-15]                      last: 11..15
        #   chunk_size=3  -> [1-3], [4-6], [7-9], [10-12], [13-15] last: 13..15
        assert sorted(outcome["large_chunk"]) == [111, 112, 113, 114, 115]
        assert sorted(outcome["small_chunk"]) == [213, 214, 215]

    def test_branch_preserves_data_order_within_chunks(self):
        """Element order inside the final chunk matches the input order."""
        ordering = (5, 3, 8, 1, 9, 2)
        source = Pipeline(list(ordering))
        branches = {
            "identity": createTransformer(int),
            "negated": createTransformer(int).map(lambda n: -n),
        }

        outcome = source.branch(branches)

        # Compared unsorted on purpose: the sequence itself is under test.
        assert outcome["identity"] == list(ordering)
        assert outcome["negated"] == [-n for n in ordering]

    def test_branch_with_error_handling(self):
        """An exception raised inside one branch propagates out of ``branch``."""
        source = Pipeline([1, 2, 0, 4, 5])
        branches = {
            # Integer division blows up on the 0 element.
            "division": createTransformer(int).map(lambda n: 10 // n),
            "safe": createTransformer(int).map(lambda n: n * 2),
        }

        with pytest.raises(ZeroDivisionError):
            source.branch(branches)

    def test_branch_context_isolation(self):
        """Both branches write to the pipeline context and both writes are visible afterwards."""
        source = Pipeline([1, 2, 3])

        # Context-aware chunk operations: each records the chunk length under
        # its own key before transforming the chunk.
        def record_and_double(chunk: list[int], ctx: PipelineContext) -> list[int]:
            ctx["branch_a_processed"] = len(chunk)
            return [item * 2 for item in chunk]

        def record_and_triple(chunk: list[int], ctx: PipelineContext) -> list[int]:
            ctx["branch_b_processed"] = len(chunk)
            return [item * 3 for item in chunk]

        branches = {
            "branch_a": createTransformer(int)._pipe(record_and_double),
            "branch_b": createTransformer(int)._pipe(record_and_triple),
        }

        outcome = source.branch(branches)

        assert sorted(outcome["branch_a"]) == [2, 4, 6]
        assert sorted(outcome["branch_b"]) == [3, 6, 9]

        # Neither branch's context entry was lost or overwritten by the other.
        assert source.ctx.get("branch_a_processed") == 3
        assert source.ctx.get("branch_b_processed") == 3
0 commit comments