Skip to content

Commit f866937

Browse files
committed
Remove seed tracking, have snapshot evaluator own tracker instance
1 parent cc80ba5 commit f866937

6 files changed

Lines changed: 27 additions & 32 deletions

File tree

sqlmesh/core/console.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4308,7 +4308,7 @@ def _calculate_annotation_str_len(
43084308
def _format_bytes(num_bytes: t.Optional[int]) -> str:
43094309
if num_bytes and num_bytes > 0:
43104310
if num_bytes < 1024:
4311-
return f"{num_bytes} Bytes"
4311+
return f"{num_bytes} bytes"
43124312

43134313
num_bytes_float = float(num_bytes) / 1024.0
43144314
for unit in ["KiB", "MiB", "GiB", "TiB", "PiB"]:

sqlmesh/core/execution_tracker.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import time
44
import typing as t
55
from contextlib import contextmanager
6-
from threading import local
6+
from threading import local, Lock
77
from dataclasses import dataclass, field
88

99

@@ -66,34 +66,32 @@ class QueryExecutionTracker:
6666

6767
_thread_local = local()
6868
_contexts: t.Dict[str, QueryExecutionContext] = {}
69+
_contexts_lock = Lock()
6970

70-
@classmethod
71-
def get_execution_context(cls, snapshot_id_batch: str) -> t.Optional[QueryExecutionContext]:
72-
return cls._contexts.get(snapshot_id_batch)
71+
def get_execution_context(self, snapshot_id_batch: str) -> t.Optional[QueryExecutionContext]:
72+
with self._contexts_lock:
73+
return self._contexts.get(snapshot_id_batch)
7374

7475
@classmethod
7576
def is_tracking(cls) -> bool:
7677
return getattr(cls._thread_local, "context", None) is not None
7778

78-
@classmethod
7979
@contextmanager
8080
def track_execution(
81-
cls, snapshot_id_batch: str, condition: bool = True
81+
self, snapshot_id_batch: str
8282
) -> t.Iterator[t.Optional[QueryExecutionContext]]:
8383
"""
8484
Context manager for tracking snapshot execution statistics.
8585
"""
86-
if not condition:
87-
yield None
88-
return
89-
9086
context = QueryExecutionContext(snapshot_batch_id=snapshot_id_batch)
91-
cls._thread_local.context = context
92-
cls._contexts[snapshot_id_batch] = context
87+
self._thread_local.context = context
88+
with self._contexts_lock:
89+
self._contexts[snapshot_id_batch] = context
90+
9391
try:
9492
yield context
9593
finally:
96-
cls._thread_local.context = None
94+
self._thread_local.context = None
9795

9896
@classmethod
9997
def record_execution(
@@ -103,8 +101,8 @@ def record_execution(
103101
if context is not None:
104102
context.add_execution(sql, row_count, bytes_processed)
105103

106-
@classmethod
107-
def get_execution_stats(cls, snapshot_id_batch: str) -> t.Optional[QueryExecutionStats]:
108-
context = cls.get_execution_context(snapshot_id_batch)
109-
cls._contexts.pop(snapshot_id_batch, None)
104+
def get_execution_stats(self, snapshot_id_batch: str) -> t.Optional[QueryExecutionStats]:
105+
with self._contexts_lock:
106+
context = self._contexts.get(snapshot_id_batch)
107+
self._contexts.pop(snapshot_id_batch, None)
110108
return context.get_execution_stats() if context else None

sqlmesh/core/scheduler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from sqlmesh.core import constants as c
1010
from sqlmesh.core.console import Console, get_console
1111
from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements
12-
from sqlmesh.core.execution_tracker import QueryExecutionTracker
1312
from sqlmesh.core.macros import RuntimeStage
1413
from sqlmesh.core.model.definition import AuditResult
1514
from sqlmesh.core.node import IntervalUnit
@@ -536,7 +535,7 @@ def run_node(node: SchedulingUnit) -> None:
536535
num_audits = len(audit_results)
537536
num_audits_failed = sum(1 for result in audit_results if result.count)
538537

539-
execution_stats = QueryExecutionTracker.get_execution_stats(
538+
execution_stats = self.snapshot_evaluator.execution_tracker.get_execution_stats(
540539
f"{snapshot.snapshot_id}_{batch_idx}"
541540
)
542541

sqlmesh/core/snapshot/evaluator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def __init__(
136136
)
137137
self.selected_gateway = selected_gateway
138138
self.ddl_concurrent_tasks = ddl_concurrent_tasks
139+
self.execution_tracker = QueryExecutionTracker()
139140

140141
def evaluate(
141142
self,
@@ -170,9 +171,7 @@ def evaluate(
170171
Returns:
171172
The WAP ID of this evaluation if supported, None otherwise.
172173
"""
173-
with QueryExecutionTracker.track_execution(
174-
f"{snapshot.snapshot_id}_{batch_index}", condition=not snapshot.is_seed
175-
):
174+
with self.execution_tracker.track_execution(f"{snapshot.snapshot_id}_{batch_index}"):
176175
result = self._evaluate_snapshot(
177176
start=start,
178177
end=end,

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2454,13 +2454,10 @@ def capture_row_counts(
24542454
assert len(physical_layer_results.tables) == len(physical_layer_results.non_temp_tables) == 3
24552455

24562456
if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING:
2457-
assert len(actual_execution_stats) == 3
2458-
assert actual_execution_stats["seed_model"].total_rows_processed == 7
24592457
assert actual_execution_stats["incremental_model"].total_rows_processed == 7
24602458
assert actual_execution_stats["full_model"].total_rows_processed == 3
24612459

24622460
if ctx.mark.startswith("bigquery"):
2463-
assert actual_execution_stats["seed_model"].total_bytes_processed
24642461
assert actual_execution_stats["incremental_model"].total_bytes_processed
24652462
assert actual_execution_stats["full_model"].total_bytes_processed
24662463

tests/core/test_execution_tracker.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,17 @@
77

88
def test_execution_tracker_thread_isolation() -> None:
99
def worker(id: str, row_counts: list[int]) -> QueryExecutionStats:
10-
with QueryExecutionTracker.track_execution(id) as ctx:
11-
assert QueryExecutionTracker.is_tracking()
10+
with execution_tracker.track_execution(id) as ctx:
11+
assert execution_tracker.is_tracking()
1212

1313
for count in row_counts:
14-
QueryExecutionTracker.record_execution("SELECT 1", count, None)
14+
execution_tracker.record_execution("SELECT 1", count, None)
1515

1616
assert ctx is not None
1717
return ctx.get_execution_stats()
1818

19+
execution_tracker = QueryExecutionTracker()
20+
1921
with ThreadPoolExecutor() as executor:
2022
futures = [
2123
executor.submit(worker, "batch_A", [10, 5]),
@@ -24,9 +26,9 @@ def worker(id: str, row_counts: list[int]) -> QueryExecutionStats:
2426
results = [f.result() for f in futures]
2527

2628
# Main thread has no active tracking context
27-
assert not QueryExecutionTracker.is_tracking()
28-
QueryExecutionTracker.record_execution("q", 10, None)
29-
assert QueryExecutionTracker.get_execution_stats("q") is None
29+
assert not execution_tracker.is_tracking()
30+
execution_tracker.record_execution("q", 10, None)
31+
assert execution_tracker.get_execution_stats("q") is None
3032

3133
# Order of results is not deterministic, so look up by id
3234
by_batch = {s.snapshot_batch_id: s for s in results}

0 commit comments

Comments (0)