Skip to content

Commit e8c8b46

Browse files
[BUG] fix memory leak in TimeSeriesDataset by using @cached_property, and clean up index construction (#1905)
#### Reference Issues/PRs
#648

#### What does this implement/fix? Explain your changes.
- Replaced the stacked `@property` / `@lru_cache` decorators with `@cached_property` to fix a self-reference leak: previously, the `lru_cache` on each method kept a strong reference to every instance (`self` was part of the cache key), preventing garbage collection and causing memory growth if many instances were created.
- Improved the `_construct_index()` function to return only the essential columns in a consistent format.
1 parent cae3174 commit e8c8b46

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

pytorch_forecasting/data/timeseries/_timeseries.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"""
88

99
from copy import copy as _copy, deepcopy
10-
from functools import lru_cache
10+
from functools import cached_property
1111
import inspect
1212
from typing import Any, Callable, Optional, TypeVar, Union
1313
import warnings
@@ -812,8 +812,7 @@ def _get_lagged_names(self, name: str) -> dict[str, int]:
812812
"""
813813
return {f"{name}_lagged_by_{lag}": lag for lag in self._lags.get(name, [])}
814814

815-
@property
816-
@lru_cache(None)
815+
@cached_property
817816
def lagged_variables(self) -> dict[str, str]:
818817
"""Lagged variables.
819818
@@ -828,8 +827,7 @@ def lagged_variables(self) -> dict[str, str]:
828827
vars.update({lag_name: name for lag_name in self._get_lagged_names(name)})
829828
return vars
830829

831-
@property
832-
@lru_cache(None)
830+
@cached_property
833831
def lagged_targets(self) -> dict[str, str]:
834832
"""Subset of lagged_variables to variables that are lagged targets.
835833
@@ -850,8 +848,7 @@ def lagged_targets(self) -> dict[str, str]:
850848
)
851849
return vars
852850

853-
@property
854-
@lru_cache(None)
851+
@cached_property
855852
def min_lag(self) -> int:
856853
"""
857854
Minimum number of time steps variables are lagged.
@@ -865,8 +862,7 @@ def min_lag(self) -> int:
865862
else:
866863
return min([min(lag) for lag in self._lags.values()])
867864

868-
@property
869-
@lru_cache(None)
865+
@cached_property
870866
def max_lag(self) -> int:
871867
"""
872868
Maximum number of time steps variables are lagged.
@@ -983,8 +979,7 @@ def _get_auto_normalizer(self, data_properties: DataProperties) -> TorchNormaliz
983979
target_normalizer = normalizers[0]
984980
return target_normalizer
985981

986-
@property
987-
@lru_cache(None)
982+
@cached_property
988983
def _group_ids_mapping(self) -> dict[str, str]:
989984
"""
990985
Mapping of group id names to group ids used to identify series in dataset -
@@ -995,8 +990,7 @@ def _group_ids_mapping(self) -> dict[str, str]:
995990
"""
996991
return {name: f"__group_id__{name}" for name in self.group_ids}
997992

998-
@property
999-
@lru_cache(None)
993+
@cached_property
1000994
def _group_ids(self) -> list[str]:
1001995
"""
1002996
Group ids used to identify series in dataset.
@@ -1487,6 +1481,7 @@ def _to_tensor(cols, long=True) -> torch.Tensor:
14871481
weight=weight,
14881482
time=time,
14891483
)
1484+
14901485
return tensors
14911486

14921487
def _check_tensors(self, tensors):
@@ -1568,8 +1563,7 @@ def reals(self) -> list[str]:
15681563
+ self._time_varying_unknown_reals
15691564
)
15701565

1571-
@property
1572-
@lru_cache(None)
1566+
@cached_property
15731567
def target_names(self) -> list[str]:
15741568
"""
15751569
List of targets.
@@ -1860,7 +1854,18 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFra
18601854
)
18611855
assert len(df_index) > 0, msg
18621856

1863-
return df_index
1857+
minimal_columns = [
1858+
"index_start",
1859+
"index_end",
1860+
"sequence_length",
1861+
"time",
1862+
"sequence_id",
1863+
]
1864+
if predict_mode and "sequence_id" in df_index.columns:
1865+
minimal_columns.append("sequence_id")
1866+
1867+
df_index = df_index[minimal_columns].astype("int32", copy=False)
1868+
return df_index.reset_index(drop=True)
18641869

18651870
def filter(self, filter_func: Callable, copy: bool = True) -> TimeSeriesDataType:
18661871
"""Filter subsequences in dataset.

0 commit comments

Comments
 (0)