Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 49 additions & 26 deletions dataframely/collection/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from dataframely._filter import Filter
from dataframely._polars import FrameType
from dataframely._typing import DataFrame as TypedDataFrame
from dataframely._typing import LazyFrame as TypedLazyFrame
from dataframely.exc import AnnotationImplementationError, ImplementationError
from dataframely.schema import Schema
Expand Down Expand Up @@ -92,6 +93,8 @@ class MemberInfo(CollectionMember):
schema: type[Schema]
#: Whether the member is optional.
is_optional: bool
#: Whether the member is a lazy frame.
is_lazy: bool = True


@dataclass
Expand Down Expand Up @@ -241,39 +244,46 @@ def _derive_member_info(
attr, annotation_args[0], annotation_args[1]
)
elif origin == typing.Union:
# Happy path: optional member
# Happy path: optional member (e.g. dy.LazyFrame[Schema] | None)
union_args = get_args(type_annotation)
if len(union_args) != 2:
raise AnnotationImplementationError(attr, type_annotation)
if not any(get_origin(arg) is None for arg in union_args):
# Check that exactly one arg is None (type(None) is NoneType)
if not any(arg is type(None) for arg in union_args):
raise AnnotationImplementationError(attr, type_annotation)

not_none_args = [arg for arg in union_args if get_origin(arg) is not None]
if len(not_none_args) == 0 or not issubclass(
get_origin(not_none_args[0]), TypedLazyFrame
):
# Get the non-None type (exactly one exists given prior checks)
not_none_arg = next(arg for arg in union_args if arg is not type(None))

frame_origin = get_origin(not_none_arg)
if frame_origin is None:
raise AnnotationImplementationError(attr, type_annotation)
Comment thread
borchero marked this conversation as resolved.

return MemberInfo(
schema=get_args(not_none_args[0])[0],
is_optional=True,
ignored_in_filters=collection_member.ignored_in_filters,
inline_for_sampling=collection_member.inline_for_sampling,
propagate_row_failures=collection_member.propagate_row_failures,
)
elif issubclass(origin, TypedLazyFrame):
# Happy path: required member
return MemberInfo(
schema=get_args(type_annotation)[0],
is_optional=False,
ignored_in_filters=collection_member.ignored_in_filters,
inline_for_sampling=collection_member.inline_for_sampling,
propagate_row_failures=collection_member.propagate_row_failures,
)
schema = get_args(not_none_arg)[0]
is_optional = True
elif issubclass(origin, (TypedLazyFrame, TypedDataFrame)):
frame_origin = origin
schema = get_args(type_annotation)[0]
is_optional = False
else:
raise AnnotationImplementationError(attr, type_annotation)

if issubclass(frame_origin, TypedLazyFrame):
is_lazy = True
elif issubclass(frame_origin, TypedDataFrame):
is_lazy = False
else:
# Some other unknown annotation
raise AnnotationImplementationError(attr, type_annotation)

return MemberInfo(
schema=schema,
is_optional=is_optional,
is_lazy=is_lazy,
ignored_in_filters=collection_member.ignored_in_filters,
inline_for_sampling=collection_member.inline_for_sampling,
propagate_row_failures=collection_member.propagate_row_failures,
)

def __repr__(cls) -> str:
parts = [f'[Collection "{cls.__class__.__name__}"]']
parts.append(textwrap.indent("Members:", prefix=" " * 2))
Expand Down Expand Up @@ -344,6 +354,16 @@ def non_ignored_members(cls) -> set[str]:
if not member.ignored_in_filters
}

@classmethod
def lazy_members(cls) -> set[str]:
"""The names of all members annotated as lazy frames."""
return {name for name, member in cls.members().items() if member.is_lazy}

@classmethod
def eager_members(cls) -> set[str]:
"""The names of all members annotated as data frames (eager)."""
return {name for name, member in cls.members().items() if not member.is_lazy}

