diff --git a/docs/source/command_line.rst b/docs/source/command_line.rst index b5081f113f91..996ee8c42aeb 100644 --- a/docs/source/command_line.rst +++ b/docs/source/command_line.rst @@ -774,6 +774,21 @@ of the above sections. f(memoryview(b"")) # Ok +.. option:: --disallow-str-iteration + + Disallow iterating over ``str`` values. + This also rejects using ``str`` where an ``Iterable[str]``, ``Collection[str]`` or ``Sequence[str]`` is expected. + To iterate over characters, call ``iter`` on the string explicitly. + + .. code-block:: python + + s = "hello" + for ch in s: # error: Iterating over "str" is disallowed + print(ch) + + for ch in iter(s): # OK + print(ch) + .. option:: --extra-checks This flag enables additional checks that are technically correct but may be diff --git a/docs/source/config_file.rst b/docs/source/config_file.rst index 77f952471007..41c15536230d 100644 --- a/docs/source/config_file.rst +++ b/docs/source/config_file.rst @@ -852,6 +852,14 @@ section of the command line docs. Disable treating ``bytearray`` and ``memoryview`` as subtypes of ``bytes``. This will be enabled by default in *mypy 2.0*. +.. confval:: disallow_str_iteration + + :type: boolean + :default: False + + Disallow iterating over ``str`` values. + This also rejects using ``str`` where an ``Iterable[str]``, ``Collection[str]`` or ``Sequence[str]`` is expected. + .. 
confval:: strict :type: boolean diff --git a/misc/typeshed_patches/0001-Add-explicit-overload-for-iter-of-str.patch b/misc/typeshed_patches/0001-Add-explicit-overload-for-iter-of-str.patch new file mode 100644 index 000000000000..d5a9a7150291 --- /dev/null +++ b/misc/typeshed_patches/0001-Add-explicit-overload-for-iter-of-str.patch @@ -0,0 +1,13 @@ +diff --git a/mypy/typeshed/stdlib/builtins.pyi b/mypy/typeshed/stdlib/builtins.pyi +index bd425ff3c..5dae75dd9 100644 +--- a/mypy/typeshed/stdlib/builtins.pyi ++++ b/mypy/typeshed/stdlib/builtins.pyi +@@ -1458,6 +1458,8 @@ class _GetItemIterable(Protocol[_T_co]): + @overload + def iter(object: SupportsIter[_SupportsNextT_co], /) -> _SupportsNextT_co: ... + @overload ++def iter(object: str, /) -> Iterator[str]: ... ++@overload + def iter(object: _GetItemIterable[_T], /) -> Iterator[_T]: ... + @overload + def iter(object: Callable[[], _T | None], sentinel: None, /) -> Iterator[_T]: ... diff --git a/mypy/checker.py b/mypy/checker.py index 008becdd3483..17b39b22563e 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -32,6 +32,7 @@ ) from mypy.checkpattern import PatternChecker from mypy.constraints import SUPERTYPE_OF +from mypy.disallow_str_iteration_state import disallow_str_iteration_state from mypy.erasetype import erase_type, erase_typevars, remove_instance_last_known_values from mypy.errorcodes import TYPE_VAR, UNUSED_AWAITABLE, UNUSED_COROUTINE, ErrorCode from mypy.errors import ( @@ -513,7 +514,11 @@ def check_first_pass(self) -> None: Deferred functions will be processed by check_second_pass(). 
""" self.recurse_into_functions = True - with state.strict_optional_set(self.options.strict_optional), checker_state.set(self): + with ( + state.strict_optional_set(self.options.strict_optional), + disallow_str_iteration_state.set(self.options.disallow_str_iteration), + checker_state.set(self), + ): self.errors.set_file( self.path, self.tree.fullname, scope=self.tscope, options=self.options ) @@ -558,7 +563,11 @@ def check_second_pass( """ self.allow_constructor_cache = allow_constructor_cache self.recurse_into_functions = True - with state.strict_optional_set(self.options.strict_optional), checker_state.set(self): + with ( + state.strict_optional_set(self.options.strict_optional), + disallow_str_iteration_state.set(self.options.disallow_str_iteration), + checker_state.set(self), + ): if not todo and not self.deferred_nodes: return False self.errors.set_file( @@ -5378,6 +5387,12 @@ def analyze_iterable_item_type_without_expression( echk = self.expr_checker iterable: Type iterable = get_proper_type(type) + + if disallow_str_iteration_state.disallow_str_iteration and self.is_str_iteration_type( + iterable + ): + self.msg.str_iteration_disallowed(context, iterable) + iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], context)[0] if ( @@ -5390,6 +5405,20 @@ def analyze_iterable_item_type_without_expression( iterable = echk.check_method_call_by_name("__next__", iterator, [], [], context)[0] return iterator, iterable + def is_str_iteration_type(self, typ: Type) -> bool: + """Return True if typ is, or may be (union item, TypeVar bound), a str or str subtype.""" + typ = get_proper_type(typ) + if isinstance(typ, LiteralType): + # LiteralType.value is also a str for bytes and enum literals; check the fallback. + return self.is_str_iteration_type(typ.fallback) + if isinstance(typ, Instance): + return is_proper_subtype(typ, self.named_type("builtins.str")) + if isinstance(typ, UnionType): + return any(self.is_str_iteration_type(item) for item in typ.relevant_items()) + if isinstance(typ, TypeVarType): + return self.is_str_iteration_type(typ.upper_bound) + return False + def analyze_range_native_int_type(self, expr: Expression) -> Type | None: 
"""Try to infer native int item type from arguments to range(...). diff --git a/mypy/disallow_str_iteration_state.py b/mypy/disallow_str_iteration_state.py new file mode 100644 index 000000000000..930243b00af1 --- /dev/null +++ b/mypy/disallow_str_iteration_state.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from typing import Final + + +class DisallowStrIterationState: + # Wrap this in a class since it's faster than using a module-level attribute. + + def __init__(self, disallow_str_iteration: bool) -> None: + # Value varies by file being processed + self.disallow_str_iteration = disallow_str_iteration + + @contextmanager + def set(self, value: bool) -> Iterator[None]: + saved = self.disallow_str_iteration + self.disallow_str_iteration = value + try: + yield + finally: + self.disallow_str_iteration = saved + + +disallow_str_iteration_state: Final = DisallowStrIterationState(disallow_str_iteration=False) diff --git a/mypy/main.py b/mypy/main.py index 926e72515d95..c46ae3d07edb 100644 --- a/mypy/main.py +++ b/mypy/main.py @@ -937,6 +937,14 @@ def add_invertible_flag( group=strictness_group, ) + add_invertible_flag( + "--disallow-str-iteration", + default=False, + strict_flag=False, + help="Disallow iterating over str instances", + group=strictness_group, + ) + add_invertible_flag( "--extra-checks", default=False, diff --git a/mypy/meet.py b/mypy/meet.py index 365544d4584f..92f549915678 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -3,6 +3,7 @@ from collections.abc import Callable from mypy import join +from mypy.disallow_str_iteration_state import disallow_str_iteration_state from mypy.erasetype import erase_type from mypy.maptype import map_instance_to_supertype from mypy.state import state @@ -14,6 +15,7 @@ is_proper_subtype, is_same_type, is_subtype, + is_subtype_relation_ignored_to_disallow_str_iteration, ) from mypy.typeops import is_recursive_pair, make_simplified_union, 
tuple_fallback from mypy.types import ( @@ -596,6 +598,12 @@ def _type_object_overlap(left: Type, right: Type) -> bool: if right.type.fullname == "builtins.int" and left.type.fullname in MYPYC_NATIVE_INT_NAMES: return True + if disallow_str_iteration_state.disallow_str_iteration: + if is_subtype_relation_ignored_to_disallow_str_iteration(left, right): + return False + elif is_subtype_relation_ignored_to_disallow_str_iteration(right, left): + return False + # Two unrelated types cannot be partially overlapping: they're disjoint. if left.type.has_base(right.type.fullname): left = map_instance_to_supertype(left, right.type) diff --git a/mypy/messages.py b/mypy/messages.py index 5863b8719b95..665ea3259841 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1136,6 +1136,10 @@ def wrong_number_values_to_unpack( def unpacking_strings_disallowed(self, context: Context) -> None: self.fail("Unpacking a string is disallowed", context, code=codes.STR_UNPACK) + def str_iteration_disallowed(self, context: Context, str_type: Type) -> None: + self.fail(f"Iterating over {format_type(str_type, self.options)} is disallowed", context) + self.note("This is because --disallow-str-iteration is enabled", context) + def type_not_iterable(self, type: Type, context: Context) -> None: self.fail(f"{format_type(type, self.options)} object is not iterable", context) @@ -2210,6 +2214,15 @@ def report_protocol_problems( conflict_types = get_conflict_protocol_types( subtype, supertype, class_obj=class_obj, options=self.options ) + + if subtype.type.has_base("builtins.str") and supertype.type.has_base("typing.Container"): + # `str` doesn't properly conform to the `Container` protocol, but we don't want to show that as the reason for the error. 
+ conflict_types = [ + conflict_type + for conflict_type in conflict_types + if conflict_type[0] != "__contains__" + ] + if conflict_types and ( not is_subtype(subtype, erase_type(supertype), options=self.options) or not subtype.type.defn.type_vars @@ -3122,6 +3135,7 @@ def get_conflict_protocol_types( Return them as a list of ('member', 'got', 'expected', 'is_lvalue'). """ assert right.type.is_protocol + conflicts: list[tuple[str, Type, Type, bool]] = [] for member in right.type.protocol_members: if member in ("__init__", "__new__"): diff --git a/mypy/options.py b/mypy/options.py index cb5088af7e79..044bacea13f3 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -37,6 +37,7 @@ class BuildType: "disallow_any_unimported", "disallow_incomplete_defs", "disallow_subclassing_any", + "disallow_str_iteration", "disallow_untyped_calls", "disallow_untyped_decorators", "disallow_untyped_defs", @@ -238,6 +239,9 @@ def __init__(self) -> None: # Disable treating bytearray and memoryview as subtypes of bytes self.strict_bytes = False + # Disallow iterating over str instances or using them as Sequence[T] + self.disallow_str_iteration = False + # Deprecated, use extra_checks instead. self.strict_concatenate = False diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 350d57a7e4ad..8b7c357edfbd 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -8,6 +8,7 @@ import mypy.constraints import mypy.typeops from mypy.checker_state import checker_state +from mypy.disallow_str_iteration_state import disallow_str_iteration_state from mypy.erasetype import erase_type from mypy.expandtype import ( expand_self_type, @@ -479,6 +480,13 @@ def visit_instance(self, left: Instance) -> bool: # dynamic base classes correctly, see #5456. 
return not isinstance(self.right, NoneType) right = self.right + + if ( + disallow_str_iteration_state.disallow_str_iteration + and isinstance(right, Instance) + and is_subtype_relation_ignored_to_disallow_str_iteration(left, right) + ): + return False if isinstance(right, TupleType) and right.partial_fallback.type.is_enum: return self._is_subtype(left, mypy.typeops.tuple_fallback(right)) if isinstance(right, TupleType): @@ -2311,3 +2319,21 @@ def is_erased_instance(t: Instance) -> bool: elif not isinstance(get_proper_type(arg), AnyType): return False return True + + +def is_subtype_relation_ignored_to_disallow_str_iteration(left: Instance, right: Instance) -> bool: + return ( + left.type.has_base("builtins.str") + and not right.type.has_base("builtins.str") + and any( + right.type.has_base(base) + for base in ( + "collections.abc.Collection", + "collections.abc.Iterable", + "collections.abc.Sequence", + "typing.Collection", + "typing.Iterable", + "typing.Sequence", + ) + ) + ) diff --git a/mypy/typeshed/stdlib/builtins.pyi b/mypy/typeshed/stdlib/builtins.pyi index bd425ff3c212..5dae75dd9815 100644 --- a/mypy/typeshed/stdlib/builtins.pyi +++ b/mypy/typeshed/stdlib/builtins.pyi @@ -1458,6 +1458,8 @@ class _GetItemIterable(Protocol[_T_co]): @overload def iter(object: SupportsIter[_SupportsNextT_co], /) -> _SupportsNextT_co: ... @overload +def iter(object: str, /) -> Iterator[str]: ... +@overload def iter(object: _GetItemIterable[_T], /) -> Iterator[_T]: ... @overload def iter(object: Callable[[], _T | None], sentinel: None, /) -> Iterator[_T]: ... 
diff --git a/test-data/unit/check-flags.test b/test-data/unit/check-flags.test index 8d18c699e628..72999c88a69e 100644 --- a/test-data/unit/check-flags.test +++ b/test-data/unit/check-flags.test @@ -2451,6 +2451,80 @@ f(bytearray(b"asdf")) # E: Argument 1 to "f" has incompatible type "bytearray"; f(memoryview(b"asdf")) # E: Argument 1 to "f" has incompatible type "memoryview"; expected "bytes" [builtins fixtures/primitives.pyi] +[case testDisallowStrIteration] +# flags: --disallow-str-iteration +from abc import abstractmethod +from typing import Collection, Container, Iterable, Mapping, Protocol, Sequence, TypeVar, Union + +def takes_str(x: str): + for ch in x: # E: Iterating over "str" is disallowed # N: This is because --disallow-str-iteration is enabled + reveal_type(ch) # N: Revealed type is "builtins.str" + [ch for ch in x] # E: Iterating over "str" is disallowed # N: This is because --disallow-str-iteration is enabled + +s = "hello" + +def takes_seq_str(x: Sequence[str]) -> None: ... +takes_seq_str(s) # E: Argument 1 to "takes_seq_str" has incompatible type "str"; expected "Sequence[str]" + +def takes_iter_str(x: Iterable[str]) -> None: ... +takes_iter_str(s) # E: Argument 1 to "takes_iter_str" has incompatible type "str"; expected "Iterable[str]" + +def takes_collection_str(x: Collection[str]) -> None: ... 
+takes_collection_str(s) # E: Argument 1 to "takes_collection_str" has incompatible type "str"; expected "Collection[str]" + +seq: Sequence[str] = s # E: Incompatible types in assignment (expression has type "str", variable has type "Sequence[str]") +iterable: Iterable[str] = s # E: Incompatible types in assignment (expression has type "str", variable has type "Iterable[str]") +collection: Collection[str] = s # E: Incompatible types in assignment (expression has type "str", variable has type "Collection[str]") + +def takes_maybe_seq(x: "str | Sequence[int]") -> None: + for ch in x: # E: Iterating over "str | Sequence[int]" is disallowed # N: This is because --disallow-str-iteration is enabled + reveal_type(ch) # N: Revealed type is "builtins.str | builtins.int" + +T = TypeVar('T', bound=str) +_T_co = TypeVar('_T_co', covariant=True) + +def takes_str_upper_bound(x: T) -> None: + for ch in x: # E: Iterating over "T" is disallowed # N: This is because --disallow-str-iteration is enabled + reveal_type(ch) # N: Revealed type is "builtins.str" + +class StrSubclass(str): + def __contains__(self, x: object) -> bool: ... + +def takes_str_subclass(x: StrSubclass): + for ch in x: # E: Iterating over "StrSubclass" is disallowed # N: This is because --disallow-str-iteration is enabled + reveal_type(ch) # N: Revealed type is "builtins.str" + +class CollectionSubclass(Collection[_T_co], Protocol[_T_co]): + @abstractmethod + def __missing_impl__(self): ... + +def takes_collection_subclass(x: CollectionSubclass[str]) -> None: ... 
+ +takes_collection_subclass(s) # E: Argument 1 to "takes_collection_subclass" has incompatible type "str"; expected "CollectionSubclass[str]" \ + # N: "str" is missing following "CollectionSubclass" protocol member: \ + # N: __missing_impl__ + +takes_collection_subclass(StrSubclass()) # E: Argument 1 to "takes_collection_subclass" has incompatible type "StrSubclass"; expected "CollectionSubclass[str]" \ + # N: "StrSubclass" is missing following "CollectionSubclass" protocol member: \ + # N: __missing_impl__ + +def dict_unpacking_unaffected_by_union_simplification(x: Mapping[str, Union[str, Sequence[str]]]) -> None: + x = {**x} + +def narrowing(x: "str | Sequence[str]"): + if isinstance(x, str): + reveal_type(x) # N: Revealed type is "builtins.str" + else: + reveal_type(x) # N: Revealed type is "typing.Sequence[builtins.str]" + +[builtins fixtures/str-iter.pyi] +[typing fixtures/typing-str-iter.pyi] + +[case testIterStrOverload] +# flags: --disallow-str-iteration +reveal_type(iter("foo")) # N: Revealed type is "typing.Iterator[builtins.str]" +[builtins fixtures/dict.pyi] + [case testNoCrashFollowImportsForStubs] # flags: --config-file tmp/mypy.ini {**{"x": "y"}} diff --git a/test-data/unit/fixtures/dict.pyi b/test-data/unit/fixtures/dict.pyi index ed2287511161..8fde41357a3d 100644 --- a/test-data/unit/fixtures/dict.pyi +++ b/test-data/unit/fixtures/dict.pyi @@ -61,4 +61,7 @@ class ellipsis: pass class BaseException: pass def isinstance(x: object, t: Union[type, Tuple[type, ...]]) -> bool: pass +@overload +def iter(__iterable: str) -> Iterator[str]: pass +@overload def iter(__iterable: Iterable[T]) -> Iterator[T]: pass diff --git a/test-data/unit/fixtures/str-iter.pyi b/test-data/unit/fixtures/str-iter.pyi new file mode 100644 index 000000000000..49c61bad65d2 --- /dev/null +++ b/test-data/unit/fixtures/str-iter.pyi @@ -0,0 +1,52 @@ +# Builtins stub used in disallow-str-iteration tests. 
+ + +from typing import Generic, Iterator, Mapping, Sequence, TypeVar, overload + +_T = TypeVar("_T") +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + +class object: + def __init__(self) -> None: pass + +class type: pass +class int: pass +class bool(int): pass +class ellipsis: pass +class slice: pass + +class str(Sequence[str]): + def __iter__(self) -> Iterator[str]: pass + def __len__(self) -> int: pass + def __contains__(self, item: object) -> bool: pass + @overload + def __getitem__(self, i: int, /) -> str: ... + @overload + def __getitem__(self, s: slice, /) -> Sequence[str]: ... + +class list(Sequence[_T], Generic[_T]): + def __iter__(self) -> Iterator[_T]: pass + def __len__(self) -> int: pass + def __contains__(self, item: object) -> bool: pass + @overload + def __getitem__(self, i: int, /) -> _T: ... + @overload + def __getitem__(self, s: slice, /) -> list[_T]: ... + +class tuple(Sequence[_T], Generic[_T]): + def __iter__(self) -> Iterator[_T]: pass + def __len__(self) -> int: pass + def __contains__(self, item: object) -> bool: pass + @overload + def __getitem__(self, i: int, /) -> _T: ... + @overload + def __getitem__(self, s: slice, /) -> list[_T]: ... + +class dict(Mapping[_KT, _VT], Generic[_KT, _VT]): + def __iter__(self) -> Iterator[_KT]: pass + def __len__(self) -> int: pass + def __contains__(self, item: object) -> bool: pass + def __getitem__(self, key: _KT) -> _VT: pass + +def isinstance(x: object, t: type) -> bool: pass diff --git a/test-data/unit/fixtures/typing-str-iter.pyi b/test-data/unit/fixtures/typing-str-iter.pyi new file mode 100644 index 000000000000..f0fb8aa57cfb --- /dev/null +++ b/test-data/unit/fixtures/typing-str-iter.pyi @@ -0,0 +1,63 @@ +# Minimal typing fixture for disallow-str-iteration tests. 
+ +import _typeshed +from abc import ABCMeta, abstractmethod + +Any = object() +TypeVar = 0 +Generic = 0 +Protocol = 0 +Union = 0 +overload = 0 + +_T = TypeVar("_T") +_KT = TypeVar("_KT") +_T_co = TypeVar("_T_co", covariant=True) +_KT_co = TypeVar("_KT_co", covariant=True) # Key type covariant containers. +_VT_co = TypeVar("_VT_co", covariant=True) # Value type covariant containers. +_TC = TypeVar("_TC", bound=type[object]) + +@runtime_checkable +class Iterable(Protocol[_T_co]): + @abstractmethod + def __iter__(self) -> Iterator[_T_co]: ... + +@runtime_checkable +class Iterator(Iterable[_T_co], Protocol[_T_co]): + @abstractmethod + def __next__(self) -> _T_co: ... + def __iter__(self) -> Iterator[_T_co]: ... + +@runtime_checkable +class Container(Protocol[_T_co]): + # This is generic more on vibes than anything else + @abstractmethod + def __contains__(self, x: object, /) -> bool: ... + +@runtime_checkable +class Collection(Iterable[_T_co], Container[_T_co], Protocol[_T_co]): + # Implement Sized (but don't have it as a base class). + @abstractmethod + def __len__(self) -> int: ... + +class Sequence(Collection[_T_co]): + @overload + @abstractmethod + def __getitem__(self, index: int) -> _T_co: ... + @overload + @abstractmethod + def __getitem__(self, index: slice) -> Sequence[_T_co]: ... + def __contains__(self, value: object) -> bool: ... + def __iter__(self) -> Iterator[_T_co]: ... + +class KeysView(Protocol[_KT_co]): + def __iter__(self) -> Iterator[_KT_co]: ... + +class Mapping(Collection[_KT], Generic[_KT, _VT_co]): + @abstractmethod + def __getitem__(self, key: _KT, /) -> _VT_co: ... + def __contains__(self, key: object, /) -> bool: ... + def keys(self) -> KeysView[_KT]: ... + +def runtime_checkable(cls: _TC) -> _TC: + return cls