Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 15 additions & 24 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@
)
from mypy.checkpattern import PatternChecker
from mypy.constraints import SUPERTYPE_OF
from mypy.erasetype import erase_type, erase_typevars, remove_instance_last_known_values
from mypy.erasetype import (
erase_type,
erase_typevars,
remove_instance_last_known_values,
shallow_erase_type_for_equality,
)
from mypy.errorcodes import TYPE_VAR, UNUSED_AWAITABLE, UNUSED_COROUTINE, ErrorCode
from mypy.errors import (
ErrorInfo,
Expand All @@ -45,7 +50,7 @@
from mypy.expandtype import expand_type
from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash
from mypy.maptype import map_instance_to_supertype
from mypy.meet import is_overlapping_erased_types, is_overlapping_types, meet_types
from mypy.meet import is_overlapping_types, meet_types
from mypy.message_registry import ErrorMessage
from mypy.messages import (
SUGGESTED_TEST_FIXTURES,
Expand Down Expand Up @@ -6540,19 +6545,6 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
narrowable_indices={0},
)

# We only try and narrow away 'None' for now
if (
not is_unreachable_map(if_map)
and is_overlapping_none(item_type)
and not is_overlapping_none(collection_item_type)
and not (
isinstance(collection_item_type, Instance)
and collection_item_type.type.fullname == "builtins.object"
)
and is_overlapping_erased_types(item_type, collection_item_type)
):
if_map[operands[left_index]] = remove_optional(item_type)

if right_index in narrowable_operand_index_to_hash:
if_type, else_type = self.conditional_types_for_iterable(
item_type, iterable_type
Expand Down Expand Up @@ -6676,6 +6668,9 @@ def narrow_type_by_identity_equality(
target_type = operand_types[j]
if should_coerce_literals:
target_type = coerce_to_literal(target_type)
# Type A[T1] could compare equal to A[T2] even if T1 is disjoint from T2
# e.g. cast(list[int], []) == cast(list[str], [])
target_type = shallow_erase_type_for_equality(target_type)

if (
# See comments in ambiguous_enum_equality_keys
Expand All @@ -6689,7 +6684,7 @@ def narrow_type_by_identity_equality(
if_map, else_map = conditional_types_to_typemaps(
operands[i], *conditional_types(expr_type, [target])
)
if is_target_for_value_narrowing(get_proper_type(target_type)):
if is_target_for_value_narrowing(target_type):
all_if_maps.append(if_map)
all_else_maps.append(else_map)
else:
Expand Down Expand Up @@ -6758,13 +6753,15 @@ def narrow_type_by_identity_equality(
target_type = operand_types[j]
if should_coerce_literals:
target_type = coerce_to_literal(target_type)
target_type = shallow_erase_type_for_equality(target_type)

target = TypeRange(target_type, is_upper_bound=False)

if_map, else_map = conditional_types_to_typemaps(
operands[i], *conditional_types(expr_type, [target], default=expr_type)
)
or_if_maps.append(if_map)
if is_target_for_value_narrowing(get_proper_type(target_type)):
if is_target_for_value_narrowing(target_type):
or_else_maps.append(else_map)

all_if_maps.append(reduce_or_conditional_type_maps(or_if_maps))
Expand Down Expand Up @@ -8609,13 +8606,7 @@ def reduce_and_conditional_type_maps(ms: list[TypeMap], *, use_meet: bool) -> Ty
return result


BUILTINS_CUSTOM_EQ_CHECKS: Final = {
"builtins.bytearray",
"builtins.memoryview",
"builtins.list",
"builtins.dict",
"builtins.set",
}
BUILTINS_CUSTOM_EQ_CHECKS: Final = {"builtins.bytearray", "builtins.memoryview"}


def has_custom_eq_checks(t: Type) -> bool:
Expand Down
14 changes: 14 additions & 0 deletions mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,17 @@ def visit_union_type(self, t: UnionType) -> Type:
merged.append(orig_item)
return UnionType.make_union(merged)
return new


def shallow_erase_type_for_equality(typ: Type) -> ProperType:
"""Erase type variables from Instance's inside a type."""
p_typ = get_proper_type(typ)
if isinstance(p_typ, Instance):
if not p_typ.args:
return p_typ
args = erased_vars(p_typ.type.defn.type_vars, TypeOfAny.special_form)
return Instance(p_typ.type, args, p_typ.line)
if isinstance(p_typ, UnionType):
items = [shallow_erase_type_for_equality(item) for item in p_typ.items]
return UnionType.make_union(items)
return p_typ
13 changes: 0 additions & 13 deletions mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from collections.abc import Callable

from mypy import join
from mypy.erasetype import erase_type
from mypy.maptype import map_instance_to_supertype
from mypy.state import state
from mypy.subtypes import (
Expand Down Expand Up @@ -657,18 +656,6 @@ def _type_object_overlap(left: Type, right: Type) -> bool:
return False


def is_overlapping_erased_types(
left: Type, right: Type, *, ignore_promotions: bool = False
) -> bool:
"""The same as 'is_overlapping_erased_types', except the types are erased first."""
return is_overlapping_types(
erase_type(left),
erase_type(right),
ignore_promotions=ignore_promotions,
prohibit_none_typevar_overlap=True,
)


def are_typed_dicts_overlapping(
left: TypedDictType, right: TypedDictType, is_overlapping: Callable[[Type, Type], bool]
) -> bool:
Expand Down
86 changes: 86 additions & 0 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -1065,6 +1065,92 @@ def f(x: Custom, y: CustomSub):
reveal_type(y) # N: Revealed type is "__main__.CustomSub"
[builtins fixtures/tuple.pyi]

[case testNarrowingCustomEqualityGeneric]
# flags: --strict-equality --warn-unreachable
from __future__ import annotations
from typing import Union

class Custom:
def __eq__(self, other: object) -> bool:
raise

class Default: ...

def f1(x: list[Custom] | Default, y: list[int]):
if x == y: # E: Non-overlapping equality check (left operand type: "list[Custom] | Default", right operand type: "list[int]")
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom]"
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int]"
else:
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom] | __main__.Default"
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int]"

f1([], [])

def f2(x: list[Custom] | Default, y: list[int] | list[Default]):
if x == y: # E: Non-overlapping equality check (left operand type: "list[Custom] | Default", right operand type: "list[int] | list[Default]")
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom]"
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"
else:
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom] | __main__.Default"
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"

listcustom_or_default = Union[list[Custom], Default]
listint_or_default = Union[list[int], list[Default]]

def f2_with_alias(x: listcustom_or_default, y: listint_or_default):
if x == y: # E: Non-overlapping equality check (left operand type: "list[Custom] | Default", right operand type: "list[int] | list[Default]")
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom]"
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"
else:
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom] | __main__.Default"
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"

