Skip to content

Commit 7b9d7c5

Browse files
committed
fix: imports
1 parent 6b730d4 commit 7b9d7c5

2 files changed

Lines changed: 45 additions & 24 deletions

File tree

laygo/context/types.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ class IContextManager(MutableMapping[str, Any], ABC):
4141
This class defines the contract for all context managers, ensuring they
4242
provide a dictionary-like interface for state manipulation by inheriting
4343
from `collections.abc.MutableMapping`. It also includes methods for
44-
distribution (get_handle) and resource management (shutdown).
44+
distribution (get_handle), resource management (shutdown), and context
45+
management (__enter__, __exit__).
4546
"""
4647

4748
@abstractmethod
@@ -66,3 +67,23 @@ def shutdown(self) -> None:
6667
background processes, or any other cleanup required by the manager.
6768
"""
6869
raise NotImplementedError
70+
71+
def __enter__(self) -> "IContextManager":
72+
"""
73+
Enters the runtime context related to this object.
74+
75+
Returns:
76+
The context manager instance itself.
77+
"""
78+
return self
79+
80+
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None:
81+
"""
82+
Exits the runtime context and performs cleanup.
83+
84+
Args:
85+
exc_type: The exception type, if an exception was raised.
86+
exc_val: The exception instance, if an exception was raised.
87+
exc_tb: The traceback object, if an exception was raised.
88+
"""
89+
self.shutdown()

tests/test_integration.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
from laygo import ParallelTransformer
44
from laygo import Pipeline
5-
from laygo import PipelineContext
65
from laygo import Transformer
76
from laygo import createTransformer
7+
from laygo.context.types import IContextManager
88

99

1010
class TestPipelineTransformerBasics:
@@ -18,7 +18,7 @@ def test_basic_pipeline_transformer_integration(self):
1818

1919
def test_pipeline_context_sharing(self):
2020
"""Test that context is properly shared between pipeline and transformers."""
21-
context = PipelineContext({"multiplier": 3, "threshold": 5})
21+
context = {"multiplier": 3, "threshold": 5}
2222
transformer = Transformer().map(lambda x, ctx: x * ctx["multiplier"]).filter(lambda x, ctx: x > ctx["threshold"])
2323
result = Pipeline([1, 2, 3]).context(context).apply(transformer).to_list()
2424
assert result == [6, 9]
@@ -82,15 +82,15 @@ def validate_and_convert(x):
8282
assert valid_numbers == [1.0, 2.0, 3.0, 5.0, 7.0]
8383

8484

85-
def safe_increment_and_transform(x: int, ctx: PipelineContext) -> int:
86-
with ctx["lock"]:
85+
def safe_increment_and_transform(x: int, ctx: IContextManager) -> int:
86+
with ctx:
8787
ctx["processed_count"] += 1
8888
ctx["sum_total"] += x
8989
return x * 2
9090

9191

92-
def count_and_transform(x: int, ctx: PipelineContext) -> int:
93-
with ctx["lock"]:
92+
def count_and_transform(x: int, ctx: IContextManager) -> int:
93+
with ctx:
9494
ctx["items_processed"] += 1
9595
if x % 2 == 0:
9696
ctx["even_count"] += 1
@@ -99,17 +99,17 @@ def count_and_transform(x: int, ctx: PipelineContext) -> int:
9999
return x * 3
100100

101101

102-
def stage1_processor(x: int, ctx: PipelineContext) -> int:
102+
def stage1_processor(x: int, ctx: IContextManager) -> int:
103103
"""First stage processing with context update."""
104-
with ctx["lock"]:
104+
with ctx:
105105
ctx["stage1_processed"] += 1
106106
ctx["total_sum"] += x
107107
return x * 2
108108

109109

110-
def stage2_processor(x: int, ctx: PipelineContext) -> int:
110+
def stage2_processor(x: int, ctx: IContextManager) -> int:
111111
"""Second stage processing with context update."""
112-
with ctx["lock"]:
112+
with ctx:
113113
ctx["stage2_processed"] += 1
114114
ctx["total_sum"] += x # Add transformed value too
115115
return x + 10
@@ -128,7 +128,7 @@ def test_parallel_transformer_basic_integration(self):
128128

