Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1948,12 +1948,7 @@ def analyze_type_type_callee(self, item: ProperType, context: Context) -> Type:
# but better than AnyType...), but replace the return type
# with typevar.
callee = self.analyze_type_type_callee(get_proper_type(item.upper_bound), context)
callee = get_proper_type(callee)
if isinstance(callee, CallableType):
callee = callee.copy_modified(ret_type=item)
elif isinstance(callee, Overloaded):
callee = Overloaded([c.copy_modified(ret_type=item) for c in callee.items])
return callee
return self.replace_type_type_callee_ret_type(callee, item)
# We support Type of namedtuples but not of tuples in general
if isinstance(item, TupleType) and tuple_fallback(item).type.fullname != "builtins.tuple":
return self.analyze_type_type_callee(tuple_fallback(item), context)
Expand All @@ -1963,6 +1958,23 @@ def analyze_type_type_callee(self, item: ProperType, context: Context) -> Type:
self.msg.unsupported_type_type(item, context)
return AnyType(TypeOfAny.from_error)

def replace_type_type_callee_ret_type(self, callee: Type, ret_type: Type) -> Type:
callee = get_proper_type(callee)
if isinstance(callee, CallableType):
return callee.copy_modified(ret_type=ret_type)
if isinstance(callee, Overloaded):
return Overloaded([c.copy_modified(ret_type=ret_type) for c in callee.items])
if isinstance(callee, UnionType):
return UnionType(
[
self.replace_type_type_callee_ret_type(item, ret_type)
for item in callee.relevant_items()
],
line=callee.line,
column=callee.column,
)
return callee

def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]:
"""Infer argument expression types in an empty context.

Expand Down
18 changes: 18 additions & 0 deletions test-data/unit/check-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -3889,6 +3889,24 @@ def process(cls: Type[U]):
[builtins fixtures/classmethod.pyi]
[out]

[case testTypeUsingTypeCConstructorReturnFromTypeVarUnionBound]
from typing import Optional, Type, TypeVar, Union

class A:
def __init__(self, value: str = "") -> None: pass
class B:
def __init__(self, value: str = "") -> None: pass

T = TypeVar("T", bound=Union[A, B])

def make(ftype: Type[T], value: Optional[str]) -> T:
if value is None:
return ftype()
return ftype(value)

reveal_type(make(A, "a")) # N: Revealed type is "__main__.A"
reveal_type(make(B, None)) # N: Revealed type is "__main__.B"

[case testTypeUsingTypeCErrorUnsupportedType]
from typing import Type, Tuple
def foo(arg: Type[Tuple[int]]):
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -3531,7 +3531,7 @@ def switch(choice: type[T_Choice]) -> None:
reveal_type(choice()) # N: Revealed type is "b.Two"
case _:
reveal_type(choice) # N: Revealed type is "type[T_Choice`-1]"
reveal_type(choice()) # N: Revealed type is "b.One | b.Two"
reveal_type(choice()) # N: Revealed type is "T_Choice`-1"

[file b.py]
class One: ...
Expand Down
14 changes: 14 additions & 0 deletions test-data/unit/check-python312.test
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,20 @@ f(1, u)
f('x', None) # E: Value of type variable "T" of "f" cannot be "str" \
# E: Value of type variable "S" of "f" cannot be "None"

[case testPEP695UpperBoundTypeTypeConstructorReturnType]
class A:
def __init__(self, value: str = "") -> None: pass
class B:
def __init__(self, value: str = "") -> None: pass

def make[T: A | B](ftype: type[T], value: str | None) -> T:
if value is None:
return ftype()
return ftype(value)

reveal_type(make(A, "a")) # N: Revealed type is "__main__.A"
reveal_type(make(B, None)) # N: Revealed type is "__main__.B"

[case testPEP695InferVarianceOfTupleType]
class Cov[T](tuple[int, str]):
def f(self) -> T: pass
Expand Down
Loading