From 14246542fdf855adcf9b1a39049af2ca590e696c Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Fri, 16 Jan 2026 20:33:15 +0100 Subject: [PATCH] canonicalized set/frozenset signatures --- .../@tests/test_cases/builtins/check_set.py | 43 +++++++++++++++++++ stdlib/builtins.pyi | 24 +++++------ 2 files changed, 55 insertions(+), 12 deletions(-) diff --git a/stdlib/@tests/test_cases/builtins/check_set.py b/stdlib/@tests/test_cases/builtins/check_set.py index 604251b0bb67..89cb8683bfe1 100644 --- a/stdlib/@tests/test_cases/builtins/check_set.py +++ b/stdlib/@tests/test_cases/builtins/check_set.py @@ -12,3 +12,46 @@ def test_set_difference(x: set[Literal["foo", "bar"]], y: set[str], z: set[int]) assert_type(z - x, set[int]) assert_type(y - z, set[str]) assert_type(z - y, set[int]) + + +def test_set_interface_overlapping_type(s: set[Literal["foo", "bar"]], y: set[str], key: str) -> None: + s.add(key) # type: ignore + s.discard(key) + s.remove(key) # type: ignore + s.difference_update(y) + s.intersection_update(y) + s.symmetric_difference_update(y) # type: ignore + s.update(y) # type: ignore + + assert_type(s.difference(y), set[Literal["foo", "bar"]]) + assert_type(s.intersection(y), set[Literal["foo", "bar"]]) + assert_type(s.isdisjoint(y), bool) + assert_type(s.issubset(y), bool) + assert_type(s.issuperset(y), bool) + assert_type(s.symmetric_difference(y), set[str]) + assert_type(s.union(y), set[str]) + + assert_type(s - y, set[Literal["foo", "bar"]]) + assert_type(s & y, set[Literal["foo", "bar"]]) + assert_type(s | y, set[str]) + assert_type(s ^ y, set[str]) + + s -= y + s &= y + s |= y # type: ignore + s ^= y # type: ignore + + +def test_frozenset_interface(s: frozenset[Literal["foo", "bar"]], y: frozenset[str]) -> None: + assert_type(s.difference(y), frozenset[Literal["foo", "bar"]]) + assert_type(s.intersection(y), frozenset[Literal["foo", "bar"]]) + assert_type(s.isdisjoint(y), bool) + assert_type(s.issubset(y), bool) + assert_type(s.issuperset(y), bool) + assert_type(s.symmetric_difference(y), frozenset[str]) + assert_type(s.union(y), frozenset[str]) + + assert_type(s - y, frozenset[Literal["foo", "bar"]]) + assert_type(s & y, frozenset[Literal["foo", "bar"]]) + assert_type(s | y, frozenset[str]) + assert_type(s ^ y, frozenset[str]) diff --git a/stdlib/builtins.pyi b/stdlib/builtins.pyi index 5695b17ca36d..a22bb400adfe 100644 --- a/stdlib/builtins.pyi +++ b/stdlib/builtins.pyi @@ -1262,16 +1262,16 @@ class set(MutableSet[_T]): def __init__(self, iterable: Iterable[_T], /) -> None: ... def add(self, element: _T, /) -> None: ... def copy(self) -> set[_T]: ... - def difference(self, *s: Iterable[Any]) -> set[_T]: ... - def difference_update(self, *s: Iterable[Any]) -> None: ... - def discard(self, element: _T, /) -> None: ... - def intersection(self, *s: Iterable[Any]) -> set[_T]: ... - def intersection_update(self, *s: Iterable[Any]) -> None: ... - def isdisjoint(self, s: Iterable[Any], /) -> bool: ... - def issubset(self, s: Iterable[Any], /) -> bool: ... - def issuperset(self, s: Iterable[Any], /) -> bool: ... + def difference(self, *s: Iterable[object]) -> set[_T]: ... + def difference_update(self, *s: Iterable[object]) -> None: ... + def discard(self, element: object, /) -> None: ... + def intersection(self, *s: Iterable[object]) -> set[_T]: ... + def intersection_update(self, *s: Iterable[object]) -> None: ... + def isdisjoint(self, s: Iterable[object], /) -> bool: ... + def issubset(self, s: Iterable[object], /) -> bool: ... + def issuperset(self, s: Iterable[object], /) -> bool: ... def remove(self, element: _T, /) -> None: ... - def symmetric_difference(self, s: Iterable[_T], /) -> set[_T]: ... + def symmetric_difference(self, s: Iterable[_S], /) -> set[_T | _S]: ... def symmetric_difference_update(self, s: Iterable[_T], /) -> None: ... def union(self, *s: Iterable[_S]) -> set[_T | _S]: ... def update(self, *s: Iterable[_T]) -> None: ... @@ -1303,15 +1303,15 @@ class frozenset(AbstractSet[_T_co]): def copy(self) -> frozenset[_T_co]: ... def difference(self, *s: Iterable[object]) -> frozenset[_T_co]: ... def intersection(self, *s: Iterable[object]) -> frozenset[_T_co]: ... - def isdisjoint(self, s: Iterable[_T_co], /) -> bool: ... + def isdisjoint(self, s: Iterable[object], /) -> bool: ... def issubset(self, s: Iterable[object], /) -> bool: ... def issuperset(self, s: Iterable[object], /) -> bool: ... - def symmetric_difference(self, s: Iterable[_T_co], /) -> frozenset[_T_co]: ... + def symmetric_difference(self, s: Iterable[_S], /) -> frozenset[_T_co | _S]: ... def union(self, *s: Iterable[_S]) -> frozenset[_T_co | _S]: ... def __len__(self) -> int: ... def __contains__(self, o: object, /) -> bool: ... def __iter__(self) -> Iterator[_T_co]: ... - def __and__(self, value: AbstractSet[_T_co], /) -> frozenset[_T_co]: ... + def __and__(self, value: AbstractSet[object], /) -> frozenset[_T_co]: ... def __or__(self, value: AbstractSet[_S], /) -> frozenset[_T_co | _S]: ... def __sub__(self, value: AbstractSet[object], /) -> frozenset[_T_co]: ... def __xor__(self, value: AbstractSet[_S], /) -> frozenset[_T_co | _S]: ...