From fc3ccdeb350afbe3382c7f7051735e8fadb20dd6 Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Sat, 7 Feb 2026 16:10:43 -0800 Subject: [PATCH] Avoid narrowing NewType Fixes #20733 --- mypy/checker.py | 9 +++ test-data/unit/check-narrowing.test | 100 ++++++++++++++++++++++++++++ test-data/unit/check-newtype.test | 2 +- 3 files changed, 110 insertions(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index b87d594acde67..8698c43a16980 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -8346,6 +8346,15 @@ def conditional_types( return proposed_type, remaining_type proposed_type = make_simplified_union([type_range.item for type_range in proposed_type_ranges]) + items = proposed_type.items if isinstance(proposed_type, UnionType) else [proposed_type] + for i in range(len(items)): + item = get_proper_type(items[i]) + # Avoid ever narrowing to a NewType. The principle is values of NewType should only be + # produce by explicit wrapping + while isinstance(item, Instance) and item.type.is_newtype: + item = item.type.bases[0] + items[i] = item + proposed_type = get_proper_type(UnionType.make_union(items)) if isinstance(proper_type, AnyType): return proposed_type, current_type diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 257e7c9d39e07..7481eb308aa38 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -3726,3 +3726,103 @@ def main( reveal_type(v_all) # N: Revealed type is "builtins.bytes | builtins.bytearray | builtins.memoryview" reveal_type(v_memoryview) # N: Revealed type is "builtins.memoryview" [builtins fixtures/primitives.pyi] + + +[case testNarrowNewTypeVsSubclass] +# mypy: strict-equality, warn-unreachable +from typing import NewType + +M1 = NewType("M1", int) +M2 = NewType("M2", int) + +def check_m(base: int, m1: M1, m2: M2): + if m1 == m2: # E: Non-overlapping equality check (left operand type: "M1", right operand type: "M2") + reveal_type(m1) # N: Revealed type is "__main__.M1" + reveal_type(m2) # N: Revealed type is "__main__.M2" + + if m1 == base: + # We do not narrow base + reveal_type(m1) # N: Revealed type is "__main__.M1" + reveal_type(base) # N: Revealed type is "builtins.int" + if m2 == base: + reveal_type(m2) # N: Revealed type is "__main__.M2" + reveal_type(base) # N: Revealed type is "builtins.int" + +# We do narrow for subclasses! (assuming no custom equality) +class A: ... +class A1(A): ... +class A2(A): ... + +def check_a(base: A, a1: A1, a2: A2): + if a1 == a2: # E: Non-overlapping equality check (left operand type: "A1", right operand type: "A2") + reveal_type(a1) # N: Revealed type is "__main__.A1" + reveal_type(a2) # N: Revealed type is "__main__.A2" + + if a1 == base: + # We do narrow base + reveal_type(a1) # N: Revealed type is "__main__.A1" + reveal_type(base) # N: Revealed type is "__main__.A1" + if a2 == base: # E: Non-overlapping equality check (left operand type: "A2", right operand type: "A1") + reveal_type(a2) # N: Revealed type is "__main__.A2" + reveal_type(base) # N: Revealed type is "__main__.A1" +[builtins fixtures/primitives.pyi] + + +[case testNarrowNewTypeFromObject] +# mypy: strict-equality, warn-unreachable +from __future__ import annotations +from typing import NewType + +UserId = NewType("UserId", int) + +def f1(whatever: object, uid: UserId): + # The general principle is that we should not be able to produce a value of NewType + # without there being explicit wrapping somewhere + if whatever == uid: + reveal_type(whatever) # N: Revealed type is "builtins.int" + reveal_type(uid) # N: Revealed type is "__main__.UserId" + +class Other: ... + +def f2(whatever: object, uid: UserId | Other): + if whatever == uid: + reveal_type(whatever) # N: Revealed type is "builtins.int | __main__.Other" + reveal_type(uid) # N: Revealed type is "__main__.UserId | __main__.Other" +[builtins fixtures/primitives.pyi] + + +[case testNarrowNewTypeNested] +# mypy: strict-equality, warn-unreachable +from typing import NewType, Final + +Path = NewType("Path", str) +NormPath = NewType("NormPath", Path) + +def op(normpath: NormPath, path: Path): + if normpath == path: + # No narrowing + reveal_type(normpath) # N: Revealed type is "__main__.NormPath" + reveal_type(path) # N: Revealed type is "__main__.Path" +[builtins fixtures/primitives.pyi] + + +[case testNarrowNewTypeSharedValue] +# mypy: strict-equality, warn-unreachable +from typing import NewType, Final + +UserId = NewType("UserId", int) +TeamId = NewType("TeamId", int) +INVALID = 123 + +def get_owner(uid: UserId, tid: TeamId): + # No narrowing for INVALID + if uid == INVALID: + reveal_type(uid) # N: Revealed type is "__main__.UserId" + reveal_type(tid) # N: Revealed type is "__main__.TeamId" + reveal_type(INVALID) # N: Revealed type is "builtins.int" + if tid == INVALID: + reveal_type(uid) # N: Revealed type is "__main__.UserId" + reveal_type(tid) # N: Revealed type is "__main__.TeamId" + reveal_type(INVALID) # N: Revealed type is "builtins.int" + return None +[builtins fixtures/primitives.pyi] diff --git a/test-data/unit/check-newtype.test b/test-data/unit/check-newtype.test index 1e2775b3aaa65..a0789596a4798 100644 --- a/test-data/unit/check-newtype.test +++ b/test-data/unit/check-newtype.test @@ -368,7 +368,7 @@ from typing import NewType T = NewType('T', int) d: object if isinstance(d, T): # E: Cannot use isinstance() with NewType type - reveal_type(d) # N: Revealed type is "__main__.T" + reveal_type(d) # N: Revealed type is "builtins.int" issubclass(object, T) # E: Cannot use issubclass() with NewType type [builtins fixtures/isinstancelist.pyi]