77"""
88
99from copy import copy as _copy , deepcopy
10- from functools import lru_cache
10+ from functools import cached_property
1111import inspect
1212from typing import Any , Callable , Optional , TypeVar , Union
1313import warnings
@@ -812,8 +812,7 @@ def _get_lagged_names(self, name: str) -> dict[str, int]:
812812 """
813813 return {f"{ name } _lagged_by_{ lag } " : lag for lag in self ._lags .get (name , [])}
814814
815- @property
816- @lru_cache (None )
815+ @cached_property
817816 def lagged_variables (self ) -> dict [str , str ]:
818817 """Lagged variables.
819818
@@ -828,8 +827,7 @@ def lagged_variables(self) -> dict[str, str]:
828827 vars .update ({lag_name : name for lag_name in self ._get_lagged_names (name )})
829828 return vars
830829
831- @property
832- @lru_cache (None )
830+ @cached_property
833831 def lagged_targets (self ) -> dict [str , str ]:
834832 """Subset of lagged_variables to variables that are lagged targets.
835833
@@ -850,8 +848,7 @@ def lagged_targets(self) -> dict[str, str]:
850848 )
851849 return vars
852850
853- @property
854- @lru_cache (None )
851+ @cached_property
855852 def min_lag (self ) -> int :
856853 """
857854 Minimum number of time steps variables are lagged.
@@ -865,8 +862,7 @@ def min_lag(self) -> int:
865862 else :
866863 return min ([min (lag ) for lag in self ._lags .values ()])
867864
868- @property
869- @lru_cache (None )
865+ @cached_property
870866 def max_lag (self ) -> int :
871867 """
872868 Maximum number of time steps variables are lagged.
@@ -983,8 +979,7 @@ def _get_auto_normalizer(self, data_properties: DataProperties) -> TorchNormaliz
983979 target_normalizer = normalizers [0 ]
984980 return target_normalizer
985981
986- @property
987- @lru_cache (None )
982+ @cached_property
988983 def _group_ids_mapping (self ) -> dict [str , str ]:
989984 """
990985 Mapping of group id names to group ids used to identify series in dataset -
@@ -995,8 +990,7 @@ def _group_ids_mapping(self) -> dict[str, str]:
995990 """
996991 return {name : f"__group_id__{ name } " for name in self .group_ids }
997992
998- @property
999- @lru_cache (None )
993+ @cached_property
1000994 def _group_ids (self ) -> list [str ]:
1001995 """
1002996 Group ids used to identify series in dataset.
@@ -1487,6 +1481,7 @@ def _to_tensor(cols, long=True) -> torch.Tensor:
14871481 weight = weight ,
14881482 time = time ,
14891483 )
1484+
14901485 return tensors
14911486
14921487 def _check_tensors (self , tensors ):
@@ -1568,8 +1563,7 @@ def reals(self) -> list[str]:
15681563 + self ._time_varying_unknown_reals
15691564 )
15701565
1571- @property
1572- @lru_cache (None )
1566+ @cached_property
15731567 def target_names (self ) -> list [str ]:
15741568 """
15751569 List of targets.
@@ -1860,7 +1854,18 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFra
18601854 )
18611855 assert len (df_index ) > 0 , msg
18621856
1863- return df_index
1857+ minimal_columns = [
1858+ "index_start" ,
1859+ "index_end" ,
1860+ "sequence_length" ,
1861+ "time" ,
1862+ "sequence_id" ,
1863+ ]
1864+ if predict_mode and "sequence_id" in df_index .columns :
1865+ minimal_columns .append ("sequence_id" )
1866+
1867+ df_index = df_index [minimal_columns ].astype ("int32" , copy = False )
1868+ return df_index .reset_index (drop = True )
18641869
18651870 def filter (self , filter_func : Callable , copy : bool = True ) -> TimeSeriesDataType :
18661871 """Filter subsequences in dataset.
0 commit comments