Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions sqlmesh/core/config/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
ValidationInfo,
field_validator,
model_validator,
validation_data,
validation_error_message,
get_concrete_types_from_typehint,
)
Expand Down Expand Up @@ -1081,7 +1082,7 @@ def validate_execution_project(
v: t.Optional[str],
info: ValidationInfo,
) -> t.Optional[str]:
if v and not info.data.get("project"):
if v and not validation_data(info).get("project"):
raise ConfigError(
"If the `execution_project` field is specified, you must also specify the `project` field to provide a default object location."
)
Expand All @@ -1093,7 +1094,7 @@ def validate_quota_project(
v: t.Optional[str],
info: ValidationInfo,
) -> t.Optional[str]:
if v and not info.data.get("project"):
if v and not validation_data(info).get("project"):
raise ConfigError(
"If the `quota_project` field is specified, you must also specify the `project` field to provide a default object location."
)
Expand Down
3 changes: 2 additions & 1 deletion sqlmesh/core/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def _sanitize_name(cls, v: str) -> str:
@classmethod
def _validate_boolean_field(cls, v: t.Any, info: ValidationInfo) -> bool:
if v is None:
return info.field_name == "normalize_name"
# Pydantic 2.13+ sets field_name to None during model_validate_json()
return (info.field_name or "") == "normalize_name"
return bool(v)

@t.overload
Expand Down
4 changes: 2 additions & 2 deletions sqlmesh/core/metric/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sqlmesh.core.node import str_or_exp_to_str
from sqlmesh.utils import UniqueKeyDict
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator
from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator, validation_data

MeasureAndDimTables = t.Tuple[str, t.Tuple[str, ...]]

Expand Down Expand Up @@ -89,7 +89,7 @@ def _string_validator(cls, v: t.Any) -> t.Optional[str]:
@field_validator("expression", mode="before")
def _validate_expression(cls, v: t.Any, info: ValidationInfo) -> exp.Expr:
if isinstance(v, str):
dialect = info.data.get("dialect")
dialect = validation_data(info).get("dialect")
return d.parse_one(v, dialect=dialect)
if isinstance(v, exp.Expr):
return v
Expand Down
17 changes: 12 additions & 5 deletions sqlmesh/core/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
prepare_env,
serialize_env,
)
from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator, get_dialect
from sqlmesh.utils.pydantic import (
PydanticModel,
ValidationInfo,
field_validator,
get_dialect,
validation_data,
)