@classmethod
def _failure_propagating_members(cls) -> set[str]:
"""The names of all members of the collection that propagate individual row
Expand Down Expand Up @@ -372,9 +392,9 @@ def _filters(cls) -> dict[str, Filter[Self]]:
return getattr(cls, _FILTER_ATTR)

def to_dict(self) -> dict[str, pl.LazyFrame]:
Comment thread
borchero marked this conversation as resolved.
"""Return a dictionary representation of this collection."""
"""Return a dictionary with all members as lazy frames."""
return {
member: getattr(self, member)
member: getattr(self, member).lazy()
for member in self.member_schemas()
if getattr(self, member) is not None
}
Expand All @@ -385,6 +405,9 @@ def _init(cls, data: Mapping[str, FrameType], /) -> Self:
for member_name, member in cls.members().items():
if member.is_optional and member_name not in data:
setattr(out, member_name, None)
else:
elif member.is_lazy:
setattr(out, member_name, data[member_name].lazy())
else:
setattr(out, member_name, data[member_name].lazy().collect())

return out
33 changes: 17 additions & 16 deletions dataframely/collection/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from dataframely._storage.constants import COLLECTION_METADATA_KEY
from dataframely._storage.delta import DeltaStorageBackend
from dataframely._storage.parquet import ParquetStorageBackend
from dataframely._typing import LazyFrame, Validation
from dataframely._typing import DataFrame, LazyFrame, Validation
from dataframely.exc import (
DeserializationError,
ValidationError,
Expand Down Expand Up @@ -68,13 +68,13 @@ class Collection(BaseCollection, ABC):
to 1-N relationships that are managed in separate data frames.

A collection must only have type annotations for :class:`~dataframely.LazyFrame`
with known schema:
or :class:`~dataframely.DataFrame` with known schema:

.. code:: python

class MyCollection(dy.Collection):
first_member: dy.LazyFrame[MyFirstSchema]
second_member: dy.LazyFrame[MySecondSchema]
second_member: dy.DataFrame[MySecondSchema]

Besides, it may define *filters* (c.f. :meth:`~dataframely.filter`) and arbitrary
methods.
Expand Down Expand Up @@ -788,17 +788,14 @@ def collect_all(self) -> Self:
particularly useful when :meth:`filter` is called with lazy frame inputs.

Returns:
The same collection with all members collected once.

Note:
As all collection members are required to be lazy frames, the returned
collection's members are still "lazy". However, they are "shallow-lazy",
meaning they are obtained by calling `.collect().lazy()`.
The same collection with all members collected once. Members annotated
with :class:`~dataframely.DataFrame` are returned as DataFrames, while
members annotated with :class:`~dataframely.LazyFrame` are returned as
"shallow-lazy" frames (obtained by calling ``.collect().lazy()``).
"""
dfs = pl.collect_all(self.to_dict().values())
return self._init(
{key: dfs[i].lazy() for i, key in enumerate(self.to_dict().keys())}
)
lazy_dict = self.to_dict()
dfs = pl.collect_all(lazy_dict.values())
return self._init(dict(zip(lazy_dict, dfs)))

# --------------------------------- SERIALIZATION -------------------------------- #

Expand Down Expand Up @@ -842,6 +839,7 @@ def serialize(cls) -> str:
name: {
"schema": info.schema._as_dict(),
"is_optional": info.is_optional,
"is_lazy": info.is_lazy,
"ignored_in_filters": info.ignored_in_filters,
"inline_for_sampling": info.inline_for_sampling,
}
Expand Down Expand Up @@ -1330,11 +1328,14 @@ def deserialize_collection(data: str, strict: bool = True) -> type[Collection] |

