|
1 | 1 | from __future__ import annotations |
2 | 2 | import logging |
3 | 3 | import typing as t |
| 4 | +import time |
4 | 5 | from sqlglot import exp |
5 | 6 | from sqlmesh.core import constants as c |
6 | 7 | from sqlmesh.core.console import Console, get_console |
|
24 | 25 | snapshots_to_dag, |
25 | 26 | Intervals, |
26 | 27 | ) |
| 28 | +from sqlmesh.core.snapshot.definition import check_ready_intervals |
27 | 29 | from sqlmesh.core.snapshot.definition import ( |
28 | 30 | Interval, |
29 | 31 | expand_range, |
|
39 | 41 | to_timestamp, |
40 | 42 | validate_date_range, |
41 | 43 | ) |
42 | | -from sqlmesh.utils.errors import AuditError, NodeAuditsErrors, CircuitBreakerError, SQLMeshError |
| 44 | +from sqlmesh.utils.errors import ( |
| 45 | + AuditError, |
| 46 | + NodeAuditsErrors, |
| 47 | + CircuitBreakerError, |
| 48 | + SQLMeshError, |
| 49 | + SignalEvalError, |
| 50 | +) |
| 51 | + |
| 52 | +if t.TYPE_CHECKING: |
| 53 | + from sqlmesh.core.context import ExecutionContext |
43 | 54 |
|
44 | 55 | logger = logging.getLogger(__name__) |
45 | 56 | SnapshotToIntervals = t.Dict[Snapshot, Intervals] |
@@ -304,12 +315,11 @@ def batch_intervals( |
304 | 315 | default_catalog=self.default_catalog, |
305 | 316 | ) |
306 | 317 |
|
307 | | - intervals = snapshot.check_ready_intervals( |
| 318 | + intervals = self._check_ready_intervals( |
| 319 | + snapshot, |
308 | 320 | intervals, |
309 | 321 | context, |
310 | | - console=self.console, |
311 | | - default_catalog=self.default_catalog, |
312 | | - environment_naming_info=environment_naming_info, |
| 322 | + environment_naming_info, |
313 | 323 | ) |
314 | 324 | unready -= set(intervals) |
315 | 325 |
|
@@ -709,6 +719,76 @@ def _audit_snapshot( |
709 | 719 |
|
710 | 720 | return audit_results |
711 | 721 |
|
| 722 | + def _check_ready_intervals( |
| 723 | + self, |
| 724 | + snapshot: Snapshot, |
| 725 | + intervals: Intervals, |
| 726 | + context: ExecutionContext, |
| 727 | + environment_naming_info: EnvironmentNamingInfo, |
| 728 | + ) -> Intervals: |
| 729 | + """Checks if the intervals are ready for evaluation for the given snapshot. |
| 730 | +
|
| 731 | + This implementation also includes the signal progress tracking. |
| 732 | + Note that this will handle gaps in the provided intervals. The returned intervals |
| 733 | + may introduce new gaps. |
| 734 | +
|
| 735 | + Args: |
| 736 | + snapshot: The snapshot to check. |
| 737 | + intervals: The intervals to check. |
| 738 | + context: The context to use. |
| 739 | + environment_naming_info: The environment naming info to use. |
| 740 | +
|
| 741 | + Returns: |
| 742 | + The intervals that are ready for evaluation. |
| 743 | + """ |
| 744 | + signals = snapshot.is_model and snapshot.model.render_signal_calls() |
| 745 | + |
| 746 | + if not signals: |
| 747 | + return intervals |
| 748 | + |
| 749 | + self.console.start_signal_progress( |
| 750 | + snapshot, |
| 751 | + self.default_catalog, |
| 752 | + environment_naming_info or EnvironmentNamingInfo(), |
| 753 | + ) |
| 754 | + |
| 755 | + for signal_idx, (signal_name, kwargs) in enumerate(signals.signals_to_kwargs.items()): |
| 756 | + # Capture intervals before signal check for display |
| 757 | + intervals_to_check = merge_intervals(intervals) |
| 758 | + |
| 759 | + signal_start_ts = time.perf_counter() |
| 760 | + |
| 761 | + try: |
| 762 | + intervals = check_ready_intervals( |
| 763 | + signals.prepared_python_env[signal_name], |
| 764 | + intervals, |
| 765 | + context, |
| 766 | + python_env=signals.python_env, |
| 767 | + dialect=snapshot.model.dialect, |
| 768 | + path=snapshot.model._path, |
| 769 | + kwargs=kwargs, |
| 770 | + ) |
| 771 | + except SQLMeshError as e: |
| 772 | + raise SignalEvalError( |
| 773 | + f"{e} '{signal_name}' for '{snapshot.model.name}' at {snapshot.model._path}" |
| 774 | + ) |
| 775 | + |
| 776 | + duration = time.perf_counter() - signal_start_ts |
| 777 | + |
| 778 | + self.console.update_signal_progress( |
| 779 | + snapshot=snapshot, |
| 780 | + signal_name=signal_name, |
| 781 | + signal_idx=signal_idx, |
| 782 | + total_signals=len(signals.signals_to_kwargs), |
| 783 | + ready_intervals=merge_intervals(intervals), |
| 784 | + check_intervals=intervals_to_check, |
| 785 | + duration=duration, |
| 786 | + ) |
| 787 | + |
| 788 | + self.console.stop_signal_progress() |
| 789 | + |
| 790 | + return intervals |
| 791 | + |
712 | 792 |
|
713 | 793 | def merged_missing_intervals( |
714 | 794 | snapshots: t.Collection[Snapshot], |
|
0 commit comments