Skip to content

Commit 335feb5

Browse files
committed
Move all tracking into snapshot evaluator, remove seed tracker class
1 parent 8872f6c commit 335feb5

6 files changed

Lines changed: 69 additions & 147 deletions

File tree

sqlmesh/core/engine_adapter/base.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
)
4141
from sqlmesh.core.model.kind import TimeColumn
4242
from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation
43-
from sqlmesh.core.execution_tracker import record_execution as track_execution_record
43+
from sqlmesh.core.execution_tracker import QueryExecutionTracker
4444
from sqlmesh.utils import (
4545
CorrelationId,
4646
columns_to_types_all_known,
@@ -2443,7 +2443,11 @@ def _log_sql(
24432443
def _execute(self, sql: str, track_row_count: bool = False, **kwargs: t.Any) -> None:
24442444
self.cursor.execute(sql, **kwargs)
24452445

2446-
if track_row_count and self.SUPPORTS_QUERY_EXECUTION_TRACKING:
2446+
if (
2447+
self.SUPPORTS_QUERY_EXECUTION_TRACKING
2448+
and track_row_count
2449+
and QueryExecutionTracker.is_tracking()
2450+
):
24472451
rowcount_raw = getattr(self.cursor, "rowcount", None)
24482452
rowcount = None
24492453
if rowcount_raw is not None:
@@ -2452,7 +2456,7 @@ def _execute(self, sql: str, track_row_count: bool = False, **kwargs: t.Any) ->
24522456
except (TypeError, ValueError):
24532457
pass
24542458

2455-
track_execution_record(sql, rowcount)
2459+
QueryExecutionTracker.record_execution(sql, rowcount)
24562460

24572461
@contextlib.contextmanager
24582462
def temp_table(

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
SourceQuery,
2222
set_catalog,
2323
)
24-
from sqlmesh.core.execution_tracker import record_execution as track_execution_record
24+
from sqlmesh.core.execution_tracker import QueryExecutionTracker
2525
from sqlmesh.core.node import IntervalUnit
2626
from sqlmesh.core.schema_diff import TableAlterOperation, NestedSupport
2727
from sqlmesh.utils import optional_import, get_source_columns_to_types
@@ -1104,7 +1104,7 @@ def _execute(
11041104
elif query_job.statement_type in ["INSERT", "DELETE", "MERGE", "UPDATE"]:
11051105
num_rows = query_job.num_dml_affected_rows
11061106

1107-
track_execution_record(sql, num_rows)
1107+
QueryExecutionTracker.record_execution(sql, num_rows)
11081108

11091109
def _get_data_objects(
11101110
self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None

sqlmesh/core/execution_tracker.py

Lines changed: 20 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ class QueryExecutionContext:
2727
queries_executed: t.List[t.Tuple[str, t.Optional[int], float]] = field(default_factory=list)
2828

2929
def add_execution(self, sql: str, row_count: t.Optional[int]) -> None:
30-
"""Record a single query execution."""
3130
if row_count is not None and row_count >= 0:
3231
self.total_rows_processed += row_count
3332
self.query_count += 1
@@ -46,97 +45,49 @@ def get_execution_stats(self) -> t.Dict[str, t.Any]:
4645

4746
class QueryExecutionTracker:
4847
"""
49-
Thread-local context manager for snapshot evaluation execution statistics, such as
48+
Thread-local context manager for snapshot execution statistics, such as
5049
rows processed.
5150
"""
5251

5352
_thread_local = local()
53+
_contexts: t.Dict[str, QueryExecutionContext] = {}
5454

5555
@classmethod
56-
def get_execution_context(cls) -> t.Optional[QueryExecutionContext]:
57-
return getattr(cls._thread_local, "context", None)
56+
def get_execution_context(cls, snapshot_id_batch: str) -> t.Optional[QueryExecutionContext]:
57+
return cls._contexts.get(snapshot_id_batch)
5858

5959
@classmethod
6060
def is_tracking(cls) -> bool:
61-
return cls.get_execution_context() is not None
61+
return getattr(cls._thread_local, "context", None) is not None
6262

6363
@classmethod
6464
@contextmanager
65-
def track_execution(cls, snapshot_name_batch: str) -> t.Iterator[QueryExecutionContext]:
65+
def track_execution(
66+
cls, snapshot_id_batch: str, condition: bool = True
67+
) -> t.Iterator[t.Optional[QueryExecutionContext]]:
6668
"""
67-
Context manager for tracking snapshot evaluation execution statistics.
69+
Context manager for tracking snapshot execution statistics.
6870
"""
69-
context = QueryExecutionContext(id=snapshot_name_batch)
71+
if not condition:
72+
yield None
73+
return
74+
75+
context = QueryExecutionContext(id=snapshot_id_batch)
7076
cls._thread_local.context = context
77+
cls._contexts[snapshot_id_batch] = context
7178
try:
7279
yield context
7380
finally:
74-
if hasattr(cls._thread_local, "context"):
75-
delattr(cls._thread_local, "context")
81+
cls._thread_local.context = None
7682

7783
@classmethod
7884
def record_execution(cls, sql: str, row_count: t.Optional[int]) -> None:
79-
context = cls.get_execution_context()
85+
context = getattr(cls._thread_local, "context", None)
8086
if context is not None:
8187
context.add_execution(sql, row_count)
8288

8389
@classmethod
84-
def get_execution_stats(cls) -> t.Optional[t.Dict[str, t.Any]]:
85-
context = cls.get_execution_context()
90+
def get_execution_stats(cls, snapshot_id_batch: str) -> t.Optional[t.Dict[str, t.Any]]:
91+
context = cls.get_execution_context(snapshot_id_batch)
92+
cls._contexts.pop(snapshot_id_batch, None)
8693
return context.get_execution_stats() if context else None
87-
88-
89-
class SeedExecutionTracker:
90-
_seed_contexts: t.Dict[str, QueryExecutionContext] = {}
91-
_thread_local = local()
92-
93-
@classmethod
94-
@contextmanager
95-
def track_execution(cls, model_name: str) -> t.Iterator[QueryExecutionContext]:
96-
"""
97-
Context manager for tracking seed creation execution statistics.
98-
"""
99-
context = QueryExecutionContext(id=model_name)
100-
cls._seed_contexts[model_name] = context
101-
cls._thread_local.seed_id = model_name
102-
103-
try:
104-
yield context
105-
finally:
106-
if hasattr(cls._thread_local, "seed_id"):
107-
delattr(cls._thread_local, "seed_id")
108-
109-
@classmethod
110-
def get_and_clear_seed_stats(cls, model_name: str) -> t.Optional[t.Dict[str, t.Any]]:
111-
context = cls._seed_contexts.pop(model_name, None)
112-
return context.get_execution_stats() if context else None
113-
114-
@classmethod
115-
def clear_all_seed_stats(cls) -> None:
116-
"""Clear all remaining seed stats. Used for cleanup after evaluation completes."""
117-
cls._seed_contexts.clear()
118-
119-
@classmethod
120-
def is_tracking(cls) -> bool:
121-
return hasattr(cls._thread_local, "seed_id")
122-
123-
@classmethod
124-
def record_execution(cls, sql: str, row_count: t.Optional[int]) -> None:
125-
seed_id = getattr(cls._thread_local, "seed_id", None)
126-
if seed_id:
127-
context = cls._seed_contexts.get(seed_id)
128-
if context is not None:
129-
context.add_execution(sql, row_count)
130-
131-
132-
def record_execution(sql: str, row_count: t.Optional[int]) -> None:
133-
"""
134-
Record execution statistics for a single SQL statement.
135-
136-
Automatically infers which tracker is active based on the current thread.
137-
"""
138-
if SeedExecutionTracker.is_tracking():
139-
SeedExecutionTracker.record_execution(sql, row_count)
140-
return
141-
if QueryExecutionTracker.is_tracking():
142-
QueryExecutionTracker.record_execution(sql, row_count)

sqlmesh/core/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sqlmesh.core import constants as c
1010
from sqlmesh.core.console import Console, get_console
1111
from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements
12-
from sqlmesh.core.execution_tracker import QueryExecutionTracker, SeedExecutionTracker
12+
from sqlmesh.core.execution_tracker import QueryExecutionTracker
1313
from sqlmesh.core.macros import RuntimeStage
1414
from sqlmesh.core.model.definition import AuditResult
1515
from sqlmesh.core.node import IntervalUnit

sqlmesh/core/snapshot/evaluator.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from sqlmesh.core.dialect import schema_
4040
from sqlmesh.core.engine_adapter import EngineAdapter
4141
from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, DataObjectType
42-
from sqlmesh.core.execution_tracker import SeedExecutionTracker
42+
from sqlmesh.core.execution_tracker import QueryExecutionTracker
4343
from sqlmesh.core.macros import RuntimeStage
4444
from sqlmesh.core.model import (
4545
AuditResult,
@@ -170,19 +170,22 @@ def evaluate(
170170
Returns:
171171
The WAP ID of this evaluation if supported, None otherwise.
172172
"""
173-
result = self._evaluate_snapshot(
174-
start=start,
175-
end=end,
176-
execution_time=execution_time,
177-
snapshot=snapshot,
178-
snapshots=snapshots,
179-
allow_destructive_snapshots=allow_destructive_snapshots or set(),
180-
allow_additive_snapshots=allow_additive_snapshots or set(),
181-
deployability_index=deployability_index,
182-
batch_index=batch_index,
183-
target_table_exists=target_table_exists,
184-
**kwargs,
185-
)
173+
with QueryExecutionTracker.track_execution(
174+
f"{snapshot.snapshot_id}_{batch_index}", condition=not snapshot.is_seed
175+
):
176+
result = self._evaluate_snapshot(
177+
start=start,
178+
end=end,
179+
execution_time=execution_time,
180+
snapshot=snapshot,
181+
snapshots=snapshots,
182+
allow_destructive_snapshots=allow_destructive_snapshots or set(),
183+
allow_additive_snapshots=allow_additive_snapshots or set(),
184+
deployability_index=deployability_index,
185+
batch_index=batch_index,
186+
target_table_exists=target_table_exists,
187+
**kwargs,
188+
)
186189
if result is None or isinstance(result, str):
187190
return result
188191
raise SQLMeshError(
Lines changed: 22 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,38 @@
1-
# Tests the sqlmesh.core.execution_tracker module
2-
# - creates a scenario where executions will take place in multiple threads
3-
# - generates the scenario with known numbers of rows to be processed
4-
# - tests that the execution tracker correctly tracks the number of rows processed in both threads
5-
# - may use mocks, an existing test project, manually created snapshots, or a duckdb database to create the scenario
6-
71
from __future__ import annotations
82

9-
import threading
10-
from queue import Queue
11-
from typing import List, Optional
3+
import typing as t
4+
from concurrent.futures import ThreadPoolExecutor
125

136
from sqlmesh.core.execution_tracker import QueryExecutionTracker
147

158

16-
def test_execution_tracker_thread_isolation_and_aggregation() -> None:
17-
"""
18-
Two worker threads each track executions in their own context. Verify:
19-
- isolation across threads
20-
- correct aggregation of rows
21-
- query metadata is captured
22-
- main thread has no active tracking
23-
"""
24-
25-
assert not QueryExecutionTracker.is_tracking()
26-
assert QueryExecutionTracker.get_execution_stats() is None
27-
28-
counts_a: List[Optional[int]] = [10, 5, None]
29-
counts_b: List[Optional[int]] = [3, 7]
30-
31-
start_barrier = threading.Barrier(3) # 2 workers + main
32-
results: "Queue[dict]" = Queue()
33-
34-
def worker(batch_id: str, counts: List[Optional[int]]) -> None:
35-
with QueryExecutionTracker.track_execution(batch_id) as ctx:
36-
# tracking active in this thread
9+
def test_execution_tracker_thread_isolation() -> None:
10+
def worker(id: str, row_counts: list[int]) -> t.Dict[str, t.Any]:
11+
with QueryExecutionTracker.track_execution(id) as ctx:
3712
assert QueryExecutionTracker.is_tracking()
38-
# synchronize start to overlap execution
39-
start_barrier.wait()
40-
for c in counts:
41-
QueryExecutionTracker.record_execution("SELECT 1", c)
4213

43-
stats = ctx.get_execution_stats()
14+
for count in row_counts:
15+
QueryExecutionTracker.record_execution("SELECT 1", count)
4416

45-
assert stats["snapshot_batch"] == batch_id
46-
assert stats["query_count"] == len(counts)
47-
results.put(stats)
17+
assert ctx is not None
18+
return ctx.get_execution_stats()
4819

49-
t1 = threading.Thread(target=worker, args=("batch_A", counts_a))
50-
t2 = threading.Thread(target=worker, args=("batch_B", counts_b))
51-
52-
t1.start()
53-
t2.start()
54-
# Release workers at the same time
55-
start_barrier.wait()
56-
t1.join()
57-
t2.join()
20+
with ThreadPoolExecutor() as executor:
21+
futures = [
22+
executor.submit(worker, "batch_A", [10, 5]),
23+
executor.submit(worker, "batch_B", [3, 7]),
24+
]
25+
results = [f.result() for f in futures]
5826

5927
# Main thread has no active tracking context
6028
assert not QueryExecutionTracker.is_tracking()
6129
QueryExecutionTracker.record_execution("q", 10)
62-
assert QueryExecutionTracker.get_execution_stats() is None
63-
64-
collected = [results.get_nowait(), results.get_nowait()]
65-
# by name since order is non-deterministic
66-
by_batch = {s["snapshot_batch"]: s for s in collected}
30+
assert QueryExecutionTracker.get_execution_stats("q") is None
6731

68-
stats_a = by_batch["batch_A"]
69-
assert stats_a["total_rows_processed"] == 15 # 10 + 5 + 0 (None)
70-
assert stats_a["query_count"] == len(counts_a)
32+
# Order of results is not deterministic, so look up by id
33+
by_batch = {s["id"]: s for s in results}
7134

72-
stats_b = by_batch["batch_B"]
73-
assert stats_b["total_rows_processed"] == 10 # 3 + 7
74-
assert stats_b["query_count"] == len(counts_b)
35+
assert by_batch["batch_A"]["total_rows_processed"] == 15
36+
assert by_batch["batch_A"]["query_count"] == 2
37+
assert by_batch["batch_B"]["total_rows_processed"] == 10
38+
assert by_batch["batch_B"]["query_count"] == 2

0 commit comments

Comments (0)