Skip to content

Commit 61ad913

Browse files
committed
cast to TimeUnit
1 parent 71f7c04 commit 61ad913

File tree

8 files changed

+36
-12
lines changed

8 files changed

+36
-12
lines changed

pandas/core/arrays/_ranges.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,18 @@
2222
from pandas.core.construction import range_to_ndarray
2323

2424
if TYPE_CHECKING:
25-
from pandas._typing import npt
25+
from pandas._typing import (
26+
TimeUnit,
27+
npt,
28+
)
2629

2730

2831
def generate_regular_range(
2932
start: Timestamp | Timedelta | None,
3033
end: Timestamp | Timedelta | None,
3134
periods: int | None,
3235
freq: BaseOffset,
33-
unit: str = "ns",
36+
unit: TimeUnit = "ns",
3437
) -> npt.NDArray[np.intp]:
3538
"""
3639
Generate a range of dates or timestamps with the spans between dates
@@ -46,7 +49,7 @@ def generate_regular_range(
4649
Number of periods in produced date range.
4750
freq : Tick
4851
Describes space between dates in produced date range.
49-
unit : str, default "ns"
52+
unit : {'s', 'ms', 'us', 'ns'}, default "ns"
5053
The resolution the output is meant to represent.
5154
5255
Returns

pandas/core/arrays/datetimes.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,15 @@
106106

107107

108108
@overload
109-
def tz_to_dtype(tz: tzinfo, unit: str = ...) -> DatetimeTZDtype: ...
109+
def tz_to_dtype(tz: tzinfo, unit: TimeUnit = ...) -> DatetimeTZDtype: ...
110110

111111

112112
@overload
113-
def tz_to_dtype(tz: None, unit: str = ...) -> np.dtype[np.datetime64]: ...
113+
def tz_to_dtype(tz: None, unit: TimeUnit = ...) -> np.dtype[np.datetime64]: ...
114114

115115

116116
def tz_to_dtype(
117-
tz: tzinfo | None, unit: str = "ns"
117+
tz: tzinfo | None, unit: TimeUnit = "ns"
118118
) -> np.dtype[np.datetime64] | DatetimeTZDtype:
119119
"""
120120
Return a datetime64[ns] dtype appropriate for the given timezone.
@@ -393,6 +393,7 @@ def _from_sequence_not_strict(
393393
)
394394

