-
Notifications
You must be signed in to change notification settings - Fork 376
Expand file tree
/
Copy pathenvironment.py
More file actions
347 lines (296 loc) · 12.9 KB
/
environment.py
File metadata and controls
347 lines (296 loc) · 12.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
from __future__ import annotations
import json
import re
import typing as t
from pydantic import Field
from sqlmesh.core import constants as c
from sqlmesh.core.config import EnvironmentSuffixTarget
from sqlmesh.core.engine_adapter.base import EngineAdapter
from sqlmesh.core.macros import RuntimeStage
from sqlmesh.core.renderer import render_statements
from sqlmesh.core.snapshot import SnapshotId, SnapshotTableInfo, Snapshot
from sqlmesh.utils import word_characters_only
from sqlmesh.utils.date import TimeLike, now_timestamp
from sqlmesh.utils.errors import SQLMeshError
from sqlmesh.utils.jinja import JinjaMacroRegistry
from sqlmesh.utils.metaprogramming import Executable
from sqlmesh.utils.pydantic import PydanticModel, field_validator, ValidationInfo
# Type variable bound to `EnvironmentNamingInfo`; lets the alternate constructor
# `from_environment_catalog_mapping` return an instance of the class it is called on.
T = t.TypeVar("T", bound="EnvironmentNamingInfo")
# Type variable bound to `PydanticModel`; used by `Environment._convert_list_to_models_and_store`
# to tie the requested model type to the element type of the returned list.
PydanticType = t.TypeVar("PydanticType", bound="PydanticModel")
class EnvironmentNamingInfo(PydanticModel):
    """
    Naming settings needed to create an object within an environment.

    Args:
        name: The name of the environment.
        suffix_target: Indicates whether the environment name is appended to the schema or to the table name.
        catalog_name_override: The catalog to use for this environment when an override was provided.
        normalize_name: Indicates whether the environment's name will be normalized. For example, if it's
            `dev`, then it will become `DEV` when targeting Snowflake.
        gateway_managed: Determines whether the virtual layer's views are created by the model-specific
            gateways, otherwise the default gateway is used. Default: False.
    """

    name: str = c.PROD
    suffix_target: EnvironmentSuffixTarget = Field(default=EnvironmentSuffixTarget.SCHEMA)
    catalog_name_override: t.Optional[str] = None
    normalize_name: bool = True
    gateway_managed: bool = False

    @property
    def is_dev(self) -> bool:
        """Whether the (lowercased) environment name differs from the prod name."""
        lowered_name = self.name.lower()
        return lowered_name != c.PROD

    @field_validator("name", mode="before")
    @classmethod
    def _sanitize_name(cls, v: str) -> str:
        """Reduce the name to word characters and lowercase the result."""
        cleaned = word_characters_only(v)
        return cleaned.lower()

    @field_validator("normalize_name", "gateway_managed", mode="before")
    @classmethod
    def _validate_boolean_field(cls, v: t.Any, info: ValidationInfo) -> bool:
        """Coerce the flag to bool; a `None` input yields True only for `normalize_name`."""
        if v is not None:
            return bool(v)
        # Pydantic 2.13+ sets field_name to None during model_validate_json()
        validated_field = info.field_name or ""
        return validated_field == "normalize_name"

    @t.overload
    @classmethod
    def sanitize_name(cls, v: str) -> str: ...

    @t.overload
    @classmethod
    def sanitize_name(cls, v: Environment) -> Environment: ...

    @classmethod
    def sanitize_name(cls, v: str | Environment) -> str | Environment:
        """
        Sanitizes the environment name so that it is a valid name for database objects.

        This means alphanumeric and underscores only. Invalid characters are replaced with underscores.
        An `Environment` instance is returned unchanged.

        Raises:
            TypeError: If the value is neither a string nor an `Environment`.
        """
        if isinstance(v, Environment):
            return v
        if isinstance(v, str):
            return cls._sanitize_name(v)
        raise TypeError(f"Expected str or Environment, got {type(v).__name__}")

    @classmethod
    def sanitize_names(cls, values: t.Iterable[str]) -> t.Set[str]:
        """Sanitize every name in `values` and collect the results into a set."""
        sanitized: t.Set[str] = set()
        for raw_name in values:
            sanitized.add(cls.sanitize_name(raw_name))
        return sanitized

    @classmethod
    def from_environment_catalog_mapping(
        cls: t.Type[T],
        environment_catalog_mapping: t.Dict[re.Pattern, str],
        name: str = c.PROD,
        **kwargs: t.Any,
    ) -> T:
        """Build an instance, applying the catalog of the first pattern that matches `name`.

        Args:
            environment_catalog_mapping: Patterns mapped to the catalog name to use on a match.
            name: The environment name the patterns are matched against.
            kwargs: Additional keyword arguments forwarded to the constructor.

        Returns:
            A new instance with `catalog_name_override` set when a pattern matched, otherwise
            an instance built from `name` and `kwargs` alone.
        """
        base_kwargs: t.Dict[str, t.Any] = {"name": name, **kwargs}
        for pattern, catalog in environment_catalog_mapping.items():
            if not re.match(pattern, name):
                continue
            # Mapping order decides precedence: the first matching pattern wins.
            return cls(
                catalog_name_override=catalog,
                **base_kwargs,
            )
        return cls(**base_kwargs)
