From 139b2a9e3df11b07ffb7ed54c9d222d3889adad4 Mon Sep 17 00:00:00 2001 From: KevinRK29 Date: Fri, 16 Jan 2026 02:53:04 -0500 Subject: [PATCH 1/2] Detect missing positional args and suggest argument in error message --- mypy/checkexpr.py | 196 +++++++++++++++++++++++++--- test-data/unit/check-functions.test | 48 +++++++ 2 files changed, 226 insertions(+), 18 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 9990caaeb7a1..2ff78ae2525d 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1798,19 +1798,9 @@ def check_callable_call( arg_types = self.infer_arg_types_in_context(callee, args, arg_kinds, formal_to_actual) - self.check_argument_count( - callee, - arg_types, - arg_kinds, - arg_names, - formal_to_actual, - context, - object_type, - callable_name, - ) - - self.check_argument_types( - arg_types, arg_kinds, args, callee, formal_to_actual, context, object_type=object_type + self.check_call_arguments( + callee, arg_types, arg_kinds, arg_names, args, + formal_to_actual, context, callable_name, object_type, ) if ( @@ -2340,6 +2330,171 @@ def apply_inferred_arguments( # arguments. return self.apply_generic_arguments(callee_type, inferred_args, context) + def check_call_arguments( + self, + callee: CallableType, + arg_types: list[Type], + arg_kinds: list[ArgKind], + arg_names: Sequence[str | None] | None, + args: list[Expression], + formal_to_actual: list[list[int]], + context: Context, + callable_name: str | None, + object_type: Type | None, + ) -> None: + """Check argument count and types, consolidating errors for missing positional args.""" + with self.msg.filter_errors(): + _, missing_positional = self.check_argument_count( + callee, arg_types, arg_kinds, arg_names, formal_to_actual, + context, object_type, callable_name, + ) + + if missing_positional: + func_name = callable_name or callee.name or "function" + if "." in func_name: + func_name = func_name.split(".")[-1] + + shift_info = None + num_positional_args = sum(1 for k in arg_kinds if k == nodes.ARG_POS) + if num_positional_args >= 2: + shift_info = self.detect_shifted_positional_args( + callee, arg_types, arg_kinds, missing_positional + ) + + with self.msg.filter_errors() as type_error_watcher: + self.check_argument_types( + arg_types, arg_kinds, args, callee, formal_to_actual, context, + object_type=object_type + ) + has_type_errors = type_error_watcher.has_new_errors() + + if shift_info is not None: + _, param_name, expected_type, high_confidence = shift_info + if high_confidence and param_name: + type_str = format_type(expected_type, self.chk.options) + self.msg.fail( + f'Expected {type_str} for parameter "{param_name}"; ' + f'did you forget argument "{param_name}"?', + context, + code=codes.CALL_ARG, + ) + else: + self.msg.fail( + f'Incompatible arguments for "{func_name}"; check for missing arguments', + context, + code=codes.CALL_ARG, + ) + elif has_type_errors: + self.msg.fail( + f'Incompatible arguments for "{func_name}"; check for missing arguments', + context, + code=codes.CALL_ARG, + ) + else: + self.check_argument_count( + callee, arg_types, arg_kinds, arg_names, formal_to_actual, + context, object_type, callable_name, + ) + else: + self.check_argument_count( + callee, arg_types, arg_kinds, arg_names, formal_to_actual, + context, object_type, callable_name, + ) + self.check_argument_types( + arg_types, arg_kinds, args, callee, formal_to_actual, context, + object_type=object_type + ) + + def detect_shifted_positional_args( + self, + callee: CallableType, + actual_types: list[Type], + actual_kinds: list[ArgKind], + missing_positional: list[int], + ) -> tuple[int, str | None, Type, bool] | None: + """Detect if positional arguments are shifted due to a missing argument. + + Returns (1-indexed position, param name, expected type, high_confidence) if a + shift pattern is found, None otherwise. High confidence is set when the function + has fixed parameters (no defaults, *args, or **kwargs). + """ + if not missing_positional: + return None + + has_star_args = any(k == nodes.ARG_STAR for k in callee.arg_kinds) + has_star_kwargs = any(k == nodes.ARG_STAR2 for k in callee.arg_kinds) + has_defaults = any(k == nodes.ARG_OPT for k in callee.arg_kinds) + single_missing = len(missing_positional) == 1 + high_confidence = ( + single_missing and not has_star_args and not has_star_kwargs and not has_defaults + ) + + positional_actual_types = [ + actual_types[i] for i, k in enumerate(actual_kinds) if k == nodes.ARG_POS + ] + if len(positional_actual_types) < 2: + return None + + positional_formal_types: list[Type] = [] + positional_formal_names: list[str | None] = [] + for i, kind in enumerate(callee.arg_kinds): + if kind.is_positional(): + positional_formal_types.append(callee.arg_types[i]) + positional_formal_names.append(callee.arg_names[i]) + + # Find first position where arg doesn't match but would match next position + shift_position = None + for i, actual_type in enumerate(positional_actual_types): + if i >= len(positional_formal_types): + break + if is_subtype(actual_type, positional_formal_types[i], options=self.chk.options): + continue + next_idx = i + 1 + if next_idx >= len(positional_formal_types): + break + if is_subtype(actual_type, positional_formal_types[next_idx], options=self.chk.options): + shift_position = i + break + else: + break + + if shift_position is None: + return None + + # Validate that all args would match if we inserted one at shift_position + if not self._validate_shift_insertion( + positional_actual_types, positional_formal_types, shift_position + ): + return None + + return ( + shift_position + 1, + positional_formal_names[shift_position], + positional_formal_types[shift_position], + high_confidence, + ) + + def _validate_shift_insertion( + self, + actual_types: list[Type], + formal_types: list[Type], + insert_position: int, + ) -> bool: + """Check if inserting an argument at insert_position would fix type errors.""" + for i, actual_type in enumerate(actual_types): + if i < insert_position: + if i >= len(formal_types): + return False + expected = formal_types[i] + else: + shifted_idx = i + 1 + if shifted_idx >= len(formal_types): + return False + expected = formal_types[shifted_idx] + if not is_subtype(actual_type, expected, options=self.chk.options): + return False + return True + def check_argument_count( self, callee: CallableType, @@ -2350,13 +2505,15 @@ def check_argument_count( context: Context | None, object_type: Type | None = None, callable_name: str | None = None, - ) -> bool: + ) -> tuple[bool, list[int]]: """Check that there is a value for all required arguments to a function. Also check that there are no duplicate values for arguments. Report found errors using 'messages' if it's not None. If 'messages' is given, 'context' must also be given. - Return False if there were any errors. Otherwise return True + Return a tuple of: + - False if there were any errors, True otherwise + - List of formal argument indices that are missing positional arguments """ if context is None: # Avoid "is None" checks @@ -2374,12 +2531,15 @@ def check_argument_count( callee, actual_types, actual_kinds, actual_names, all_actuals, context ) + missing_positional: list[int] = [] + # Check for too many or few values for formals. for i, kind in enumerate(callee.arg_kinds): mapped_args = formal_to_actual[i] if kind.is_required() and not mapped_args and not is_unexpected_arg_error: # No actual for a mandatory formal if kind.is_positional(): + missing_positional.append(i) self.msg.too_few_arguments(callee, context, actual_names) if object_type and callable_name and "." in callable_name: self.missing_classvar_callable_note(object_type, callable_name, context) @@ -2418,7 +2578,7 @@ def check_argument_count( if actual_kinds[mapped_args[0]] == nodes.ARG_STAR2 and paramspec_entries > 1: self.msg.fail("ParamSpec.kwargs should only be passed once", context) ok = False - return ok + return ok, missing_positional def check_for_extra_actual_arguments( self, @@ -2878,7 +3038,7 @@ def has_shape(typ: Type) -> bool: matches.append(typ) elif self.check_argument_count( typ, arg_types, arg_kinds, arg_names, formal_to_actual, None - ): + )[0]: if args_have_var_arg and typ.is_var_arg: star_matches.append(typ) elif args_have_kw_arg and typ.is_kw_arg: @@ -3251,7 +3411,7 @@ def erased_signature_similarity( with self.msg.filter_errors(): if not self.check_argument_count( callee, arg_types, arg_kinds, arg_names, formal_to_actual, None - ): + )[0]: # Too few or many arguments -> no match. return False diff --git a/test-data/unit/check-functions.test b/test-data/unit/check-functions.test index b54dffe836b8..c4dff9028396 100644 --- a/test-data/unit/check-functions.test +++ b/test-data/unit/check-functions.test @@ -3703,3 +3703,51 @@ foo(*args) # E: Argument 1 to "foo" has incompatible type "*list[object]"; expe kwargs: dict[str, object] foo(**kwargs) # E: Argument 1 to "foo" has incompatible type "**dict[str, object]"; expected "P" [builtins fixtures/dict.pyi] + +[case testMissingPositionalArgumentShiftedTypes] +def f(x: int, y: str, z: bytes, aa: int) -> None: ... + +f(1, b'x', 1) +[builtins fixtures/primitives.pyi] +[out] +main:3: error: Expected "str" for parameter "y"; did you forget argument "y"? + +[case testMissingPositionalArgumentShiftedTypesFirstArg] +def f(x: int, y: str, z: bytes) -> None: ... + +f("hello", b'x') +[builtins fixtures/primitives.pyi] +[out] +main:3: error: Expected "int" for parameter "x"; did you forget argument "x"? + +[case testMissingPositionalArgumentNoShift] +def f(x: int, y: str, z: bytes) -> None: ... + +f("wrong", 123) +[builtins fixtures/primitives.pyi] +[out] +main:3: error: Incompatible arguments for "f"; check for missing arguments + +[case testMissingPositionalArgumentShiftedTypesManyArgs] +def f(a: int, b: str, c: float, d: list[int], e: tuple[str, ...]) -> None: ... + +f(1, 1.5, [1, 2, 3], ("a", "b")) +[builtins fixtures/list.pyi] +[out] +main:3: error: Expected "str" for parameter "b"; did you forget argument "b"? + +[case testMissingPositionalArgumentShiftedWithDefaults] +def f(x: int, y: str, z: bytes = b'default') -> None: ... + +f("hello") +[builtins fixtures/primitives.pyi] +[out] +main:3: error: Incompatible arguments for "f"; check for missing arguments + +[case testMissingPositionalArgumentShiftedWithStarArgs] +def f(x: int, y: str, z: bytes, *args: int) -> None: ... + +f("hello", b'x') +[builtins fixtures/primitives.pyi] +[out] +main:3: error: Incompatible arguments for "f"; check for missing arguments From 7745bd2dad0ecc8af4e7d53463cb521528a4e96a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 16 Jan 2026 07:56:17 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/checkexpr.py | 68 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 51 insertions(+), 17 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 2ff78ae2525d..32a8b8062f93 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1799,8 +1799,15 @@ def check_callable_call( arg_types = self.infer_arg_types_in_context(callee, args, arg_kinds, formal_to_actual) self.check_call_arguments( - callee, arg_types, arg_kinds, arg_names, args, - formal_to_actual, context, callable_name, object_type, + callee, + arg_types, + arg_kinds, + arg_names, + args, + formal_to_actual, + context, + callable_name, + object_type, ) if ( @@ -2345,8 +2352,14 @@ def check_call_arguments( """Check argument count and types, consolidating errors for missing positional args.""" with self.msg.filter_errors(): _, missing_positional = self.check_argument_count( - callee, arg_types, arg_kinds, arg_names, formal_to_actual, - context, object_type, callable_name, + callee, + arg_types, + arg_kinds, + arg_names, + formal_to_actual, + context, + object_type, + callable_name, ) if missing_positional: @@ -2363,8 +2376,13 @@ def check_call_arguments( with self.msg.filter_errors() as type_error_watcher: self.check_argument_types( - arg_types, arg_kinds, args, callee, formal_to_actual, context, - object_type=object_type + arg_types, + arg_kinds, + args, + callee, + formal_to_actual, + context, + object_type=object_type, ) has_type_errors = type_error_watcher.has_new_errors() @@ -2392,17 +2410,34 @@ def check_call_arguments( ) else: self.check_argument_count( - callee, arg_types, arg_kinds, arg_names, formal_to_actual, - context, object_type, callable_name, + callee, + arg_types, + arg_kinds, + arg_names, + formal_to_actual, + context, + object_type, + callable_name, ) else: self.check_argument_count( - callee, arg_types, arg_kinds, arg_names, formal_to_actual, - context, object_type, callable_name, + callee, + arg_types, + arg_kinds, + arg_names, + formal_to_actual, + context, + object_type, + callable_name, ) self.check_argument_types( - arg_types, arg_kinds, args, callee, formal_to_actual, context, - object_type=object_type + arg_types, + arg_kinds, + args, + callee, + formal_to_actual, + context, + object_type=object_type, ) def detect_shifted_positional_args( @@ -2452,7 +2487,9 @@ def detect_shifted_positional_args( next_idx = i + 1 if next_idx >= len(positional_formal_types): break - if is_subtype(actual_type, positional_formal_types[next_idx], options=self.chk.options): + if is_subtype( + actual_type, positional_formal_types[next_idx], options=self.chk.options + ): shift_position = i break else: @@ -2475,10 +2512,7 @@ def detect_shifted_positional_args( ) def _validate_shift_insertion( - self, - actual_types: list[Type], - formal_types: list[Type], - insert_position: int, + self, actual_types: list[Type], formal_types: list[Type], insert_position: int ) -> bool: """Check if inserting an argument at insert_position would fix type errors.""" for i, actual_type in enumerate(actual_types):