From fd53440951594fbf6a0dcf3872194e4871c30052 Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Fri, 8 May 2026 10:37:47 +0200 Subject: [PATCH 1/8] fix(rewrite): prevent walrus operator double evaluation in assertions Fixes #14445 - assertion rewriting evaluated NamedExpr (:=) expressions multiple times, causing side effects to fire repeatedly. The root cause was the `variables_overwrite` mechanism which stored and re-evaluated NamedExpr AST nodes in subsequent assertions, in `_call_reprcompare`'s results tuple, and in explanation formatting. The fix: - visit_NamedExpr: reference the target variable in explanations instead of re-evaluating the full expression - visit_Compare: assign left-side NamedExpr to a temp before right-side hoisting; freeze left_res when a comparator walrus targets the same name; replace NamedExpr entries in `results` with target variables - visit_BoolOp: capture short-circuit condition in a stable temp for the explanation path; remove walrus target rename logic - visit_Call: remove variables_overwrite substitution (walrus now properly assigns to user variables in its natural evaluation position) - Remove variables_overwrite, scope tracking, Sentinel class Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Sonnet 4 --- src/_pytest/assertion/rewrite.py | 102 +++++++++++-------------------- testing/test_assertrewrite.py | 68 ++++++++++++++++++++- 2 files changed, 102 insertions(+), 68 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 99815b70cf1..d3af1db26bc 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -3,7 +3,6 @@ from __future__ import annotations import ast -from collections import defaultdict from collections.abc import Callable from collections.abc import Iterable from collections.abc import Iterator @@ -58,10 +57,6 @@ from _pytest.assertion import AssertionState -class Sentinel: - pass - - assertstate_key = StashKey["AssertionState"]() # pytest caches rewritten pycs in pycache dirs @@ -69,9 +64,6 @@ class Sentinel: PYC_EXT = ".py" + ((__debug__ and "c") or "o") PYC_TAIL = "." + PYTEST_TAG + PYC_EXT -# Special marker that denotes we have just left a scope definition -_SCOPE_END_MARKER = Sentinel() - class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader): """PEP302/PEP451 import hook which rewrites asserts.""" @@ -652,14 +644,8 @@ class AssertionRewriter(ast.NodeVisitor): .push_format_context() and .pop_format_context() which allows to build another %-formatted string while already building one. - :scope: A tuple containing the current scope used for variables_overwrite. - - :variables_overwrite: A dict filled with references to variables - that change value within an assert. This happens when a variable is - reassigned with the walrus operator - - This state, except the variables_overwrite, is reset on every new assert - statement visited and used by the other visitors. + This state is reset on every new assert statement visited and used by + the other visitors. """ def __init__( @@ -675,10 +661,6 @@ def __init__( else: self.enable_assertion_pass_hook = False self.source = source - self.scope: tuple[ast.AST, ...] = () - self.variables_overwrite: defaultdict[tuple[ast.AST, ...], dict[str, str]] = ( - defaultdict(dict) - ) def run(self, mod: ast.Module) -> None: """Find all assert statements in *mod* and rewrite them.""" @@ -728,16 +710,9 @@ def run(self, mod: ast.Module) -> None: mod.body[pos:pos] = imports # Collect asserts. - self.scope = (mod,) - nodes: list[ast.AST | Sentinel] = [mod] + nodes: list[ast.AST] = [mod] while nodes: node = nodes.pop() - if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef): - self.scope = tuple((*self.scope, node)) - nodes.append(_SCOPE_END_MARKER) - if node == _SCOPE_END_MARKER: - self.scope = self.scope[:-1] - continue assert isinstance(node, ast.AST) for name, field in ast.iter_fields(node): if isinstance(field, list): @@ -964,15 +939,17 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]: return self.statements def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]: - # This method handles the 'walrus operator' repr of the target - # name if it's a local variable or _should_repr_global_name() - # thinks it's acceptable. + # Return the NamedExpr as-is so it evaluates in its natural position + # (preserving left-to-right evaluation order). For the explanation, + # reference the target variable (already assigned by the walrus) to + # avoid re-evaluating the expression. locs = ast.Call(self.builtin("locals"), [], []) target_id = name.target.id + target_name = ast.Name(target_id, ast.Load()) inlocs = ast.Compare(ast.Constant(target_id), [ast.In()], [locs]) - dorepr = self.helper("_should_repr_global_name", name) + dorepr = self.helper("_should_repr_global_name", target_name) test = ast.BoolOp(ast.Or(), [inlocs, dorepr]) - expr = ast.IfExp(test, self.display(name), ast.Constant(target_id)) + expr = ast.IfExp(test, self.display(target_name), ast.Constant(target_id)) return name, self.explanation_param(expr) def visit_Name(self, name: ast.Name) -> tuple[ast.Name, str]: @@ -998,20 +975,9 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]: for i, v in enumerate(boolop.values): if i: fail_inner: list[ast.stmt] = [] - # cond is set in a prior loop iteration below - self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa: F821 + # expl_cond is set in a prior loop iteration below + self.expl_stmts.append(ast.If(expl_cond, fail_inner, [])) # noqa: F821 self.expl_stmts = fail_inner - match v: - # Check if the left operand is an ast.NamedExpr and the value has already been visited - case ast.Compare( - left=ast.NamedExpr(target=ast.Name(id=target_id)) - ) if target_id in [ - e.id for e in boolop.values[:i] if hasattr(e, "id") - ]: - pytest_temp = self.variable() - self.variables_overwrite[self.scope][target_id] = v.left # type:ignore[assignment] - # mypy's false positive, we're checking that the 'target' attribute exists. - v.left.target.id = pytest_temp # type:ignore[attr-defined] self.push_format_context() res, expl = self.visit(v) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) @@ -1022,8 +988,16 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]: cond: ast.expr = res if is_or: cond = ast.UnaryOp(ast.Not(), cond) + # Capture the condition in a temp variable so the explanation + # path (which runs after walrus operators may have modified + # the original variable) sees the correct truthiness. + cond_var = self.variable() + body.append(ast.Assign([ast.Name(cond_var, ast.Store())], cond)) + expl_cond: ast.expr = ast.Name(cond_var, ast.Load()) # noqa: F841 inner: list[ast.stmt] = [] - self.statements.append(ast.If(cond, inner, [])) + self.statements.append( + ast.If(ast.Name(cond_var, ast.Load()), inner, []) + ) self.statements = body = inner self.statements = save self.expl_stmts = fail_save @@ -1053,19 +1027,10 @@ def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]: new_args = [] new_kwargs = [] for arg in call.args: - if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get( - self.scope, {} - ): - arg = self.variables_overwrite[self.scope][arg.id] # type:ignore[assignment] res, expl = self.visit(arg) arg_expls.append(expl) new_args.append(res) for keyword in call.keywords: - match keyword.value: - case ast.Name(id=id) if id in self.variables_overwrite.get( - self.scope, {} - ): - keyword.value = self.variables_overwrite[self.scope][id] # type:ignore[assignment] res, expl = self.visit(keyword.value) new_kwargs.append(ast.keyword(keyword.arg, res)) if keyword.arg: @@ -1100,17 +1065,13 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]: def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: self.push_format_context() - # We first check if we have overwritten a variable in the previous assert - match comp.left: - case ast.Name(id=name_id) if name_id in self.variables_overwrite.get( - self.scope, {} - ): - comp.left = self.variables_overwrite[self.scope][name_id] # type: ignore[assignment] - case ast.NamedExpr(target=ast.Name(id=target_id)): - self.variables_overwrite[self.scope][target_id] = comp.left # type: ignore[assignment] left_res, left_expl = self.visit(comp.left) if isinstance(comp.left, ast.Compare | ast.BoolOp): left_expl = f"({left_expl})" + # If the left operand is a NamedExpr, assign it to a temp so the + # walrus executes before any right-side expressions are hoisted. + if isinstance(left_res, ast.NamedExpr): + left_res = self.assign(left_res) res_variables = [self.variable() for i in range(len(comp.ops))] load_names: list[ast.expr] = [ast.Name(v, ast.Load()) for v in res_variables] store_names = [ast.Name(v, ast.Store()) for v in res_variables] @@ -1119,13 +1080,16 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: syms: list[ast.expr] = [] results = [left_res] for i, op, next_operand in it: + # If the next operand is a walrus that assigns to the same name as + # the current left_res, we must freeze left_res's value before the + # walrus modifies it. match (next_operand, left_res): case ( ast.NamedExpr(target=ast.Name(id=target_id)), ast.Name(id=name_id), ) if target_id == name_id: - next_operand.target.id = self.variable() - self.variables_overwrite[self.scope][name_id] = next_operand # type: ignore[assignment] + left_res = self.assign(left_res) + results[-1] = left_res next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, ast.Compare | ast.BoolOp): @@ -1138,6 +1102,12 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: res_expr = ast.copy_location(ast.Compare(left_res, [op], [next_res]), comp) self.statements.append(ast.Assign([store_names[i]], res_expr)) left_res, left_expl = next_res, next_expl + # Replace NamedExpr entries in results with their target variable + # to avoid re-evaluating walrus operators in the explanation path. + results = [ + ast.Name(r.target.id, ast.Load()) if isinstance(r, ast.NamedExpr) else r + for r in results + ] # Use pytest.assertion.util._reprcompare if that's available. expl_call = self.helper( "_call_reprcompare", diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 2668001af65..7c131c9a4f5 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1688,7 +1688,7 @@ def test_walrus_operator_change_boolean_value(): ) result = pytester.runpytest() assert result.ret == 1 - result.stdout.fnmatch_lines(["*assert not (True and False is False)"]) + result.stdout.fnmatch_lines(["*assert not (False and False is False)"]) def test_assertion_walrus_operator_boolean_none_fails( self, pytester: Pytester @@ -1702,7 +1702,7 @@ def test_walrus_operator_change_boolean_value(): ) result = pytester.runpytest() assert result.ret == 1 - result.stdout.fnmatch_lines(["*assert not (True and None is None)"]) + result.stdout.fnmatch_lines(["*assert not (None and None is None)"]) def test_assertion_walrus_operator_value_changes_cleared_after_each_test( self, pytester: Pytester @@ -1846,6 +1846,70 @@ def test_2(): assert result.ret == 0 +class TestIssue14445: + """Regression tests for #14445: walrus operator double evaluation.""" + + def test_walrus_no_double_eval_basic(self, pytester: Pytester) -> None: + """Walrus captures the value at assignment time, not re-evaluated later.""" + pytester.makepyfile( + """ + class Counter: + def __init__(self): + self.value = 0 + def increment(self): + self.value += 1 + + def test_walrus_in_assertion_basic(): + c = Counter() + assert (before := c.value) == 0 + c.increment() + assert before != (after := c.value) + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_walrus_no_double_eval_running_counter(self, pytester: Pytester) -> None: + """Walrus increments fire exactly once per assert statement.""" + pytester.makepyfile( + """ + def test_walrus_running_counter(): + count = 0 + items = [] + items.append("a") + assert (count := count + 1) == len(items) + items.append("b") + assert (count := count + 1) == len(items) + items.append("c") + assert (count := count + 1) == len(items) + assert count == 3 + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_walrus_no_double_eval_in_function_call(self, pytester: Pytester) -> None: + """Walrus in function call arguments not evaluated twice.""" + pytester.makepyfile( + """ + call_count = 0 + + def side_effect(): + global call_count + call_count += 1 + return call_count + + def test_walrus_side_effect(): + assert (val := side_effect()) == 1 + assert val == 1 + assert (val := side_effect()) == 2 + assert val == 2 + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + @pytest.mark.skipif( sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems" ) From 9f371c71c74a4271c30a91342ff6a6f981ed6150 Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Fri, 8 May 2026 12:59:20 +0200 Subject: [PATCH 2/8] Add changelog fragment for #14445 Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Sonnet 4 --- changelog/14445.bugfix.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/14445.bugfix.rst diff --git a/changelog/14445.bugfix.rst b/changelog/14445.bugfix.rst new file mode 100644 index 00000000000..aaae0c615f5 --- /dev/null +++ b/changelog/14445.bugfix.rst @@ -0,0 +1 @@ +Fixed assertion rewriting evaluating walrus operator (``:=``) expressions multiple times, causing incorrect test results when the expression had side effects (e.g., incrementing a counter or calling a function). From df0e0a9d2c5fda04ed59b7c94c714ba9f369ef03 Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Fri, 8 May 2026 13:04:00 +0200 Subject: [PATCH 3/8] test(rewrite): add xfail tests for remaining walrus edge cases Add tests for two remaining walrus double-evaluation scenarios: - Bare NamedExpr as BoolOp operand evaluated twice via condition check - Same walrus target in chained comparison evaluated multiple times Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Sonnet 4 --- testing/test_assertrewrite.py | 40 +++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 7c131c9a4f5..91394263756 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1909,6 +1909,46 @@ def test_walrus_side_effect(): result = pytester.runpytest() assert result.ret == 0 + @pytest.mark.xfail(reason="BoolOp condition re-evaluates walrus operand") + def test_walrus_no_double_eval_in_boolop(self, pytester: Pytester) -> None: + """Bare walrus as a BoolOp operand must not be evaluated twice.""" + pytester.makepyfile( + """ + call_count = 0 + + def side_effect(): + global call_count + call_count += 1 + return call_count + + def test_walrus_boolop(): + assert (x := side_effect()) and x == 1 + assert call_count == 1 + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + @pytest.mark.xfail(reason="Chained compare re-evaluates walrus with same target") + def test_walrus_no_double_eval_chained_compare(self, pytester: Pytester) -> None: + """Same walrus target in chained comparison must evaluate each once.""" + pytester.makepyfile( + """ + call_count = 0 + + def track(value): + global call_count + call_count += 1 + return value + + def test_walrus_chained(): + assert (x := track(1)) < (x := track(3)) < (x := track(5)) + assert call_count == 3 + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + @pytest.mark.skipif( sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems" From 60f8e24297a28925baaf50921faa1872721b8321 Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Fri, 8 May 2026 13:05:41 +0200 Subject: [PATCH 4/8] fix(rewrite): avoid double evaluation of walrus in BoolOp condition Use the already-assigned res_var to build the short-circuit condition instead of the raw visitor result, preventing bare NamedExpr operands from being evaluated a second time when checking truthiness. Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Sonnet 4 --- src/_pytest/assertion/rewrite.py | 9 +++++---- testing/test_assertrewrite.py | 1 - 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index d3af1db26bc..d8ea1c5ec8e 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -985,12 +985,13 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]: call = ast.Call(app, [expl_format], []) self.expl_stmts.append(ast.Expr(call)) if i < levels: - cond: ast.expr = res + # Use res_var (already assigned above) rather than res directly, + # so that NamedExpr operands aren't evaluated a second time. + cond: ast.expr = ast.Name(res_var, ast.Load()) if is_or: cond = ast.UnaryOp(ast.Not(), cond) - # Capture the condition in a temp variable so the explanation - # path (which runs after walrus operators may have modified - # the original variable) sees the correct truthiness. + # Capture the condition in a stable temp for the explanation + # path — res_var is overwritten by subsequent operands. cond_var = self.variable() body.append(ast.Assign([ast.Name(cond_var, ast.Store())], cond)) expl_cond: ast.expr = ast.Name(cond_var, ast.Load()) # noqa: F841 diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 91394263756..341514b377e 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1909,7 +1909,6 @@ def test_walrus_side_effect(): result = pytester.runpytest() assert result.ret == 0 - @pytest.mark.xfail(reason="BoolOp condition re-evaluates walrus operand") def test_walrus_no_double_eval_in_boolop(self, pytester: Pytester) -> None: """Bare walrus as a BoolOp operand must not be evaluated twice.""" pytester.makepyfile( From ffdc372a3613faed3af8f6d9e86c030223d945ed Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Fri, 8 May 2026 13:08:10 +0200 Subject: [PATCH 5/8] fix(rewrite): assign walrus comparators to temps in chained comparisons In a chained comparison like `(x := f()) < (x := g()) < (x := h())`, each NamedExpr comparator is now assigned to a temp variable so it evaluates exactly once. Previously the raw NamedExpr node would be reused as left_res in the next iteration, causing double evaluation. Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Sonnet 4 --- src/_pytest/assertion/rewrite.py | 11 +++++------ testing/test_assertrewrite.py | 1 - 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index d8ea1c5ec8e..3fa3217f6e0 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -1095,6 +1095,11 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, ast.Compare | ast.BoolOp): next_expl = f"({next_expl})" + # Assign NamedExpr comparators to a temp so each walrus evaluates + # exactly once — critical for chained comparisons where the same + # node would otherwise be re-evaluated as left_res next iteration. + if isinstance(next_res, ast.NamedExpr): + next_res = self.assign(next_res) results.append(next_res) sym = BINOP_MAP[op.__class__] syms.append(ast.Constant(sym)) @@ -1103,12 +1108,6 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: res_expr = ast.copy_location(ast.Compare(left_res, [op], [next_res]), comp) self.statements.append(ast.Assign([store_names[i]], res_expr)) left_res, left_expl = next_res, next_expl - # Replace NamedExpr entries in results with their target variable - # to avoid re-evaluating walrus operators in the explanation path. - results = [ - ast.Name(r.target.id, ast.Load()) if isinstance(r, ast.NamedExpr) else r - for r in results - ] # Use pytest.assertion.util._reprcompare if that's available. expl_call = self.helper( "_call_reprcompare", diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 341514b377e..11995321826 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1928,7 +1928,6 @@ def test_walrus_boolop(): result = pytester.runpytest() assert result.ret == 0 - @pytest.mark.xfail(reason="Chained compare re-evaluates walrus with same target") def test_walrus_no_double_eval_chained_compare(self, pytester: Pytester) -> None: """Same walrus target in chained comparison must evaluate each once.""" pytester.makepyfile( From ece8626888713fdab5686b7adc79ce2307a1050d Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Wed, 3 Jun 2026 10:55:00 +0200 Subject: [PATCH 6/8] fix(rewrite): show correct walrus values in BoolOp explanations When multiple walrus operators target the same variable in a BoolOp (e.g., `assert (x := side_effect()) and (x := False)`), the assertion explanation previously showed the final value of `x` for all operands because the format context evaluated lazily after all operands ran. Fix by tracking Name/NamedExpr operand values in stable @py_assert variables (via self.assign) immediately after evaluation, then pointing the explanation format context at the tracked copy. This uses the same value-tracking mechanism already used by visit_Call, visit_Attribute, etc. Fixes the case reported by @bluetech in PR review. Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Sonnet 4 --- src/_pytest/assertion/rewrite.py | 19 +++++++++---------- testing/test_assertrewrite.py | 22 ++++++++++++++++++++-- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 3fa3217f6e0..97b3bb74ace 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -940,9 +940,8 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]: def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]: # Return the NamedExpr as-is so it evaluates in its natural position - # (preserving left-to-right evaluation order). For the explanation, - # reference the target variable (already assigned by the walrus) to - # avoid re-evaluating the expression. + # (preserving left-to-right evaluation order in function calls, etc.). + # For the explanation, reference the target variable. locs = ast.Call(self.builtin("locals"), [], []) target_id = name.target.id target_name = ast.Name(target_id, ast.Load()) @@ -981,12 +980,17 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]: self.push_format_context() res, expl = self.visit(v) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) + # For Name/NamedExpr operands, track the value in a stable + # @py_assert variable so the explanation shows the value at + # evaluation time — even if a later walrus overwrites the name. + if isinstance(v, ast.NamedExpr | ast.Name): + tracked = self.assign(ast.Name(res_var, ast.Load())) + for key in self.stack[-1]: + self.stack[-1][key] = self.display(tracked) expl_format = self.pop_format_context(ast.Constant(expl)) call = ast.Call(app, [expl_format], []) self.expl_stmts.append(ast.Expr(call)) if i < levels: - # Use res_var (already assigned above) rather than res directly, - # so that NamedExpr operands aren't evaluated a second time. cond: ast.expr = ast.Name(res_var, ast.Load()) if is_or: cond = ast.UnaryOp(ast.Not(), cond) @@ -1069,8 +1073,6 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: left_res, left_expl = self.visit(comp.left) if isinstance(comp.left, ast.Compare | ast.BoolOp): left_expl = f"({left_expl})" - # If the left operand is a NamedExpr, assign it to a temp so the - # walrus executes before any right-side expressions are hoisted. if isinstance(left_res, ast.NamedExpr): left_res = self.assign(left_res) res_variables = [self.variable() for i in range(len(comp.ops))] @@ -1095,9 +1097,6 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, ast.Compare | ast.BoolOp): next_expl = f"({next_expl})" - # Assign NamedExpr comparators to a temp so each walrus evaluates - # exactly once — critical for chained comparisons where the same - # node would otherwise be re-evaluated as left_res next iteration. if isinstance(next_res, ast.NamedExpr): next_res = self.assign(next_res) results.append(next_res) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 11995321826..867f6adf447 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1688,7 +1688,7 @@ def test_walrus_operator_change_boolean_value(): ) result = pytester.runpytest() assert result.ret == 1 - result.stdout.fnmatch_lines(["*assert not (False and False is False)"]) + result.stdout.fnmatch_lines(["*assert not (True and False is False)"]) def test_assertion_walrus_operator_boolean_none_fails( self, pytester: Pytester @@ -1702,7 +1702,7 @@ def test_walrus_operator_change_boolean_value(): ) result = pytester.runpytest() assert result.ret == 1 - result.stdout.fnmatch_lines(["*assert not (None and None is None)"]) + result.stdout.fnmatch_lines(["*assert not (True and None is None)"]) def test_assertion_walrus_operator_value_changes_cleared_after_each_test( self, pytester: Pytester @@ -1947,6 +1947,24 @@ def test_walrus_chained(): result = pytester.runpytest() assert result.ret == 0 + def test_walrus_boolop_same_target_correct_explanation( + self, pytester: Pytester + ) -> None: + """Multiple walrus operators to the same name in a BoolOp must show + each operand's value at evaluation time, not the final value.""" + pytester.makepyfile( + """ + def side_effect(): + return True + + def test_walrus_boolop(): + assert (x := side_effect()) and (x := False) + """ + ) + result = pytester.runpytest() + assert result.ret == 1 + result.stdout.fnmatch_lines(["*assert (True and False)"]) + @pytest.mark.skipif( sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems" From 302f9c9ec75e239d3e3b9f792da4d2c90e2f5bd2 Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Wed, 3 Jun 2026 11:55:59 +0200 Subject: [PATCH 7/8] scripts: add assert-rewrite dump and diff tools Add scripts to inspect and compare assertion rewriting across pytest versions, useful for tracking changes in the rewriter output. - dump-assert-rewrite.py: dumps the rewritten form of a Python file using any pytest version (via uv ephemeral env) or the local worktree. Supports source, ast, and compact output formats. - diff-assert-rewrite.py: compares rewrite output between two versions (or worktree) with colored unified diff, runs both sides in parallel. - example_asserts.py: example file covering common assertion patterns including bluetech's walrus-in-BoolOp case from #14445. Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Opus 4 --- scripts/diff-assert-rewrite.py | 202 +++++++++++++++++++++++++++++++++ scripts/dump-assert-rewrite.py | 166 +++++++++++++++++++++++++++ scripts/example_asserts.py | 42 +++++++ 3 files changed, 410 insertions(+) create mode 100644 scripts/diff-assert-rewrite.py create mode 100644 scripts/dump-assert-rewrite.py create mode 100644 scripts/example_asserts.py diff --git a/scripts/diff-assert-rewrite.py b/scripts/diff-assert-rewrite.py new file mode 100644 index 00000000000..070b4040d71 --- /dev/null +++ b/scripts/diff-assert-rewrite.py @@ -0,0 +1,202 @@ +"""Compare assert-rewrite output between two pytest versions (or worktree). + +Runs the dump script for both sides (in parallel) and shows a unified diff. +Supports all three output formats: source, ast, and compact. + +Usage:: + + # Two released versions: + python scripts/diff-assert-rewrite.py --left 7.4.0 --right 8.0.0 example.py + + # Release vs local worktree: + python scripts/diff-assert-rewrite.py --left 8.3.0 --right worktree example.py + + # Compact AST diff (strips position noise): + python scripts/diff-assert-rewrite.py --left 7.4.0 --right worktree -f compact example.py + + # All formats at once: + python scripts/diff-assert-rewrite.py --left 7.4.0 --right 8.0.0 -f all example.py +""" + +from __future__ import annotations + +import argparse +from concurrent.futures import ThreadPoolExecutor +import difflib +from pathlib import Path +import sys + + +_DUMP_SCRIPT = Path(__file__).resolve().parent / "dump-assert-rewrite.py" + +_FORMATS = ("source", "ast", "compact") + + +def _label(spec: str) -> str: + return "worktree" if spec == "worktree" else f"pytest=={spec}" + + +def get_dump(spec: str, file_path: Path, fmt: str) -> str: + """Run the dump script for a single side and return its output.""" + # Import inline so this file stays lightweight at module level. + import subprocess + + args = [sys.executable, str(_DUMP_SCRIPT)] + if spec == "worktree": + args.append("--worktree") + else: + args.extend(["--pytest-version", spec]) + args.extend(["--format", fmt, str(file_path)]) + + result = subprocess.run(args, capture_output=True, text=True, check=False) + if result.returncode != 0: + sys.stderr.write(result.stderr) + raise SystemExit(f"Dump failed for {_label(spec)}") + return result.stdout + + +def colored_diff(lines: list[str]) -> str: + """Apply ANSI colours to a unified-diff line list.""" + RED = "\033[31m" + GREEN = "\033[32m" + CYAN = "\033[36m" + RESET = "\033[0m" + + out: list[str] = [] + for line in lines: + if line.startswith(("---", "+++")): + out.append(f"{CYAN}{line}{RESET}") + elif line.startswith("-"): + out.append(f"{RED}{line}{RESET}") + elif line.startswith("+"): + out.append(f"{GREEN}{line}{RESET}") + elif line.startswith("@@"): + out.append(f"{CYAN}{line}{RESET}") + else: + out.append(line) + return "\n".join(out) + + +def show_diff( + left_text: str, + right_text: str, + *, + left_label: str, + right_label: str, + fmt: str, + context: int, + use_color: bool, +) -> bool: + """Print a unified diff; return True if differences were found.""" + if left_text == right_text: + return False + + diff_lines = list( + difflib.unified_diff( + left_text.splitlines(), + right_text.splitlines(), + fromfile=f"{left_label} [{fmt}]", + tofile=f"{right_label} [{fmt}]", + n=context, + ) + ) + + if use_color: + print(colored_diff(diff_lines)) + else: + print("\n".join(diff_lines)) + + return True + + +def main(argv: list[str] | None = None) -> None: + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--left", + required=True, + metavar="VER|worktree", + help="Left side: pytest version or 'worktree'", + ) + parser.add_argument( + "--right", + required=True, + metavar="VER|worktree", + help="Right side: pytest version or 'worktree'", + ) + parser.add_argument( + "--format", + "-f", + dest="fmt", + choices=(*_FORMATS, "all"), + default="source", + help="Output format (default: source)", + ) + parser.add_argument( + "--context", + "-C", + type=int, + default=3, + help="Context lines in diff (default: 3)", + ) + parser.add_argument( + "--no-color", + action="store_true", + help="Disable coloured output", + ) + parser.add_argument("file", type=Path, help="Python file to compare") + args = parser.parse_args(argv) + + if not args.file.is_file(): + raise SystemExit(f"File not found: {args.file}") + + formats = _FORMATS if args.fmt == "all" else (args.fmt,) + use_color = not args.no_color and sys.stdout.isatty() + left_label = _label(args.left) + right_label = _label(args.right) + + # Fetch all needed dumps in parallel (both sides x all formats). + jobs: dict[tuple[str, str], str] = {} + with ThreadPoolExecutor(max_workers=len(formats) * 2) as pool: + futures = { + (side, fmt): pool.submit(get_dump, spec, args.file, fmt) + for fmt in formats + for side, spec in [("left", args.left), ("right", args.right)] + } + for key, future in futures.items(): + jobs[key] = future.result() + + any_diff = False + for fmt in formats: + if len(formats) > 1: + header = f"=== {fmt} ===" + if use_color: + header = f"\033[1m{header}\033[0m" + print(header) + + had_diff = show_diff( + jobs[("left", fmt)], + jobs[("right", fmt)], + left_label=left_label, + right_label=right_label, + fmt=fmt, + context=args.context, + use_color=use_color, + ) + if not had_diff: + print( + f"No differences in {fmt} output between {left_label} and {right_label}" + ) + else: + any_diff = True + + if len(formats) > 1: + print() + + raise SystemExit(1 if any_diff else 0) + + +if __name__ == "__main__": + main() diff --git a/scripts/dump-assert-rewrite.py b/scripts/dump-assert-rewrite.py new file mode 100644 index 00000000000..75e0b0d734f --- /dev/null +++ b/scripts/dump-assert-rewrite.py @@ -0,0 +1,166 @@ +"""Dump the assert-rewritten form of a Python file for a specific pytest version. + +Uses ``uv run`` to execute the rewriter in an ephemeral environment, so any +released pytest version can be inspected without installing it globally. + +Usage:: + + # Rewritten source (default): + python scripts/dump-assert-rewrite.py --pytest-version 8.0.0 example.py + + # Using local worktree: + python scripts/dump-assert-rewrite.py --worktree example.py + + # Compact AST (best for diffing -- no position attributes): + python scripts/dump-assert-rewrite.py --worktree --format compact example.py + + # Full AST with positions: + python scripts/dump-assert-rewrite.py --pytest-version 7.4.0 --format ast example.py +""" + +from __future__ import annotations + +import argparse +import os +from pathlib import Path +import subprocess +import sys +import textwrap + + +# Self-contained script executed inside the target pytest environment. +# Reads source from stdin, writes the rewritten form to stdout. +_WORKER_SCRIPT = textwrap.dedent("""\ + import ast + import sys + + source = sys.stdin.buffer.read() + fmt = sys.argv[1] + + try: + from _pytest.assertion.rewrite import rewrite_asserts + except ImportError: + sys.exit("Could not import rewrite_asserts from this pytest version") + + tree = ast.parse(source) + try: + rewrite_asserts(tree, source) + except TypeError: + # pytest < 6 did not accept the source parameter + tree = ast.parse(source) + rewrite_asserts(tree) + + ast.fix_missing_locations(tree) + + if fmt == "source": + print(ast.unparse(tree)) + elif fmt == "ast": + print(ast.dump(tree, indent=2)) + elif fmt == "compact": + print(ast.dump(tree, indent=2, include_attributes=False)) + else: + sys.exit(f"Unknown format: {fmt!r}") +""") + + +def run_worker( + *, + pytest_version: str | None, + worktree: bool, + file_content: bytes, + fmt: str, +) -> str: + """Execute the worker script and return its stdout.""" + if worktree: + repo_root = Path(__file__).resolve().parent.parent + env = os.environ.copy() + env["PYTHONPATH"] = ( + str(repo_root / "src") + os.pathsep + env.get("PYTHONPATH", "") + ) + cmd = [sys.executable, "-c", _WORKER_SCRIPT, fmt] + else: + assert pytest_version is not None + cmd = [ + "uv", + "run", + "--no-project", + "--with", + f"pytest=={pytest_version}", + "--", + "python", + "-c", + _WORKER_SCRIPT, + fmt, + ] + env = None + + try: + result = subprocess.run( + cmd, input=file_content, capture_output=True, check=False, env=env + ) + except FileNotFoundError as exc: + if "uv" in str(exc): + raise SystemExit( + "'uv' not found — install it: https://docs.astral.sh/uv/" + ) from exc + raise + + if result.returncode != 0: + label = version_label(pytest_version=pytest_version, worktree=worktree) + sys.stderr.buffer.write(result.stderr) + raise SystemExit(f"Worker failed for {label}") + + return result.stdout.decode() + + +def version_label(*, pytest_version: str | None = None, worktree: bool = False) -> str: + """Human-readable label for a pytest source.""" + return "worktree" if worktree else f"pytest=={pytest_version}" + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + src = parser.add_mutually_exclusive_group(required=True) + src.add_argument( + "--pytest-version", + metavar="VER", + help="Released pytest version (e.g. 8.0.0)", + ) + src.add_argument( + "--worktree", + action="store_true", + help="Use the local worktree's src/", + ) + parser.add_argument( + "--format", + dest="fmt", + choices=("source", "ast", "compact"), + default="source", + help="Output format (default: source)", + ) + parser.add_argument("file", type=Path, help="Python file to rewrite") + return parser + + +def main(argv: list[str] | None = None) -> None: + args = build_parser().parse_args(argv) + + if not args.file.is_file(): + raise SystemExit(f"File not found: {args.file}") + + output = run_worker( + pytest_version=args.pytest_version, + worktree=args.worktree, + file_content=args.file.read_bytes(), + fmt=args.fmt, + ) + sys.stdout.write(output) + if output and not output.endswith("\n"): + sys.stdout.write("\n") + + +if __name__ == "__main__": + main() diff --git a/scripts/example_asserts.py b/scripts/example_asserts.py new file mode 100644 index 00000000000..22571d9995f --- /dev/null +++ b/scripts/example_asserts.py @@ -0,0 +1,42 @@ +"""Example file for dump-assert-rewrite.py and diff-assert-rewrite.py. + +Includes bluetech's walrus-in-BoolOp case from the #14445 review, plus +other common assertion patterns useful for spotting rewrite changes. +""" + +from __future__ import annotations + + +def side_effect() -> bool: + return True + + +def test_walrus_boolop() -> None: + """Bluetech's example: walrus reassignment inside BoolOp.""" + assert (x := side_effect()) and (x := False) # noqa: F841 + + +def test_simple_equality() -> None: + x = 1 + assert x == 2 + + +def test_comparison() -> None: + a = [1, 2, 3] + b = [1, 2, 4] + assert a == b + + +def test_boolean() -> None: + x = True + y = False + assert x and y + + +def test_membership() -> None: + assert 5 in [1, 2, 3] + + +def test_message() -> None: + value = 42 + assert value > 100, f"expected > 100, got {value}" From 817d4314e947a94f93ecf32dd434ab9a62d80539 Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Wed, 3 Jun 2026 11:56:28 +0200 Subject: [PATCH 8/8] refactor(rewrite): minimal snapshots in BoolOp for walrus conflicts Replace the blanket snapshot-all-operands approach with a targeted one: pre-scan the BoolOp to find walrus targets, then only snapshot operands whose value a later walrus would corrupt. Snapshot rules: - NamedExpr (non-last): always, to avoid re-evaluating side effects - Name with later walrus conflict: to freeze the pre-overwrite value - Everything else: use res directly (stable @py_assert or plain name) Non-walrus BoolOps now generate identical code to 8.3.5 (no snapshots). Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Opus 4 --- src/_pytest/assertion/rewrite.py | 52 +++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 97b3bb74ace..d1e2cfce207 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -969,40 +969,56 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]: body = save = self.statements fail_save = self.expl_stmts levels = len(boolop.values) - 1 + # Pre-scan: for each operand position, collect the set of variable + # names that a *later* operand's walrus operator will overwrite. + # An operand needs a snapshot only when its value references a name + # in this set (otherwise the explanation would show the post-walrus + # value instead of the value at evaluation time). + later_walrus_targets: list[set[str]] = [set() for _ in boolop.values] + seen: set[str] = set() + for idx in range(len(boolop.values) - 1, -1, -1): + later_walrus_targets[idx] = set(seen) + for node in ast.walk(boolop.values[idx]): + if isinstance(node, ast.NamedExpr): + seen.add(node.target.id) self.push_format_context() - # Process each operand, short-circuiting if needed. + # Process each operand, short-circuiting as needed. for i, v in enumerate(boolop.values): if i: fail_inner: list[ast.stmt] = [] - # expl_cond is set in a prior loop iteration below - self.expl_stmts.append(ast.If(expl_cond, fail_inner, [])) # noqa: F821 + # cond is set in a prior loop iteration below + self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa: F821 self.expl_stmts = fail_inner self.push_format_context() res, expl = self.visit(v) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) - # For Name/NamedExpr operands, track the value in a stable - # @py_assert variable so the explanation shows the value at - # evaluation time — even if a later walrus overwrites the name. - if isinstance(v, ast.NamedExpr | ast.Name): - tracked = self.assign(ast.Name(res_var, ast.Load())) + # Snapshot when the raw ``res`` node would be unsafe to reuse + # as a condition or explanation reference: + # - NamedExpr (non-last): reusing the node re-evaluates the + # walrus expression including any side effects. + # - Name whose variable a later walrus overwrites: the + # explanation would show the post-walrus value. + needs_snapshot = (isinstance(v, ast.NamedExpr) and i < levels) or ( + isinstance(v, ast.Name) and v.id in later_walrus_targets[i] + ) + if needs_snapshot: + snapshot = self.assign(ast.Name(res_var, ast.Load())) + res = snapshot for key in self.stack[-1]: - self.stack[-1][key] = self.display(tracked) + self.stack[-1][key] = self.display(snapshot) expl_format = self.pop_format_context(ast.Constant(expl)) call = ast.Call(app, [expl_format], []) self.expl_stmts.append(ast.Expr(call)) if i < levels: - cond: ast.expr = ast.Name(res_var, ast.Load()) + # Short-circuit: and → continue if truthy; or → if falsy. + # ``res`` is a stable reference (Name vars are only + # snapshotted when a later walrus would corrupt them; + # calls/compares return @py_assert vars from assign()). + cond: ast.expr = res if is_or: cond = ast.UnaryOp(ast.Not(), cond) - # Capture the condition in a stable temp for the explanation - # path — res_var is overwritten by subsequent operands. - cond_var = self.variable() - body.append(ast.Assign([ast.Name(cond_var, ast.Store())], cond)) - expl_cond: ast.expr = ast.Name(cond_var, ast.Load()) # noqa: F841 inner: list[ast.stmt] = [] - self.statements.append( - ast.If(ast.Name(cond_var, ast.Load()), inner, []) - ) + self.statements.append(ast.If(cond, inner, [])) self.statements = body = inner self.statements = save self.expl_stmts = fail_save