@@ -781,10 +781,10 @@ def base(self) -> DtypeObj: # type: ignore[override]
781781 def str (self ) -> str : # type: ignore[override]
782782 return f"|M8[{ self .unit } ]"
783783
784- def __init__ (self , unit : str_type | DatetimeTZDtype = "ns" , tz = None ) -> None :
784+ def __init__ (self , unit : TimeUnit | DatetimeTZDtype = "ns" , tz = None ) -> None :
785785 if isinstance (unit , DatetimeTZDtype ):
786786 # error: "str" has no attribute "tz"
787- unit , tz = unit .unit , unit .tz # type: ignore[attr-defined ]
787+ unit , tz = unit .unit , unit .tz # type: ignore[union-attr ]
788788
789789 if unit != "ns" :
790790 if isinstance (unit , str ) and tz is None :
@@ -895,7 +895,8 @@ def construct_from_string(cls, string: str_type) -> DatetimeTZDtype:
895895 if match :
896896 d = match .groupdict ()
897897 try :
898- return cls (unit = d ["unit" ], tz = d ["tz" ])
898+ unit = cast ("TimeUnit" , d ["unit" ])
899+ return cls (unit = unit , tz = d ["tz" ])
899900 except (KeyError , TypeError , ValueError ) as err :
900901 # KeyError if maybe_get_tz tries and fails to get a
901902 # zoneinfo timezone (actually zoneinfo.ZoneInfoNotFoundError).
@@ -972,6 +973,7 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
972973 if all (isinstance (t , DatetimeTZDtype ) and t .tz == self .tz for t in dtypes ):
973974 np_dtype = np .max ([cast (DatetimeTZDtype , t ).base for t in [self , * dtypes ]])
974975 unit = np .datetime_data (np_dtype )[0 ]
976+ unit = cast ("TimeUnit" , unit )
975977 return type (self )(unit = unit , tz = self .tz )
976978 return super ()._get_common_dtype (dtypes )
977979
0 commit comments