Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 205 additions & 11 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1798,19 +1798,16 @@ def check_callable_call(

arg_types = self.infer_arg_types_in_context(callee, args, arg_kinds, formal_to_actual)

self.check_argument_count(
self.check_call_arguments(
callee,
arg_types,
arg_kinds,
arg_names,
args,
formal_to_actual,
context,
object_type,
callable_name,
)

self.check_argument_types(
arg_types, arg_kinds, args, callee, formal_to_actual, context, object_type=object_type
object_type,
)

if (
Expand Down Expand Up @@ -2340,6 +2337,198 @@ def apply_inferred_arguments(
# arguments.
return self.apply_generic_arguments(callee_type, inferred_args, context)

def check_call_arguments(
self,
callee: CallableType,
arg_types: list[Type],
arg_kinds: list[ArgKind],
arg_names: Sequence[str | None] | None,
args: list[Expression],
formal_to_actual: list[list[int]],
context: Context,
callable_name: str | None,
object_type: Type | None,
) -> None:
"""Check argument count and types, consolidating errors for missing positional args."""
with self.msg.filter_errors():
_, missing_positional = self.check_argument_count(
callee,
arg_types,
arg_kinds,
arg_names,
formal_to_actual,
context,
object_type,
callable_name,
)

if missing_positional:
func_name = callable_name or callee.name or "function"
if "." in func_name:
func_name = func_name.split(".")[-1]

shift_info = None
num_positional_args = sum(1 for k in arg_kinds if k == nodes.ARG_POS)
if num_positional_args >= 2:
shift_info = self.detect_shifted_positional_args(
callee, arg_types, arg_kinds, missing_positional
)

with self.msg.filter_errors() as type_error_watcher:
self.check_argument_types(
arg_types,
arg_kinds,
args,
callee,
formal_to_actual,
context,
object_type=object_type,
)
has_type_errors = type_error_watcher.has_new_errors()

if shift_info is not None:
_, param_name, expected_type, high_confidence = shift_info
if high_confidence and param_name:
type_str = format_type(expected_type, self.chk.options)
self.msg.fail(
f'Expected {type_str} for parameter "{param_name}"; '
f'did you forget argument "{param_name}"?',
context,
code=codes.CALL_ARG,
)
else:
self.msg.fail(
f'Incompatible arguments for "{func_name}"; check for missing arguments',
context,
code=codes.CALL_ARG,
)
elif has_type_errors:
self.msg.fail(
f'Incompatible arguments for "{func_name}"; check for missing arguments',
context,
code=codes.CALL_ARG,
)
else:
self.check_argument_count(
callee,
arg_types,
arg_kinds,
arg_names,
formal_to_actual,
context,
object_type,
callable_name,
)
else:
self.check_argument_count(
callee,
arg_types,
arg_kinds,
arg_names,
formal_to_actual,
context,
object_type,
callable_name,
)
self.check_argument_types(
arg_types,
arg_kinds,
args,
callee,
formal_to_actual,
context,
object_type=object_type,
)

def detect_shifted_positional_args(
self,
callee: CallableType,
actual_types: list[Type],
actual_kinds: list[ArgKind],
missing_positional: list[int],
) -> tuple[int, str | None, Type, bool] | None:
"""Detect if positional arguments are shifted due to a missing argument.

Returns (1-indexed position, param name, expected type, high_confidence) if a
shift pattern is found, None otherwise. High confidence is set when the function
has fixed parameters (no defaults, *args, or **kwargs).
"""
if not missing_positional:
return None

has_star_args = any(k == nodes.ARG_STAR for k in callee.arg_kinds)
has_star_kwargs = any(k == nodes.ARG_STAR2 for k in callee.arg_kinds)
has_defaults = any(k == nodes.ARG_OPT for k in callee.arg_kinds)
single_missing = len(missing_positional) == 1
high_confidence = (
single_missing and not has_star_args and not has_star_kwargs and not has_defaults
)

positional_actual_types = [
actual_types[i] for i, k in enumerate(actual_kinds) if k == nodes.ARG_POS
]
if len(positional_actual_types) < 2:
return None

positional_formal_types: list[Type] = []
positional_formal_names: list[str | None] = []
for i, kind in enumerate(callee.arg_kinds):
if kind.is_positional():
positional_formal_types.append(callee.arg_types[i])
positional_formal_names.append(callee.arg_names[i])

# Find first position where arg doesn't match but would match next position
shift_position = None
for i, actual_type in enumerate(positional_actual_types):
if i >= len(positional_formal_types):
break
if is_subtype(actual_type, positional_formal_types[i], options=self.chk.options):
continue
next_idx = i + 1
if next_idx >= len(positional_formal_types):
break
if is_subtype(
actual_type, positional_formal_types[next_idx], options=self.chk.options
):
shift_position = i
break
else:
break

if shift_position is None:
return None

# Validate that all args would match if we inserted one at shift_position
if not self._validate_shift_insertion(
positional_actual_types, positional_formal_types, shift_position
):
return None

return (
shift_position + 1,
positional_formal_names[shift_position],
positional_formal_types[shift_position],
high_confidence,
)

def _validate_shift_insertion(
self, actual_types: list[Type], formal_types: list[Type], insert_position: int
) -> bool:
"""Check if inserting an argument at insert_position would fix type errors."""
for i, actual_type in enumerate(actual_types):
if i < insert_position:
if i >= len(formal_types):
return False
expected = formal_types[i]
else:
shifted_idx = i + 1
if shifted_idx >= len(formal_types):
return False
expected = formal_types[shifted_idx]
if not is_subtype(actual_type, expected, options=self.chk.options):
return False
return True

def check_argument_count(
self,
callee: CallableType,
Expand All @@ -2350,13 +2539,15 @@ def check_argument_count(
context: Context | None,
object_type: Type | None = None,
callable_name: str | None = None,
) -> bool:
) -> tuple[bool, list[int]]:
"""Check that there is a value for all required arguments to a function.

Also check that there are no duplicate values for arguments. Report found errors
using 'messages' if it's not None. If 'messages' is given, 'context' must also be given.

Return False if there were any errors. Otherwise return True
Return a tuple of:
- False if there were any errors, True otherwise
- List of formal argument indices that are missing positional arguments
"""
if context is None:
# Avoid "is None" checks
Expand All @@ -2374,12 +2565,15 @@ def check_argument_count(
callee, actual_types, actual_kinds, actual_names, all_actuals, context
)

missing_positional: list[int] = []

# Check for too many or few values for formals.
for i, kind in enumerate(callee.arg_kinds):
mapped_args = formal_to_actual[i]
if kind.is_required() and not mapped_args and not is_unexpected_arg_error:
# No actual for a mandatory formal
if kind.is_positional():
missing_positional.append(i)
self.msg.too_few_arguments(callee, context, actual_names)
if object_type and callable_name and "." in callable_name:
self.missing_classvar_callable_note(object_type, callable_name, context)
Expand Down Expand Up @@ -2418,7 +2612,7 @@ def check_argument_count(
if actual_kinds[mapped_args[0]] == nodes.ARG_STAR2 and paramspec_entries > 1:
self.msg.fail("ParamSpec.kwargs should only be passed once", context)
ok = False
return ok
return ok, missing_positional

def check_for_extra_actual_arguments(
self,
Expand Down Expand Up @@ -2878,7 +3072,7 @@ def has_shape(typ: Type) -> bool:
matches.append(typ)
elif self.check_argument_count(
typ, arg_types, arg_kinds, arg_names, formal_to_actual, None
):
)[0]:
if args_have_var_arg and typ.is_var_arg:
star_matches.append(typ)
elif args_have_kw_arg and typ.is_kw_arg:
Expand Down Expand Up @@ -3251,7 +3445,7 @@ def erased_signature_similarity(
with self.msg.filter_errors():
if not self.check_argument_count(
callee, arg_types, arg_kinds, arg_names, formal_to_actual, None
):
)[0]:
# Too few or many arguments -> no match.
return False

Expand Down
48 changes: 48 additions & 0 deletions test-data/unit/check-functions.test
Original file line number Diff line number Diff line change
Expand Up @@ -3703,3 +3703,51 @@ foo(*args) # E: Argument 1 to "foo" has incompatible type "*list[object]"; expe
kwargs: dict[str, object]
foo(**kwargs) # E: Argument 1 to "foo" has incompatible type "**dict[str, object]"; expected "P"
[builtins fixtures/dict.pyi]

[case testMissingPositionalArgumentShiftedTypes]
def f(x: int, y: str, z: bytes, aa: int) -> None: ...

f(1, b'x', 1)
[builtins fixtures/primitives.pyi]
[out]
main:3: error: Expected "str" for parameter "y"; did you forget argument "y"?

[case testMissingPositionalArgumentShiftedTypesFirstArg]
def f(x: int, y: str, z: bytes) -> None: ...

f("hello", b'x')
[builtins fixtures/primitives.pyi]
[out]
main:3: error: Expected "int" for parameter "x"; did you forget argument "x"?

[case testMissingPositionalArgumentNoShift]
def f(x: int, y: str, z: bytes) -> None: ...

f("wrong", 123)
[builtins fixtures/primitives.pyi]
[out]
main:3: error: Incompatible arguments for "f"; check for missing arguments

[case testMissingPositionalArgumentShiftedTypesManyArgs]
def f(a: int, b: str, c: float, d: list[int], e: tuple[str, ...]) -> None: ...

f(1, 1.5, [1, 2, 3], ("a", "b"))
[builtins fixtures/list.pyi]
[out]
main:3: error: Expected "str" for parameter "b"; did you forget argument "b"?

[case testMissingPositionalArgumentShiftedWithDefaults]
def f(x: int, y: str, z: bytes = b'default') -> None: ...

f("hello")
[builtins fixtures/primitives.pyi]
[out]
main:3: error: Incompatible arguments for "f"; check for missing arguments

[case testMissingPositionalArgumentShiftedWithStarArgs]
def f(x: int, y: str, z: bytes, *args: int) -> None: ...

f("hello", b'x')
[builtins fixtures/primitives.pyi]
[out]
main:3: error: Incompatible arguments for "f"; check for missing arguments
Loading