annotations: dict[str, Any] = {}
for name, info in decoded["members"].items():
lf_type = LazyFrame[_schema_from_dict(info["schema"])] # type: ignore
schema = _schema_from_dict(info["schema"])
# Default to lazy for backwards compatibility with old serialized data
is_lazy = info.get("is_lazy", True)
frame_type = LazyFrame[schema] if is_lazy else DataFrame[schema] # type: ignore
if info["is_optional"]:
lf_type = lf_type | None # type: ignore
frame_type = frame_type | None # type: ignore
annotations[name] = Annotated[
lf_type,
frame_type,
CollectionMember(
ignored_in_filters=info["ignored_in_filters"],
inline_for_sampling=info["inline_for_sampling"],
Expand Down
4 changes: 3 additions & 1 deletion dataframely/collection/filter_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def collect_all(self, **kwargs: Any) -> CollectionFilterResult[C]:
kwargs: Keyword arguments passed directly to :meth:`polars.collect_all`.

Returns:
The same filter result object with all lazy frames collected and exposed as
The same filter result object with all frames collected. Members annotated
with :class:`~dataframely.DataFrame` are returned as DataFrames, while
members annotated with :class:`~dataframely.LazyFrame` are returned as
"shallow" lazy frames.

Attention:
Expand Down
127 changes: 127 additions & 0 deletions tests/collection/test_dataframe_members.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright (c) QuantCo 2025-2026
# SPDX-License-Identifier: BSD-3-Clause
"""Tests for dy.DataFrame members in collections.

Members annotated with dy.DataFrame are collected once during _init and stored as
DataFrames, while dy.LazyFrame members remain lazy.
"""

import polars as pl
import pytest

import dataframely as dy

# ------------------------------------------------------------------------------------ #
# SCHEMA #
# ------------------------------------------------------------------------------------ #


class UserSchema(dy.Schema):
id = dy.Integer(primary_key=True)
name = dy.String()


class OrderSchema(dy.Schema):
id = dy.Integer(primary_key=True)
user_id = dy.Integer()
amount = dy.Float(min=0)


class EagerCollection(dy.Collection):
"""Collection with only DataFrame (eager) members."""

users: dy.DataFrame[UserSchema]
orders: dy.DataFrame[OrderSchema]


class MixedCollection(dy.Collection):
"""Collection with mixed DataFrame and LazyFrame members."""

users: dy.DataFrame[UserSchema]
orders: dy.LazyFrame[OrderSchema]


class LazyCollection(dy.Collection):
"""Collection with only LazyFrame members (traditional)."""

users: dy.LazyFrame[UserSchema]
orders: dy.LazyFrame[OrderSchema]


class OptionalEagerCollection(dy.Collection):
"""Collection with optional DataFrame member."""

users: dy.DataFrame[UserSchema]
orders: dy.DataFrame[OrderSchema] | None


# ------------------------------------------------------------------------------------ #
# FIXTURES #
# ------------------------------------------------------------------------------------ #


@pytest.fixture()
def valid_data() -> dict[str, pl.DataFrame]:
return {
"users": pl.DataFrame({"id": [1, 2], "name": ["Alice", "Bob"]}),
"orders": pl.DataFrame(
{"id": [1, 2], "user_id": [1, 2], "amount": [10.0, 20.0]}
),
}


# ------------------------------------------------------------------------------------ #
# MEMBER INFO TESTS #
# ------------------------------------------------------------------------------------ #


@pytest.mark.parametrize(
("collection_cls", "expected_lazy", "expected_eager"),
[
(EagerCollection, set(), {"users", "orders"}),
(LazyCollection, {"users", "orders"}, set()),
(MixedCollection, {"orders"}, {"users"}),
(OptionalEagerCollection, set(), {"users", "orders"}),
],
)
def test_member_detection(
collection_cls: type[dy.Collection],
expected_lazy: set[str],
expected_eager: set[str],
) -> None:
members = collection_cls.members()
for name in expected_lazy:
assert members[name].is_lazy
for name in expected_eager:
assert not members[name].is_lazy
assert collection_cls.lazy_members() == expected_lazy
assert collection_cls.eager_members() == expected_eager


def test_optional_eager_member_detection() -> None:
members = OptionalEagerCollection.members()
assert not members["users"].is_optional
assert members["orders"].is_optional


# ------------------------------------------------------------------------------------ #
# ACCESS PATTERN TESTS #
# ------------------------------------------------------------------------------------ #


@pytest.mark.parametrize(
("collection_cls", "expected_types"),
[
(EagerCollection, {"users": pl.DataFrame, "orders": pl.DataFrame}),
(LazyCollection, {"users": pl.LazyFrame, "orders": pl.LazyFrame}),
(MixedCollection, {"users": pl.DataFrame, "orders": pl.LazyFrame}),
],
)
def test_member_access_returns_correct_type(
collection_cls: type[dy.Collection],
expected_types: dict[str, type],
valid_data: dict[str, pl.DataFrame],
) -> None:
collection = collection_cls.validate(valid_data)
for name, expected_type in expected_types.items():
assert isinstance(getattr(collection, name), expected_type)
Loading
Loading