129129
def test_parallel_transformer_with_context_modification(self):
130130
"""Test parallel transformer safely modifying shared context."""
131-
context = PipelineContext({"processed_count": 0, "sum_total": 0})
131+
context = {"processed_count": 0, "sum_total": 0}
132132

133133
parallel_transformer = ParallelTransformer[int, int](max_workers=2, chunk_size=2)
134134
parallel_transformer = parallel_transformer.map(safe_increment_and_transform)
@@ -144,7 +144,7 @@ def test_parallel_transformer_with_context_modification(self):
144144

145145
def test_pipeline_accesses_modified_context(self):
146146
"""Test that pipeline can access context data modified by parallel transformer."""
147-
context = PipelineContext({"items_processed": 0, "even_count": 0, "odd_count": 0})
147+
context = {"items_processed": 0, "even_count": 0, "odd_count": 0}
148148

149149
parallel_transformer = ParallelTransformer[int, int](max_workers=2, chunk_size=3)
150150
parallel_transformer = parallel_transformer.map(count_and_transform)
@@ -155,14 +155,14 @@ def test_pipeline_accesses_modified_context(self):
155155

156156
# Verify results and context access
157157
assert sorted(result) == [3, 6, 9, 12, 15, 18]
158-
assert pipeline.ctx["items_processed"] == 6
159-
assert pipeline.ctx["even_count"] == 3 # 2, 4, 6
160-
assert pipeline.ctx["odd_count"] == 3 # 1, 3, 5
158+
assert pipeline.context_manager["items_processed"] == 6
159+
assert pipeline.context_manager["even_count"] == 3 # 2, 4, 6
160+
assert pipeline.context_manager["odd_count"] == 3 # 1, 3, 5
161161

162162
def test_multiple_parallel_transformers_chaining(self):
163163
"""Test chaining multiple parallel transformers with shared context."""
164164
# Shared context for statistics across transformations
165-
context = PipelineContext({"stage1_processed": 0, "stage2_processed": 0, "total_sum": 0})
165+
context = {"stage1_processed": 0, "stage2_processed": 0, "total_sum": 0}
166166

167167
# Create two parallel transformers
168168
stage1 = ParallelTransformer[int, int](max_workers=2, chunk_size=2).map(stage1_processor)
@@ -184,7 +184,7 @@ def test_multiple_parallel_transformers_chaining(self):
184184
assert result == expected_final
185185

186186
# Verify context reflects both stages
187-
final_context = pipeline.ctx
187+
final_context = pipeline.context_manager
188188
assert final_context["stage1_processed"] == 5
189189
assert final_context["stage2_processed"] == 5
190190

@@ -199,11 +199,11 @@ def test_pipeline_context_isolation_with_parallel_processing(self):
199199

200200
# Create base context structure
201201
def create_context():
202-
return PipelineContext({"count": 0})
202+
return {"count": 0}
203203

204-
def increment_counter(x: int, ctx: PipelineContext) -> int:
204+
def increment_counter(x: int, ctx: IContextManager) -> int:
205205
"""Increment counter in context."""
206-
with ctx["lock"]:
206+
with ctx:
207207
ctx["count"] += 1
208208
return x * 2
209209

@@ -225,8 +225,8 @@ def increment_counter(x: int, ctx: PipelineContext) -> int:
225225
assert result2 == [2, 4, 6]
226226

227227
# But contexts should be isolated
228-
assert pipeline1.ctx["count"] == 3
229-
assert pipeline2.ctx["count"] == 3
228+
assert pipeline1.context_manager["count"] == 3
229+
assert pipeline2.context_manager["count"] == 3
230230

231231
# Verify they are different context objects
232-
assert pipeline1.ctx is not pipeline2.ctx
232+
assert pipeline1.context_manager is not pipeline2.context_manager

0 commit comments

Comments
 (0)