Skip to content

Commit f7f0889

Browse files
jakkdlbluetech
authored andcommitted
Add ParamSpec to legacy callable forms of raises/warns/deprecated_call
`pytest.raises`, `warns` & `deprecated_call` previously typed `*args` and `**kwargs` as `Any` in the legacy callable form, so this did not raise errors: ```py def foo(x: int) -> None: raise ValueError raises(ValueError, foo, None) ``` but now it will give call-overload. It also makes it possible to pass `func` as a kwarg, which the type hints previously showed as possible, but it didn't work. It's possible that `func` (and the expected type?) should be pos-only, as this looks quite weird: ```py raises(1, 2, kwarg1=3, func=my_func, kwarg2=4, expected_exception=ValueError) ``` but if somebody is dynamically generating parameters to send to `raises` then we probably shouldn't ban it needlessly; and we can't make `func` pos-only without making `expected_exception` pos-only, and that could break backwards compatibility.
1 parent d42cef1 commit f7f0889

5 files changed

Lines changed: 30 additions & 19 deletions

File tree

changelog/13241.improvement.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
:func:`pytest.raises`, :func:`pytest.warns` and :func:`pytest.deprecated_call` now uses :class:`ParamSpec` for the type hint to the (old and not recommended) callable overload, instead of :class:`Any`. This allows type checkers to raise errors when passing incorrect function parameters.
2+
``func`` can now also be passed as a kwarg, which the type hint previously showed as possible but didn't accept.

doc/en/conf.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@
102102
# TypeVars
103103
("py:class", "_pytest._code.code.E"),
104104
("py:class", "E"), # due to delayed annotation
105+
("py:class", "T"),
106+
("py:class", "P"),
107+
("py:class", "P.args"),
108+
("py:class", "P.kwargs"),
105109
("py:class", "_pytest.fixtures.FixtureFunction"),
106110
("py:class", "_pytest.nodes._NodeType"),
107111
("py:class", "_NodeType"), # due to delayed annotation

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,7 @@ exclude_lines = [
453453
'^\s*case unreachable:',
454454
'^\s*assert_never\(',
455455
'^\s*if TYPE_CHECKING:',
456+
'^\s*(el)?if TYPE_CHECKING:',
456457
'^\s*@overload( |$)',
457458
'^\s*def .+: \.\.\.$',
458459
'^\s*@pytest\.mark\.xfail',

src/_pytest/raises.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,15 @@ def raises(*, check: Callable[[BaseException], bool]) -> RaisesExc[BaseException
9595
@overload
9696
def raises(
9797
expected_exception: type[E] | tuple[type[E], ...],
98-
func: Callable[..., Any],
99-
*args: Any,
100-
**kwargs: Any,
98+
func: Callable[P, object],
99+
*args: P.args,
100+
**kwargs: P.kwargs,
101101
) -> ExceptionInfo[E]: ...
102102

103103

104104
def raises(
105105
expected_exception: type[E] | tuple[type[E], ...] | None = None,
106+
func: Callable[P, object] | None = None,
106107
*args: Any,
107108
**kwargs: Any,
108109
) -> RaisesExc[BaseException] | ExceptionInfo[E]:
@@ -253,7 +254,7 @@ def raises(
253254
"""
254255
__tracebackhide__ = True
255256

256-
if not args:
257+
if func is None and not args:
257258
if set(kwargs) - {"match", "check", "expected_exception"}:
258259
msg = "Unexpected keyword arguments passed to pytest.raises: "
259260
msg += ", ".join(sorted(kwargs))
@@ -270,11 +271,10 @@ def raises(
270271
f"Raising exceptions is already understood as failing the test, so you don't need "
271272
f"any special code to say 'this should never raise an exception'."
272273
)
273-
func = args[0]
274274
if not callable(func):
275275
raise TypeError(f"{func!r} object (type: {type(func)}) must be callable")
276276
with RaisesExc(expected_exception) as excinfo:
277-
func(*args[1:], **kwargs)
277+
func(*args, **kwargs)
278278
try:
279279
return excinfo
280280
finally:

src/_pytest/recwarn.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@
1717

1818

1919
if TYPE_CHECKING:
20+
from typing_extensions import ParamSpec
2021
from typing_extensions import Self
2122

23+
P = ParamSpec("P")
24+
2225
import warnings
2326

2427
from _pytest.deprecated import check_ispytest
@@ -49,7 +52,7 @@ def deprecated_call(
4952

5053

5154
@overload
52-
def deprecated_call(func: Callable[..., T], *args: Any, **kwargs: Any) -> T: ...
55+
def deprecated_call(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ...
5356

5457

5558
def deprecated_call(
@@ -78,11 +81,12 @@ def deprecated_call(
7881
(regardless of whether it is an ``expected_warning`` or not).
7982
"""
8083
__tracebackhide__ = True
81-
if func is not None:
82-
args = (func, *args)
83-
return warns(
84-
(DeprecationWarning, PendingDeprecationWarning, FutureWarning), *args, **kwargs
85-
)
84+
dep_warnings = (DeprecationWarning, PendingDeprecationWarning, FutureWarning)
85+
if func is None:
86+
return warns(dep_warnings, *args, **kwargs)
87+
88+
with warns(dep_warnings):
89+
return func(*args, **kwargs)
8690

8791

8892
@overload
@@ -96,16 +100,16 @@ def warns(
96100
@overload
97101
def warns(
98102
expected_warning: type[Warning] | tuple[type[Warning], ...],
99-
func: Callable[..., T],
100-
*args: Any,
101-
**kwargs: Any,
103+
func: Callable[P, T],
104+
*args: P.args,
105+
**kwargs: P.kwargs,
102106
) -> T: ...
103107

104108

105109
def warns(
106110
expected_warning: type[Warning] | tuple[type[Warning], ...] = Warning,
111+
func: Callable[..., object] | None = None,
107112
*args: Any,
108-
match: str | re.Pattern[str] | None = None,
109113
**kwargs: Any,
110114
) -> WarningsChecker | Any:
111115
r"""Assert that code raises a particular class of warning.
@@ -152,7 +156,8 @@ def warns(
152156
153157
"""
154158
__tracebackhide__ = True
155-
if not args:
159+
if func is None and not args:
160+
match: str | re.Pattern[str] | None = kwargs.pop("match", None)
156161
if kwargs:
157162
argnames = ", ".join(sorted(kwargs))
158163
raise TypeError(
@@ -161,11 +166,10 @@ def warns(
161166
)
162167
return WarningsChecker(expected_warning, match_expr=match, _ispytest=True)
163168
else:
164-
func = args[0]
165169
if not callable(func):
166170
raise TypeError(f"{func!r} object (type: {type(func)}) must be callable")
167171
with WarningsChecker(expected_warning, _ispytest=True):
168-
return func(*args[1:], **kwargs)
172+
return func(*args, **kwargs)
169173

170174

171175
class WarningsRecorder(warnings.catch_warnings):

0 commit comments

Comments
 (0)