Skip to content

Commit 61268ab

Browse files
committed
refactor: extract generic metaclass logic from ListResponse
1 parent aef537b commit 61268ab

2 files changed

Lines changed: 90 additions & 64 deletions

File tree

scim2_models/rfc7644/list_response.py

Lines changed: 2 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2,84 +2,22 @@
22
from typing import Any
33
from typing import Generic
44
from typing import Optional
5-
from typing import Union
6-
from typing import get_args
7-
from typing import get_origin
85

9-
from pydantic import Discriminator
106
from pydantic import Field
11-
from pydantic import Tag
127
from pydantic import ValidationInfo
138
from pydantic import ValidatorFunctionWrapHandler
149
from pydantic import model_validator
1510
from pydantic_core import PydanticCustomError
1611
from typing_extensions import Self
1712

18-
from ..base import BaseModel
19-
from ..base import BaseModelType
2013
from ..base import Context
2114
from ..base import Required
2215
from ..rfc7643.resource import AnyResource
23-
from ..utils import UNION_TYPES
16+
from .message import GenericMessageMetaclass
2417
from .message import Message
2518

2619

27-
class ListResponseMetaclass(BaseModelType):
28-
def tagged_resource_union(resource_union):
29-
"""Build Discriminated Unions, so pydantic can guess which class are needed to instantiate by inspecting a payload.
30-
31-
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
32-
"""
33-
if get_origin(resource_union) not in UNION_TYPES:
34-
return resource_union
35-
36-
resource_types = get_args(resource_union)
37-
38-
def get_schema_from_payload(payload: Any) -> Optional[str]:
39-
if not payload:
40-
return None
41-
42-
payload_schemas = (
43-
payload.get("schemas", [])
44-
if isinstance(payload, dict)
45-
else payload.schemas
46-
)
47-
48-
resource_types_schemas = [
49-
resource_type.model_fields["schemas"].default[0]
50-
for resource_type in resource_types
51-
]
52-
common_schemas = [
53-
schema for schema in payload_schemas if schema in resource_types_schemas
54-
]
55-
return common_schemas[0] if common_schemas else None
56-
57-
discriminator = Discriminator(get_schema_from_payload)
58-
59-
def get_tag(resource_type: type[BaseModel]) -> Tag:
60-
return Tag(resource_type.model_fields["schemas"].default[0])
61-
62-
tagged_resources = [
63-
Annotated[resource_type, get_tag(resource_type)]
64-
for resource_type in resource_types
65-
]
66-
union = Union[tuple(tagged_resources)]
67-
return Annotated[union, discriminator]
68-
69-
def __new__(cls, name, bases, attrs, **kwargs):
70-
if kwargs.get("__pydantic_generic_metadata__") and kwargs[
71-
"__pydantic_generic_metadata__"
72-
].get("args"):
73-
tagged_union = cls.tagged_resource_union(
74-
kwargs["__pydantic_generic_metadata__"]["args"][0]
75-
)
76-
kwargs["__pydantic_generic_metadata__"]["args"] = (tagged_union,)
77-
78-
klass = super().__new__(cls, name, bases, attrs, **kwargs)
79-
return klass
80-
81-
82-
class ListResponse(Message, Generic[AnyResource], metaclass=ListResponseMetaclass):
20+
class ListResponse(Message, Generic[AnyResource], metaclass=GenericMessageMetaclass):
8321
schemas: Annotated[list[str], Required.true] = [
8422
"urn:ietf:params:scim:api:messages:2.0:ListResponse"
8523
]

scim2_models/rfc7644/message.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,98 @@
11
from typing import Annotated
2+
from typing import Any
3+
from typing import Optional
4+
from typing import Union
5+
from typing import get_args
6+
from typing import get_origin
7+
8+
from pydantic import Discriminator
9+
from pydantic import Tag
210

311
from ..base import BaseModel
12+
from ..base import BaseModelType
413
from ..base import Required
14+
from ..utils import UNION_TYPES
515

616

717
class Message(BaseModel):
818
"""SCIM protocol messages as defined by :rfc:`RFC7644 §3.1 <7644#section-3.1>`."""
919

1020
schemas: Annotated[list[str], Required.true]
21+
22+
23+
def get_schema_from_payload(payload: Any) -> Optional[str]:
24+
"""Extract schema from SCIM payload for discrimination.
25+
26+
:param payload: SCIM payload dict or object
27+
:return: First matching schema or None
28+
"""
29+
if not payload:
30+
return None
31+
32+
payload_schemas = (
33+
payload.get("schemas", []) if isinstance(payload, dict) else payload.schemas
34+
)
35+
36+
# This will be set by the calling context
37+
resource_types_schemas = getattr(get_schema_from_payload, "_resource_schemas", [])
38+
common_schemas = [
39+
schema for schema in payload_schemas if schema in resource_types_schemas
40+
]
41+
return common_schemas[0] if common_schemas else None
42+
43+
44+
def get_tag(resource_type: type[BaseModel]) -> Tag:
45+
"""Create Pydantic tag from resource type schema.
46+
47+
:param resource_type: SCIM resource type
48+
:return: Pydantic Tag for discrimination
49+
"""
50+
return Tag(resource_type.model_fields["schemas"].default[0])
51+
52+
53+
def create_tagged_resource_union(resource_union):
54+
"""Build Discriminated Unions for SCIM resources.
55+
56+
Creates discriminated unions so Pydantic can determine which class to instantiate
57+
by inspecting the payload's schemas field.
58+
59+
:param resource_union: Union type of SCIM resources
60+
:return: Annotated discriminated union or original type
61+
"""
62+
if get_origin(resource_union) not in UNION_TYPES:
63+
return resource_union
64+
65+
resource_types = get_args(resource_union)
66+
67+
# Set up schemas for the discriminator function
68+
resource_types_schemas = [
69+
resource_type.model_fields["schemas"].default[0]
70+
for resource_type in resource_types
71+
]
72+
get_schema_from_payload._resource_schemas = resource_types_schemas
73+
74+
discriminator = Discriminator(get_schema_from_payload)
75+
76+
tagged_resources = [
77+
Annotated[resource_type, get_tag(resource_type)]
78+
for resource_type in resource_types
79+
]
80+
union = Union[tuple(tagged_resources)]
81+
return Annotated[union, discriminator]
82+
83+
84+
class GenericMessageMetaclass(BaseModelType):
85+
"""Metaclass for SCIM generic types with discriminated unions."""
86+
87+
def __new__(cls, name, bases, attrs, **kwargs):
88+
"""Create class with tagged resource unions for generic parameters."""
89+
if kwargs.get("__pydantic_generic_metadata__") and kwargs[
90+
"__pydantic_generic_metadata__"
91+
].get("args"):
92+
tagged_union = create_tagged_resource_union(
93+
kwargs["__pydantic_generic_metadata__"]["args"][0]
94+
)
95+
kwargs["__pydantic_generic_metadata__"]["args"] = (tagged_union,)
96+
97+
klass = super().__new__(cls, name, bases, attrs, **kwargs)
98+
return klass

0 commit comments

Comments
 (0)