Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions stdlib/@tests/test_cases/builtins/check_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,46 @@ def test_set_difference(x: set[Literal["foo", "bar"]], y: set[str], z: set[int])
assert_type(z - x, set[int])
assert_type(y - z, set[str])
assert_type(z - y, set[int])


def test_set_interface_overlapping_type(s: set[Literal["foo", "bar"]], y: set[str], key: str) -> None:
s.add(key) # type: ignore
s.discard(key)
s.remove(key) # type: ignore
s.difference_update(y)
s.intersection_update(y)
s.symmetric_difference_update(y) # type: ignore
s.update(y) # type: ignore

assert_type(s.difference(y), set[Literal["foo", "bar"]])
assert_type(s.intersection(y), set[Literal["foo", "bar"]])
assert_type(s.isdisjoint(y), bool)
assert_type(s.issubset(y), bool)
assert_type(s.issuperset(y), bool)
assert_type(s.symmetric_difference(y), set[str])
assert_type(s.union(y), set[str])

assert_type(s - y, set[Literal["foo", "bar"]])
assert_type(s & y, set[Literal["foo", "bar"]])
assert_type(s | y, set[str])
assert_type(s ^ y, set[str])

s -= y
s &= y
s |= y # type: ignore
s ^= y # type: ignore


def test_frozenset_interface(s: frozenset[Literal["foo", "bar"]], y: frozenset[str]) -> None:
assert_type(s.difference(y), frozenset[Literal["foo", "bar"]])
assert_type(s.intersection(y), frozenset[Literal["foo", "bar"]])
assert_type(s.isdisjoint(y), bool)
assert_type(s.issubset(y), bool)
assert_type(s.issuperset(y), bool)
assert_type(s.symmetric_difference(y), frozenset[str])
assert_type(s.union(y), frozenset[str])

assert_type(s - y, frozenset[Literal["foo", "bar"]])
assert_type(s & y, frozenset[Literal["foo", "bar"]])
assert_type(s | y, frozenset[str])
assert_type(s ^ y, frozenset[str])
24 changes: 12 additions & 12 deletions stdlib/builtins.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1262,16 +1262,16 @@ class set(MutableSet[_T]):
def __init__(self, iterable: Iterable[_T], /) -> None: ...
def add(self, element: _T, /) -> None: ...
def copy(self) -> set[_T]: ...
def difference(self, *s: Iterable[Any]) -> set[_T]: ...
def difference_update(self, *s: Iterable[Any]) -> None: ...
def discard(self, element: _T, /) -> None: ...
def intersection(self, *s: Iterable[Any]) -> set[_T]: ...
def intersection_update(self, *s: Iterable[Any]) -> None: ...
def isdisjoint(self, s: Iterable[Any], /) -> bool: ...
def issubset(self, s: Iterable[Any], /) -> bool: ...
def issuperset(self, s: Iterable[Any], /) -> bool: ...
def difference(self, *s: Iterable[object]) -> set[_T]: ...
def difference_update(self, *s: Iterable[object]) -> None: ...
def discard(self, element: object, /) -> None: ...
def intersection(self, *s: Iterable[object]) -> set[_T]: ...
def intersection_update(self, *s: Iterable[object]) -> None: ...
def isdisjoint(self, s: Iterable[object], /) -> bool: ...
def issubset(self, s: Iterable[object], /) -> bool: ...
def issuperset(self, s: Iterable[object], /) -> bool: ...
def remove(self, element: _T, /) -> None: ...
def symmetric_difference(self, s: Iterable[_T], /) -> set[_T]: ...
def symmetric_difference(self, s: Iterable[_S], /) -> set[_T | _S]: ...
def symmetric_difference_update(self, s: Iterable[_T], /) -> None: ...
def union(self, *s: Iterable[_S]) -> set[_T | _S]: ...
def update(self, *s: Iterable[_T]) -> None: ...
Expand Down Expand Up @@ -1303,15 +1303,15 @@ class frozenset(AbstractSet[_T_co]):
def copy(self) -> frozenset[_T_co]: ...
def difference(self, *s: Iterable[object]) -> frozenset[_T_co]: ...
def intersection(self, *s: Iterable[object]) -> frozenset[_T_co]: ...
def isdisjoint(self, s: Iterable[_T_co], /) -> bool: ...
def isdisjoint(self, s: Iterable[object], /) -> bool: ...
def issubset(self, s: Iterable[object], /) -> bool: ...
def issuperset(self, s: Iterable[object], /) -> bool: ...
def symmetric_difference(self, s: Iterable[_T_co], /) -> frozenset[_T_co]: ...
def symmetric_difference(self, s: Iterable[_S], /) -> frozenset[_T_co | _S]: ...
def union(self, *s: Iterable[_S]) -> frozenset[_T_co | _S]: ...
def __len__(self) -> int: ...
def __contains__(self, o: object, /) -> bool: ...
def __iter__(self) -> Iterator[_T_co]: ...
def __and__(self, value: AbstractSet[_T_co], /) -> frozenset[_T_co]: ...
def __and__(self, value: AbstractSet[object], /) -> frozenset[_T_co]: ...
def __or__(self, value: AbstractSet[_S], /) -> frozenset[_T_co | _S]: ...
def __sub__(self, value: AbstractSet[object], /) -> frozenset[_T_co]: ...
def __xor__(self, value: AbstractSet[_S], /) -> frozenset[_T_co | _S]: ...
Expand Down