Skip to content
170 changes: 166 additions & 4 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,13 @@
CallExpr,
ClassDef,
ComparisonExpr,
ConditionalExpr,
Context,
ContinueStmt,
Decorator,
DelStmt,
DictExpr,
DictionaryComprehension,
EllipsisExpr,
Expression,
ExpressionStmt,
Expand All @@ -98,6 +100,7 @@
FuncBase,
FuncDef,
FuncItem,
GeneratorExpr,
GlobalDecl,
IfStmt,
Import,
Expand All @@ -107,6 +110,7 @@
IndexExpr,
IntExpr,
LambdaExpr,
ListComprehension,
ListExpr,
Lvalue,
MatchStmt,
Expand All @@ -124,7 +128,9 @@
RaiseStmt,
RefExpr,
ReturnStmt,
SetComprehension,
SetExpr,
SliceExpr,
StarExpr,
Statement,
StrExpr,
Expand Down Expand Up @@ -4928,6 +4934,8 @@ def infer_context_dependent(
return typ

# If there are errors with the original type context, try re-inferring in empty context.
# However, skip this fallback if the expression contains assignment expressions (walrus
# operator), as they can cause incorrect type inference when the context is removed.
original_messages = msg.filtered_errors()
original_type_map = type_map
with self.msg.filter_errors(
Expand All @@ -4937,7 +4945,12 @@ def infer_context_dependent(
alt_typ = get_proper_type(
self.expr_checker.accept(expr, None, allow_none_return=allow_none_func_call)
)
if not msg.has_new_errors() and is_subtype(alt_typ, type_ctx):

if (
not msg.has_new_errors()
and is_subtype(alt_typ, type_ctx)
and not self.contains_assignment_expr(expr)
):
self.store_types(type_map)
return alt_typ

Expand Down Expand Up @@ -4979,7 +4992,10 @@ def check_return_stmt(self, s: ReturnStmt) -> None:

# Return with a value.
if (
isinstance(s.expr, (CallExpr, ListExpr, TupleExpr, DictExpr, SetExpr, OpExpr))
isinstance(
s.expr,
(CallExpr, ListExpr, TupleExpr, DictExpr, SetExpr, OpExpr, AssignmentExpr),
)
or isinstance(s.expr, AwaitExpr)
and isinstance(s.expr.expr, CallExpr)
):
Expand Down Expand Up @@ -5057,6 +5073,125 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
if self.in_checked_function():
self.fail(message_registry.RETURN_VALUE_EXPECTED, s)

def contains_assignment_expr(self, expr: Expression) -> bool:
"""Check if expression contains any AssignmentExpr (walrus operator)."""
# Base case: found an assignment expression
if isinstance(expr, AssignmentExpr):
return True

# Recursively check nested expressions in various expression types

# Container expressions
if isinstance(expr, (TupleExpr, ListExpr, SetExpr)):
return any(self.contains_assignment_expr(item) for item in expr.items)

if isinstance(expr, DictExpr):
# Check both keys and values
# DictExpr.items is list[tuple[Expression | None, Expression]]
for key_expr, value_expr in expr.items:
if key_expr is not None and self.contains_assignment_expr(key_expr):
return True
if self.contains_assignment_expr(value_expr):
return True
return False

# Binary operations (left and right operands)
if isinstance(expr, OpExpr):
return self.contains_assignment_expr(expr.left) or self.contains_assignment_expr(
expr.right
)

# Unary operations
if isinstance(expr, UnaryExpr):
return self.contains_assignment_expr(expr.expr)

# Comparison expressions (multiple operands)
if isinstance(expr, ComparisonExpr):
return any(self.contains_assignment_expr(operand) for operand in expr.operands)

# Function calls (check arguments)
if isinstance(expr, CallExpr):
# Check callee and all arguments
if self.contains_assignment_expr(expr.callee):
return True
return any(self.contains_assignment_expr(arg) for arg in expr.args)

# Index expressions (subscripts)
if isinstance(expr, IndexExpr):
if self.contains_assignment_expr(expr.base):
return True
return self.contains_assignment_expr(expr.index)

# Member access
if isinstance(expr, MemberExpr):
return self.contains_assignment_expr(expr.expr)

# Starred expressions (unpacking)
if isinstance(expr, StarExpr):
return self.contains_assignment_expr(expr.expr)

# Await expressions
if isinstance(expr, AwaitExpr):
return self.contains_assignment_expr(expr.expr)

# Yield expressions
if isinstance(expr, YieldExpr):
if expr.expr is not None:
return self.contains_assignment_expr(expr.expr)
return False

# Conditional expressions (ternary operator: x if cond else y)
if isinstance(expr, ConditionalExpr):
return (
self.contains_assignment_expr(expr.cond)
or self.contains_assignment_expr(expr.if_expr)
or self.contains_assignment_expr(expr.else_expr)
)

# Slice expressions (x:y:z)
if isinstance(expr, SliceExpr):
return (
(expr.begin_index is not None and self.contains_assignment_expr(expr.begin_index))
or (expr.end_index is not None and self.contains_assignment_expr(expr.end_index))
or (expr.stride is not None and self.contains_assignment_expr(expr.stride))
)

# Generator expressions and comprehensions
if isinstance(expr, GeneratorExpr):
if self.contains_assignment_expr(expr.left_expr):
return True
for seq in expr.sequences:
if self.contains_assignment_expr(seq):
return True
for condlist in expr.condlists:
for cond in condlist:
if self.contains_assignment_expr(cond):
return True
return False

if isinstance(expr, ListComprehension):
return self.contains_assignment_expr(expr.generator)

if isinstance(expr, SetComprehension):
return self.contains_assignment_expr(expr.generator)

if isinstance(expr, DictionaryComprehension):
if self.contains_assignment_expr(expr.key) or self.contains_assignment_expr(
expr.value
):
return True
for seq in expr.sequences:
if self.contains_assignment_expr(seq):
return True
for condlist in expr.condlists:
for cond in condlist:
if self.contains_assignment_expr(cond):
return True
return False

# All other expression types (NameExpr, IntExpr, StrExpr, etc.) don't contain nested expressions
return False

def visit_if_stmt(self, s: IfStmt) -> None:
"""Type check an if statement."""
# This frame records the knowledge from previous if/elif clauses not being taken.
Expand Down Expand Up @@ -6114,8 +6249,35 @@ def conditional_callable_type_map(
if not current_type:
return {}, {}

if isinstance(get_proper_type(current_type), AnyType):
return {}, {}
proper_type = get_proper_type(current_type)
if isinstance(proper_type, AnyType):
# Narrow Any to a generic callable type to satisfy no-any-return in strict mode.
# We use a synthesized fallback "<any callable fallback>" to preserve attribute
# access (fixing regressions in sympy, pandas, etc.) without triggering
# metaclass-related internal errors or breaking invariant subtyping.
obj_fallback = self.named_type("builtins.object")
if obj_fallback.type.fullname == "builtins.object":
cdef = nodes.ClassDef("<any callable fallback>", nodes.Block([]))
cdef._fullname = "<any callable fallback>"
info = TypeInfo(nodes.SymbolTable(), cdef, "")
info.mro = obj_fallback.type.mro
info.bases = obj_fallback.type.bases
info.fallback_to_any = True
fallback_instance = Instance(info, [])

return {
expr: CallableType(
[
AnyType(TypeOfAny.from_another_any, source_any=proper_type),
AnyType(TypeOfAny.from_another_any, source_any=proper_type),
],
[nodes.ARG_STAR, nodes.ARG_STAR2],
[None, None],
ret_type=AnyType(TypeOfAny.from_another_any, source_any=proper_type),
fallback=fallback_instance,
is_ellipsis_args=True,
)
}, {}

callables, uncallables = self.partition_by_callable(current_type, unsound_partition=False)

Expand Down
4 changes: 3 additions & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3829,7 +3829,9 @@ def is_metaclass(self, *, precise: bool = False) -> bool:
return (
self.has_base("builtins.type")
or self.fullname == "abc.ABCMeta"
or (self.fallback_to_any and not precise)
or (
self.fallback_to_any and not precise and self.fullname != "<any callable fallback>"
)
)

def has_base(self, fullname: str) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-callable.test
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ from typing import Any
x = 5 # type: Any

if callable(x):
reveal_type(x) # N: Revealed type is "Any"
reveal_type(x) # N: Revealed type is "def (*Any, **Any) -> Any"
else:
reveal_type(x) # N: Revealed type is "Any"
[builtins fixtures/callable.pyi]
Expand Down