if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
Expand Down Expand Up @@ -479,7 +485,7 @@ def parse_expression(
if callable(v):
return v

dialect = info.data.get("dialect") if info else ""
dialect = validation_data(info).get("dialect") if info else ""

if isinstance(v, list):
return [
Expand Down Expand Up @@ -519,7 +525,7 @@ def parse_properties(
if v is None:
return v

dialect = info.data.get("dialect") if info else ""
dialect = validation_data(info).get("dialect") if info else ""

if isinstance(v, str):
v = d.parse_one(v, dialect=dialect)
Expand Down Expand Up @@ -557,8 +563,9 @@ def default_catalog(cls: t.Type, v: t.Any) -> t.Optional[str]:


def depends_on(cls: t.Type, v: t.Any, info: ValidationInfo) -> t.Optional[t.Set[str]]:
dialect = info.data.get("dialect")
default_catalog = info.data.get("default_catalog")
data = validation_data(info)
dialect = data.get("dialect")
default_catalog = data.get("default_catalog")

if isinstance(v, exp.Paren):
v = v.unnest()
Expand Down
20 changes: 11 additions & 9 deletions sqlmesh/core/model/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
list_of_fields_validator,
model_validator,
get_dialect,
validation_data,
)

if t.TYPE_CHECKING:
Expand Down Expand Up @@ -135,7 +136,7 @@ def _func_call_validator(cls, v: t.Any, field: t.Any) -> t.Any:

@field_validator("tags", mode="before")
def _value_or_tuple_validator(cls, v: t.Any, info: ValidationInfo) -> t.Any:
return ensure_list(cls._validate_value_or_tuple(v, info.data))
return ensure_list(cls._validate_value_or_tuple(v, validation_data(info)))

@classmethod
def _validate_value_or_tuple(
Expand Down Expand Up @@ -164,7 +165,7 @@ def _normalize(value: t.Any) -> t.Any:
@field_validator("table_format", "storage_format", mode="before")
def _format_validator(cls, v: t.Any, info: ValidationInfo) -> t.Optional[str]:
if isinstance(v, exp.Expr) and not (isinstance(v, (exp.Literal, exp.Identifier))):
return v.sql(info.data.get("dialect"))
return v.sql(validation_data(info).get("dialect"))
return str_or_exp_to_str(v)

@field_validator("dialect", mode="before")
Expand Down Expand Up @@ -192,7 +193,7 @@ def _partition_and_cluster_validator(cls, v: t.Any, info: ValidationInfo) -> t.L
if (
isinstance(v, list)
and all(isinstance(i, str) for i in v)
and info.field_name == "partitioned_by_"
and (info.field_name or "") == "partitioned_by_"
):
# this branch gets hit when we are deserializing from json because `partitioned_by` is stored as a List[str]
# however, we should only invoke this if the list contains strings because this validator is also
Expand All @@ -205,7 +206,7 @@ def _partition_and_cluster_validator(cls, v: t.Any, info: ValidationInfo) -> t.L
)
v = parsed.this.expressions if isinstance(parsed.this, exp.Schema) else v

expressions = list_of_fields_validator(v, info.data)
expressions = list_of_fields_validator(v, validation_data(info))

for expression in expressions:
num_cols = len(list(expression.find_all(exp.Column)))
Expand All @@ -228,7 +229,7 @@ def _columns_validator(
cls, v: t.Any, info: ValidationInfo
) -> t.Optional[t.Dict[str, exp.DataType]]:
columns_to_types = {}
dialect = info.data.get("dialect")
dialect = validation_data(info).get("dialect")

if isinstance(v, exp.Schema):
for column in v.expressions:
Expand Down Expand Up @@ -280,7 +281,8 @@ def _columns_validator(
def _column_descriptions_validator(
cls, vs: t.Any, info: ValidationInfo
) -> t.Optional[t.Dict[str, str]]:
dialect = info.data.get("dialect")
data = validation_data(info)
dialect = data.get("dialect")

if vs is None:
return None
Expand All @@ -302,23 +304,23 @@ def _column_descriptions_validator(
for k, v in raw_col_descriptions.items()
}

columns_to_types = info.data.get("columns_to_types_")
columns_to_types = data.get("columns_to_types_")
if columns_to_types:
from sqlmesh.core.console import get_console

console = get_console()
for column_name in list(col_descriptions):
if column_name not in columns_to_types:
console.log_warning(
f"In model '{info.data['name']}', a description is provided for column '{column_name}' but it is not a column in the model."
f"In model '{data.get('name', '<unknown>')}', a description is provided for column '{column_name}' but it is not a column in the model."
)
del col_descriptions[column_name]

return col_descriptions

@field_validator("grains", "references", mode="before")
def _refs_validator(cls, vs: t.Any, info: ValidationInfo) -> t.List[exp.Expr]:
dialect = info.data.get("dialect")
dialect = validation_data(info).get("dialect")

if isinstance(vs, exp.Paren):
vs = vs.unnest()
Expand Down
4 changes: 2 additions & 2 deletions sqlmesh/core/state_sync/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pydantic_core.core_schema import ValidationInfo
from sqlglot import exp

from sqlmesh.utils.pydantic import PydanticModel, field_validator
from sqlmesh.utils.pydantic import PydanticModel, field_validator, validation_data
from sqlmesh.core.environment import Environment, EnvironmentStatements, EnvironmentNamingInfo
from sqlmesh.core.snapshot import (
Snapshot,
Expand Down Expand Up @@ -269,7 +269,7 @@ class PromotionResult(PydanticModel):
def _validate_removed_environment_naming_info(
cls, v: t.Optional[EnvironmentNamingInfo], info: ValidationInfo
) -> t.Optional[EnvironmentNamingInfo]:
if v and not info.data.get("removed"):
if v and not validation_data(info).get("removed"):
raise ValueError("removed_environment_naming_info must be None if removed is empty")
return v

Expand Down
4 changes: 2 additions & 2 deletions sqlmesh/core/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum

from sqlmesh.core.notification_target import BasicSMTPNotificationTarget, NotificationTarget
from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator
from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator, validation_data


class UserRole(str, Enum):
Expand Down Expand Up @@ -42,7 +42,7 @@ def validate_notification_targets(
v: t.List[NotificationTarget],
info: ValidationInfo,
) -> t.List[NotificationTarget]:
email = info.data["email"]
email = validation_data(info).get("email")
for target in v:
if isinstance(target, BasicSMTPNotificationTarget) and target.recipients != {email}:
raise ValueError("Recipient emails do not match user email")
Expand Down
15 changes: 14 additions & 1 deletion sqlmesh/utils/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,19 @@ def field_serializer(*args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Any], t.Any
return pydantic.field_serializer(*args, **kwargs)


def validation_data(info_or_data: t.Any) -> t.Dict[str, t.Any]:
"""Safely extract the validated-data dict from a ValidationInfo, dict, or None.

Pydantic 2.13+ sets ValidationInfo.data to None during model_validate_json().
This normalizes all inputs to a dict, returning an empty dict when data is unavailable.
"""
if isinstance(info_or_data, dict):
return info_or_data
if info_or_data is not None:
return info_or_data.data or {}
return {}


def get_dialect(values: t.Any) -> str:
"""Extracts dialect from a dict or pydantic obj, defaulting to the globally set dialect.

Expand All @@ -52,7 +65,7 @@ def get_dialect(values: t.Any) -> str:

from sqlmesh.core.model import model

dialect = (values if isinstance(values, dict) else values.data).get("dialect")
dialect = validation_data(values).get("dialect")
return model._dialect if dialect is None else dialect # type: ignore


Expand Down
7 changes: 4 additions & 3 deletions web/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
SnapshotId,
)
from sqlmesh.utils.date import TimeLike, now_timestamp
from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator
from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator, validation_data

SUPPORTED_EXTENSIONS = {".py", ".sql", ".yaml", ".yml", ".csv"}

Expand Down Expand Up @@ -117,8 +117,9 @@ class File(PydanticModel):

@field_validator("extension", mode="before")
def default_extension(cls, v: str, info: ValidationInfo) -> str:
if "name" in info.data:
return pathlib.Path(info.data["name"]).suffix
data = validation_data(info)
if "name" in data:
return pathlib.Path(data["name"]).suffix
return v


Expand Down
Loading