diff --git a/dataframely/collection/_base.py b/dataframely/collection/_base.py index ff4140c..2d0cb34 100644 --- a/dataframely/collection/_base.py +++ b/dataframely/collection/_base.py @@ -15,6 +15,7 @@ from dataframely._filter import Filter from dataframely._polars import FrameType +from dataframely._typing import DataFrame as TypedDataFrame from dataframely._typing import LazyFrame as TypedLazyFrame from dataframely.exc import AnnotationImplementationError, ImplementationError from dataframely.schema import Schema @@ -92,6 +93,8 @@ class MemberInfo(CollectionMember): schema: type[Schema] #: Whether the member is optional. is_optional: bool + #: Whether the member is a lazy frame. + is_lazy: bool = True @dataclass @@ -241,39 +244,46 @@ def _derive_member_info( attr, annotation_args[0], annotation_args[1] ) elif origin == typing.Union: - # Happy path: optional member + # Happy path: optional member (e.g. dy.LazyFrame[Schema] | None) union_args = get_args(type_annotation) if len(union_args) != 2: raise AnnotationImplementationError(attr, type_annotation) - if not any(get_origin(arg) is None for arg in union_args): + # Check that exactly one arg is None (type(None) is NoneType) + if not any(arg is type(None) for arg in union_args): raise AnnotationImplementationError(attr, type_annotation) - not_none_args = [arg for arg in union_args if get_origin(arg) is not None] - if len(not_none_args) == 0 or not issubclass( - get_origin(not_none_args[0]), TypedLazyFrame - ): + # Get the non-None type (exactly one exists given prior checks) + not_none_arg = next(arg for arg in union_args if arg is not type(None)) + + frame_origin = get_origin(not_none_arg) + if frame_origin is None: raise AnnotationImplementationError(attr, type_annotation) - return MemberInfo( - schema=get_args(not_none_args[0])[0], - is_optional=True, - ignored_in_filters=collection_member.ignored_in_filters, - inline_for_sampling=collection_member.inline_for_sampling, - propagate_row_failures=collection_member.propagate_row_failures, - ) - elif issubclass(origin, TypedLazyFrame): - # Happy path: required member - return MemberInfo( - schema=get_args(type_annotation)[0], - is_optional=False, - ignored_in_filters=collection_member.ignored_in_filters, - inline_for_sampling=collection_member.inline_for_sampling, - propagate_row_failures=collection_member.propagate_row_failures, - ) + schema = get_args(not_none_arg)[0] + is_optional = True + elif issubclass(origin, (TypedLazyFrame, TypedDataFrame)): + frame_origin = origin + schema = get_args(type_annotation)[0] + is_optional = False + else: + raise AnnotationImplementationError(attr, type_annotation) + + if issubclass(frame_origin, TypedLazyFrame): + is_lazy = True + elif issubclass(frame_origin, TypedDataFrame): + is_lazy = False else: - # Some other unknown annotation raise AnnotationImplementationError(attr, type_annotation) + return MemberInfo( + schema=schema, + is_optional=is_optional, + is_lazy=is_lazy, + ignored_in_filters=collection_member.ignored_in_filters, + inline_for_sampling=collection_member.inline_for_sampling, + propagate_row_failures=collection_member.propagate_row_failures, + ) + def __repr__(cls) -> str: parts = [f'[Collection "{cls.__class__.__name__}"]'] parts.append(textwrap.indent("Members:", prefix=" " * 2)) @@ -344,6 +354,16 @@ def non_ignored_members(cls) -> set[str]: if not member.ignored_in_filters } + @classmethod + def lazy_members(cls) -> set[str]: + """The names of all members annotated as lazy frames.""" + return {name for name, member in cls.members().items() if member.is_lazy} + + @classmethod + def eager_members(cls) -> set[str]: + """The names of all members annotated as data frames (eager).""" + return {name for name, member in cls.members().items() if not member.is_lazy} + @classmethod def _failure_propagating_members(cls) -> set[str]: """The names of all members of the collection that propagate individual row @@ -372,9 +392,9 @@ def _filters(cls) -> dict[str, Filter[Self]]: return getattr(cls, _FILTER_ATTR) def to_dict(self) -> dict[str, pl.LazyFrame]: - """Return a dictionary representation of this collection.""" + """Return a dictionary with all members as lazy frames.""" return { - member: getattr(self, member) + member: getattr(self, member).lazy() for member in self.member_schemas() if getattr(self, member) is not None } @@ -385,6 +405,9 @@ def _init(cls, data: Mapping[str, FrameType], /) -> Self: for member_name, member in cls.members().items(): if member.is_optional and member_name not in data: setattr(out, member_name, None) - else: + elif member.is_lazy: setattr(out, member_name, data[member_name].lazy()) + else: + setattr(out, member_name, data[member_name].lazy().collect()) + return out diff --git a/dataframely/collection/collection.py b/dataframely/collection/collection.py index 665b85a..61f8697 100644 --- a/dataframely/collection/collection.py +++ b/dataframely/collection/collection.py @@ -32,7 +32,7 @@ from dataframely._storage.constants import COLLECTION_METADATA_KEY from dataframely._storage.delta import DeltaStorageBackend from dataframely._storage.parquet import ParquetStorageBackend -from dataframely._typing import LazyFrame, Validation +from dataframely._typing import DataFrame, LazyFrame, Validation from dataframely.exc import ( DeserializationError, ValidationError, @@ -68,13 +68,13 @@ class Collection(BaseCollection, ABC): to 1-N relationships that are managed in separate data frames. A collection must only have type annotations for :class:`~dataframely.LazyFrame` - with known schema: + or :class:`~dataframely.DataFrame` with known schema: .. code:: python class MyCollection(dy.Collection): first_member: dy.LazyFrame[MyFirstSchema] - second_member: dy.LazyFrame[MySecondSchema] + second_member: dy.DataFrame[MySecondSchema] Besides, it may define *filters* (c.f. :meth:`~dataframely.filter`) and arbitrary methods. @@ -788,17 +788,14 @@ def collect_all(self) -> Self: particularly useful when :meth:`filter` is called with lazy frame inputs. Returns: - The same collection with all members collected once. - - Note: - As all collection members are required to be lazy frames, the returned - collection's members are still "lazy". However, they are "shallow-lazy", - meaning they are obtained by calling `.collect().lazy()`. + The same collection with all members collected once. Members annotated + with :class:`~dataframely.DataFrame` are returned as DataFrames, while + members annotated with :class:`~dataframely.LazyFrame` are returned as + "shallow-lazy" frames (obtained by calling ``.collect().lazy()``). """ - dfs = pl.collect_all(self.to_dict().values()) - return self._init( - {key: dfs[i].lazy() for i, key in enumerate(self.to_dict().keys())} - ) + lazy_dict = self.to_dict() + dfs = pl.collect_all(lazy_dict.values()) + return self._init(dict(zip(lazy_dict, dfs))) # --------------------------------- SERIALIZATION -------------------------------- # @@ -842,6 +839,7 @@ def serialize(cls) -> str: name: { "schema": info.schema._as_dict(), "is_optional": info.is_optional, + "is_lazy": info.is_lazy, "ignored_in_filters": info.ignored_in_filters, "inline_for_sampling": info.inline_for_sampling, } @@ -1330,11 +1328,14 @@ def deserialize_collection(data: str, strict: bool = True) -> type[Collection] | annotations: dict[str, Any] = {} for name, info in decoded["members"].items(): - lf_type = LazyFrame[_schema_from_dict(info["schema"])] # type: ignore + schema = _schema_from_dict(info["schema"]) + # Default to lazy for backwards compatibility with old serialized data + is_lazy = info.get("is_lazy", True) + frame_type = LazyFrame[schema] if is_lazy else DataFrame[schema] # type: ignore if info["is_optional"]: - lf_type = lf_type | None # type: ignore + frame_type = frame_type | None # type: ignore annotations[name] = Annotated[ - lf_type, + frame_type, CollectionMember( ignored_in_filters=info["ignored_in_filters"], inline_for_sampling=info["inline_for_sampling"], diff --git a/dataframely/collection/filter_result.py b/dataframely/collection/filter_result.py index 432d14c..d95e850 100644 --- a/dataframely/collection/filter_result.py +++ b/dataframely/collection/filter_result.py @@ -40,7 +40,9 @@ def collect_all(self, **kwargs: Any) -> CollectionFilterResult[C]: kwargs: Keyword arguments passed directly to :meth:`polars.collect_all`. Returns: - The same filter result object with all lazy frames collected and exposed as + The same filter result object with all frames collected. Members annotated + with :class:`~dataframely.DataFrame` are returned as DataFrames, while + members annotated with :class:`~dataframely.LazyFrame` are returned as "shallow" lazy frames. Attention: diff --git a/tests/collection/test_dataframe_members.py b/tests/collection/test_dataframe_members.py new file mode 100644 index 0000000..7ccc539 --- /dev/null +++ b/tests/collection/test_dataframe_members.py @@ -0,0 +1,127 @@ +# Copyright (c) QuantCo 2025-2026 +# SPDX-License-Identifier: BSD-3-Clause +"""Tests for dy.DataFrame members in collections. + +Members annotated with dy.DataFrame are collected once during _init and stored as +DataFrames, while dy.LazyFrame members remain lazy. +""" + +import polars as pl +import pytest + +import dataframely as dy + +# ------------------------------------------------------------------------------------ # +# SCHEMA # +# ------------------------------------------------------------------------------------ # + + +class UserSchema(dy.Schema): + id = dy.Integer(primary_key=True) + name = dy.String() + + +class OrderSchema(dy.Schema): + id = dy.Integer(primary_key=True) + user_id = dy.Integer() + amount = dy.Float(min=0) + + +class EagerCollection(dy.Collection): + """Collection with only DataFrame (eager) members.""" + + users: dy.DataFrame[UserSchema] + orders: dy.DataFrame[OrderSchema] + + +class MixedCollection(dy.Collection): + """Collection with mixed DataFrame and LazyFrame members.""" + + users: dy.DataFrame[UserSchema] + orders: dy.LazyFrame[OrderSchema] + + +class LazyCollection(dy.Collection): + """Collection with only LazyFrame members (traditional).""" + + users: dy.LazyFrame[UserSchema] + orders: dy.LazyFrame[OrderSchema] + + +class OptionalEagerCollection(dy.Collection): + """Collection with optional DataFrame member.""" + + users: dy.DataFrame[UserSchema] + orders: dy.DataFrame[OrderSchema] | None + + +# ------------------------------------------------------------------------------------ # +# FIXTURES # +# ------------------------------------------------------------------------------------ # + + +@pytest.fixture() +def valid_data() -> dict[str, pl.DataFrame]: + return { + "users": pl.DataFrame({"id": [1, 2], "name": ["Alice", "Bob"]}), + "orders": pl.DataFrame( + {"id": [1, 2], "user_id": [1, 2], "amount": [10.0, 20.0]} + ), + } + + +# ------------------------------------------------------------------------------------ # +# MEMBER INFO TESTS # +# ------------------------------------------------------------------------------------ # + + +@pytest.mark.parametrize( + ("collection_cls", "expected_lazy", "expected_eager"), + [ + (EagerCollection, set(), {"users", "orders"}), + (LazyCollection, {"users", "orders"}, set()), + (MixedCollection, {"orders"}, {"users"}), + (OptionalEagerCollection, set(), {"users", "orders"}), + ], +) +def test_member_detection( + collection_cls: type[dy.Collection], + expected_lazy: set[str], + expected_eager: set[str], +) -> None: + members = collection_cls.members() + for name in expected_lazy: + assert members[name].is_lazy + for name in expected_eager: + assert not members[name].is_lazy + assert collection_cls.lazy_members() == expected_lazy + assert collection_cls.eager_members() == expected_eager + + +def test_optional_eager_member_detection() -> None: + members = OptionalEagerCollection.members() + assert not members["users"].is_optional + assert members["orders"].is_optional + + +# ------------------------------------------------------------------------------------ # +# ACCESS PATTERN TESTS # +# ------------------------------------------------------------------------------------ # + + +@pytest.mark.parametrize( + ("collection_cls", "expected_types"), + [ + (EagerCollection, {"users": pl.DataFrame, "orders": pl.DataFrame}), + (LazyCollection, {"users": pl.LazyFrame, "orders": pl.LazyFrame}), + (MixedCollection, {"users": pl.DataFrame, "orders": pl.LazyFrame}), + ], +) +def test_member_access_returns_correct_type( + collection_cls: type[dy.Collection], + expected_types: dict[str, type], + valid_data: dict[str, pl.DataFrame], +) -> None: + collection = collection_cls.validate(valid_data) + for name, expected_type in expected_types.items(): + assert isinstance(getattr(collection, name), expected_type) diff --git a/tests/collection/test_implementation.py b/tests/collection/test_implementation.py index 13a30a9..39dbeab 100644 --- a/tests/collection/test_implementation.py +++ b/tests/collection/test_implementation.py @@ -16,17 +16,17 @@ class MyTestSchema(dy.Schema): a = dy.Integer(primary_key=True) -def test_annotation_type_failure() -> None: - with pytest.raises( - AnnotationImplementationError, - ): - create_collection( - "test", - { - "first": create_schema("first", {"a": dy.Integer()}), - }, - annotation_base_class=dy.DataFrame, - ) +def test_annotation_dataframe_success() -> None: + """DataFrame annotations are now supported.""" + collection = create_collection( + "test", + { + "first": create_schema("first", {"a": dy.Integer()}), + }, + annotation_base_class=dy.DataFrame, + ) + members = collection.members() + assert not members["first"].is_lazy def test_annotation_union_success() -> None: @@ -40,14 +40,16 @@ def test_annotation_union_success() -> None: def test_annotation_union_with_data_frame() -> None: - """When we use a union annotation, it must contain one typed LazyFrame and None.""" - with pytest.raises(AnnotationImplementationError): - create_collection_raw( - "test", - { - "first": dy.DataFrame[MyTestSchema] | None, - }, - ) + """DataFrame union with None is now supported for optional eager members.""" + collection = create_collection_raw( + "test", + { + "first": dy.DataFrame[MyTestSchema] | None, + }, + ) + members = collection.members() + assert not members["first"].is_lazy + assert members["first"].is_optional def test_annotation_union_too_many_arg_failure() -> None: @@ -109,7 +111,7 @@ def test_annotation_only_none_failure() -> None: def test_annotation_invalid_type_failure() -> None: - """First argument of union must be a LazyFrame.""" + """First argument of union must be a LazyFrame or DataFrame.""" with pytest.raises(AnnotationImplementationError): create_collection_raw( "test", @@ -119,6 +121,39 @@ def test_annotation_invalid_type_failure() -> None: ) +def test_annotation_invalid_generic_type_in_union() -> None: + """Union with generic type that's not LazyFrame/DataFrame should fail.""" + with pytest.raises(AnnotationImplementationError): + create_collection_raw( + "test", + { + "first": list[int] | None, + }, + ) + + +def test_annotation_union_frame_with_non_none_type() -> None: + """Union of DataFrame with non-None type should fail.""" + with pytest.raises(AnnotationImplementationError): + create_collection_raw( + "test", + { + "first": dy.DataFrame[MyTestSchema] | int, + }, + ) + + +def test_annotation_invalid_standalone_generic() -> None: + """Standalone generic type that's not LazyFrame/DataFrame should fail.""" + with pytest.raises(AnnotationImplementationError): + create_collection_raw( + "test", + { + "first": list[int], + }, + ) + + def test_explicit_annotation_type_failure_no_frame_type() -> None: """First argument of the annotated union must be a LazyFrame.""" with pytest.raises(AnnotationImplementationError): diff --git a/tests/collection/test_serialization.py b/tests/collection/test_serialization.py index 6b9b219..75c9ec5 100644 --- a/tests/collection/test_serialization.py +++ b/tests/collection/test_serialization.py @@ -84,6 +84,55 @@ def test_roundtrip_matches(collection: type[dy.Collection]) -> None: assert collection.matches(decoded) +# --------------------------------- DATAFRAME MEMBERS -------------------------------- # + + +class EagerSchema(dy.Schema): + id = dy.Int64(primary_key=True) + + +class MixedEagerCollection(dy.Collection): + """Collection with mixed DataFrame and LazyFrame members.""" + + eager: dy.DataFrame[EagerSchema] + lazy: dy.LazyFrame[EagerSchema] + + +def test_serialize_includes_is_lazy() -> None: + """Serialization includes the is_lazy field for each member.""" + serialized = MixedEagerCollection.serialize() + decoded = json.loads(serialized) + + assert decoded["members"]["eager"]["is_lazy"] is False + assert decoded["members"]["lazy"]["is_lazy"] is True + + +def test_roundtrip_dataframe_members() -> None: + """DataFrame members round-trip correctly through serialization.""" + serialized = MixedEagerCollection.serialize() + decoded = dy.deserialize_collection(serialized) + + assert MixedEagerCollection.matches(decoded) + assert not decoded.members()["eager"].is_lazy + assert decoded.members()["lazy"].is_lazy + + +def test_deserialize_without_is_lazy_defaults_to_lazy() -> None: + """Old serialized data without is_lazy defaults to lazy for backwards compat.""" + collection = create_collection( + "test", {"s1": create_schema("schema1", {"a": dy.Int64()})} + ) + serialized = collection.serialize() + + # Remove is_lazy from serialized data to simulate old format + decoded_dict = json.loads(serialized) + del decoded_dict["members"]["s1"]["is_lazy"] + modified = json.dumps(decoded_dict) + + result = dy.deserialize_collection(modified) + assert result.members()["s1"].is_lazy is True + + # ----------------------------- DESERIALIZATION FAILURES ----------------------------- #