diff --git a/src/cattrs/enums.py b/src/cattrs/enums.py index b1ab5040..94d73e29 100644 --- a/src/cattrs/enums.py +++ b/src/cattrs/enums.py @@ -1,20 +1,37 @@ -from collections.abc import Callable +from collections.abc import Callable, Mapping from enum import Enum from typing import TYPE_CHECKING, Any +from ._compat import has + if TYPE_CHECKING: from .converters import BaseConverter +def _needs_recursive_unstructure(value: Any) -> bool: + if isinstance(value, Enum) or has(value.__class__): + return True + if isinstance(value, tuple | list | set | frozenset): + return any(_needs_recursive_unstructure(v) for v in value) + if isinstance(value, Mapping): + return any( + _needs_recursive_unstructure(k) or _needs_recursive_unstructure(v) + for k, v in value.items() + ) + return False + + def enum_unstructure_factory( type: type[Enum], converter: "BaseConverter" ) -> Callable[[Enum], Any]: """A factory for generating enum unstructure hooks. If the enum is a typed enum (has `_value_`), we use the underlying value's hook. - Otherwise, we use the value directly. + Otherwise, we only use the converter when the values are known to need it. """ - if "_value_" in type.__annotations__: + if "_value_" in type.__annotations__ or any( + _needs_recursive_unstructure(member.value) for member in type + ): return lambda e: converter.unstructure(e.value) return lambda e: e.value diff --git a/tests/test_enums.py b/tests/test_enums.py index bdf591f1..5a58f2ae 100644 --- a/tests/test_enums.py +++ b/tests/test_enums.py @@ -2,6 +2,7 @@ from enum import Enum +from attrs import define from hypothesis import given from hypothesis.strategies import data, sampled_from from pytest import raises @@ -68,3 +69,57 @@ def test_structure_complex_enum() -> None: assert converter.structure(0, SimpleEnum) == SimpleEnum.A assert converter.structure("E", SimpleEnumWithTypeHint) == SimpleEnumWithTypeHint.E assert converter.structure((0, "D"), ComplexEnum) == ComplexEnum.AD + + +class EnumValuedEnum(Enum): + """Enum whose members have other Enum instances as values (no type annotation).""" + + X = SimpleEnum.A + Y = SimpleEnum.B + + +class TupleValuedEnum(Enum): + """Enum whose members have tuples with Enum instances (no type annotation).""" + + X = (SimpleEnum.A, 1) + + +@define +class AnAttrsClass: + a: int + + +class AttrsValuedEnum(Enum): + """Enum whose members have attrs instances as values (no type annotation).""" + + X = AnAttrsClass(1) + + +def test_unstructure_simple_enum_uses_value_directly() -> None: + """Simple enum values do not recurse through the converter.""" + converter = BaseConverter() + converter.register_unstructure_hook(int, lambda _: "overridden") + + assert converter.unstructure(SimpleEnum.A) == 0 + + +def test_unstructure_enum_with_enum_values() -> None: + """Enum members whose values are themselves Enums are unstructured recursively. + + Regression test for https://github.com/python-attrs/cattrs/issues/679. + """ + converter = BaseConverter() + assert converter.unstructure(EnumValuedEnum.X) == 0 + assert converter.unstructure(EnumValuedEnum.Y) == 1 + + +def test_unstructure_enum_with_tuple_values() -> None: + """Enum member tuples containing Enums are unstructured recursively.""" + converter = BaseConverter() + assert converter.unstructure(TupleValuedEnum.X) == (0, 1) + + +def test_unstructure_enum_with_attrs_values() -> None: + """Enum members whose values are attrs classes are unstructured recursively.""" + converter = BaseConverter() + assert converter.unstructure(AttrsValuedEnum.X) == {"a": 1}