diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py index ef4cb2506e..4326f11f1e 100644 --- a/pyiceberg/expressions/__init__.py +++ b/pyiceberg/expressions/__init__.py @@ -21,10 +21,10 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Sequence from functools import cached_property -from typing import Any, TypeAlias +from typing import Annotated, Any, TypeAlias from typing import Literal as TypingLiteral -from pydantic import ConfigDict, Field, SerializeAsAny, model_validator +from pydantic import BeforeValidator, ConfigDict, Field, SerializeAsAny, model_validator from pydantic_core.core_schema import ValidatorFunctionWrapHandler from pyiceberg.expressions.literals import AboveMax, BelowMin, Literal, literal @@ -508,7 +508,7 @@ def as_unbound(self) -> type[UnboundPredicate]: ... class UnboundPredicate(Unbound, BooleanExpression, ABC): model_config = ConfigDict(arbitrary_types_allowed=True) - term: UnboundTerm + term: Annotated[str | UnboundTerm, BeforeValidator(_to_unbound_term)] def __init__(self, term: str | UnboundTerm, **kwargs: Any) -> None: super().__init__(term=_to_unbound_term(term), **kwargs) @@ -540,7 +540,7 @@ def __str__(self) -> str: return f"{str(self.__class__.__name__)}(term={str(self.term)})" def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundUnaryPredicate: - bound_term = self.term.bind(schema, case_sensitive) + bound_term = self.term.bind(schema, case_sensitive) # type: ignore[union-attr] bound_type = self.as_bound return bound_type(bound_term) # type: ignore[misc] @@ -696,7 +696,7 @@ def __init__( super().__init__(term=_to_unbound_term(term), values=literal_set) def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundSetPredicate: - bound_term = self.term.bind(schema, case_sensitive) + bound_term = self.term.bind(schema, case_sensitive) # type: ignore[union-attr] literal_set = self.literals return self.as_bound(bound_term, {lit.to(bound_term.ref().field.field_type) for lit in literal_set}) # type: ignore @@ -716,7 +716,7 @@ def __eq__(self, other: Any) -> bool: """Return the equality of two instances of the SetPredicate class.""" return self.term == other.term and self.literals == other.literals if isinstance(other, self.__class__) else False - def __getnewargs__(self) -> tuple[UnboundTerm, set[Any]]: + def __getnewargs__(self) -> tuple[str | UnboundTerm, set[Any]]: """Pickle the SetPredicate class.""" return (self.term, self.literals) @@ -870,7 +870,7 @@ def as_bound(self) -> type[BoundNotIn]: # type: ignore class LiteralPredicate(UnboundPredicate, ABC): type: TypingLiteral["lt", "lt-eq", "gt", "gt-eq", "eq", "not-eq", "starts-with", "not-starts-with"] = Field(alias="type") - term: UnboundTerm + term: Annotated[str | UnboundTerm, BeforeValidator(_to_unbound_term)] value: LiteralValue = Field() model_config = ConfigDict(populate_by_name=True, frozen=True, arbitrary_types_allowed=True) @@ -885,7 +885,7 @@ def literal(self) -> LiteralValue: return self.value def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredicate: - bound_term = self.term.bind(schema, case_sensitive) + bound_term = self.term.bind(schema, case_sensitive) # type: ignore[union-attr] lit = self.literal.to(bound_term.ref().field.field_type) if isinstance(lit, AboveMax):