395395
data_unit = np.datetime_data(subarr.dtype)[0]
396+
data_unit = cast("TimeUnit", data_unit)
396397
data_dtype = tz_to_dtype(tz, data_unit)
397398
result = cls._simple_new(subarr, freq=inferred_freq, dtype=data_dtype)
398399
if unit is not None and unit != result.unit:
@@ -2935,7 +2936,7 @@ def _generate_range(
29352936
periods: int | None,
29362937
offset: BaseOffset,
29372938
*,
2938-
unit: str,
2939+
unit: TimeUnit,
29392940
) -> Generator[Timestamp]:
29402941
"""
29412942
Generates a sequence of dates corresponding to the specified time

pandas/core/dtypes/cast.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
DtypeObj,
101101
NumpyIndexT,
102102
Scalar,
103+
TimeUnit,
103104
)
104105

105106
from pandas import Index
@@ -567,6 +568,7 @@ def _maybe_promote(dtype: np.dtype, fill_value=np.nan):
567568
# different unit, e.g. passed np.timedelta64(24, "h") with dtype=m8[ns]
568569
# see if we can losslessly cast it to our dtype
569570
unit = np.datetime_data(dtype)[0]
571+
unit = cast("TimeUnit", unit)
570572
try:
571573
td = Timedelta(fill_value).as_unit(unit, round_ok=False)
572574
except OutOfBoundsTimedelta:

pandas/core/dtypes/dtypes.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -781,10 +781,10 @@ def base(self) -> DtypeObj: # type: ignore[override]
781781
def str(self) -> str: # type: ignore[override]
782782
return f"|M8[{self.unit}]"
783783

784-
def __init__(self, unit: str_type | DatetimeTZDtype = "ns", tz=None) -> None:
784+
def __init__(self, unit: TimeUnit | DatetimeTZDtype = "ns", tz=None) -> None:
785785
if isinstance(unit, DatetimeTZDtype):
786786
# error: "str" has no attribute "tz"
787-
unit, tz = unit.unit, unit.tz # type: ignore[attr-defined]
787+
unit, tz = unit.unit, unit.tz # type: ignore[union-attr]
788788

789789
if unit != "ns":
790790
if isinstance(unit, str) and tz is None:
@@ -895,7 +895,8 @@ def construct_from_string(cls, string: str_type) -> DatetimeTZDtype:
895895
if match:
896896
d = match.groupdict()
897897
try:
898-
return cls(unit=d["unit"], tz=d["tz"])
898+
unit = cast("TimeUnit", d["unit"])
899+
return cls(unit=unit, tz=d["tz"])
899900
except (KeyError, TypeError, ValueError) as err:
900901
# KeyError if maybe_get_tz tries and fails to get a
901902
# zoneinfo timezone (actually zoneinfo.ZoneInfoNotFoundError).
@@ -972,6 +973,7 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
972973
if all(isinstance(t, DatetimeTZDtype) and t.tz == self.tz for t in dtypes):
973974
np_dtype = np.max([cast(DatetimeTZDtype, t).base for t in [self, *dtypes]])
974975
unit = np.datetime_data(np_dtype)[0]
976+
unit = cast("TimeUnit", unit)
975977
return type(self)(unit=unit, tz=self.tz)
976978
return super()._get_common_dtype(dtypes)
977979

pandas/core/reshape/tile.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
TYPE_CHECKING,
99
Any,
1010
Literal,
11+
cast,
1112
)
1213

1314
import numpy as np
@@ -49,6 +50,7 @@
4950
from pandas._typing import (
5051
DtypeObj,
5152
IntervalLeftRight,
53+
TimeUnit,
5254
)
5355

5456

@@ -412,7 +414,7 @@ def _nbins_to_bins(x_idx: Index, nbins: int, right: bool) -> Index:
412414
# error: Argument 1 to "dtype_to_unit" has incompatible type
413415
# "dtype[Any] | ExtensionDtype"; expected "DatetimeTZDtype | dtype[Any]"
414416
unit = dtype_to_unit(x_idx.dtype) # type: ignore[arg-type]
415-
td = Timedelta(seconds=1).as_unit(unit)
417+
td = Timedelta(seconds=1).as_unit(cast("TimeUnit", unit))
416418
# Use DatetimeArray/TimedeltaArray method instead of linspace
417419
# error: Item "ExtensionArray" of "ExtensionArray | ndarray[Any, Any]"
418420
# has no attribute "_generate_range"
@@ -595,6 +597,7 @@ def _format_labels(
595597
# error: Argument 1 to "dtype_to_unit" has incompatible type
596598
# "dtype[Any] | ExtensionDtype"; expected "DatetimeTZDtype | dtype[Any]"
597599
unit = dtype_to_unit(bins.dtype) # type: ignore[arg-type]
600+
unit = cast("TimeUnit", unit)
598601
formatter = lambda x: x
599602
adjust = lambda x: x - Timedelta(1, unit=unit).as_unit(unit)
600603
else:

pandas/core/tools/datetimes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787

8888
from pandas._libs.tslibs.nattype import NaTType
8989
from pandas._libs.tslibs.timedeltas import UnitChoices
90+
from pandas._typing import TimeUnit
9091

9192
from pandas import (
9293
DataFrame,
@@ -447,6 +448,7 @@ def _convert_listlike_datetimes(
447448
# We can take a shortcut since the datetime64 numpy array
448449
# is in UTC
449450
out_unit = np.datetime_data(result.dtype)[0]
451+
out_unit = cast("TimeUnit", out_unit)
450452
dtype = tz_to_dtype(tz_parsed, out_unit)
451453
dt64_values = result.view(f"M8[{dtype.unit}]")
452454
dta = DatetimeArray._simple_new(dt64_values, dtype=dtype)
@@ -469,13 +471,15 @@ def _array_strptime_with_fallback(
469471
result, tz_out = array_strptime(arg, fmt, exact=exact, errors=errors, utc=utc)
470472
if tz_out is not None:
471473
unit = np.datetime_data(result.dtype)[0]
474+
unit = cast("TimeUnit", unit)
472475
dtype = DatetimeTZDtype(tz=tz_out, unit=unit)
473476
dta = DatetimeArray._simple_new(result, dtype=dtype)
474477
if utc:
475478
dta = dta.tz_convert("UTC")
476479
return Index(dta, name=name)
477480
elif result.dtype != object and utc:
478481
unit = np.datetime_data(result.dtype)[0]
482+
unit = cast("TimeUnit", unit)
479483
res = Index(result, dtype=f"M8[{unit}, UTC]", name=name)
480484
return res
481485
return Index(result, dtype=result.dtype, name=name)

pandas/core/window/ewm.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
import datetime
44
from functools import partial
55
from textwrap import dedent
6-
from typing import TYPE_CHECKING
6+
from typing import (
7+
TYPE_CHECKING,
8+
cast,
9+
)
710

811
import numpy as np
912

@@ -57,6 +60,7 @@
5760
if TYPE_CHECKING:
5861
from pandas._typing import (
5962
TimedeltaConvertibleTypes,
63+
TimeUnit,
6064
npt,
6165
)
6266

@@ -122,6 +126,7 @@ def _calculate_deltas(
122126
Diff of the times divided by the half-life
123127
"""
124128
unit = dtype_to_unit(times.dtype)
129+
unit = cast("TimeUnit", unit)
125130
if isinstance(times, ABCSeries):
126131
times = times._values
127132
_times = np.asarray(times.view(np.int64), dtype=np.float64)

pandas/io/pytables.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@
134134
AxisInt,
135135
DtypeArg,
136136
FilePath,
137+
TimeUnit,
137138
npt,
138139
)
139140

@@ -5093,6 +5094,9 @@ def _set_tz(
50935094
# Argument "tz" to "tz_to_dtype" has incompatible type "str | tzinfo | None";
50945095
# expected "tzinfo"
50955096
unit, _ = np.datetime_data(datetime64_dtype) # parsing dtype: unit, count
5097+
unit = cast("TimeUnit", unit)
5098+
# error: Argument "tz" to "tz_to_dtype" has incompatible type
5099+
# "str | tzinfo | None"; expected "tzinfo"
50965100
dtype = tz_to_dtype(tz=tz, unit=unit) # type: ignore[arg-type]
50975101
dta = DatetimeArray._from_sequence(values, dtype=dtype)
50985102
return dta

0 commit comments

Comments
 (0)