|
1 | 1 | from typing import Annotated |
| 2 | +from typing import Any |
| 3 | +from typing import Optional |
| 4 | +from typing import Union |
| 5 | +from typing import get_args |
| 6 | +from typing import get_origin |
| 7 | + |
| 8 | +from pydantic import Discriminator |
| 9 | +from pydantic import Tag |
2 | 10 |
|
3 | 11 | from ..base import BaseModel |
| 12 | +from ..base import BaseModelType |
4 | 13 | from ..base import Required |
| 14 | +from ..utils import UNION_TYPES |
5 | 15 |
|
6 | 16 |
|
7 | 17 | class Message(BaseModel): |
8 | 18 | """SCIM protocol messages as defined by :rfc:`RFC7644 §3.1 <7644#section-3.1>`.""" |
9 | 19 |
|
10 | 20 | schemas: Annotated[list[str], Required.true] |
| 21 | + |
| 22 | + |
| 23 | +def get_schema_from_payload(payload: Any) -> Optional[str]: |
| 24 | + """Extract schema from SCIM payload for discrimination. |
| 25 | +
|
| 26 | + :param payload: SCIM payload dict or object |
| 27 | + :return: First matching schema or None |
| 28 | + """ |
| 29 | + if not payload: |
| 30 | + return None |
| 31 | + |
| 32 | + payload_schemas = ( |
| 33 | + payload.get("schemas", []) if isinstance(payload, dict) else payload.schemas |
| 34 | + ) |
| 35 | + |
| 36 | + # This will be set by the calling context |
| 37 | + resource_types_schemas = getattr(get_schema_from_payload, "_resource_schemas", []) |
| 38 | + common_schemas = [ |
| 39 | + schema for schema in payload_schemas if schema in resource_types_schemas |
| 40 | + ] |
| 41 | + return common_schemas[0] if common_schemas else None |
| 42 | + |
| 43 | + |
| 44 | +def get_tag(resource_type: type[BaseModel]) -> Tag: |
| 45 | + """Create Pydantic tag from resource type schema. |
| 46 | +
|
| 47 | + :param resource_type: SCIM resource type |
| 48 | + :return: Pydantic Tag for discrimination |
| 49 | + """ |
| 50 | + return Tag(resource_type.model_fields["schemas"].default[0]) |
| 51 | + |
| 52 | + |
| 53 | +def create_tagged_resource_union(resource_union): |
| 54 | + """Build Discriminated Unions for SCIM resources. |
| 55 | +
|
| 56 | + Creates discriminated unions so Pydantic can determine which class to instantiate |
| 57 | + by inspecting the payload's schemas field. |
| 58 | +
|
| 59 | + :param resource_union: Union type of SCIM resources |
| 60 | + :return: Annotated discriminated union or original type |
| 61 | + """ |
| 62 | + if get_origin(resource_union) not in UNION_TYPES: |
| 63 | + return resource_union |
| 64 | + |
| 65 | + resource_types = get_args(resource_union) |
| 66 | + |
| 67 | + # Set up schemas for the discriminator function |
| 68 | + resource_types_schemas = [ |
| 69 | + resource_type.model_fields["schemas"].default[0] |
| 70 | + for resource_type in resource_types |
| 71 | + ] |
| 72 | + get_schema_from_payload._resource_schemas = resource_types_schemas |
| 73 | + |
| 74 | + discriminator = Discriminator(get_schema_from_payload) |
| 75 | + |
| 76 | + tagged_resources = [ |
| 77 | + Annotated[resource_type, get_tag(resource_type)] |
| 78 | + for resource_type in resource_types |
| 79 | + ] |
| 80 | + union = Union[tuple(tagged_resources)] |
| 81 | + return Annotated[union, discriminator] |
| 82 | + |
| 83 | + |
| 84 | +class GenericMessageMetaclass(BaseModelType): |
| 85 | + """Metaclass for SCIM generic types with discriminated unions.""" |
| 86 | + |
| 87 | + def __new__(cls, name, bases, attrs, **kwargs): |
| 88 | + """Create class with tagged resource unions for generic parameters.""" |
| 89 | + if kwargs.get("__pydantic_generic_metadata__") and kwargs[ |
| 90 | + "__pydantic_generic_metadata__" |
| 91 | + ].get("args"): |
| 92 | + tagged_union = create_tagged_resource_union( |
| 93 | + kwargs["__pydantic_generic_metadata__"]["args"][0] |
| 94 | + ) |
| 95 | + kwargs["__pydantic_generic_metadata__"]["args"] = (tagged_union,) |
| 96 | + |
| 97 | + klass = super().__new__(cls, name, bases, attrs, **kwargs) |
| 98 | + return klass |
0 commit comments