Skip to content

Commit 6ce5897

Browse files
committed
refactor: use __class_getitem__ instead of metaclasses for Resources
This fixes a typing issue with mypy.
1 parent 8112273 commit 6ce5897

3 files changed

Lines changed: 86 additions & 45 deletions

File tree

scim2_models/rfc7643/resource.py

Lines changed: 69 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from ..attributes import MultiValuedComplexAttribute
2525
from ..attributes import is_complex_attribute
2626
from ..base import BaseModel
27-
from ..base import BaseModelType
2827
from ..context import Context
2928
from ..reference import Reference
3029
from ..scim_object import ScimObject
@@ -104,6 +103,8 @@ def from_schema(cls, schema: "Schema") -> type["Extension"]:
104103

105104
AnyExtension = TypeVar("AnyExtension", bound="Extension")
106105

106+
_PARAMETERIZED_CLASSES: dict[tuple[type, tuple], type] = {}
107+
107108

108109
def extension_serializer(
109110
value: Any, handler: SerializerFunctionWrapHandler, info: SerializationInfo
@@ -122,33 +123,7 @@ def extension_serializer(
122123
return result or None
123124

124125

125-
class ResourceMetaclass(BaseModelType):
126-
def __new__(cls, name: str, bases: tuple, attrs: dict, **kwargs: Any) -> type:
127-
"""Dynamically add a field for each extension."""
128-
if "__pydantic_generic_metadata__" in kwargs:
129-
extensions = kwargs["__pydantic_generic_metadata__"]["args"][0]
130-
extensions = (
131-
get_args(extensions)
132-
if get_origin(extensions) in UNION_TYPES
133-
else [extensions]
134-
)
135-
for extension in extensions:
136-
schema = extension.model_fields["schemas"].default[0]
137-
attrs.setdefault("__annotations__", {})[extension.__name__] = Annotated[
138-
Optional[extension],
139-
WrapSerializer(extension_serializer),
140-
]
141-
attrs[extension.__name__] = Field(
142-
None,
143-
serialization_alias=schema,
144-
validation_alias=normalize_attribute_name(schema),
145-
)
146-
147-
klass = super().__new__(cls, name, bases, attrs, **kwargs)
148-
return klass
149-
150-
151-
class Resource(ScimObject, Generic[AnyExtension], metaclass=ResourceMetaclass):
126+
class Resource(ScimObject, Generic[AnyExtension]):
152127
# Common attributes as defined by
153128
# https://www.rfc-editor.org/rfc/rfc7643#section-3.1
154129

@@ -171,6 +146,71 @@ class Resource(ScimObject, Generic[AnyExtension], metaclass=ResourceMetaclass):
171146
meta: Annotated[Optional[Meta], Mutability.read_only, Returned.default] = None
172147
"""A complex attribute containing resource metadata."""
173148

149+
@classmethod
150+
def __class_getitem__(cls, item: Any) -> type["Resource"]:
151+
"""Create a Resource class with extension fields dynamically added."""
152+
if hasattr(cls, "__scim_extension_metadata__"):
153+
return cls
154+
155+
extensions = get_args(item) if get_origin(item) in UNION_TYPES else [item]
156+
157+
# Skip TypeVar parameters (used for generic class definitions)
158+
valid_extensions = [
159+
extension for extension in extensions if not isinstance(extension, TypeVar)
160+
]
161+
162+
if not valid_extensions:
163+
return cls
164+
165+
cache_key = (cls, tuple(valid_extensions))
166+
if cache_key in _PARAMETERIZED_CLASSES:
167+
return _PARAMETERIZED_CLASSES[cache_key]
168+
169+
for extension in valid_extensions:
170+
if not (isinstance(extension, type) and issubclass(extension, Extension)):
171+
raise TypeError(f"{extension} is not a valid Extension type")
172+
173+
class_name = (
174+
f"{cls.__name__}[{', '.join(ext.__name__ for ext in valid_extensions)}]"
175+
)
176+
177+
class_attrs = {
178+
"__scim_extension_metadata__": {
179+
"args": (item,),
180+
"origin": cls,
181+
"extensions": valid_extensions,
182+
}
183+
}
184+
185+
for extension in valid_extensions:
186+
schema = extension.model_fields["schemas"].default[0]
187+
class_attrs[extension.__name__] = Field(
188+
None,
189+
serialization_alias=schema,
190+
validation_alias=normalize_attribute_name(schema),
191+
)
192+
193+
new_annotations = {
194+
extension.__name__: Annotated[
195+
Optional[extension],
196+
WrapSerializer(extension_serializer),
197+
]
198+
for extension in valid_extensions
199+
}
200+
201+
new_class = type(
202+
class_name,
203+
(cls,),
204+
{
205+
"__annotations__": new_annotations,
206+
**class_attrs,
207+
},
208+
)
209+
210+
_PARAMETERIZED_CLASSES[cache_key] = new_class
211+
212+
return new_class
213+
174214
def __getitem__(self, item: Any) -> Optional[Extension]:
175215
if not isinstance(item, type) or not issubclass(item, Extension):
176216
raise KeyError(f"{item} is not a valid extension type")
@@ -186,13 +226,7 @@ def __setitem__(self, item: Any, value: "Resource") -> None:
186226
@classmethod
187227
def get_extension_models(cls) -> dict[str, type[Extension]]:
188228
"""Return extension a dict associating extension models with their schemas."""
189-
generic_args: Any = cls.__pydantic_generic_metadata__.get("args", [])
190-
extension_models = (
191-
get_args(generic_args[0])
192-
if len(generic_args) == 1 and get_origin(generic_args[0]) in UNION_TYPES
193-
else generic_args
194-
)
195-
229+
extension_models = getattr(cls, "__scim_extension_metadata__", [])
196230
by_schema = {
197231
ext.model_fields["schemas"].default[0]: ext for ext in extension_models
198232
}

scim2_models/rfc7643/resource_type.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from typing import Annotated
22
from typing import Optional
3-
from typing import get_args
4-
from typing import get_origin
53

64
from pydantic import Field
75
from typing_extensions import Self
@@ -13,7 +11,6 @@
1311
from ..attributes import ComplexAttribute
1412
from ..reference import Reference
1513
from ..reference import URIReference
16-
from ..utils import UNION_TYPES
1714
from .resource import Resource
1815

1916

@@ -85,13 +82,10 @@ def from_resource(cls, resource_model: type[Resource]) -> Self:
8582
"""Build a naive ResourceType from a resource model."""
8683
schema = resource_model.model_fields["schemas"].default[0]
8784
name = schema.split(":")[-1]
88-
if resource_model.__pydantic_generic_metadata__["args"]:
89-
extensions = resource_model.__pydantic_generic_metadata__["args"][0]
90-
extensions = (
91-
get_args(extensions)
92-
if get_origin(extensions) in UNION_TYPES
93-
else [extensions]
94-
)
85+
86+
# Get extensions from the metadata system
87+
if hasattr(resource_model, "__scim_extension_metadata__"):
88+
extensions = resource_model.__scim_extension_metadata__["extensions"]
9589
else:
9690
extensions = []
9791

tests/test_resource_extension.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,3 +307,16 @@ def test_get_extension_model():
307307
)
308308
is None
309309
)
310+
311+
312+
def test_class_getitem():
313+
UserEnt = User[EnterpriseUser]
314+
UserEnt2 = UserEnt[EnterpriseUser]
315+
assert UserEnt is UserEnt2
316+
317+
# Test line 178: invalid extension type raises TypeError
318+
with pytest.raises(TypeError, match="is not a valid Extension type"):
319+
User[str]
320+
321+
with pytest.raises(TypeError, match="is not a valid Extension type"):
322+
User[int]

0 commit comments

Comments
 (0)