Skip to content
Open
1 change: 1 addition & 0 deletions changelog/14445.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed assertion rewriting evaluating walrus operator (``:=``) expressions multiple times, causing incorrect test results when the expression had side effects (e.g., incrementing a counter or calling a function).
202 changes: 202 additions & 0 deletions scripts/diff-assert-rewrite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
"""Compare assert-rewrite output between two pytest versions (or worktree).

Runs the dump script for both sides (in parallel) and shows a unified diff.
Supports all three output formats: source, ast, and compact.

Usage::

# Two released versions:
python scripts/diff-assert-rewrite.py --left 7.4.0 --right 8.0.0 example.py

# Release vs local worktree:
python scripts/diff-assert-rewrite.py --left 8.3.0 --right worktree example.py

# Compact AST diff (strips position noise):
python scripts/diff-assert-rewrite.py --left 7.4.0 --right worktree -f compact example.py

# All formats at once:
python scripts/diff-assert-rewrite.py --left 7.4.0 --right 8.0.0 -f all example.py
"""

from __future__ import annotations

import argparse
from concurrent.futures import ThreadPoolExecutor
import difflib
from pathlib import Path
import sys


_DUMP_SCRIPT = Path(__file__).resolve().parent / "dump-assert-rewrite.py"

_FORMATS = ("source", "ast", "compact")


def _label(spec: str) -> str:
return "worktree" if spec == "worktree" else f"pytest=={spec}"


def get_dump(spec: str, file_path: Path, fmt: str) -> str:
"""Run the dump script for a single side and return its output."""
# Import inline so this file stays lightweight at module level.
import subprocess

args = [sys.executable, str(_DUMP_SCRIPT)]
if spec == "worktree":
args.append("--worktree")
else:
args.extend(["--pytest-version", spec])
args.extend(["--format", fmt, str(file_path)])

result = subprocess.run(args, capture_output=True, text=True, check=False)
if result.returncode != 0:
sys.stderr.write(result.stderr)
raise SystemExit(f"Dump failed for {_label(spec)}")
return result.stdout


def colored_diff(lines: list[str]) -> str:
"""Apply ANSI colours to a unified-diff line list."""
RED = "\033[31m"
GREEN = "\033[32m"
CYAN = "\033[36m"
RESET = "\033[0m"

out: list[str] = []
for line in lines:
if line.startswith(("---", "+++")):
out.append(f"{CYAN}{line}{RESET}")
elif line.startswith("-"):
out.append(f"{RED}{line}{RESET}")
elif line.startswith("+"):
out.append(f"{GREEN}{line}{RESET}")
elif line.startswith("@@"):
out.append(f"{CYAN}{line}{RESET}")
else:
out.append(line)
return "\n".join(out)


def show_diff(
left_text: str,
right_text: str,
*,
left_label: str,
right_label: str,
fmt: str,
context: int,
use_color: bool,
) -> bool:
"""Print a unified diff; return True if differences were found."""
if left_text == right_text:
return False

diff_lines = list(
difflib.unified_diff(
left_text.splitlines(),
right_text.splitlines(),
fromfile=f"{left_label} [{fmt}]",
tofile=f"{right_label} [{fmt}]",
n=context,
)
)

if use_color:
print(colored_diff(diff_lines))
else:
print("\n".join(diff_lines))

return True


def main(argv: list[str] | None = None) -> None:
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--left",
required=True,
metavar="VER|worktree",
help="Left side: pytest version or 'worktree'",
)
parser.add_argument(
"--right",
required=True,
metavar="VER|worktree",
help="Right side: pytest version or 'worktree'",
)
parser.add_argument(
"--format",
"-f",
dest="fmt",
choices=(*_FORMATS, "all"),
default="source",
help="Output format (default: source)",
)
parser.add_argument(
"--context",
"-C",
type=int,
default=3,
help="Context lines in diff (default: 3)",
)
parser.add_argument(
"--no-color",
action="store_true",
help="Disable coloured output",
)
parser.add_argument("file", type=Path, help="Python file to compare")
args = parser.parse_args(argv)

if not args.file.is_file():
raise SystemExit(f"File not found: {args.file}")

formats = _FORMATS if args.fmt == "all" else (args.fmt,)
use_color = not args.no_color and sys.stdout.isatty()
left_label = _label(args.left)
right_label = _label(args.right)

# Fetch all needed dumps in parallel (both sides x all formats).
jobs: dict[tuple[str, str], str] = {}
with ThreadPoolExecutor(max_workers=len(formats) * 2) as pool:
futures = {
(side, fmt): pool.submit(get_dump, spec, args.file, fmt)
for fmt in formats
for side, spec in [("left", args.left), ("right", args.right)]
}
for key, future in futures.items():
jobs[key] = future.result()

any_diff = False
for fmt in formats:
if len(formats) > 1:
header = f"=== {fmt} ==="
if use_color:
header = f"\033[1m{header}\033[0m"
print(header)

had_diff = show_diff(
jobs[("left", fmt)],
jobs[("right", fmt)],
left_label=left_label,
right_label=right_label,
fmt=fmt,
context=args.context,
use_color=use_color,
)
if not had_diff:
print(
f"No differences in {fmt} output between {left_label} and {right_label}"
)
else:
any_diff = True

if len(formats) > 1:
print()

raise SystemExit(1 if any_diff else 0)


if __name__ == "__main__":
main()
166 changes: 166 additions & 0 deletions scripts/dump-assert-rewrite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""Dump the assert-rewritten form of a Python file for a specific pytest version.

Uses ``uv run`` to execute the rewriter in an ephemeral environment, so any
released pytest version can be inspected without installing it globally.

Usage::

# Rewritten source (default):
python scripts/dump-assert-rewrite.py --pytest-version 8.0.0 example.py

# Using local worktree:
python scripts/dump-assert-rewrite.py --worktree example.py

# Compact AST (best for diffing -- no position attributes):
python scripts/dump-assert-rewrite.py --worktree --format compact example.py

# Full AST with positions:
python scripts/dump-assert-rewrite.py --pytest-version 7.4.0 --format ast example.py
"""

from __future__ import annotations

import argparse
import os
from pathlib import Path
import subprocess
import sys
import textwrap


# Self-contained script executed inside the target pytest environment.
# Reads source from stdin, writes the rewritten form to stdout.
_WORKER_SCRIPT = textwrap.dedent("""\
import ast
import sys

source = sys.stdin.buffer.read()
fmt = sys.argv[1]

try:
from _pytest.assertion.rewrite import rewrite_asserts
except ImportError:
sys.exit("Could not import rewrite_asserts from this pytest version")

tree = ast.parse(source)
try:
rewrite_asserts(tree, source)
except TypeError:
# pytest < 6 did not accept the source parameter
tree = ast.parse(source)
rewrite_asserts(tree)

ast.fix_missing_locations(tree)

if fmt == "source":
print(ast.unparse(tree))
elif fmt == "ast":
print(ast.dump(tree, indent=2))
elif fmt == "compact":
print(ast.dump(tree, indent=2, include_attributes=False))
else:
sys.exit(f"Unknown format: {fmt!r}")
""")


def run_worker(
*,
pytest_version: str | None,
worktree: bool,
file_content: bytes,
fmt: str,
) -> str:
"""Execute the worker script and return its stdout."""
if worktree:
repo_root = Path(__file__).resolve().parent.parent
env = os.environ.copy()
env["PYTHONPATH"] = (
str(repo_root / "src") + os.pathsep + env.get("PYTHONPATH", "")
)
cmd = [sys.executable, "-c", _WORKER_SCRIPT, fmt]
else:
assert pytest_version is not None
cmd = [
"uv",
"run",
"--no-project",
"--with",
f"pytest=={pytest_version}",
"--",
"python",
"-c",
_WORKER_SCRIPT,
fmt,
]
env = None

try:
result = subprocess.run(
cmd, input=file_content, capture_output=True, check=False, env=env
)
except FileNotFoundError as exc:
if "uv" in str(exc):
raise SystemExit(
"'uv' not found — install it: https://docs.astral.sh/uv/"
) from exc
raise

if result.returncode != 0:
label = version_label(pytest_version=pytest_version, worktree=worktree)
sys.stderr.buffer.write(result.stderr)
raise SystemExit(f"Worker failed for {label}")

return result.stdout.decode()


def version_label(*, pytest_version: str | None = None, worktree: bool = False) -> str:
"""Human-readable label for a pytest source."""
return "worktree" if worktree else f"pytest=={pytest_version}"


def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
src = parser.add_mutually_exclusive_group(required=True)
src.add_argument(
"--pytest-version",
metavar="VER",
help="Released pytest version (e.g. 8.0.0)",
)
src.add_argument(
"--worktree",
action="store_true",
help="Use the local worktree's src/",
)
parser.add_argument(
"--format",
dest="fmt",
choices=("source", "ast", "compact"),
default="source",
help="Output format (default: source)",
)
parser.add_argument("file", type=Path, help="Python file to rewrite")
return parser


def main(argv: list[str] | None = None) -> None:
args = build_parser().parse_args(argv)

if not args.file.is_file():
raise SystemExit(f"File not found: {args.file}")

output = run_worker(
pytest_version=args.pytest_version,
worktree=args.worktree,
file_content=args.file.read_bytes(),
fmt=args.fmt,
)
sys.stdout.write(output)
if output and not output.endswith("\n"):
sys.stdout.write("\n")


if __name__ == "__main__":
main()
Loading
Loading