From ab33b972e6886d545fb82652640b8ad037e3be2b Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 16 Feb 2025 13:32:24 +0530 Subject: [PATCH 01/11] initial commit --- pytorch_forecasting/data/data_modules.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 pytorch_forecasting/data/data_modules.py diff --git a/pytorch_forecasting/data/data_modules.py b/pytorch_forecasting/data/data_modules.py new file mode 100644 index 000000000..e69de29bb From dd8b6a067f56bb22020eb67983a4fa6f0c861de9 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 16 Feb 2025 13:32:55 +0530 Subject: [PATCH 02/11] adding the timeseries and data module --- pytorch_forecasting/data/data_modules.py | 370 +++++++++++++++++++++++ pytorch_forecasting/data/timeseries.py | 210 +++++++++++++ 2 files changed, 580 insertions(+) diff --git a/pytorch_forecasting/data/data_modules.py b/pytorch_forecasting/data/data_modules.py index e69de29bb..767efd2c7 100644 --- a/pytorch_forecasting/data/data_modules.py +++ b/pytorch_forecasting/data/data_modules.py @@ -0,0 +1,370 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +from lightning.pytorch import LightningDataModule +from sklearn.preprocessing import RobustScaler, StandardScaler +import torch +from torch.utils.data import DataLoader, Dataset + +from pytorch_forecasting.data.encoders import ( + EncoderNormalizer, + NaNLabelEncoder, + TorchNormalizer, +) +from pytorch_forecasting.data.timeseries import TimeSeries, _coerce_to_dict + +NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer] + + +class EncoderDecoderTimeSeriesDataModule(LightningDataModule): + """ + Lightning DataModule for processing time series data in an encoder-decoder format. + + This module handles preprocessing, splitting, and batching of time series data + for use in deep learning models. It supports categorical and continuous features, + various scalers, and automatic target normalization. + + Parameters + ---------- + time_series_dataset : TimeSeries + The dataset containing time series data. + max_encoder_length : int, default=30 + Maximum length of the encoder input sequence. + min_encoder_length : Optional[int], default=None + Minimum length of the encoder input sequence. + Defaults to `max_encoder_length` if not specified. + max_prediction_length : int, default=1 + Maximum length of the decoder output sequence. + min_prediction_length : Optional[int], default=None + Minimum length of the decoder output sequence. + Defaults to `max_prediction_length` if not specified. + min_prediction_idx : Optional[int], default=None + Minimum index from which predictions start. + allow_missing_timesteps : bool, default=False + Whether to allow missing timesteps in the dataset. + add_relative_time_idx : bool, default=False + Whether to add a relative time index feature. + add_target_scales : bool, default=False + Whether to add target scaling information. + add_encoder_length : Union[bool, str], default="auto" + Whether to include encoder length information. + target_normalizer : + Union[NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None], + default="auto" + Normalizer for the target variable. If "auto", uses `RobustScaler`. + + categorical_encoders : Optional[Dict[str, NaNLabelEncoder]], default=None + Dictionary of categorical encoders. + + scalers : + Optional[Dict[str, Union[StandardScaler, RobustScaler, + TorchNormalizer, EncoderNormalizer]]], default=None + Dictionary of feature scalers. 
+ + randomize_length : Union[None, Tuple[float, float], bool], default=False + Whether to randomize input sequence length. + predict_mode : bool, default=False + Whether the module is in prediction mode. + batch_size : int, default=32 + Batch size for DataLoader. + num_workers : int, default=0 + Number of workers for DataLoader. + train_val_test_split : tuple, default=(0.7, 0.15, 0.15) + Proportions for train, validation, and test dataset splits. + """ + + def __init__( + self, + time_series_dataset: TimeSeries, + max_encoder_length: int = 30, + min_encoder_length: Optional[int] = None, + max_prediction_length: int = 1, + min_prediction_length: Optional[int] = None, + min_prediction_idx: Optional[int] = None, + allow_missing_timesteps: bool = False, + add_relative_time_idx: bool = False, + add_target_scales: bool = False, + add_encoder_length: Union[bool, str] = "auto", + target_normalizer: Union[ + NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None + ] = "auto", + categorical_encoders: Optional[Dict[str, NaNLabelEncoder]] = None, + scalers: Optional[ + Dict[ + str, + Union[StandardScaler, RobustScaler, TorchNormalizer, EncoderNormalizer], + ] + ] = None, + randomize_length: Union[None, Tuple[float, float], bool] = False, + batch_size: int = 32, + num_workers: int = 0, + train_val_test_split: tuple = (0.7, 0.15, 0.15), + ): + super().__init__() + self.time_series_dataset = time_series_dataset + self.metadata = time_series_dataset.get_metadata() + + self.max_encoder_length = max_encoder_length + self.min_encoder_length = min_encoder_length or max_encoder_length + self.max_prediction_length = max_prediction_length + self.min_prediction_length = min_prediction_length or max_prediction_length + self.min_prediction_idx = min_prediction_idx + + self.allow_missing_timesteps = allow_missing_timesteps + self.add_relative_time_idx = add_relative_time_idx + self.add_target_scales = add_target_scales + self.add_encoder_length = add_encoder_length + self.randomize_length = randomize_length + + self.batch_size = batch_size + self.num_workers = num_workers + self.train_val_test_split = train_val_test_split + + if isinstance(target_normalizer, str) and target_normalizer.lower() == "auto": + self.target_normalizer = RobustScaler() + else: + self.target_normalizer = target_normalizer + + self.categorical_encoders = _coerce_to_dict(categorical_encoders) + self.scalers = _coerce_to_dict(scalers) + + self.categorical_indices = [] + self.continuous_indices = [] + + for idx, col in enumerate(self.metadata["cols"]["x"]): + if self.metadata["col_type"].get(col) == "C": + self.categorical_indices.append(idx) + else: + self.continuous_indices.append(idx) + + def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]: + processed_data = [] + + for idx in indices: + sample = self.time_series_dataset[idx.item()] + + target = sample["y"] + # if torch.isnan(target).any(): + # (f"Warning: NaNs detected. 
Sample index: {idx}, Value: {target}") + + if isinstance(target, torch.Tensor): + target = target.float() + else: + target = torch.tensor(target, dtype=torch.float32) + + features = sample["x"] + if isinstance(features, torch.Tensor): + features = features.float() + else: + features = torch.tensor(features, dtype=torch.float32) + + categorical = ( + features[:, self.categorical_indices] + if self.categorical_indices + else torch.zeros((features.shape[0], 0)) + ) + continuous = ( + features[:, self.continuous_indices] + if self.continuous_indices + else torch.zeros((features.shape[0], 0)) + ) + + processed_data.append( + { + "features": {"categorical": categorical, "continuous": continuous}, + "target": target, + "static": sample.get("st", None), + "group": sample.get("group", torch.tensor([0])), + "length": len(target), + } + ) + + return processed_data + + class _ProcessedEncoderDecoderDataset(Dataset): + """PyTorch Dataset for processed encoder-decoder time series data. + + Parameters + ---------- + processed_data : List[Dict[str, Any]] + List of preprocessed time series samples. + windows : List[Tuple[int, int, int, int]] + List of window tuples containing + (series_idx, start_idx, enc_length, pred_length). + add_relative_time_idx : bool, default=False + Whether to include relative time indices. + """ + + def __init__( + self, + processed_data: List[Dict[str, Any]], + windows: List[Tuple[int, int, int, int]], + add_relative_time_idx: bool = False, + ): + self.processed_data = processed_data + self.windows = windows + self.add_relative_time_idx = add_relative_time_idx + + def __len__(self): + return len(self.windows) + + def __getitem__(self, idx): + series_idx, start_idx, enc_length, pred_length = self.windows[idx] + data = self.processed_data[series_idx] + # if start_idx + enc_length + pred_length > len(data['target']): + # print(f"start_idx: {start_idx}, enc_length: {enc_length}, + # pred_length: {pred_length}, target length: {len(data['target'])}") + + end_idx = start_idx + enc_length + pred_length + encoder_indices = slice(start_idx, start_idx + enc_length) + decoder_indices = slice(start_idx + enc_length, end_idx) + + target_scale = data["target"][encoder_indices].abs().mean() + if target_scale == 0: + target_scale = torch.tensor(1.0) + + x = { + "encoder_cat": data["features"]["categorical"][encoder_indices], + "encoder_cont": data["features"]["continuous"][encoder_indices], + "decoder_cat": data["features"]["categorical"][decoder_indices], + "decoder_cont": data["features"]["continuous"][decoder_indices], + "encoder_lengths": torch.tensor(enc_length), + "decoder_lengths": torch.tensor(pred_length), + "decoder_target_lengths": torch.tensor(pred_length), + "groups": data["group"], + "encoder_time_idx": torch.arange(enc_length), + "decoder_time_idx": torch.arange(enc_length, enc_length + pred_length), + "target_scale": target_scale, + } + + if data["static"] is not None: + x["static_categorical_features"] = data["static"].unsqueeze(0) + x["static_continuous_features"] = torch.zeros((1, 0)) + + y = data["target"][decoder_indices] + if y.ndim == 1: + y = y.unsqueeze(-1) + + return x, y + + def _create_windows( + self, processed_data: List[Dict[str, Any]] + ) -> List[Tuple[int, int, int, int]]: + windows = [] + + for idx, data in enumerate(processed_data): + sequence_length = data["length"] + + if sequence_length < self.max_encoder_length + self.max_prediction_length: + continue + + effective_min_prediction_idx = ( + self.min_prediction_idx + if self.min_prediction_idx is not None + 
else self.max_encoder_length + ) + + max_prediction_idx = sequence_length - self.max_prediction_length + 1 + + if max_prediction_idx <= effective_min_prediction_idx: + continue + + for start_idx in range( + 0, max_prediction_idx - effective_min_prediction_idx + ): + if ( + start_idx + self.max_encoder_length + self.max_prediction_length + <= sequence_length + ): + windows.append( + ( + idx, + start_idx, + self.max_encoder_length, + self.max_prediction_length, + ) + ) + + return windows + + def setup(self, stage: Optional[str] = None): + total_series = len(self.time_series_dataset) + self._split_indices = torch.randperm(total_series) + + self._train_size = int(self.train_val_test_split[0] * total_series) + self._val_size = int(self.train_val_test_split[1] * total_series) + + self._train_indices = self._split_indices[: self._train_size] + self._val_indices = self._split_indices[ + self._train_size : self._train_size + self._val_size + ] + self._test_indices = self._split_indices[self._train_size + self._val_size :] + + if stage is None or stage == "fit": + if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"): + self.train_processed = self._preprocess_data(self._train_indices) + self.val_processed = self._preprocess_data(self._val_indices) + + self.train_windows = self._create_windows(self.train_processed) + self.val_windows = self._create_windows(self.val_processed) + + self.train_dataset = self._ProcessedEncoderDecoderDataset( + self.train_processed, self.train_windows, self.add_relative_time_idx + ) + self.val_dataset = self._ProcessedEncoderDecoderDataset( + self.val_processed, self.val_windows, self.add_relative_time_idx + ) + + if stage is None or stage == "test": + if not hasattr(self, "test_dataset"): + self.test_processed = self._preprocess_data(self._test_indices) + self.test_windows = self._create_windows(self.test_processed) + + self.test_dataset = self._ProcessedEncoderDecoderDataset( + self.test_processed, self.test_windows, self.add_relative_time_idx + ) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + collate_fn=self.collate_fn, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + @staticmethod + def collate_fn(batch): + x_batch = { + "encoder_cat": torch.stack([x["encoder_cat"] for x, _ in batch]), + "encoder_cont": torch.stack([x["encoder_cont"] for x, _ in batch]), + "decoder_cat": torch.stack([x["decoder_cat"] for x, _ in batch]), + "decoder_cont": torch.stack([x["decoder_cont"] for x, _ in batch]), + "encoder_lengths": torch.stack([x["encoder_lengths"] for x, _ in batch]), + "decoder_lengths": torch.stack([x["decoder_lengths"] for x, _ in batch]), + "decoder_target_lengths": torch.stack( + [x["decoder_target_lengths"] for x, _ in batch] + ), + "groups": torch.stack([x["groups"] for x, _ in batch]), + "encoder_time_idx": torch.stack([x["encoder_time_idx"] for x, _ in batch]), + "decoder_time_idx": torch.stack([x["decoder_time_idx"] for x, _ in batch]), + "target_scale": torch.stack([x["target_scale"] for x, _ in batch]), + } + + y_batch = torch.stack([y for _, y in batch]) + return x_batch, y_batch diff --git a/pytorch_forecasting/data/timeseries.py 
b/pytorch_forecasting/data/timeseries.py index 336eecd5f..8037be9fc 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -2668,3 +2668,213 @@ def _coerce_to_dict(obj): if obj is None: return {} return deepcopy(obj) + + +class TimeSeries(Dataset): + """PyTorch Dataset for time series data stored in pandas DataFrame. + + ``__getitem__`` returns: + + * ``t``: tensor of shape (n_timepoints) + Time index for each time point in the past or present. Aligned with ``y``, + and ``x`` not ending in ``f``. + * ``y``: tensor of shape (n_timepoints, n_targets) + Target values for each time point. Rows are time points, aligned with ``t``. + * ``x``: tensor of shape (n_timepoints, n_features) + Features for each time point. Rows are time points, aligned with ``t``. + * ``group``: tensor of shape (n_groups) + Group identifiers for time series instances. + * ``st``: tensor of shape (n_static_features) + Static features. + + Optionally, the following str-keyed entries can be included: + + * ``t_f``: tensor of shape (n_timepoints_future) + Time index for each time point in the future. + Aligned with ``x_f``. + * ``x_f``: tensor of shape (n_timepoints_future, n_features) + Known features for each time point in the future. + Rows are time points, aligned with ``t_f``. + * ``weights``: tensor of shape (n_timepoints), only if weight is not None + * ``weight_f``: tensor of shape (n_timepoints_future), only if weight is + not None. + + ----------------------------------------------------------------------------------- + + ``get_metadata`` returns metadata: + + * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] } + Names of columns for y, x, and static features. + List elements are in same order as column dimensions. + Columns not appearing are assumed to be named (x0, x1, etc.), + (y0, y1, etc.), (st0, st1, etc.). + * ``col_type``: dict[str, str] + maps column names to data types "F" (numerical) and "C" (categorical). + Column names not occurring are assumed "F". + * ``col_known``: dict[str, str] + maps column names to "K" (future known) or "U" (future unknown). + Column names not occurring are assumed "K". + + Parameters + ---------- + data : pd.DataFrame + data frame with sequence data. + Column names must all be str, and contain str as referred to below. + data_future : pd.DataFrame, optional, default=None + data frame with future data. + Column names must all be str, and contain str as referred to below. + May contain only columns that are in time, group, weight, known, or static. + time : str, optional, default = first col not in group_ids, weight, target, static. + integer typed column denoting the time index within ``data``. + This column is used to determine the sequence of samples. + If there are no missing observations, + the time index should increase by ``+1`` for each subsequent sample. + The first time_idx for each series does not necessarily + have to be ``0`` but any value is allowed. + target : str or List[str], optional, default = last column (at iloc -1) + column(s) in ``data`` denoting the forecasting target. + Can be categorical or numerical dtype. + group : List[str], optional, default = None + list of column names identifying a time series instance within ``data``. + This means that the ``group`` together uniquely identify an instance, + and ``group`` together with ``time`` uniquely identify a single observation + within a time series instance. + If ``None``, the dataset is assumed to be a single time series. 
+ weight : str, optional, default=None + column name for weights. + If ``None``, it is assumed that there is no weight column. + num : list of str, optional, default = all columns with dtype in "fi" + list of numerical variables in ``data``, + list may also contain list of str, which are then grouped together. + cat : list of str, optional, default = all columns with dtype in "Obc" + list of categorical variables in ``data``, + list may also contain list of str, which are then grouped together + (e.g. useful for product categories). + known : list of str, optional, default = all variables + list of variables that change over time and are known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for special days or promotion categories). + unknown : list of str, optional, default = no variables + list of variables that are not known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for weather categories). + static : list of str, optional, default = all variables not in known, unknown + list of variables that do not change over time, + list may also contain list of str, which are then grouped together. + """ + + def __init__( + self, + data: pd.DataFrame, + data_future: Optional[pd.DataFrame] = None, + time: Optional[str] = None, + target: Optional[Union[str, List[str]]] = None, + group: Optional[List[str]] = None, + weight: Optional[str] = None, + num: Optional[List[Union[str, List[str]]]] = None, + cat: Optional[List[Union[str, List[str]]]] = None, + known: Optional[List[Union[str, List[str]]]] = None, + unknown: Optional[List[Union[str, List[str]]]] = None, + static: Optional[List[Union[str, List[str]]]] = None, + ): + + self.data = data + self.data_future = data_future + self.time = time + self.target = _coerce_to_list(target) + self.group = _coerce_to_list(group) + self.weight = weight + self.num = _coerce_to_list(num) + self.cat = _coerce_to_list(cat) + self.known = _coerce_to_list(known) + self.unknown = _coerce_to_list(unknown) + self.static = _coerce_to_list(static) + + self.feature_cols = [ + col + for col in data.columns + if col not in [self.time] + self.group + [self.weight] + self.target + ] + if self.group: + self._groups = self.data.groupby(self.group).groups + self._group_ids = list(self._groups.keys()) + else: + self._groups = {"_single_group": self.data.index} + self._group_ids = ["_single_group"] + + self._prepare_metadata() + + def _prepare_metadata(self): + """Prepare metadata for the dataset.""" + self.metadata = { + "cols": { + "y": self.target, + "x": self.feature_cols, + "st": self.static, + }, + "col_type": {}, + "col_known": {}, + } + + all_cols = self.target + self.feature_cols + self.static + for col in all_cols: + self.metadata["col_type"][col] = "C" if col in self.cat else "F" + + self.metadata["col_known"][col] = "K" if col in self.known else "U" + + def __len__(self) -> int: + """Return number of time series in the dataset.""" + return len(self._group_ids) + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Get time series data for given index.""" + group_id = self._group_ids[index] + + if self.group: + mask = self._groups[group_id] + data = self.data.loc[mask] + else: + data = self.data + + result = { + "t": data[self.time].values, + "y": torch.tensor(data[self.target].values), + "x": torch.tensor(data[self.feature_cols].values), + "group": torch.tensor([hash(str(group_id))]), + "st": torch.tensor(data[self.static].iloc[0].values if 
self.static else []), + } + + if self.data_future is not None: + if self.group: + future_mask = self.data_future.groupby(self.group).groups[group_id] + future_data = self.data_future.loc[future_mask] + else: + future_data = self.data_future + + result.update( + { + "t_f": torch.tensor(future_data[self.time].values), + "x_f": torch.tensor(future_data[self.known].values), + } + ) + + if self.weight: + result["weight_f"] = torch.tensor(future_data[self.weight].values) + + if self.weight: + result["weights"] = torch.tensor(data[self.weight].values) + + return result + + def get_metadata(self) -> Dict: + """Return metadata about the dataset. + + Returns + ------- + Dict + Dictionary containing: + - cols: column names for y, x, and static features + - col_type: mapping of columns to their types (F/C) + - col_known: mapping of columns to their future known status (K/U) + """ + return self.metadata From c4dd9cfc6ab6083de7068bbac1de02123743934a Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 16 Feb 2025 18:10:32 +0530 Subject: [PATCH 03/11] adding predict to setup --- pytorch_forecasting/data/data_modules.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/pytorch_forecasting/data/data_modules.py b/pytorch_forecasting/data/data_modules.py index 767efd2c7..52b20eacd 100644 --- a/pytorch_forecasting/data/data_modules.py +++ b/pytorch_forecasting/data/data_modules.py @@ -62,8 +62,6 @@ class EncoderDecoderTimeSeriesDataModule(LightningDataModule): randomize_length : Union[None, Tuple[float, float], bool], default=False Whether to randomize input sequence length. - predict_mode : bool, default=False - Whether the module is in prediction mode. batch_size : int, default=32 Batch size for DataLoader. num_workers : int, default=0 @@ -314,7 +312,7 @@ def setup(self, stage: Optional[str] = None): self.val_processed, self.val_windows, self.add_relative_time_idx ) - if stage is None or stage == "test": + elif stage is None or stage == "test": if not hasattr(self, "test_dataset"): self.test_processed = self._preprocess_data(self._test_indices) self.test_windows = self._create_windows(self.test_processed) @@ -322,6 +320,13 @@ def setup(self, stage: Optional[str] = None): self.test_dataset = self._ProcessedEncoderDecoderDataset( self.test_processed, self.test_windows, self.add_relative_time_idx ) + elif stage == "predict": + predict_indices = torch.arange(len(self.time_series_dataset)) + self.predict_processed = self._preprocess_data(predict_indices) + self.predict_windows = self._create_windows(self.predict_processed) + self.predict_dataset = self._ProcessedEncoderDecoderDataset( + self.predict_processed, self.predict_windows, self.add_relative_time_idx + ) def train_dataloader(self): return DataLoader( @@ -348,6 +353,14 @@ def test_dataloader(self): collate_fn=self.collate_fn, ) + def predict_dataloader(self): + return DataLoader( + self.predict_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + @staticmethod def collate_fn(batch): x_batch = { From 54d7828d33a53801b4e8e3a2c57b1b7793c9ee32 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 19 Feb 2025 19:28:20 +0530 Subject: [PATCH 04/11] Adding tests and some debugging --- pytorch_forecasting/data/data_modules.py | 8 + pytorch_forecasting/data/timeseries.py | 2 + tests/test_data/test_data_module.py | 432 +++++++++++++++++++++++ 3 files changed, 442 insertions(+) create mode 100644 tests/test_data/test_data_module.py diff --git 
a/pytorch_forecasting/data/data_modules.py b/pytorch_forecasting/data/data_modules.py index 52b20eacd..f2c5752de 100644 --- a/pytorch_forecasting/data/data_modules.py +++ b/pytorch_forecasting/data/data_modules.py @@ -379,5 +379,13 @@ def collate_fn(batch): "target_scale": torch.stack([x["target_scale"] for x, _ in batch]), } + if "static_categorical_features" in batch[0][0]: + x_batch["static_categorical_features"] = torch.stack( + [x["static_categorical_features"] for x, _ in batch] + ) + x_batch["static_continuous_features"] = torch.stack( + [x["static_continuous_features"] for x, _ in batch] + ) + y_batch = torch.stack([y for _, y in batch]) return x_batch, y_batch diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index 8037be9fc..a08dc3721 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -2657,6 +2657,8 @@ def _coerce_to_list(obj): """ if obj is None: return [] + if isinstance(obj, str): + return [obj] return list(obj) diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py new file mode 100644 index 000000000..5d85e55e5 --- /dev/null +++ b/tests/test_data/test_data_module.py @@ -0,0 +1,432 @@ +import numpy as np +import pandas as pd +import pytest +from torch.utils.data import DataLoader + +from pytorch_forecasting.data.data_modules import EncoderDecoderTimeSeriesDataModule +from pytorch_forecasting.data.timeseries import TimeSeries + + +@pytest.fixture +def sample_timeseries_data(): + """Generate a sample time series dataset for testing.""" + dates = pd.date_range(start="2023-01-01", periods=50, freq="D") + n_series = 100 + + data = [] + for i in range(n_series): + group_id = i + static_feat = i % 2 + + series = pd.DataFrame( + { + "time": (dates - dates[0]).days, + "group_id": group_id, + "category_1": np.random.randint(0, 3, len(dates), dtype=np.int32), + "category_2": np.random.randint(0, 5, len(dates), dtype=np.int32), + "value_1": np.random.randn(len(dates)).astype(np.float32), + "value_2": np.random.randn(len(dates)).astype(np.float32), + "known_future_1": np.random.randn(len(dates)).astype(np.float32), + "known_future_2": np.random.randint(0, 3, len(dates), dtype=np.int32), + "unknown_future_1": np.random.randn(len(dates)).astype(np.float32), + "target": np.sin(np.linspace(0, 8 * np.pi, len(dates))).astype( + np.float32 + ) + + np.random.randn(len(dates)).astype(np.float32) * 0.1, + "static_feat": np.full(len(dates), static_feat, dtype=np.int32), + } + ) + data.append(series) + + df = pd.concat(data, ignore_index=True) + + df = df.astype( + { + "time": np.int32, + "group_id": np.int32, + "category_1": np.int32, + "category_2": np.int32, + "value_1": np.float32, + "value_2": np.float32, + "known_future_1": np.float32, + "known_future_2": np.int32, + "unknown_future_1": np.float32, + "target": np.float32, + "static_feat": np.int32, + } + ) + + future_dates = pd.date_range(start="2023-02-20", periods=20, freq="D") + future_data = [] + for i in range(n_series): + group_id = i + future_series = pd.DataFrame( + { + "time": (future_dates - dates[0]).days, + "group_id": group_id, + "known_future_1": np.random.randn(len(future_dates)).astype(np.float32), + "known_future_2": np.random.randint( + 0, 3, len(future_dates), dtype=np.int32 + ), + } + ) + future_data.append(future_series) + + future_df = pd.concat(future_data, ignore_index=True) + + future_df = future_df.astype( + { + "time": np.int32, + "group_id": np.int32, + "known_future_1": np.float32, + 
"known_future_2": np.int32, + } + ) + + ts = TimeSeries( + data=df, + data_future=future_df, + time="time", + target="target", + group=["group_id"], + static=["static_feat"], + cat=["category_1", "category_2", "known_future_2"], + num=["value_1", "value_2", "known_future_1", "unknown_future_1"], + known=["known_future_1", "known_future_2"], + unknown=["unknown_future_1"], + ) + + return ts + + +def test_known_unknown_features(sample_timeseries_data): + """Test handling of known and unknown future features. + + This test checks: + + - Whether metadata correctly identifies known and unknown future features. + - Whether future data is correctly included in the dataset. + - The structure and presence of known future feature tensors in a sample. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + ) + + datamodule.setup() + batch = next(iter(datamodule.train_dataloader())) + x_batch, _ = batch + + # Verify metadata contains known/unknown information + metadata = sample_timeseries_data.get_metadata() + assert "col_known" in metadata + assert metadata["col_known"]["known_future_1"] == "K" + assert metadata["col_known"]["known_future_2"] == "K" + assert metadata["col_known"]["unknown_future_1"] == "U" + + # Verify future data handling + sample = sample_timeseries_data[0] + assert "x_f" in sample + assert sample["x_f"].shape[1] == 2 # known_future_1 and known_future_2 + + +def test_initialization(sample_timeseries_data): + """Test the initialization of the EncoderDecoderTimeSeriesDataModule. + + This test verifies: + + - The correct assignment of encoder and prediction lengths. + - The default batch size is set correctly. + - Categorical and continuous features are correctly identified. + - Metadata correctly maps categorical and continuous features. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=30, + max_prediction_length=10, + ) + + assert datamodule.max_encoder_length == 30 + assert datamodule.max_prediction_length == 10 + assert datamodule.batch_size == 32 + + # Check correct identification of categorical and continuous features + assert len(datamodule.categorical_indices) == 3 # category_1, category_2 + assert len(datamodule.continuous_indices) == 5 # value_1, value_2, static_feat + + # You might also want to verify the actual indices are correct + metadata = sample_timeseries_data.get_metadata() + feature_cols = metadata["cols"]["x"] + + # Verify categorical indices point to the right columns + for idx in datamodule.categorical_indices: + assert metadata["col_type"][feature_cols[idx]] == "C" + + # Verify continuous indices point to the right columns + for idx in datamodule.continuous_indices: + assert metadata["col_type"][feature_cols[idx]] == "F" + + +def test_setup_train_val_split(sample_timeseries_data): + """Test dataset splitting into train and validation sets. + + This test ensures: + + - The `setup` method properly splits the dataset. + - The train and validation datasets are correctly created. + - The size of the train dataset matches expectations based on the split ratio. 
+ """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + train_val_test_split=(0.7, 0.15, 0.15), + ) + + datamodule.setup(stage="fit") + + # Verify dataset creation + assert hasattr(datamodule, "train_dataset") + assert hasattr(datamodule, "val_dataset") + + # Check split sizes + expected_train_size = int(0.7 * len(sample_timeseries_data)) + assert len(datamodule._train_indices) == expected_train_size + + +def test_data_loading(sample_timeseries_data): + """Test data loading and batch structure. + + This test checks: + + - The train dataloader is correctly instantiated. + - The batch contains all necessary components. + - The categorical and continuous features have the correct dimensions. + - The target tensor has the expected shape. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + batch_size=16, + ) + + datamodule.setup(stage="fit") + train_loader = datamodule.train_dataloader() + + # Verify DataLoader + assert isinstance(train_loader, DataLoader) + assert train_loader.batch_size == 16 + + # Check batch structure + batch = next(iter(train_loader)) + x_batch, y_batch = batch + + # Verify all required components are present + expected_keys = { + "encoder_cat", + "encoder_cont", + "decoder_cat", + "decoder_cont", + "encoder_lengths", + "decoder_lengths", + "decoder_target_lengths", + "groups", + "encoder_time_idx", + "decoder_time_idx", + "target_scale", + } + assert all(key in x_batch for key in expected_keys) + + # Check shapes + batch_size = 16 + assert x_batch["encoder_cat"].shape == ( + batch_size, + 20, + 3, + ) # (batch, time, n_cat_features) + assert x_batch["encoder_cont"].shape == ( + batch_size, + 20, + 5, + ) # (batch, time, n_cont_features) + assert x_batch["decoder_cat"].shape == ( + batch_size, + 5, + 3, + ) # (batch, pred_length, n_cat_features) + assert x_batch["decoder_cont"].shape == ( + batch_size, + 5, + 5, + ) # (batch, pred_length, n_cont_features) + assert y_batch.shape == (batch_size, 5, 1) # (batch, pred_length, n_targets) + + +def test_different_settings(sample_timeseries_data): + """Test different configuration settings. + + This test verifies: + + - The model handles different encoder and prediction lengths correctly. + - Relative time indices and target scales are properly included. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=15, + min_encoder_length=10, + max_prediction_length=3, + min_prediction_length=2, + batch_size=8, + add_relative_time_idx=True, + add_target_scales=True, + ) + + datamodule.setup(stage="fit") + train_loader = datamodule.train_dataloader() + batch = next(iter(train_loader)) + x_batch, y_batch = batch + + assert x_batch["encoder_cat"].shape[1] == 15 # max_encoder_length + assert x_batch["decoder_cat"].shape[1] == 3 # max_prediction_length + assert x_batch["encoder_time_idx"].shape[1] == 15 + assert "target_scale" in x_batch # verify target scales are included + + +def test_static_features(sample_timeseries_data): + """Test that static features are correctly included. + + This test ensures: + + - Static categorical features are present in the batch. + - Static feature tensor dimensions are as expected. 
+ """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + ) + + datamodule.setup(stage="fit") + train_loader = datamodule.train_dataloader() + batch = next(iter(train_loader)) + x_batch, _ = batch + + # Verify static features are present + assert "static_categorical_features" in x_batch + assert ( + x_batch["static_categorical_features"].dim() == 3 + ) # (batch_size, 1, n_static_features) + + +def test_group_handling(sample_timeseries_data): + """Test that group information is correctly processed. + + This test verifies: + + - The presence of group identifiers in the batch. + - Group tensor dimensions are as expected. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + ) + + datamodule.setup(stage="fit") + train_loader = datamodule.train_dataloader() + batch = next(iter(train_loader)) + x_batch, _ = batch + + # Verify group information + assert "groups" in x_batch + assert x_batch["groups"].dim() == 2 # (batch_size, 1) + + +def test_window_creation(sample_timeseries_data): + """Test window creation for encoder-decoder time series. + + This test ensures: + + - Windows are correctly generated for each time series. + - Encoder and decoder window sizes match the expected values. + - Window indices reference valid series in the dataset. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + ) + + datamodule.setup(stage="fit") + + # Check that windows are created for each series in training set + processed_data = datamodule.train_processed + windows = datamodule.train_windows + + # Verify window parameters + for window in windows: + series_idx, start_idx, enc_length, pred_length = window + assert enc_length == 20 # max_encoder_length + assert pred_length == 5 # max_prediction_length + assert series_idx < len(processed_data) + + +def test_prediction_mode(sample_timeseries_data): + """Test the behavior of the datamodule in prediction mode. + + This test checks: + + - Whether the prediction dataset is properly created. + - The structure of the prediction batch. + - The presence of target scale information. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + ) + + datamodule.setup(stage="predict") + predict_loader = datamodule.predict_dataloader() + + # Check prediction dataset + assert hasattr(datamodule, "predict_dataset") + + # Verify prediction batch structure + batch = next(iter(predict_loader)) + x_batch, y_batch = batch + + assert x_batch["encoder_cat"].shape[1] == 20 + assert x_batch["decoder_cat"].shape[1] == 5 + assert "target_scale" in x_batch + + +@pytest.mark.parametrize( + "train_val_test_split", [(0.6, 0.2, 0.2), (0.8, 0.1, 0.1), (0.7, 0.15, 0.15)] +) +def test_different_splits(sample_timeseries_data, train_val_test_split): + """Test different train-validation-test splits. + + This test verifies: + + - The dataset is correctly split according to different ratios. + - The sizes of train, validation, and test sets match expected values. 
+ """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + train_val_test_split=train_val_test_split, + ) + + datamodule.setup(stage="fit") + total_size = len(sample_timeseries_data) + expected_train_size = int(train_val_test_split[0] * total_size) + expected_val_size = int(train_val_test_split[1] * total_size) + expected_test_size = int(train_val_test_split[2] * total_size) + + assert len(datamodule._train_indices) == expected_train_size + assert len(datamodule._val_indices) == expected_val_size + assert len(datamodule._test_indices) == expected_test_size From 9f8256ed84a84a0461dba33c3c2811fa2f773104 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 19 Feb 2025 19:30:58 +0530 Subject: [PATCH 05/11] debug --- tests/test_data/test_data_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py index 5d85e55e5..5600244fe 100644 --- a/tests/test_data/test_data_module.py +++ b/tests/test_data/test_data_module.py @@ -155,7 +155,7 @@ def test_initialization(sample_timeseries_data): assert len(datamodule.categorical_indices) == 3 # category_1, category_2 assert len(datamodule.continuous_indices) == 5 # value_1, value_2, static_feat - # You might also want to verify the actual indices are correct + # Verify the actual indices are correct metadata = sample_timeseries_data.get_metadata() feature_cols = metadata["cols"]["x"] From fda5f7edef7a76f7f90358d4997ff5b8d4afe9f0 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 19 Feb 2025 19:35:34 +0530 Subject: [PATCH 06/11] update comments --- tests/test_data/test_data_module.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py index 5600244fe..52ecab27b 100644 --- a/tests/test_data/test_data_module.py +++ b/tests/test_data/test_data_module.py @@ -152,8 +152,12 @@ def test_initialization(sample_timeseries_data): assert datamodule.batch_size == 32 # Check correct identification of categorical and continuous features - assert len(datamodule.categorical_indices) == 3 # category_1, category_2 - assert len(datamodule.continuous_indices) == 5 # value_1, value_2, static_feat + assert ( + len(datamodule.categorical_indices) == 3 + ) # category_1, category_2, known_future_2 + assert ( + len(datamodule.continuous_indices) == 5 + ) # value_1, value_2, static_feat, known_future_1, unknown_future_1 # Verify the actual indices are correct metadata = sample_timeseries_data.get_metadata() From d2a1f3881f8dc78c4d1e8f1a3563e9d948f6cfdf Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 7 Mar 2025 17:03:55 +0530 Subject: [PATCH 07/11] version-2 commit --- pytorch_forecasting/data/data_modules.py | 55 ++++++++++++++++++------ pytorch_forecasting/data/timeseries.py | 55 +++++++++++++++++++++--- 2 files changed, 92 insertions(+), 18 deletions(-) diff --git a/pytorch_forecasting/data/data_modules.py b/pytorch_forecasting/data/data_modules.py index f2c5752de..450b83790 100644 --- a/pytorch_forecasting/data/data_modules.py +++ b/pytorch_forecasting/data/data_modules.py @@ -141,29 +141,44 @@ def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]: sample = self.time_series_dataset[idx.item()] target = sample["y"] - # if torch.isnan(target).any(): - # (f"Warning: NaNs detected. 
Sample index: {idx}, Value: {target}") + features = sample["x"] + times = sample["t"] + cutoff_time = sample["cutoff_time"] + + time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool) if isinstance(target, torch.Tensor): target = target.float() else: target = torch.tensor(target, dtype=torch.float32) - features = sample["x"] if isinstance(features, torch.Tensor): features = features.float() else: features = torch.tensor(features, dtype=torch.float32) + features_imputed = features.clone() + for i in range(features.shape[1]): + if torch.isnan(features[:, i]).any(): + valid_values = features[time_mask, i] + valid_values = valid_values[~torch.isnan(valid_values)] + if len(valid_values) > 0: + mean_value = valid_values.mean() + else: + mean_value = 0.0 + features_imputed[:, i] = torch.where( + torch.isnan(features[:, i]), mean_value, features[:, i] + ) + categorical = ( - features[:, self.categorical_indices] + features_imputed[:, self.categorical_indices] if self.categorical_indices - else torch.zeros((features.shape[0], 0)) + else torch.zeros((features_imputed.shape[0], 0)) ) continuous = ( - features[:, self.continuous_indices] + features_imputed[:, self.continuous_indices] if self.continuous_indices - else torch.zeros((features.shape[0], 0)) + else torch.zeros((features_imputed.shape[0], 0)) ) processed_data.append( @@ -173,6 +188,9 @@ def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]: "static": sample.get("st", None), "group": sample.get("group", torch.tensor([0])), "length": len(target), + "time_mask": time_mask, + "times": times, + "cutoff_time": cutoff_time, } ) @@ -208,18 +226,27 @@ def __len__(self): def __getitem__(self, idx): series_idx, start_idx, enc_length, pred_length = self.windows[idx] data = self.processed_data[series_idx] - # if start_idx + enc_length + pred_length > len(data['target']): - # print(f"start_idx: {start_idx}, enc_length: {enc_length}, - # pred_length: {pred_length}, target length: {len(data['target'])}") end_idx = start_idx + enc_length + pred_length encoder_indices = slice(start_idx, start_idx + enc_length) decoder_indices = slice(start_idx + enc_length, end_idx) - target_scale = data["target"][encoder_indices].abs().mean() - if target_scale == 0: + target_scale = data["target"][encoder_indices] + target_scale = target_scale[~torch.isnan(target_scale)].abs().mean() + if torch.isnan(target_scale) or target_scale == 0: target_scale = torch.tensor(1.0) + encoder_mask = ( + data["time_mask"][encoder_indices] + if "time_mask" in data + else torch.ones(enc_length, dtype=torch.bool) + ) + decoder_mask = ( + data["time_mask"][decoder_indices] + if "time_mask" in data + else torch.zeros(pred_length, dtype=torch.bool) + ) + x = { "encoder_cat": data["features"]["categorical"][encoder_indices], "encoder_cont": data["features"]["continuous"][encoder_indices], @@ -232,6 +259,8 @@ def __getitem__(self, idx): "encoder_time_idx": torch.arange(enc_length), "decoder_time_idx": torch.arange(enc_length, enc_length + pred_length), "target_scale": target_scale, + "encoder_mask": encoder_mask, + "decoder_mask": decoder_mask, } if data["static"] is not None: @@ -377,6 +406,8 @@ def collate_fn(batch): "encoder_time_idx": torch.stack([x["encoder_time_idx"] for x, _ in batch]), "decoder_time_idx": torch.stack([x["decoder_time_idx"] for x, _ in batch]), "target_scale": torch.stack([x["target_scale"] for x, _ in batch]), + "encoder_mask": torch.stack([x["encoder_mask"] for x, _ in batch]), + "decoder_mask": torch.stack([x["decoder_mask"] for x, _ in 
batch]), } if "static_categorical_features" in batch[0][0]: diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index a08dc3721..debba1056 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -2838,12 +2838,15 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: else: data = self.data + cutoff_time = data[self.time].max() + result = { "t": data[self.time].values, "y": torch.tensor(data[self.target].values), "x": torch.tensor(data[self.feature_cols].values), "group": torch.tensor([hash(str(group_id))]), "st": torch.tensor(data[self.static].iloc[0].values if self.static else []), + "cutoff_time": cutoff_time, } if self.data_future is not None: @@ -2853,18 +2856,58 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: else: future_data = self.data_future + combined_times = np.concatenate( + [data[self.time].values, future_data[self.time].values] + ) + combined_times = np.unique(combined_times) + combined_times.sort() + + num_timepoints = len(combined_times) + x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan) + y_merged = np.full((num_timepoints, len(self.target)), np.nan) + + # Fill in current data + current_time_indices = {t: i for i, t in enumerate(combined_times)} + for i, t in enumerate(data[self.time].values): + idx = current_time_indices[t] + x_merged[idx] = data[self.feature_cols].values[i] + y_merged[idx] = data[self.target].values[i] + + # Fill in known future features + for i, t in enumerate(future_data[self.time].values): + if t in current_time_indices: + idx = current_time_indices[t] + # Only fill known features + for j, col in enumerate(self.known): + if col in self.feature_cols: + feature_idx = self.feature_cols.index(col) + x_merged[idx, feature_idx] = future_data[col].values[i] + result.update( { - "t_f": torch.tensor(future_data[self.time].values), - "x_f": torch.tensor(future_data[self.known].values), + "t": combined_times, + "x": torch.tensor(x_merged, dtype=torch.float32), + "y": torch.tensor(y_merged, dtype=torch.float32), } ) - if self.weight: - result["weight_f"] = torch.tensor(future_data[self.weight].values) - if self.weight: - result["weights"] = torch.tensor(data[self.weight].values) + if self.data_future is not None and self.weight in self.data_future.columns: + weights_merged = np.full(num_timepoints, np.nan) + for i, t in enumerate(data[self.time].values): + idx = current_time_indices[t] + weights_merged[idx] = data[self.weight].values[i] + + for i, t in enumerate(future_data[self.time].values): + if t in current_time_indices and self.weight in future_data.columns: + idx = current_time_indices[t] + weights_merged[idx] = future_data[self.weight].values[i] + + result["weights"] = torch.tensor(weights_merged, dtype=torch.float32) + else: + result["weights"] = torch.tensor( + data[self.weight].values, dtype=torch.float32 + ) return result From 07c80c6877286e1c8a06098a2ac7f9d147e211bb Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 7 Mar 2025 17:20:45 +0530 Subject: [PATCH 08/11] update the tests (just removed the reference of ) --- tests/test_data/test_data_module.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py index 52ecab27b..34fd68895 100644 --- a/tests/test_data/test_data_module.py +++ b/tests/test_data/test_data_module.py @@ -125,11 +125,6 @@ def test_known_unknown_features(sample_timeseries_data): assert metadata["col_known"]["known_future_2"] == 
"K" assert metadata["col_known"]["unknown_future_1"] == "U" - # Verify future data handling - sample = sample_timeseries_data[0] - assert "x_f" in sample - assert sample["x_f"].shape[1] == 2 # known_future_1 and known_future_2 - def test_initialization(sample_timeseries_data): """Test the initialization of the EncoderDecoderTimeSeriesDataModule. From 8d17925a8a97a4ebe601174b1d2ea3f5006bb4fa Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 4 Apr 2025 12:12:22 +0530 Subject: [PATCH 09/11] update the data_module and timeseries --- pytorch_forecasting/data/data_modules.py | 217 ++++++++++++++++++++++- pytorch_forecasting/data/timeseries.py | 85 ++++----- 2 files changed, 249 insertions(+), 53 deletions(-) diff --git a/pytorch_forecasting/data/data_modules.py b/pytorch_forecasting/data/data_modules.py index 450b83790..476e5efb7 100644 --- a/pytorch_forecasting/data/data_modules.py +++ b/pytorch_forecasting/data/data_modules.py @@ -98,8 +98,15 @@ def __init__( train_val_test_split: tuple = (0.7, 0.15, 0.15), ): super().__init__() + # TODO: AG comment: probably lot of these variables are useless + # TODO: AG comment: train_val_test_split seems related to groups and it is ok + # IN CASE of global forecasting but here + # TODO: AG comment: we need also the temporal split (ok by percentage) + # AS comment: These variables are copied from the original implementation of + # TimeSeriesDataset, some are here and others in OUR TimeSeries dataset class, + # we should discuss about what to do with these variables. self.time_series_dataset = time_series_dataset - self.metadata = time_series_dataset.get_metadata() + self.time_series_metadata = time_series_dataset.get_metadata() self.max_encoder_length = max_encoder_length self.min_encoder_length = min_encoder_length or max_encoder_length @@ -127,16 +134,128 @@ def __init__( self.categorical_indices = [] self.continuous_indices = [] + self._metadata = None + ## - for idx, col in enumerate(self.metadata["cols"]["x"]): - if self.metadata["col_type"].get(col) == "C": + for idx, col in enumerate(self.time_series_metadata["cols"]["x"]): + if self.time_series_metadata["col_type"].get(col) == "C": self.categorical_indices.append(idx) else: self.continuous_indices.append(idx) + @property + def metadata(self): + """Compute metadata for model initialization. + + This property returns a dictionary containing the shapes and key information + related to the time series model. The metadata includes: + + * ``encoder_cat``: Number of categorical variables in the encoder. + * ``encoder_cont``: Number of continuous variables in the encoder. + * ``decoder_cat``: Number of categorical variables in the decoder that are + known in advance. + * ``decoder_cont``: Number of continuous variables in the decoder that are + known in advance. + * ``target``: Number of target variables. 
+ + If static features are present, the following keys are added: + + * ``static_categorical_features``: Number of static categorical features + * ``static_continuous_features``: Number of static continuous features + + It also contains the following information: + + * ``max_encoder_length``: maximum encoder length + * ``max_prediction_length``: maximum prediction length + * ``min_encoder_length``: minimum encoder length + * ``min_prediction_length``: minimum prediction length + """ + encoder_cat_count = len(self.categorical_indices) + encoder_cont_count = len(self.continuous_indices) + + decoder_cat_count = len( + [ + col + for col in self.time_series_metadata["cols"]["x"] + if self.time_series_metadata["col_type"].get(col) == "C" + and self.time_series_metadata["col_known"].get(col) == "K" + ] + ) + decoder_cont_count = len( + [ + col + for col in self.time_series_metadata["cols"]["x"] + if self.time_series_metadata["col_type"].get(col) == "F" + and self.time_series_metadata["col_known"].get(col) == "K" + ] + ) + + target_count = len(self.time_series_metadata["cols"]["y"]) + ## TODO: AG comment: if global forecast is FALSE we may want to add + # also the group as a categorical variable + + metadata = { + "encoder_cat": encoder_cat_count, + "encoder_cont": encoder_cont_count, + "decoder_cat": decoder_cat_count, + "decoder_cont": decoder_cont_count, + "target": target_count, + } + if self.time_series_metadata["cols"]["st"]: + static_cat_count = len( + [ + col + for col in self.time_series_metadata["cols"]["st"] + if self.time_series_metadata["col_type"].get(col) == "C" + ] + ) + static_cont_count = ( + len(self.time_series_metadata["cols"]["st"]) - static_cat_count + ) + + metadata["static_categorical_features"] = static_cat_count + metadata["static_continuous_features"] = static_cont_count + else: + metadata["static_categorical_features"] = 0 + metadata["static_continuous_features"] = 0 + + metadata.update( + { + "max_encoder_length": self.max_encoder_length, + "max_prediction_length": self.max_prediction_length, + "min_encoder_length": self.min_encoder_length, + "min_prediction_length": self.min_prediction_length, + } + ) + + return metadata + def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]: + """Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset. + + Preprocessing steps + -------------------- + + * Converts target (`y`) and features (`x`) to `torch.float32`. + * Masks time points that are at or before the cutoff time. + * Handles missing values in features by imputing with the mean of valid + historical values. + * Splits features into categorical and continuous subsets based on + predefined indices. + + + TODO: add scalers, target normalizers etc. + """ processed_data = [] + ##TODO AG comment: for me this implementation hardly depends on the d1 layer. You are + ## loading all of the d1 layer in memory. This is for sure the most common use case + ## BUT this can be used only for small datasets. I suggest you have a look at + ## the implementation made by Sandeep + + # AS Comment: We can add chunking once the basic backbone is agreed upon by + # everyone.
for idx in indices: sample = self.time_series_dataset[idx.item()] @@ -168,7 +287,8 @@ def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]: mean_value = 0.0 features_imputed[:, i] = torch.where( torch.isnan(features[:, i]), mean_value, features[:, i] - ) + ) ##TODO AG comment: for me interpolating is something + # to do before the D1 layer categorical = ( features_imputed[:, self.categorical_indices] @@ -180,7 +300,8 @@ def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]: if self.continuous_indices else torch.zeros((features_imputed.shape[0], 0)) ) - + ##TODO AG COMMENT: this can slow down the whole process, + # dictionaries can be very resource consuming processed_data.append( { "features": {"categorical": categorical, "continuous": continuous}, @@ -224,6 +345,57 @@ def __len__(self): return len(self.windows) def __getitem__(self, idx): + """Retrieve a processed time series window for dataloader input. + + x : dict + Dictionary containing model inputs: + + * ``encoder_cat`` : tensor of shape (enc_length, n_cat_features) + Categorical features for the encoder. + * ``encoder_cont`` : tensor of shape (enc_length, n_cont_features) + Continuous features for the encoder. + * ``decoder_cat`` : tensor of shape (pred_length, n_cat_features) + Categorical features for the decoder. + * ``decoder_cont`` : tensor of shape (pred_length, n_cont_features) + Continuous features for the decoder. + * ``encoder_lengths`` : tensor of shape (1,) + Length of the encoder sequence. + * ``decoder_lengths`` : tensor of shape (1,) + Length of the decoder sequence. + * ``decoder_target_lengths`` : tensor of shape (1,) + Length of the decoder target sequence. + * ``groups`` : tensor of shape (1,) + Group identifier for the time series instance. + * ``encoder_time_idx`` : tensor of shape (enc_length,) + Time indices for the encoder sequence. + * ``decoder_time_idx`` : tensor of shape (pred_length,) + Time indices for the decoder sequence. + * ``target_scale`` : tensor of shape (1,) + Scaling factor for the target values. + * ``encoder_mask`` : tensor of shape (enc_length,) + Boolean mask indicating valid encoder time points. + * ``decoder_mask`` : tensor of shape (pred_length,) + Boolean mask indicating valid decoder time points. + + If static features are present, the following keys are added: + + * ``static_categorical_features`` : tensor of shape + (1, n_static_cat_features), optional + Static categorical features, if available. + * ``static_continuous_features`` : tensor of shape (1, 0), optional + Placeholder for static continuous features (currently empty). + + y : tensor of shape ``(pred_length, n_targets)`` + Target values for the decoder sequence.
+ """ + + ## TODO AG commment: we need to be sure that the sample idx leads to + # something that is valid aka has no nan values + ## the check sould be done before and not during the getitem in my opinion + ## in my mind BEFORE calling the get_item you already know, given the + # index i, which + ## time series you need to retrieve from the d1 layer AND which slice of it + ## maybe it worth to precomputed as Sandeep is doing series_idx, start_idx, enc_length, pred_length = self.windows[idx] data = self.processed_data[series_idx] @@ -253,16 +425,22 @@ def __getitem__(self, idx): "decoder_cat": data["features"]["categorical"][decoder_indices], "decoder_cont": data["features"]["continuous"][decoder_indices], "encoder_lengths": torch.tensor(enc_length), + ##TODO AG comment: not useful in the getitem "decoder_lengths": torch.tensor(pred_length), + ##TODO AG comment: not useful in the getitem "decoder_target_lengths": torch.tensor(pred_length), + ##TODO AG comment: not useful in the getitem "groups": data["group"], "encoder_time_idx": torch.arange(enc_length), + ##TODO AG comment: not useful in the getitem "decoder_time_idx": torch.arange(enc_length, enc_length + pred_length), + ##TODO AG comment: not useful in the getitem "target_scale": target_scale, "encoder_mask": encoder_mask, "decoder_mask": decoder_mask, } - + # AS comment: The getitem of original implementation also returns similar + # keys, should I remove these and directly add them to collate_fn? if data["static"] is not None: x["static_categorical_features"] = data["static"].unsqueeze(0) x["static_continuous_features"] = torch.zeros((1, 0)) @@ -276,6 +454,21 @@ def __getitem__(self, idx): def _create_windows( self, processed_data: List[Dict[str, Any]] ) -> List[Tuple[int, int, int, int]]: + """Generate sliding windows for training, validation, and testing. + + Returns + ------- + List[Tuple[int, int, int, int]] + A list of tuples, where each tuple consists of: + - ``series_idx`` : int + Index of the time series in `processed_data`. + - ``start_idx`` : int + Start index of the encoder window. + - ``enc_length`` : int + Length of the encoder input sequence. + - ``pred_length`` : int + Length of the decoder output sequence. + """ windows = [] for idx, data in enumerate(processed_data): @@ -314,6 +507,17 @@ def _create_windows( return windows def setup(self, stage: Optional[str] = None): + """Prepare the datasets for training, validation, testing, or prediction. + + Parameters + ---------- + stage : Optional[str], default=None + Specifies the stage of setup. Can be one of: + - ``"fit"`` : Prepares training and validation datasets. + - ``"test"`` : Prepares the test dataset. + - ``"predict"`` : Prepares the dataset for inference. + - ``None`` : Prepares all datasets. + """ total_series = len(self.time_series_dataset) self._split_indices = torch.randperm(total_series) @@ -340,6 +544,7 @@ def setup(self, stage: Optional[str] = None): self.val_dataset = self._ProcessedEncoderDecoderDataset( self.val_processed, self.val_windows, self.add_relative_time_idx ) + # print(self.val_dataset[0]) elif stage is None or stage == "test": if not hasattr(self, "test_dataset"): diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index debba1056..f8284bdbe 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -2675,48 +2675,6 @@ def _coerce_to_dict(obj): class TimeSeries(Dataset): """PyTorch Dataset for time series data stored in pandas DataFrame. 
diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py
index debba1056..f8284bdbe 100644
--- a/pytorch_forecasting/data/timeseries.py
+++ b/pytorch_forecasting/data/timeseries.py
@@ -2675,48 +2675,6 @@ def _coerce_to_dict(obj):

 class TimeSeries(Dataset):
     """PyTorch Dataset for time series data stored in pandas DataFrame.

-    ``__getitem__`` returns:
-
-    * ``t``: tensor of shape (n_timepoints)
-        Time index for each time point in the past or present. Aligned with ``y``,
-        and ``x`` not ending in ``f``.
-    * ``y``: tensor of shape (n_timepoints, n_targets)
-        Target values for each time point. Rows are time points, aligned with ``t``.
-    * ``x``: tensor of shape (n_timepoints, n_features)
-        Features for each time point. Rows are time points, aligned with ``t``.
-    * ``group``: tensor of shape (n_groups)
-        Group identifiers for time series instances.
-    * ``st``: tensor of shape (n_static_features)
-        Static features.
-
-    Optionally, the following str-keyed entries can be included:
-
-    * ``t_f``: tensor of shape (n_timepoints_future)
-        Time index for each time point in the future.
-        Aligned with ``x_f``.
-    * ``x_f``: tensor of shape (n_timepoints_future, n_features)
-        Known features for each time point in the future.
-        Rows are time points, aligned with ``t_f``.
-    * ``weights``: tensor of shape (n_timepoints), only if weight is not None
-    * ``weight_f``: tensor of shape (n_timepoints_future), only if weight is
-        not None.
-
-    -----------------------------------------------------------------------------------
-
-    ``get_metadata`` returns metadata:
-
-    * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] }
-        Names of columns for y, x, and static features.
-        List elements are in same order as column dimensions.
-        Columns not appearing are assumed to be named (x0, x1, etc.),
-        (y0, y1, etc.), (st0, st1, etc.).
-    * ``col_type``: dict[str, str]
-        maps column names to data types "F" (numerical) and "C" (categorical).
-        Column names not occurring are assumed "F".
-    * ``col_known``: dict[str, str]
-        maps column names to "K" (future known) or "U" (future unknown).
-        Column names not occurring are assumed "K".
-
     Parameters
     ----------
     data : pd.DataFrame
@@ -2807,7 +2765,22 @@ def __init__(
         self._prepare_metadata()

     def _prepare_metadata(self):
-        """Prepare metadata for the dataset."""
+        """Prepare metadata for the dataset.
+
+        The method populates ``self.metadata``, which contains:
+
+        * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] }
+          Names of columns for y, x, and static features.
+          List elements are in the same order as column dimensions.
+          Columns not appearing are assumed to be named (x0, x1, etc.),
+          (y0, y1, etc.), (st0, st1, etc.).
+        * ``col_type``: dict[str, str]
+          Maps column names to data types "F" (numerical) and "C" (categorical).
+          Column names not occurring are assumed "F".
+        * ``col_known``: dict[str, str]
+          Maps column names to "K" (future known) or "U" (future unknown).
+          Column names not occurring are assumed "K".
+        """
         self.metadata = {
             "cols": {
                 "y": self.target,
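To make the metadata layout concrete, here is an illustrative, hand-written value of ``self.metadata`` for a small dataset with one target, two features, and no static columns, following the conventions in the docstring above:

    # target "value"; "price" is numeric and unknown in the future,
    # "weekday" is categorical and known in the future
    metadata = {
        "cols": {"y": ["value"], "x": ["price", "weekday"], "st": []},
        "col_type": {"price": "F", "weekday": "C"},
        "col_known": {"price": "U", "weekday": "K"},
    }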
@@ -2829,7 +2802,28 @@ def __len__(self) -> int:
         return len(self._group_ids)

     def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
-        """Get time series data for given index."""
+        """Get time series data for the given index.
+
+        It returns:
+
+        * ``t``: ``numpy.ndarray`` of shape (n_timepoints,)
+          Time index for each time point in the past or present. Aligned with ``y``,
+          and ``x`` not ending in ``f``.
+        * ``y``: tensor of shape (n_timepoints, n_targets)
+          Target values for each time point. Rows are time points, aligned with ``t``.
+        * ``x``: tensor of shape (n_timepoints, n_features)
+          Features for each time point. Rows are time points, aligned with ``t``.
+        * ``group``: tensor of shape (n_groups)
+          Group identifiers for time series instances.
+        * ``st``: tensor of shape (n_static_features)
+          Static features.
+        * ``cutoff_time``: float or ``numpy.float64``
+          Cutoff time for the time series instance.
+
+        Optionally, the following str-keyed entry can be included:
+
+        * ``weights``: tensor of shape (n_timepoints,), only if weight is not None
+        """
         group_id = self._group_ids[index]

         if self.group:
@@ -2866,18 +2860,15 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
             x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan)
             y_merged = np.full((num_timepoints, len(self.target)), np.nan)

-        # Fill in current data
         current_time_indices = {t: i for i, t in enumerate(combined_times)}
         for i, t in enumerate(data[self.time].values):
             idx = current_time_indices[t]
             x_merged[idx] = data[self.feature_cols].values[i]
             y_merged[idx] = data[self.target].values[i]

-        # Fill in known future features
         for i, t in enumerate(future_data[self.time].values):
             if t in current_time_indices:
                 idx = current_time_indices[t]
-                # Only fill known features
                 for j, col in enumerate(self.known):
                     if col in self.feature_cols:
                         feature_idx = self.feature_cols.index(col)
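For context, the merge performed in ``__getitem__`` above can be reproduced in a few lines of NumPy; the column names and values below are invented for the sketch:

    import numpy as np

    feature_cols = ["price", "weekday"]
    known = ["weekday"]

    current_times = np.array([0, 1, 2])
    future_times = np.array([2, 3])
    combined_times = np.union1d(current_times, future_times)  # [0, 1, 2, 3]

    # start from all-NaN, then fill observed rows and known-future entries
    x_merged = np.full((len(combined_times), len(feature_cols)), np.nan)
    current_time_indices = {t: i for i, t in enumerate(combined_times)}

    current_x = np.array([[10.0, 1.0], [11.0, 2.0], [12.0, 3.0]])
    for i, t in enumerate(current_times):
        x_merged[current_time_indices[t]] = current_x[i]

    future_weekday = np.array([4.0, 5.0])  # known feature at the future times
    for i, t in enumerate(future_times):
        idx = current_time_indices[t]
        for col in known:
            if col in feature_cols:
                x_merged[idx, feature_cols.index(col)] = future_weekday[i]

    # "price" stays NaN at t == 3, since it is unknown in the future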
From cd2d34e6bb2ddb40e29c9a9b4fb5c9fea10fd4b0 Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Fri, 4 Apr 2025 18:06:35 +0530
Subject: [PATCH 10/11] add _prepare_metadata to datamodule

---
 pytorch_forecasting/data/data_modules.py | 75 ++++++++++++++++--------
 1 file changed, 51 insertions(+), 24 deletions(-)

diff --git a/pytorch_forecasting/data/data_modules.py b/pytorch_forecasting/data/data_modules.py
index 476e5efb7..93f4c9bad 100644
--- a/pytorch_forecasting/data/data_modules.py
+++ b/pytorch_forecasting/data/data_modules.py
@@ -135,7 +135,6 @@ def __init__(
         self.categorical_indices = []
         self.continuous_indices = []
         self._metadata = None
-        ##

         for idx, col in enumerate(self.time_series_metadata["cols"]["x"]):
             if self.time_series_metadata["col_type"].get(col) == "C":
@@ -143,32 +142,29 @@ def __init__(
             else:
                 self.continuous_indices.append(idx)

-    @property
-    def metadata(self):
-        """Compute metadata for model initialization.
-
-        This property returns a dictionary containing the shapes and key information
-        related to the time series model. The metadata includes:
-
-        * ``encoder_cat``: Number of categorical variables in the encoder.
-        * ``encoder_cont``: Number of continuous variables in the encoder.
-        * ``decoder_cat``: Number of categorical variables in the decoder that are
-          known in advance.
-        * ``decoder_cont``: Number of continuous variables in the decoder that are
-          known in advance.
-        * ``target``: Number of target variables.
+    def _prepare_metadata(self):
+        """Prepare metadata for model initialization.

-        If static features are present, the following keys are added:
-
-        * ``static_categorical_features``: Number of static categorical features
-        * ``static_continuous_features``: Number of static continuous features
+        Returns
+        -------
+        dict
+            Dictionary containing the following keys:
+
+            * ``encoder_cat``: Number of categorical variables in the encoder.
+            * ``encoder_cont``: Number of continuous variables in the encoder.
+            * ``decoder_cat``: Number of categorical variables in the decoder that
+              are known in advance.
+            * ``decoder_cont``: Number of continuous variables in the decoder that
+              are known in advance.
+            * ``target``: Number of target variables.
+            * ``static_categorical_features``: Number of static categorical features
+            * ``static_continuous_features``: Number of static continuous features
+            * ``max_encoder_length``: maximum encoder length
+            * ``max_prediction_length``: maximum prediction length
+            * ``min_encoder_length``: minimum encoder length
+            * ``min_prediction_length``: minimum prediction length
+
-        It also contains the following information:
-        * ``max_encoder_length``: maximum encoder length
-        * ``max_prediction_length``: maximum prediction length
-        * ``min_encoder_length``: minimum encoder length
-        * ``min_prediction_length``: minimum prediction length
         """
         encoder_cat_count = len(self.categorical_indices)
         encoder_cont_count = len(self.continuous_indices)
@@ -230,6 +226,37 @@ def metadata(self):

         return metadata

+    @property
+    def metadata(self):
+        """Compute metadata for model initialization.
+
+        This property returns a dictionary containing the shapes and key information
+        related to the time series model. The metadata includes:
+
+        * ``encoder_cat``: Number of categorical variables in the encoder.
+        * ``encoder_cont``: Number of continuous variables in the encoder.
+        * ``decoder_cat``: Number of categorical variables in the decoder that are
+          known in advance.
+        * ``decoder_cont``: Number of continuous variables in the decoder that are
+          known in advance.
+        * ``target``: Number of target variables.
+
+        If static features are present, the following keys are added:
+
+        * ``static_categorical_features``: Number of static categorical features
+        * ``static_continuous_features``: Number of static continuous features
+
+        It also contains the following information:
+
+        * ``max_encoder_length``: maximum encoder length
+        * ``max_prediction_length``: maximum prediction length
+        * ``min_encoder_length``: minimum encoder length
+        * ``min_prediction_length``: minimum prediction length
+        """
+        if self._metadata is None:
+            self._metadata = self._prepare_metadata()
+        return self._metadata
+
     def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]:
         """Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset.
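For context, a short sketch of how the counts documented in the ``_prepare_metadata`` docstring above can be derived from a ``TimeSeries`` metadata dict. The defaults ("F" and "K") follow the conventions quoted earlier, and the example dict is invented:

    metadata = {
        "cols": {"y": ["value"], "x": ["price", "weekday"], "st": ["region"]},
        "col_type": {"price": "F", "weekday": "C", "region": "C"},
        "col_known": {"price": "U", "weekday": "K"},
    }

    # decoder counts: feature columns that are both of the right type and known
    decoder_cat_count = sum(
        1
        for col in metadata["cols"]["x"]
        if metadata["col_type"].get(col, "F") == "C"
        and metadata["col_known"].get(col, "K") == "K"
    )
    decoder_cont_count = sum(
        1
        for col in metadata["cols"]["x"]
        if metadata["col_type"].get(col, "F") == "F"
        and metadata["col_known"].get(col, "K") == "K"
    )

    # static counts: categoricals counted directly, continuous as the remainder
    static_cat_count = sum(
        1
        for col in metadata["cols"]["st"]
        if metadata["col_type"].get(col, "F") == "C"
    )
    static_cont_count = len(metadata["cols"]["st"]) - static_cat_count

    # decoder_cat_count == 1 (weekday); decoder_cont_count == 0 (price unknown);
    # static_cat_count == 1 (region); static_cont_count == 0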
From 33e2b0427dbd182fc42522a55a52e41b9cba3327 Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Fri, 4 Apr 2025 18:20:24 +0530
Subject: [PATCH 11/11] update docstring

---
 pytorch_forecasting/data/data_modules.py | 24 +++++++++++++++++++++---
 1 file changed, 21 insertions(+), 3 deletions(-)

diff --git a/pytorch_forecasting/data/data_modules.py b/pytorch_forecasting/data/data_modules.py
index 93f4c9bad..7950b9f9f 100644
--- a/pytorch_forecasting/data/data_modules.py
+++ b/pytorch_forecasting/data/data_modules.py
@@ -151,19 +151,37 @@ def _prepare_metadata(self):
         dict
             Dictionary containing the following keys:

             * ``encoder_cat``: Number of categorical variables in the encoder.
+              Computed as ``len(self.categorical_indices)``, which counts the
+              categorical feature indices.
             * ``encoder_cont``: Number of continuous variables in the encoder.
+              Computed as ``len(self.continuous_indices)``, which counts the
+              continuous feature indices.
             * ``decoder_cat``: Number of categorical variables in the decoder that
-              are known in advance.
+              are known in advance.
+              Computed by filtering ``self.time_series_metadata["cols"]["x"]``
+              where ``col_type == "C"`` (categorical) and ``col_known == "K"``
+              (known).
             * ``decoder_cont``: Number of continuous variables in the decoder that
-              are known in advance.
+              are known in advance.
+              Computed by filtering ``self.time_series_metadata["cols"]["x"]``
+              where ``col_type == "F"`` (continuous) and ``col_known == "K"``
+              (known).
             * ``target``: Number of target variables.
+              Computed as ``len(self.time_series_metadata["cols"]["y"])``, which
+              gives the number of output target columns.
             * ``static_categorical_features``: Number of static categorical features
+              Computed by filtering ``self.time_series_metadata["cols"]["st"]``
+              (static features) where ``col_type == "C"`` (categorical).
             * ``static_continuous_features``: Number of static continuous features
+              Computed as the difference of
+              ``len(self.time_series_metadata["cols"]["st"])`` (static features)
+              and ``static_categorical_features``, which gives the number of
+              static continuous features.
             * ``max_encoder_length``: maximum encoder length
+              Taken directly from ``self.max_encoder_length``.
             * ``max_prediction_length``: maximum prediction length
+              Taken directly from ``self.max_prediction_length``.
             * ``min_encoder_length``: minimum encoder length
+              Taken directly from ``self.min_encoder_length``.
             * ``min_prediction_length``: minimum prediction length
-
+              Taken directly from ``self.min_prediction_length``.
         """
         encoder_cat_count = len(self.categorical_indices)
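Finally, a hypothetical sketch of consuming the cached ``metadata`` property when wiring up a model; ``dm`` is the datamodule from the earlier usage sketch, and the model-side keyword names are invented for illustration:

    # first access runs _prepare_metadata() and caches the result in
    # self._metadata; subsequent accesses reuse the cache
    meta = dm.metadata

    # hypothetical model-side wiring; these kwarg names are not from the patch
    model_kwargs = {
        "num_cont_features": meta["encoder_cont"],
        "num_cat_features": meta["encoder_cat"],
        "prediction_length": meta["max_prediction_length"],
    }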