diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 59a0937a2..dc39ca8c1 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -38,6 +38,7 @@ from pandas import ( Timestamp, ) from pandas.core.arraylike import OpsMixin +from pandas.core.base import IndexOpsMixin from pandas.core.generic import NDFrame from pandas.core.groupby.generic import DataFrameGroupBy from pandas.core.indexers import BaseIndexer @@ -92,6 +93,7 @@ from pandas._typing import ( ArrayLike, AstypeArg, Axes, + AxesData, Axis, AxisColumn, AxisIndex, @@ -197,6 +199,8 @@ class _iLocIndexerFrame(_iLocIndexer, Generic[_T]): | tuple[slice] ), ) -> _T: ... + + # Keep in sync with `DataFrame.__setitem__` def __setitem__( self, idx: ( @@ -209,7 +213,7 @@ class _iLocIndexerFrame(_iLocIndexer, Generic[_T]): ), value: ( Scalar - | Series + | IndexOpsMixin | DataFrame | np_ndarray | NAType @@ -273,6 +277,8 @@ class _LocIndexerFrame(_LocIndexer, Generic[_T]): ) -> Series: ... @overload def __getitem__(self, idx: tuple[Scalar, slice]) -> Series | _T: ... + + # Keep in sync with `DataFrame.__setitem__` @overload def __setitem__( self, @@ -284,7 +290,7 @@ class _LocIndexerFrame(_LocIndexer, Generic[_T]): | NAType | NaTType | ArrayLike - | Series + | IndexOpsMixin | DataFrame | list | Mapping[Hashable, Scalar | NAType | NaTType] @@ -328,7 +334,7 @@ class _AtIndexerFrame(_AtIndexer): | NAType | NaTType | ArrayLike - | Series + | IndexOpsMixin | DataFrame | list | Mapping[Hashable, Scalar | NAType | NaTType] @@ -800,7 +806,74 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): def isetitem( self, loc: int | Sequence[int], value: Scalar | ArrayLike | list[Any] ) -> None: ... - def __setitem__(self, key, value) -> None: ... + + # Keep in sync with `_iLocIndexerFrame.__setitem__` + @overload + def __setitem__( + self, + idx: ( + int + | IndexType + | tuple[int, int] + | tuple[IndexType, int] + | tuple[IndexType, IndexType] + | tuple[int, IndexType] + ), + value: ( + Scalar + | IndexOpsMixin + | Sequence[Scalar] + | DataFrame + | np.ndarray + | NAType + | NaTType + | Mapping[Hashable, Scalar | NAType | NaTType] + | None + ), + ) -> None: ... + # Keep in sync with `_LocIndexerFrame.__setitem__` + @overload + def __setitem__( + self, + idx: ( + MaskType | StrLike | _IndexSliceTuple | list[ScalarT] | IndexingInt | slice + ), + value: ( + Scalar + | NAType + | NaTType + | ArrayLike + | IndexOpsMixin + | Sequence[Scalar] + | DataFrame + | list + | Mapping[Hashable, Scalar | NAType | NaTType] + | None + ), + ) -> None: ... + @overload + def __setitem__( + self, + idx: tuple[_IndexSliceTuple, Hashable], + value: ( + Scalar + | NAType + | NaTType + | ArrayLike + | IndexOpsMixin + | Sequence[Scalar] + | dict + | None + ), + ) -> None: ... + # Extra cases not supported by `_LocIndexerFrame.__setitem__` / + # `_iLocIndexerFrame.__setitem__`. + @overload + def __setitem__( + self, + idx: IndexOpsMixin | DataFrame, + value: Scalar | NAType | NaTType | ArrayLike | Series | list | dict | None, + ): ... @overload def query( self, @@ -1917,7 +1990,11 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): **kwargs: Any, ) -> Series[_bool]: ... @final - def asof(self, where, subset: _str | list[_str] | None = None) -> Self: ... + def asof( + self, + where: Scalar | AnyArrayLike | Sequence[Scalar], + subset: Hashable | list[Hashable] | None = None, + ) -> Self: ... @final def asfreq( self, @@ -2454,7 +2531,9 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): **kwargs: Any, ) -> Series: ... # Not actually positional, but used to handle removal of deprecated - def set_axis(self, labels, *, axis: Axis = ..., copy: _bool = ...) -> Self: ... + def set_axis( + self, labels: AxesData, *, axis: Axis = 0, copy: _bool = ... + ) -> Self: ... def skew( self, axis: Axis | None = ..., diff --git a/pandas-stubs/core/indexing.pyi b/pandas-stubs/core/indexing.pyi index 832e2618a..868d10afc 100644 --- a/pandas-stubs/core/indexing.pyi +++ b/pandas-stubs/core/indexing.pyi @@ -3,7 +3,7 @@ from typing import ( TypeVar, ) -from pandas.core.indexes.api import Index +from pandas.core.base import IndexOpsMixin from pandas._libs.indexing import _NDFrameIndexerBase from pandas._typing import ( @@ -13,7 +13,7 @@ from pandas._typing import ( ) _IndexSliceTuple: TypeAlias = tuple[ - Index | MaskType | Scalar | list[ScalarT] | slice | tuple[Scalar, ...], ... + IndexOpsMixin | MaskType | Scalar | list[ScalarT] | slice | tuple[Scalar, ...], ... ] _IndexSliceUnion: TypeAlias = slice | _IndexSliceTuple diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 61eb0222e..321b82281 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -557,14 +557,14 @@ class Series(IndexOpsMixin[S1], ElementOpsMixin[S1], NDFrame): @final def __getattr__(self, name: _str) -> S1: ... - # Keep in sync with `iLocIndexerSeries.__getitem__` + # Keep in sync with `_iLocIndexerSeries.__getitem__` @overload def __getitem__(self, idx: IndexingInt) -> S1: ... @overload def __getitem__( self, idx: Index | Series | slice | np_ndarray_anyint ) -> Series[S1]: ... - # Keep in sync with `LocIndexerSeries.__getitem__` + # Keep in sync with `_LocIndexerSeries.__getitem__` @overload def __getitem__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] self, @@ -1546,8 +1546,8 @@ class Series(IndexOpsMixin[S1], ElementOpsMixin[S1], NDFrame): @final def asof( self, - where: Scalar | Sequence[Scalar], - subset: _str | Sequence[_str] | None = None, + where: Scalar | AnyArrayLike | Sequence[Scalar], + subset: None = None, ) -> Scalar | Series[S1]: ... @overload def clip( # pyright: ignore[reportOverlappingOverload] diff --git a/tests/test_frame.py b/tests/test_frame.py index eb2960b0f..63fa9e65e 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -292,6 +292,9 @@ def test_types_setitem() -> None: df[5] = [5, 6] df[["col1", "col2"]] = [[1, 2], [3, 4]] df[s] = [5, 6] + df.loc[:, s] = [5, 6] + df["col1"] = [5, 6] + df[df["col1"] > 1] = [5, 6, 7] df[a] = [[1, 2], [3, 4]] df[i] = [8, 9]