diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 9b5fc9619..6ff0b2f12 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -1103,25 +1103,16 @@ class Series(IndexOpsMixin[S1], ElementOpsMixin[S1], NDFrame): fill_value: int | _str | dict | None = None, sort: _bool = True, ) -> DataFrame: ... - @overload - def map( - self, - arg: Callable[[S1], S2 | NAType] | Mapping[S1, S2] | Series[S2], - na_action: Literal["ignore"] = ..., - ) -> Series[S2]: ... - @overload def map( self, - arg: Callable[[S1 | NAType], S2 | NAType] | Mapping[S1, S2] | Series[S2], - na_action: None = None, + arg: ( + Callable[[S1], S2 | None | NAType] + | Mapping[S1, S2 | None | NAType] + | Series[S2] + ), + na_action: Literal["ignore"] | None = None, ) -> Series[S2]: ... @overload - def map( - self, - arg: Callable[[Any], Any] | Mapping[Any, Any] | Series, - na_action: Literal["ignore"] | None = ..., - ) -> Series: ... - @overload def aggregate( self: Series[int], func: Literal["mean"], diff --git a/tests/series/test_series.py b/tests/series/test_series.py index dc2417a8d..24513a1c8 100644 --- a/tests/series/test_series.py +++ b/tests/series/test_series.py @@ -3584,6 +3584,7 @@ def test_map() -> None: pd.Series, str, ) + check(assert_type(s.map(mapping), "pd.Series[str]"), pd.Series, str) def callable(x: int) -> str: return str(x) @@ -3593,16 +3594,29 @@ def callable(x: int) -> str: pd.Series, str, ) + check(assert_type(s.map(callable), "pd.Series[str]"), pd.Series, str) series = pd.Series(["a", "b", "c"]) check( assert_type(s.map(series, na_action="ignore"), "pd.Series[str]"), pd.Series, str ) + check(assert_type(s.map(series), "pd.Series[str]"), pd.Series, str) unknown_series = pd.Series([1, 0, None]) check( - assert_type(unknown_series.map({1: True, 0: False, None: None}), pd.Series), + assert_type( + unknown_series.map({1: True, 0: False, None: None}), "pd.Series[bool]" + ), pd.Series, + bool, + ) + check( + assert_type( + unknown_series.map({1: True, 0: False, None: None}, na_action="ignore"), + "pd.Series[bool]", + ), + pd.Series, + bool, ) @@ -3620,10 +3634,22 @@ def callable(x: int | NAType) -> str | NAType: check( assert_type(s.map(callable, na_action=None), "pd.Series[str]"), pd.Series, str ) + # na_action defaults to None + check(assert_type(s.map(callable), "pd.Series[str]"), pd.Series, str) series = pd.Series(["a", "b", "c"]) check(assert_type(s.map(series, na_action=None), "pd.Series[str]"), pd.Series, str) + def callable2(x: int | NAType | None) -> str | None: + if isinstance(x, int): + return str(x) + return None + + check( + assert_type(s.map(callable2, na_action=None), "pd.Series[str]"), pd.Series, str + ) + check(assert_type(s.map(callable2), "pd.Series[str]"), pd.Series, str) + def test_case_when() -> None: c = pd.Series([6, 7, 8, 9], name="c")