Skip to content

Commit 0edfea7

Browse files
authored
Chore: Move SnapshotEvaluator back to the plan evaluator constructor (#4877)
1 parent 13024bc commit 0edfea7

7 files changed

Lines changed: 26 additions & 78 deletions

File tree

sqlmesh/core/config/scheduler.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -130,6 +130,7 @@ class BuiltInSchedulerConfig(_EngineAdapterStateSyncSchedulerConfig, BaseConfig)
130130
def create_plan_evaluator(self, context: GenericContext) -> PlanEvaluator:
131131
return BuiltInPlanEvaluator(
132132
state_sync=context.state_sync,
133+
snapshot_evaluator=context.snapshot_evaluator,
133134
create_scheduler=context.create_scheduler,
134135
default_catalog=context.default_catalog,
135136
console=context.console,

sqlmesh/core/context.py

Lines changed: 19 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,7 @@
116116
run_tests,
117117
)
118118
from sqlmesh.core.user import User
119-
from sqlmesh.utils import UniqueKeyDict, Verbosity, CorrelationId
119+
from sqlmesh.utils import UniqueKeyDict, Verbosity
120120
from sqlmesh.utils.concurrency import concurrent_apply_to_values
121121
from sqlmesh.utils.dag import DAG
122122
from sqlmesh.utils.date import (
@@ -417,7 +417,7 @@ def __init__(
417417
self.config.get_state_connection(self.gateway) or self.connection_config
418418
)
419419

420-
self._snapshot_evaluators: t.Dict[t.Optional[CorrelationId], SnapshotEvaluator] = {}
420+
self._snapshot_evaluator: t.Optional[SnapshotEvaluator] = None
421421

422422
self.console = get_console()
423423
setattr(self.console, "dialect", self.config.dialect)
@@ -445,22 +445,18 @@ def engine_adapter(self) -> EngineAdapter:
445445
self._engine_adapter = self.connection_config.create_engine_adapter()
446446
return self._engine_adapter
447447

448-
def snapshot_evaluator(
449-
self, correlation_id: t.Optional[CorrelationId] = None
450-
) -> SnapshotEvaluator:
451-
# Cache snapshot evaluators by correlation_id to avoid old correlation_ids being attached to future Context operations
452-
if correlation_id not in self._snapshot_evaluators:
453-
self._snapshot_evaluators[correlation_id] = SnapshotEvaluator(
448+
@property
449+
def snapshot_evaluator(self) -> SnapshotEvaluator:
450+
if not self._snapshot_evaluator:
451+
self._snapshot_evaluator = SnapshotEvaluator(
454452
{
455-
gateway: adapter.with_settings(
456-
log_level=logging.INFO, correlation_id=correlation_id
457-
)
453+
gateway: adapter.with_settings(log_level=logging.INFO)
458454
for gateway, adapter in self.engine_adapters.items()
459455
},
460456
ddl_concurrent_tasks=self.concurrent_tasks,
461457
selected_gateway=self.selected_gateway,
462458
)
463-
return self._snapshot_evaluators[correlation_id]
459+
return self._snapshot_evaluator
464460

465461
def execution_context(
466462
self,
@@ -541,9 +537,7 @@ def scheduler(self, environment: t.Optional[str] = None) -> Scheduler:
541537

542538
return self.create_scheduler(snapshots)
543539

544-
def create_scheduler(
545-
self, snapshots: t.Iterable[Snapshot], correlation_id: t.Optional[CorrelationId] = None
546-
) -> Scheduler:
540+
def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler:
547541
"""Creates the built-in scheduler.
548542
549543
Args:
@@ -554,7 +548,7 @@ def create_scheduler(
554548
"""
555549
return Scheduler(
556550
snapshots,
557-
self.snapshot_evaluator(correlation_id),
551+
self.snapshot_evaluator,
558552
self.state_sync,
559553
default_catalog=self.default_catalog,
560554
max_workers=self.concurrent_tasks,
@@ -719,7 +713,7 @@ def run(
719713
NotificationEvent.RUN_START, environment=environment
720714
)
721715
analytics_run_id = analytics.collector.on_run_start(
722-
engine_type=self.snapshot_evaluator().adapter.dialect,
716+
engine_type=self.snapshot_evaluator.adapter.dialect,
723717
state_sync_type=self.state_sync.state_type(),
724718
)
725719
self._load_materializations()
@@ -1081,7 +1075,7 @@ def evaluate(
10811075
and not parent_snapshot.categorized
10821076
]
10831077

1084-
df = self.snapshot_evaluator().evaluate_and_fetch(
1078+
df = self.snapshot_evaluator.evaluate_and_fetch(
10851079
snapshot,
10861080
start=start,
10871081
end=end,
@@ -1593,12 +1587,7 @@ def apply(
15931587
default_catalog=self.default_catalog,
15941588
console=self.console,
15951589
)
1596-
explainer.evaluate(
1597-
plan.to_evaluatable(),
1598-
snapshot_evaluator=self.snapshot_evaluator(
1599-
correlation_id=CorrelationId.from_plan_id(plan.plan_id)
1600-
),
1601-
)
1590+
explainer.evaluate(plan.to_evaluatable())
16021591
return
16031592

16041593
self.notification_target_manager.notify(
@@ -2121,7 +2110,7 @@ def audit(
21212110
errors = []
21222111
skipped_count = 0
21232112
for snapshot in snapshots:
2124-
for audit_result in self.snapshot_evaluator().audit(
2113+
for audit_result in self.snapshot_evaluator.audit(
21252114
snapshot=snapshot,
21262115
start=start,
21272116
end=end,
@@ -2153,7 +2142,7 @@ def audit(
21532142
self.console.log_status_update(f"Got {error.count} results, expected 0.")
21542143
if error.query:
21552144
self.console.show_sql(
2156-
f"{error.query.sql(dialect=self.snapshot_evaluator().adapter.dialect)}"
2145+
f"{error.query.sql(dialect=self.snapshot_evaluator.adapter.dialect)}"
21572146
)
21582147

21592148
self.console.log_status_update("Done.")
@@ -2345,14 +2334,12 @@ def print_environment_names(self) -> None:
23452334

23462335
def close(self) -> None:
23472336
"""Releases all resources allocated by this context."""
2348-
for evaluator in self._snapshot_evaluators.values():
2349-
evaluator.close()
2337+
if self._snapshot_evaluator:
2338+
self._snapshot_evaluator.close()
23502339

23512340
if self._state_sync:
23522341
self._state_sync.close()
23532342

2354-
self._snapshot_evaluators.clear()
2355-
23562343
def _run(
23572344
self,
23582345
environment: str,
@@ -2403,11 +2390,7 @@ def _run(
24032390

24042391
def _apply(self, plan: Plan, circuit_breaker: t.Optional[t.Callable[[], bool]]) -> None:
24052392
self._scheduler.create_plan_evaluator(self).evaluate(
2406-
plan.to_evaluatable(),
2407-
snapshot_evaluator=self.snapshot_evaluator(
2408-
correlation_id=CorrelationId.from_plan_id(plan.plan_id)
2409-
),
2410-
circuit_breaker=circuit_breaker,
2393+
plan.to_evaluatable(), circuit_breaker=circuit_breaker
24112394
)
24122395

24132396
@python_api_analytics
@@ -2700,7 +2683,7 @@ def _run_janitor(self, ignore_ttl: bool = False) -> None:
27002683
)
27012684

27022685
# Remove the expired snapshots tables
2703-
self.snapshot_evaluator().cleanup(
2686+
self.snapshot_evaluator.cleanup(
27042687
target_snapshots=cleanup_targets,
27052688
on_complete=self.console.update_cleanup_progress,
27062689
)

sqlmesh/core/plan/evaluator.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,6 @@ class PlanEvaluator(abc.ABC):
5151
def evaluate(
5252
self,
5353
plan: EvaluatablePlan,
54-
snapshot_evaluator: SnapshotEvaluator,
5554
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
5655
) -> None:
5756
"""Evaluates a plan by pushing snapshots and backfilling data.
@@ -63,7 +62,6 @@ def evaluate(
6362
6463
Args:
6564
plan: The plan to evaluate.
66-
snapshot_evaluator: The snapshot evaluator to use.
6765
circuit_breaker: The circuit breaker to use.
6866
"""
6967

@@ -72,11 +70,13 @@ class BuiltInPlanEvaluator(PlanEvaluator):
7270
def __init__(
7371
self,
7472
state_sync: StateSync,
73+
snapshot_evaluator: SnapshotEvaluator,
7574
create_scheduler: t.Callable[[t.Iterable[Snapshot]], Scheduler],
7675
default_catalog: t.Optional[str],
7776
console: t.Optional[Console] = None,
7877
):
7978
self.state_sync = state_sync
79+
self.snapshot_evaluator = snapshot_evaluator
8080
self.create_scheduler = create_scheduler
8181
self.default_catalog = default_catalog
8282
self.console = console or get_console()
@@ -85,11 +85,9 @@ def __init__(
8585
def evaluate(
8686
self,
8787
plan: EvaluatablePlan,
88-
snapshot_evaluator: SnapshotEvaluator,
8988
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
9089
) -> None:
9190
self._circuit_breaker = circuit_breaker
92-
self.snapshot_evaluator = snapshot_evaluator
9391

9492
self.console.start_plan_evaluation(plan)
9593
analytics.collector.on_plan_apply_start(

sqlmesh/core/plan/explainer.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,6 @@
2020
from sqlmesh.utils import Verbosity, rich as srich, to_snake_case
2121
from sqlmesh.utils.date import to_ts
2222
from sqlmesh.utils.errors import SQLMeshError
23-
from sqlmesh.core.snapshot.evaluator import SnapshotEvaluator
2423

2524

2625
logger = logging.getLogger(__name__)
@@ -40,7 +39,6 @@ def __init__(
4039
def evaluate(
4140
self,
4241
plan: EvaluatablePlan,
43-
snapshot_evaluator: SnapshotEvaluator,
4442
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
4543
) -> None:
4644
plan_stages = stages.build_plan_stages(plan, self.state_reader, self.default_catalog)

tests/conftest.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@
4242
SnapshotDataVersion,
4343
SnapshotFingerprint,
4444
)
45-
from sqlmesh.utils import random_id, CorrelationId
45+
from sqlmesh.utils import random_id
4646
from sqlmesh.utils.date import TimeLike, to_date
4747
from sqlmesh.utils.windows import IS_WINDOWS, fix_windows_path
4848
from sqlmesh.core.engine_adapter.shared import CatalogSupport
@@ -266,12 +266,10 @@ def duck_conn() -> duckdb.DuckDBPyConnection:
266266
def push_plan(context: Context, plan: Plan) -> None:
267267
plan_evaluator = BuiltInPlanEvaluator(
268268
context.state_sync,
269+
context.snapshot_evaluator,
269270
context.create_scheduler,
270271
context.default_catalog,
271272
)
272-
plan_evaluator.snapshot_evaluator = context.snapshot_evaluator(
273-
CorrelationId.from_plan_id(plan.plan_id)
274-
)
275273
deployability_index = DeployabilityIndex.create(context.snapshots.values())
276274
evaluatable_plan = plan.to_evaluatable()
277275
stages = plan_stages.build_plan_stages(

tests/core/test_integration.py

Lines changed: 1 addition & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,6 @@
6767
SnapshotInfoLike,
6868
SnapshotTableInfo,
6969
)
70-
from sqlmesh.utils import CorrelationId
7170
from sqlmesh.utils.date import TimeLike, now, to_date, to_datetime, to_timestamp
7271
from sqlmesh.utils.errors import NoChangesPlanError, SQLMeshError, PlanError, ConfigError
7372
from sqlmesh.utils.pydantic import validate_string
@@ -1138,7 +1137,7 @@ def test_non_breaking_change_after_forward_only_in_dev(
11381137
init_and_plan_context: t.Callable, has_view_binding: bool
11391138
):
11401139
context, plan = init_and_plan_context("examples/sushi")
1141-
context.snapshot_evaluator().adapter.HAS_VIEW_BINDING = has_view_binding
1140+
context.snapshot_evaluator.adapter.HAS_VIEW_BINDING = has_view_binding
11421141
context.apply(plan)
11431142

11441143
model = context.get_model("sushi.waiter_revenue_by_day")
@@ -6794,29 +6793,3 @@ def test_scd_type_2_full_restatement_no_start_date(init_and_plan_context: t.Call
67946793
# valid_from should be the epoch, valid_to should be NaT
67956794
assert str(row["valid_from"]) == "1970-01-01 00:00:00"
67966795
assert pd.isna(row["valid_to"])
6797-
6798-
6799-
def test_plan_evaluator_correlation_id(tmp_path: Path):
6800-
def _correlation_id_in_sqls(correlation_id: CorrelationId, mock_logger):
6801-
sqls = [call[0][0] for call in mock_logger.call_args_list]
6802-
return any(f"/* {correlation_id} */" in sql for sql in sqls)
6803-
6804-
create_temp_file(
6805-
tmp_path, Path("models") / "test.sql", "MODEL (name test.a, kind FULL); SELECT 1 AS col"
6806-
)
6807-
6808-
# Case 1: Ensure that the correlation id (plan_id) is included in the SQL
6809-
with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger:
6810-
ctx = Context(paths=[tmp_path], config=Config())
6811-
plan = ctx.plan(auto_apply=True, no_prompts=True)
6812-
6813-
correlation_id = CorrelationId.from_plan_id(plan.plan_id)
6814-
assert str(correlation_id) == f"SQLMESH_PLAN: {plan.plan_id}"
6815-
6816-
assert _correlation_id_in_sqls(correlation_id, mock_logger)
6817-
6818-
# Case 2: Ensure that the previous correlation id is not included in the SQL for other operations
6819-
with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger:
6820-
ctx.snapshot_evaluator().adapter.execute("SELECT 1")
6821-
6822-
assert not _correlation_id_in_sqls(correlation_id, mock_logger)

tests/core/test_plan_evaluator.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111
stages as plan_stages,
1212
)
1313
from sqlmesh.core.snapshot import SnapshotChangeCategory
14-
from sqlmesh.utils import CorrelationId
1514

1615

1716
@pytest.fixture
@@ -60,13 +59,11 @@ def test_builtin_evaluator_push(sushi_context: Context, make_snapshot):
6059

6160
evaluator = BuiltInPlanEvaluator(
6261
sushi_context.state_sync,
62+
sushi_context.snapshot_evaluator,
6363
sushi_context.create_scheduler,
6464
sushi_context.default_catalog,
6565
console=sushi_context.console,
6666
)
67-
evaluator.snapshot_evaluator = sushi_context.snapshot_evaluator(
68-
CorrelationId.from_plan_id(plan.plan_id)
69-
)
7067

7168
evaluatable_plan = plan.to_evaluatable()
7269
stages = plan_stages.build_plan_stages(

0 commit comments

Comments
 (0)