def f3(x: Custom | dict[str, str], y: dict[int, int]):
if x == y:
reveal_type(x) # N: Revealed type is "__main__.Custom | builtins.dict[builtins.str, builtins.str]"
reveal_type(y) # N: Revealed type is "builtins.dict[builtins.int, builtins.int]"
else:
reveal_type(x) # N: Revealed type is "__main__.Custom | builtins.dict[builtins.str, builtins.str]"
reveal_type(y) # N: Revealed type is "builtins.dict[builtins.int, builtins.int]"
[builtins fixtures/primitives.pyi]

[case testNarrowingRecursiveCallable]
# flags: --strict-equality --warn-unreachable
from __future__ import annotations
from typing import Callable

class A: ...
class B: ...

T = Callable[[A], "S"]
S = Callable[[B], "T"]

def f(x: S, y: T):
if x == y: # E: Unsupported left operand type for == ("Callable[[B], T]")
reveal_type(x) # N: Revealed type is "def (__main__.B) -> def (__main__.A) -> ..."
reveal_type(y) # N: Revealed type is "def (__main__.A) -> def (__main__.B) -> ..."
else:
reveal_type(x) # N: Revealed type is "def (__main__.B) -> def (__main__.A) -> ..."
reveal_type(y) # N: Revealed type is "def (__main__.A) -> def (__main__.B) -> ..."
[builtins fixtures/tuple.pyi]

[case testNarrowingRecursiveUnion]
# flags: --strict-equality --warn-unreachable
from __future__ import annotations
from typing import Union

class A: ...
class B: ...

T = Union[A, "S"]
S = Union[B, "T"] # E: Invalid recursive alias: a union item of itself

def f(x: S, y: T):
if x == y:
reveal_type(x) # N: Revealed type is "Any"
reveal_type(y) # N: Revealed type is "__main__.A | Any"
[builtins fixtures/tuple.pyi]

[case testNarrowingUnreachableCases]
# flags: --strict-equality --warn-unreachable
from typing import Literal, Union
Expand Down
4 changes: 3 additions & 1 deletion test-data/unit/check-tuples.test
Original file line number Diff line number Diff line change
Expand Up @@ -1540,7 +1540,9 @@ class B: pass

def f1(possibles: Tuple[int, Tuple[A]], x: Optional[Tuple[B]]):
if x in possibles:
reveal_type(x) # N: Revealed type is "tuple[__main__.B]"
# TODO: this branch is actually unreachable
# This is an easy fix: https://github.com/python/mypy/pull/20660
reveal_type(x) # N: Revealed type is "tuple[__main__.B] | None"
else:
reveal_type(x) # N: Revealed type is "tuple[__main__.B] | None"

Expand Down