class EnvironmentSummary(PydanticModel):
    """Summary information describing an isolated environment.

    Args:
        name: The name of the environment.
        start_at: The start time of the environment.
        end_at: The end time of the environment.
        plan_id: The ID of the plan that last updated this environment.
        previous_plan_id: The ID of the previous plan that updated this environment.
        expiration_ts: The timestamp when this environment will expire.
        finalized_ts: The timestamp when this environment was finalized.
    """

    name: str
    start_at: TimeLike
    end_at: t.Optional[TimeLike] = None
    plan_id: str
    previous_plan_id: t.Optional[str] = None
    expiration_ts: t.Optional[int] = None
    finalized_ts: t.Optional[int] = None

    @property
    def expired(self) -> bool:
        """True when an expiration timestamp is set and is not later than the current timestamp."""
        expiration = self.expiration_ts
        if expiration is None:
            # No expiration configured: never considered expired.
            return False
        return expiration <= now_timestamp()
class Environment(EnvironmentNamingInfo, EnvironmentSummary):
    """Represents an isolated environment.

    Environments are isolated workspaces that hold pointers to physical tables.

    The list-valued fields are stored in underscore-suffixed attributes that may hold either raw
    dicts or model instances; the public properties convert them to models on access.

    Args:
        snapshots: The snapshots that are part of this environment.
        promoted_snapshot_ids: The IDs of the snapshots that are promoted in this environment
            (i.e. for which the views are created). If not specified, all snapshots are promoted.
        previous_finalized_snapshots: Snapshots that were part of this environment last time it was finalized.
        requirements: A mapping of library versions for all the snapshots in this environment.
    """

    snapshots_: t.List[t.Any] = Field(alias="snapshots")
    promoted_snapshot_ids_: t.Optional[t.List[t.Any]] = Field(
        default=None, alias="promoted_snapshot_ids"
    )
    previous_finalized_snapshots_: t.Optional[t.List[t.Any]] = Field(
        default=None, alias="previous_finalized_snapshots"
    )
    requirements: t.Dict[str, str] = {}

    @field_validator("snapshots_", "previous_finalized_snapshots_", mode="before")
    @classmethod
    def _load_snapshots(cls, v: str | t.List[t.Any] | None) -> t.List[t.Any] | None:
        """Decode a JSON string, or check that a list holds snapshot dicts/objects.

        Only the first element of a non-empty list is inspected.
        """
        if isinstance(v, str):
            return json.loads(v)
        if v and not isinstance(next(iter(v)), (dict, SnapshotTableInfo)):
            raise ValueError("Must be a list of SnapshotTableInfo dicts or objects")
        return v

    @field_validator("promoted_snapshot_ids_", mode="before")
    @classmethod
    def _load_snapshot_ids(cls, v: str | t.List[t.Any] | None) -> t.List[t.Any] | None:
        """Decode a JSON string, or check that a list holds snapshot ID dicts/objects.

        Only the first element of a non-empty list is inspected.
        """
        if isinstance(v, str):
            return json.loads(v)
        if v and not isinstance(next(iter(v)), (dict, SnapshotId)):
            raise ValueError("Must be a list of SnapshotId dicts or objects")
        return v

    @field_validator("requirements", mode="before")
    @classmethod
    def _load_requirements(cls, v: t.Any) -> t.Any:
        """Decode a JSON string if given and replace any falsy value with an empty dict."""
        if isinstance(v, str):
            v = json.loads(v)
        return v or {}

    @property
    def snapshots(self) -> t.List[SnapshotTableInfo]:
        """The environment's snapshots as `SnapshotTableInfo` models (empty list if none)."""
        return self._convert_list_to_models_and_store("snapshots_", SnapshotTableInfo) or []

    def snapshot_dicts(self) -> t.List[dict]:
        """The environment's snapshots as dicts."""
        return self._convert_list_to_dicts(self.snapshots_)

    @property
    def promoted_snapshot_ids(self) -> t.Optional[t.List[SnapshotId]]:
        """The promoted snapshot IDs as `SnapshotId` models, or None if not specified."""
        return self._convert_list_to_models_and_store("promoted_snapshot_ids_", SnapshotId)

    def promoted_snapshot_id_dicts(self) -> t.List[dict]:
        """The promoted snapshot IDs as dicts (empty list if not specified)."""
        return self._convert_list_to_dicts(self.promoted_snapshot_ids_)

    @property
    def promoted_snapshots(self) -> t.List[SnapshotTableInfo]:
        """The snapshots whose IDs are promoted; all snapshots when no IDs are specified."""
        if self.promoted_snapshot_ids is None:
            return self.snapshots
        # A set gives constant-time membership checks while filtering.
        promoted_snapshot_ids = set(self.promoted_snapshot_ids)
        return [s for s in self.snapshots if s.snapshot_id in promoted_snapshot_ids]

    @property
    def previous_finalized_snapshots(self) -> t.Optional[t.List[SnapshotTableInfo]]:
        """The previously finalized snapshots as `SnapshotTableInfo` models, if any."""
        return self._convert_list_to_models_and_store(
            "previous_finalized_snapshots_", SnapshotTableInfo
        )

    def previous_finalized_snapshot_dicts(self) -> t.List[dict]:
        """The previously finalized snapshots as dicts (empty list if none)."""
        return self._convert_list_to_dicts(self.previous_finalized_snapshots_)

    @property
    def finalized_or_current_snapshots(self) -> t.List[SnapshotTableInfo]:
        """The current snapshots if finalized, else the previously finalized ones when available.

        Falls back to the current snapshots when the environment is not finalized and there
        are no previously finalized snapshots.
        """
        return (
            self.snapshots
            if self.finalized_ts
            else self.previous_finalized_snapshots or self.snapshots
        )

    @property
    def naming_info(self) -> EnvironmentNamingInfo:
        """A standalone `EnvironmentNamingInfo` copied from this environment's naming fields."""
        return EnvironmentNamingInfo(
            name=self.name,
            suffix_target=self.suffix_target,
            catalog_name_override=self.catalog_name_override,
            normalize_name=self.normalize_name,
            gateway_managed=self.gateway_managed,
        )

    @property
    def summary(self) -> EnvironmentSummary:
        """A standalone `EnvironmentSummary` copied from this environment's summary fields."""
        return EnvironmentSummary(
            name=self.name,
            start_at=self.start_at,
            end_at=self.end_at,
            plan_id=self.plan_id,
            previous_plan_id=self.previous_plan_id,
            expiration_ts=self.expiration_ts,
            finalized_ts=self.finalized_ts,
        )

    def can_partially_promote(self, existing_environment: Environment) -> bool:
        """Returns True if the existing environment can be partially promoted to the current environment.

        Partial promotion means that we don't need to re-create views for snapshots that are already promoted in the
        target environment.

        The existing environment must be finalized, not expired, named prod, and have the same
        `gateway_managed` setting as this environment.
        """
        return (
            bool(existing_environment.finalized_ts)
            and not existing_environment.expired
            and existing_environment.gateway_managed == self.gateway_managed
            and existing_environment.name == c.PROD
        )

    def _convert_list_to_models_and_store(
        self, field: str, type_: t.Type[PydanticType]
    ) -> t.Optional[t.List[PydanticType]]:
        """Return the list stored in `field` as `type_` instances, converting it in place if needed.

        Conversion is decided by the first element only; the converted list is written back to
        the attribute so later accesses skip the parsing step. Falsy values are returned as-is.
        """
        value = getattr(self, field)
        if value and not isinstance(value[0], type_):
            value = [type_.parse_obj(obj) for obj in value]
            setattr(self, field, value)
        return value

    def _convert_list_to_dicts(self, value: t.Optional[t.List[t.Any]]) -> t.List[dict]:
        """Return `value` as a list of dicts; the first element decides whether conversion is needed."""
        if not value:
            return []
        return value if isinstance(value[0], dict) else [v.dict() for v in value]
