From b87c76150fcc351a15d45a28798bf149f15f3e03 Mon Sep 17 00:00:00 2001 From: marko1olo Date: Sun, 7 Jun 2026 03:47:55 +0400 Subject: [PATCH] Fix walrus state leaking between asserts --- AUTHORS | 1 + changelog/14445.bugfix.rst | 1 + src/_pytest/assertion/rewrite.py | 32 +++++++++++++++++++-- testing/test_assertrewrite.py | 49 ++++++++++++++++++++++++++++++++ 4 files changed, 81 insertions(+), 2 deletions(-) create mode 100644 changelog/14445.bugfix.rst diff --git a/AUTHORS b/AUTHORS index 27c0b3ac408..408ad6c063c 100644 --- a/AUTHORS +++ b/AUTHORS @@ -302,6 +302,7 @@ Mark Abramowitz Mark Dickinson Mark Vong Marko Pacak +marko1olo Markus Unterwaditzer Martijn Faassen Martin Altmayer diff --git a/changelog/14445.bugfix.rst b/changelog/14445.bugfix.rst new file mode 100644 index 00000000000..df3722cfd06 --- /dev/null +++ b/changelog/14445.bugfix.rst @@ -0,0 +1 @@ +Fixed assertion rewriting leaking walrus operator state into later assertions in the same function. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 99815b70cf1..4a10277e929 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -869,6 +869,8 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]: self.statements: list[ast.stmt] = [] self.variables: list[str] = [] self.variable_counter = itertools.count() + self.variables_overwrite[self.scope] = {} + self.variable_restore_names: list[tuple[str, str]] = [] if self.enable_assertion_pass_hook: self.format_variables: list[str] = [] @@ -881,6 +883,15 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]: negation = ast.UnaryOp(ast.Not(), top_condition) + def restore_walrus_targets() -> list[ast.Assign]: + return [ + ast.Assign( + [ast.Name(target_id, ast.Store())], + ast.Name(temp_id, ast.Load()), + ) + for target_id, temp_id in self.variable_restore_names + ] + if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook msg = self.pop_format_context(ast.Constant(explanation)) @@ -899,6 +910,7 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]: raise_ = ast.Raise(exc, None) statements_fail = [] statements_fail.extend(self.expl_stmts) + statements_fail.extend(restore_walrus_targets()) statements_fail.append(raise_) # Passed @@ -918,7 +930,10 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]: [*self.expl_stmts, hook_call_pass], [], ) - statements_pass: list[ast.stmt] = [hook_impl_test] + statements_pass: list[ast.stmt] = [ + hook_impl_test, + *restore_walrus_targets(), + ] # Test for assertion condition main_test = ast.If(negation, statements_fail, statements_pass) @@ -947,7 +962,9 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]: exc = ast.Call(err_name, [fmt], []) raise_ = ast.Raise(exc, None) + body.extend(restore_walrus_targets()) body.append(raise_) + self.statements.extend(restore_walrus_targets()) # Clear temporary variables by setting them to None. if self.variables: @@ -1001,6 +1018,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]: # cond is set in a prior loop iteration below self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa: F821 self.expl_stmts = fail_inner + restore_after_operand: tuple[str, str] | None = None match v: # Check if the left operand is an ast.NamedExpr and the value has already been visited case ast.Compare( @@ -1012,9 +1030,14 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]: self.variables_overwrite[self.scope][target_id] = v.left # type:ignore[assignment] # mypy's false positive, we're checking that the 'target' attribute exists. v.left.target.id = pytest_temp # type:ignore[attr-defined] + restore_after_operand = (target_id, pytest_temp) + else: + restore_after_operand = None self.push_format_context() res, expl = self.visit(v) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) + if restore_after_operand is not None: + self.variable_restore_names.append(restore_after_operand) expl_format = self.pop_format_context(ast.Constant(expl)) call = ast.Call(app, [expl_format], []) self.expl_stmts.append(ast.Expr(call)) @@ -1119,13 +1142,16 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: syms: list[ast.expr] = [] results = [left_res] for i, op, next_operand in it: + restore_after_compare: tuple[str, str] | None = None match (next_operand, left_res): case ( ast.NamedExpr(target=ast.Name(id=target_id)), ast.Name(id=name_id), ) if target_id == name_id: - next_operand.target.id = self.variable() + temp_id = self.variable() + next_operand.target.id = temp_id self.variables_overwrite[self.scope][name_id] = next_operand # type: ignore[assignment] + restore_after_compare = (name_id, temp_id) next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, ast.Compare | ast.BoolOp): @@ -1137,6 +1163,8 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: expls.append(ast.Constant(expl)) res_expr = ast.copy_location(ast.Compare(left_res, [op], [next_res]), comp) self.statements.append(ast.Assign([store_names[i]], res_expr)) + if restore_after_compare is not None: + self.variable_restore_names.append(restore_after_compare) left_res, left_expl = next_res, next_expl # Use pytest.assertion.util._reprcompare if that's available. expl_call = self.helper( diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index e11863547ba..0f8b87f2824 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1722,6 +1722,55 @@ def test_walrus_operator_not_override_value(): result = pytester.runpytest() assert result.ret == 0 + def test_assertion_walrus_operator_value_changes_cleared_after_each_assert( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + class Counter: + def __init__(self): + self.value = 0 + + def increment(self): + self.value += 1 + + def test_walrus_operator_change_value_between_asserts(): + counter = Counter() + assert (before := counter.value) == 0 + counter.increment() + assert before != (after := counter.value) + assert before == 0 + assert after == 1 + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_walrus_operator_restore_with_assertion_pass_hook( + self, pytester: Pytester + ) -> None: + pytester.makeini("[pytest]\nenable_assertion_pass_hook = True\n") + pytester.makepyfile( + """ + def test_walrus_operator_pass_compare_restore(): + a = "Hello" + assert a != (a := a.lower()) + assert a == "hello" + + def test_walrus_operator_pass_bool_restore(): + a = True + assert a and ((a := False) is False) and (a is False) + assert a is False + + def test_walrus_operator_fail_compare_explanation(): + a = "Hello" + assert a == (a := a.lower()) + """ + ) + result = pytester.runpytest() + result.assert_outcomes(passed=2, failed=1) + result.stdout.fnmatch_lines(["*assert 'Hello' == 'hello'"]) + def test_assertion_namedexpr_compare_left_overwrite( self, pytester: Pytester ) -> None: