Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 42 additions & 47 deletions src/_pytest/assertion/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from collections.abc import Mapping
from collections.abc import Sequence
from collections.abc import Set as AbstractSet
import dataclasses
import pprint
from typing import Any
from typing import Literal
from typing import Protocol
from typing import TypeGuard
from unicodedata import normalize

from _pytest import outcomes
Expand Down Expand Up @@ -118,45 +119,42 @@ def _format_lines(lines: Sequence[str]) -> list[str]:
return result


def issequence(x: Any) -> bool:
def issequence(x: object) -> TypeGuard[collections.abc.Sequence[object]]:
return isinstance(x, collections.abc.Sequence) and not isinstance(x, str)


def istext(x: Any) -> bool:
def istext(x: object) -> TypeGuard[str]:
return isinstance(x, str)


def isdict(x: Any) -> bool:
def isdict(x: object) -> TypeGuard[dict[object, object]]:
return isinstance(x, dict)


def isset(x: Any) -> bool:
def isset(x: object) -> TypeGuard[set[object] | frozenset[object]]:
return isinstance(x, set | frozenset)


def isnamedtuple(obj: Any) -> bool:
def isnamedtuple(obj: object) -> bool:
return isinstance(obj, tuple) and getattr(obj, "_fields", None) is not None


def isdatacls(obj: Any) -> bool:
return getattr(obj, "__dataclass_fields__", None) is not None
isdatacls = dataclasses.is_dataclass


def isattrs(obj: Any) -> bool:
def isattrs(obj: object) -> bool:
return getattr(obj, "__attrs_attrs__", None) is not None


def isiterable(obj: Any) -> bool:
def isiterable(obj: object) -> TypeGuard[collections.abc.Iterable[object]]:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one should take object|Iterable[object]

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Iterable[object] is a subtype of object, so isn't this equivalent to object?

try:
iter(obj)
iter(obj) # type: ignore[call-overload]
return not istext(obj)
except Exception:
return False


def has_default_eq(
obj: object,
) -> bool:
def has_default_eq(obj: object) -> bool:
"""Check if an instance of an object contains the default eq

First, we check if the object's __eq__ attribute has __code__,
Expand All @@ -176,7 +174,7 @@ def has_default_eq(


def assertrepr_compare(
config, op: str, left: Any, right: Any, use_ascii: bool = False
config: Config, op: str, left: object, right: object
) -> list[str] | None:
"""Return specialised explanations for some operators/operands."""
verbose = config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
Expand Down Expand Up @@ -246,20 +244,19 @@ def assertrepr_compare(


def _compare_eq_any(
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int = 0
left: object, right: object, highlighter: _HighlightFunc, verbose: int = 0
) -> list[str]:
explanation = []
if istext(left) and istext(right):
explanation = _diff_text(left, right, highlighter, verbose)
else:
from _pytest.python_api import ApproxBase

if isinstance(left, ApproxBase) or isinstance(right, ApproxBase):
# Although the common order should be obtained == expected, this ensures both ways
approx_side = left if isinstance(left, ApproxBase) else right
other_side = right if isinstance(left, ApproxBase) else left

explanation = approx_side._repr_compare(other_side)
# Although the common order should be obtained == approx(...), allow both ways.
if isinstance(right, ApproxBase):
explanation = right._repr_compare(left)
elif isinstance(left, ApproxBase):
explanation = left._repr_compare(right)
elif type(left) is type(right) and (
isdatacls(left) or isattrs(left) or isnamedtuple(left)
):
Expand Down Expand Up @@ -338,8 +335,8 @@ def _diff_text(


def _compare_eq_iterable(
left: Iterable[Any],
right: Iterable[Any],
left: Iterable[object],
right: Iterable[object],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
Expand Down Expand Up @@ -367,8 +364,8 @@ def _compare_eq_iterable(


def _compare_eq_sequence(
left: Sequence[Any],
right: Sequence[Any],
left: Sequence[object],
right: Sequence[object],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
Expand All @@ -387,8 +384,8 @@ def _compare_eq_sequence(
# 102
# >>> s[0:1]
# b'f'
left_value = left[i : i + 1]
right_value = right[i : i + 1]
left_value: object = left[i : i + 1]
right_value: object = right[i : i + 1]
else:
left_value = left[i]
right_value = right[i]
Expand Down Expand Up @@ -427,8 +424,8 @@ def _compare_eq_sequence(


def _compare_eq_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
left: AbstractSet[object],
right: AbstractSet[object],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
Expand All @@ -439,8 +436,8 @@ def _compare_eq_set(


def _compare_gt_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
left: AbstractSet[object],
right: AbstractSet[object],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
Expand All @@ -451,8 +448,8 @@ def _compare_gt_set(


def _compare_lt_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
left: AbstractSet[object],
right: AbstractSet[object],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
Expand All @@ -463,17 +460,17 @@ def _compare_lt_set(


def _compare_gte_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
left: AbstractSet[object],
right: AbstractSet[object],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
return _set_one_sided_diff("right", right, left, highlighter)


def _compare_lte_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
left: AbstractSet[object],
right: AbstractSet[object],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
Expand All @@ -482,8 +479,8 @@ def _compare_lte_set(

def _set_one_sided_diff(
posn: str,
set1: AbstractSet[Any],
set2: AbstractSet[Any],
set1: AbstractSet[object],
set2: AbstractSet[object],
highlighter: _HighlightFunc,
) -> list[str]:
explanation = []
Expand All @@ -496,8 +493,8 @@ def _set_one_sided_diff(


def _compare_eq_dict(
left: Mapping[Any, Any],
right: Mapping[Any, Any],
left: Mapping[object, object],
right: Mapping[object, object],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
Expand Down Expand Up @@ -542,20 +539,18 @@ def _compare_eq_dict(


def _compare_eq_cls(
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int
left: object, right: object, highlighter: _HighlightFunc, verbose: int
) -> list[str]:
if not has_default_eq(left):
return []
if isdatacls(left):
import dataclasses

all_fields = dataclasses.fields(left)
fields_to_check = [info.name for info in all_fields if info.compare]
elif isattrs(left):
all_fields = left.__attrs_attrs__
all_fields = left.__attrs_attrs__ # type: ignore[attr-defined]
fields_to_check = [field.name for field in all_fields if getattr(field, "eq")]
elif isnamedtuple(left):
fields_to_check = left._fields
fields_to_check = left._fields # type: ignore[attr-defined]
else:
assert False

Expand Down
13 changes: 13 additions & 0 deletions testing/python/approx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,19 @@ def test_approx_on_unordered_mapping_matching():
result = pytester.runpytest()
result.assert_outcomes(passed=1)

def test_assertion_rewriting_works_with_approx_on_lhs(
self, pytestconfig: pytest.Config
) -> None:
"""Assertion rewriting works also when approx() is on the left-hand side."""
with temporary_verbosity(pytestconfig, verbosity=0):
with pytest.raises(AssertionError) as e:
assert pytest.approx(1) == 2
obtained_message = str(e.value).splitlines()[-2:]
assert obtained_message == [
" Obtained: 2",
" Expected: 1 ± 1.0e-06",
]


class MyVec3: # incomplete
"""sequence like"""
Expand Down