class EnvironmentStatements(PydanticModel):
    """Statements attached to an environment, together with what is needed to render them.

    Args:
        before_all: The statements for the before-all stage.
        after_all: The statements for the after-all stage.
        python_env: The Python environment passed to the renderer.
        jinja_macros: The optional Jinja macro registry passed to the renderer.
        project: The optional project name.
    """

    before_all: t.List[str]
    after_all: t.List[str]
    python_env: t.Dict[str, Executable]
    jinja_macros: t.Optional[JinjaMacroRegistry] = None
    project: t.Optional[str] = None

    def render_before_all(
        self,
        dialect: str,
        default_catalog: t.Optional[str] = None,
        **render_kwargs: t.Any,
    ) -> t.List[str]:
        """Render the statements of the `BEFORE_ALL` runtime stage."""
        stage = RuntimeStage.BEFORE_ALL
        return self.render(stage, dialect, default_catalog, **render_kwargs)

    def render_after_all(
        self,
        dialect: str,
        default_catalog: t.Optional[str] = None,
        **render_kwargs: t.Any,
    ) -> t.List[str]:
        """Render the statements of the `AFTER_ALL` runtime stage."""
        stage = RuntimeStage.AFTER_ALL
        return self.render(stage, dialect, default_catalog, **render_kwargs)

    def render(
        self,
        runtime_stage: RuntimeStage,
        dialect: str,
        default_catalog: t.Optional[str] = None,
        **render_kwargs: t.Any,
    ) -> t.List[str]:
        """Render the statements belonging to the given runtime stage.

        Args:
            runtime_stage: The stage whose statements are rendered; its value is the name of the
                attribute holding them.
            dialect: The SQL dialect passed to the renderer.
            default_catalog: The optional default catalog passed to the renderer.
            render_kwargs: Additional keyword arguments forwarded to the renderer.

        Returns:
            The rendered statements.
        """
        # The stage's value doubles as the attribute name of the statement list.
        stage_statements = getattr(self, runtime_stage.value)
        return render_statements(
            statements=stage_statements,
            dialect=dialect,
            default_catalog=default_catalog,
            python_env=self.python_env,
            jinja_macros=self.jinja_macros,
            runtime_stage=runtime_stage,
            **render_kwargs,
        )
