diff --git a/mypy/plugins/attrs.py b/mypy/plugins/attrs.py index 1593c73cd2bfe..75995822417b8 100644 --- a/mypy/plugins/attrs.py +++ b/mypy/plugins/attrs.py @@ -734,6 +734,11 @@ def _parse_converter( converter_type = converter_expr.node.type elif isinstance(converter_expr.node, TypeInfo): converter_type = type_object_type(converter_expr.node) + elif isinstance(converter_expr.node, Var) and converter_expr.node.type: + # The converter is a variable annotated with a callable type. + var_type = get_proper_type(converter_expr.node.type) + if isinstance(var_type, FunctionLike): + converter_type = var_type elif ( isinstance(converter_expr, IndexExpr) and isinstance(converter_expr.analyzed, TypeApplication) @@ -751,6 +756,10 @@ def _parse_converter( ) else: converter_type = None + elif isinstance(converter_expr, CallExpr): + # The converter is the result of a call, e.g. `converter=make_converter(arg)`. + # Use the return type of the callee as the converter type. + converter_type = _callable_return_type(converter_expr) if isinstance(converter_expr, LambdaExpr): # TODO: should we send a fail if converter_expr.min_args > 1? @@ -794,6 +803,44 @@ def _parse_converter( return converter_info +def _callable_return_type(call: CallExpr) -> Type | None: + """Return the return type of `call` if it is statically known to be callable. + + This is used to support converters created by higher-order functions, e.g. + `converter=make_converter(arg)`. We don't perform full type inference at the + call site; we just look at the statically declared return type of the callee. + Generic returns are returned as-is and may contain unresolved type variables. + """ + callee = call.callee + callee_type: Type | None = None + if isinstance(callee, RefExpr) and callee.node: + if isinstance(callee.node, (FuncDef, OverloadedFuncDef)): + callee_type = callee.node.type + elif isinstance(callee.node, Var): + callee_type = callee.node.type + elif isinstance(callee, CallExpr): + # Chained calls like `factory()(arg)`. + callee_type = _callable_return_type(callee) + if callee_type is None: + return None + callee_type = get_proper_type(callee_type) + if isinstance(callee_type, CallableType): + ret = get_proper_type(callee_type.ret_type) + if isinstance(ret, FunctionLike): + return ret + elif isinstance(callee_type, Overloaded): + # Without type inference at the call site we can't pick the correct + # overload. As a heuristic, take the first overload whose return type is + # itself a callable. This matches helpers like `attrs.converters.pipe` + # and `attrs.converters.default_if_none`, whose first overload is the + # most specific callable form. + for item in callee_type.items: + ret = get_proper_type(item.ret_type) + if isinstance(ret, FunctionLike): + return ret + return None + + def is_valid_overloaded_converter(defn: OverloadedFuncDef) -> bool: return all( (not isinstance(item, Decorator) or isinstance(item.func.type, FunctionLike)) diff --git a/test-data/unit/check-plugin-attrs.test b/test-data/unit/check-plugin-attrs.test index 5e6dd4d83ce02..fbede1eaa0f7f 100644 --- a/test-data/unit/check-plugin-attrs.test +++ b/test-data/unit/check-plugin-attrs.test @@ -942,6 +942,76 @@ class C: reveal_type(C) # N: Revealed type is "def (x: Any, y: Any, z: Any) -> __main__.C" [builtins fixtures/list.pyi] +[case testAttrsUsingHigherOrderConverter] +# Regression test for https://github.com/python/mypy/issues/15736 +from typing import Any, Callable +from attrs import define, field + +def make_converter(_length: int) -> Callable[[str], str]: + def converter(val: str) -> str: + return val + return converter + +def make_untyped_converter(_length: int) -> Callable[[Any], Any]: + def f(val: Any) -> Any: + return val + return f + +@define +class C: + a: str = field(converter=make_converter(40)) + b: str = field(converter=make_untyped_converter(40)) + +reveal_type(C) # N: Revealed type is "def (a: builtins.str, b: Any) -> __main__.C" +reveal_type(C("hi", 5).a) # N: Revealed type is "builtins.str" +[builtins fixtures/list.pyi] + +[case testAttrsUsingCallableVariableConverter] +from typing import Callable +from attrs import define, field + +def to_str(x: int) -> str: + return "" +my_converter: Callable[[int], str] = to_str + +@define +class C: + x: str = field(converter=my_converter) + +reveal_type(C) # N: Revealed type is "def (x: builtins.int) -> __main__.C" +reveal_type(C(15).x) # N: Revealed type is "builtins.str" +[builtins fixtures/list.pyi] + +[case testAttrsUsingHigherOrderConverterChainedCall] +from typing import Callable +from attrs import define, field + +def outer() -> Callable[[int], Callable[[str], str]]: + def middle(_n: int) -> Callable[[str], str]: + def inner(v: str) -> str: + return v + return inner + return middle + +@define +class C: + x: str = field(converter=outer()(40)) + +reveal_type(C) # N: Revealed type is "def (x: builtins.str) -> __main__.C" +[builtins fixtures/list.pyi] + +[case testAttrsUsingDefaultIfNoneConverter] +from typing import Optional +from attrs import define, field +from attrs.converters import default_if_none + +@define +class C: + x: int = field(default=None, converter=default_if_none(0)) + +reveal_type(C) # N: Revealed type is "def (x: Any =) -> __main__.C" +[builtins fixtures/plugin_attrs.pyi] + [case testAttrsUsingConverterAndSubclass] import attr