def execute_environment_statements(
    adapter: EngineAdapter,
    environment_statements: t.List[EnvironmentStatements],
    runtime_stage: RuntimeStage,
    environment_naming_info: EnvironmentNamingInfo,
    default_catalog: t.Optional[str] = None,
    snapshots: t.Optional[t.Dict[str, Snapshot]] = None,
    start: t.Optional[TimeLike] = None,
    end: t.Optional[TimeLike] = None,
    execution_time: t.Optional[TimeLike] = None,
    selected_models: t.Optional[t.Set[str]] = None,
) -> None:
    """Render and execute the environment statements of the given runtime stage.

    All statements are rendered first. If anything was rendered, the expressions are executed
    one by one, in order, inside a single adapter transaction.

    Args:
        adapter: The engine adapter used for its dialect and to execute the expressions.
        environment_statements: The statement collections to render and execute.
        runtime_stage: The stage whose statements are rendered.
        environment_naming_info: Passed through to the renderer.
        default_catalog: Passed through to the renderer.
        snapshots: Passed through to the renderer.
        start: Passed through to the renderer.
        end: Passed through to the renderer.
        execution_time: Passed through to the renderer.
        selected_models: Passed through to the renderer.

    Raises:
        SQLMeshError: If rendering fails, or if executing any expression fails. The original
            exception is attached as the cause.
    """
    try:
        rendered_expressions = [
            expr
            for statements in environment_statements
            for expr in statements.render(
                runtime_stage=runtime_stage,
                dialect=adapter.dialect,
                default_catalog=default_catalog,
                snapshots=snapshots,
                start=start,
                end=end,
                execution_time=execution_time,
                environment_naming_info=environment_naming_info,
                engine_adapter=adapter,
                selected_models=selected_models,
            )
        ]
    except Exception as e:
        # Chain explicitly so the root cause is kept on the wrapped error.
        raise SQLMeshError(
            f"An error occurred during rendering of the '{runtime_stage.value}' statements:\n\n{e}"
        ) from e

    # Skip opening a transaction when there is nothing to run.
    if rendered_expressions:
        with adapter.transaction():
            for expr in rendered_expressions:
                try:
                    adapter.execute(expr)
                except Exception as e:
                    raise SQLMeshError(
                        f"An error occurred during execution of the following '{runtime_stage.value}' statement:\n\n{expr}\n\n{e}"
                    ) from e