diff --git a/pytorch_forecasting/data/data_modules.py b/pytorch_forecasting/data/data_modules.py new file mode 100644 index 000000000..7950b9f9f --- /dev/null +++ b/pytorch_forecasting/data/data_modules.py @@ -0,0 +1,672 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +from lightning.pytorch import LightningDataModule +from sklearn.preprocessing import RobustScaler, StandardScaler +import torch +from torch.utils.data import DataLoader, Dataset + +from pytorch_forecasting.data.encoders import ( + EncoderNormalizer, + NaNLabelEncoder, + TorchNormalizer, +) +from pytorch_forecasting.data.timeseries import TimeSeries, _coerce_to_dict + +NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer] + + +class EncoderDecoderTimeSeriesDataModule(LightningDataModule): + """ + Lightning DataModule for processing time series data in an encoder-decoder format. + + This module handles preprocessing, splitting, and batching of time series data + for use in deep learning models. It supports categorical and continuous features, + various scalers, and automatic target normalization. + + Parameters + ---------- + time_series_dataset : TimeSeries + The dataset containing time series data. + max_encoder_length : int, default=30 + Maximum length of the encoder input sequence. + min_encoder_length : Optional[int], default=None + Minimum length of the encoder input sequence. + Defaults to `max_encoder_length` if not specified. + max_prediction_length : int, default=1 + Maximum length of the decoder output sequence. + min_prediction_length : Optional[int], default=None + Minimum length of the decoder output sequence. + Defaults to `max_prediction_length` if not specified. + min_prediction_idx : Optional[int], default=None + Minimum index from which predictions start. + allow_missing_timesteps : bool, default=False + Whether to allow missing timesteps in the dataset. + add_relative_time_idx : bool, default=False + Whether to add a relative time index feature. + add_target_scales : bool, default=False + Whether to add target scaling information. + add_encoder_length : Union[bool, str], default="auto" + Whether to include encoder length information. + target_normalizer : + Union[NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None], + default="auto" + Normalizer for the target variable. If "auto", uses `RobustScaler`. + + categorical_encoders : Optional[Dict[str, NaNLabelEncoder]], default=None + Dictionary of categorical encoders. + + scalers : + Optional[Dict[str, Union[StandardScaler, RobustScaler, + TorchNormalizer, EncoderNormalizer]]], default=None + Dictionary of feature scalers. + + randomize_length : Union[None, Tuple[float, float], bool], default=False + Whether to randomize input sequence length. + batch_size : int, default=32 + Batch size for DataLoader. + num_workers : int, default=0 + Number of workers for DataLoader. + train_val_test_split : tuple, default=(0.7, 0.15, 0.15) + Proportions for train, validation, and test dataset splits. 
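+
+    Examples
+    --------
+    A minimal usage sketch, assuming ``dataset`` is an existing
+    :class:`TimeSeries` instance (argument values are illustrative only)::
+
+        dm = EncoderDecoderTimeSeriesDataModule(
+            time_series_dataset=dataset,
+            max_encoder_length=30,
+            max_prediction_length=7,
+            batch_size=64,
+        )
+        dm.setup(stage="fit")
+        x, y = next(iter(dm.train_dataloader()))
+        # x["encoder_cont"]: (batch_size, max_encoder_length, n_cont_features)
+        # y: (batch_size, max_prediction_length, n_targets)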
+ """ + + def __init__( + self, + time_series_dataset: TimeSeries, + max_encoder_length: int = 30, + min_encoder_length: Optional[int] = None, + max_prediction_length: int = 1, + min_prediction_length: Optional[int] = None, + min_prediction_idx: Optional[int] = None, + allow_missing_timesteps: bool = False, + add_relative_time_idx: bool = False, + add_target_scales: bool = False, + add_encoder_length: Union[bool, str] = "auto", + target_normalizer: Union[ + NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None + ] = "auto", + categorical_encoders: Optional[Dict[str, NaNLabelEncoder]] = None, + scalers: Optional[ + Dict[ + str, + Union[StandardScaler, RobustScaler, TorchNormalizer, EncoderNormalizer], + ] + ] = None, + randomize_length: Union[None, Tuple[float, float], bool] = False, + batch_size: int = 32, + num_workers: int = 0, + train_val_test_split: tuple = (0.7, 0.15, 0.15), + ): + super().__init__() + # TODO: AG comment: probably lot of these variables are useless + # TODO: AG comment: train_val_test_split seems related to groups and it is ok + # IN CASE of global forecasting but here + # TODO: AG comment: we need also the temporal split (ok by percentage) + # AS comment: These variables are copied from the original implementation of + # TimeSeriesDataset, some are here and others in OUR TimeSeries dataset class, + # we should discuss about what to do with these variables. + self.time_series_dataset = time_series_dataset + self.time_series_metadata = time_series_dataset.get_metadata() + + self.max_encoder_length = max_encoder_length + self.min_encoder_length = min_encoder_length or max_encoder_length + self.max_prediction_length = max_prediction_length + self.min_prediction_length = min_prediction_length or max_prediction_length + self.min_prediction_idx = min_prediction_idx + + self.allow_missing_timesteps = allow_missing_timesteps + self.add_relative_time_idx = add_relative_time_idx + self.add_target_scales = add_target_scales + self.add_encoder_length = add_encoder_length + self.randomize_length = randomize_length + + self.batch_size = batch_size + self.num_workers = num_workers + self.train_val_test_split = train_val_test_split + + if isinstance(target_normalizer, str) and target_normalizer.lower() == "auto": + self.target_normalizer = RobustScaler() + else: + self.target_normalizer = target_normalizer + + self.categorical_encoders = _coerce_to_dict(categorical_encoders) + self.scalers = _coerce_to_dict(scalers) + + self.categorical_indices = [] + self.continuous_indices = [] + self._metadata = None + + for idx, col in enumerate(self.time_series_metadata["cols"]["x"]): + if self.time_series_metadata["col_type"].get(col) == "C": + self.categorical_indices.append(idx) + else: + self.continuous_indices.append(idx) + + def _prepare_metadata(self): + """Prepare metadata for model initialisation. + + Returns + ------- + dict + dictionary containing the following keys: + + * ``encoder_cat``: Number of categorical variables in the encoder. + Computed as ``len(self.categorical_indices)``, which counts the + categorical feature indices. + * ``encoder_cont``: Number of continuous variables in the encoder. + Computed as ``len(self.continuous_indices)``, which counts the + continuous feature indices. + * ``decoder_cat``: Number of categorical variables in the decoder that + are known in advance. 
+ Computed by filtering ``self.time_series_metadata["cols"]["x"]`` + where col_type == "C"(categorical) and col_known == "K" (known) + * ``decoder_cont``: Number of continuous variables in the decoder that + are known in advance. + Computed by filtering ``self.time_series_metadata["cols"]["x"]`` + where col_type == "F"(continuous) and col_known == "K"(known) + * ``target``: Number of target variables. + Computed as ``len(self.time_series_metadata["cols"]["y"])``, which + gives the number of output target columns.. + * ``static_categorical_features``: Number of static categorical features + Computed by filtering ``self.time_series_metadata["cols"]["st"]`` + (static features) where col_type == "C" (categorical). + * ``static_continuous_features``: Number of static continuous features + Computed as difference of + ``len(self.time_series_metadata["cols"]["st"])`` (static features) + and static_categorical_features that gives static continuous feature + * ``max_encoder_length``: maximum encoder length + Taken directly from `self.max_encoder_length`. + * ``max_prediction_length``: maximum prediction length + Taken directly from `self.max_prediction_length`. + * ``min_encoder_length``: minimum encoder length + Taken directly from `self.min_encoder_length`. + * ``min_prediction_length``: minimum prediction length + Taken directly from `self.min_prediction_length`. + + """ + encoder_cat_count = len(self.categorical_indices) + encoder_cont_count = len(self.continuous_indices) + + decoder_cat_count = len( + [ + col + for col in self.time_series_metadata["cols"]["x"] + if self.time_series_metadata["col_type"].get(col) == "C" + and self.time_series_metadata["col_known"].get(col) == "K" + ] + ) + decoder_cont_count = len( + [ + col + for col in self.time_series_metadata["cols"]["x"] + if self.time_series_metadata["col_type"].get(col) == "F" + and self.time_series_metadata["col_known"].get(col) == "K" + ] + ) + + target_count = len(self.time_series_metadata["cols"]["y"]) + ## TODO: AG comment: if global forecast is FALSE we may want to add + # also the group as categorical variable + + metadata = { + "encoder_cat": encoder_cat_count, + "encoder_cont": encoder_cont_count, + "decoder_cat": decoder_cat_count, + "decoder_cont": decoder_cont_count, + "target": target_count, + } + if self.time_series_metadata["cols"]["st"]: + static_cat_count = len( + [ + col + for col in self.time_series_metadata["cols"]["st"] + if self.time_series_metadata["col_type"].get(col) == "C" + ] + ) + static_cont_count = ( + len(self.time_series_metadata["cols"]["st"]) - static_cat_count + ) + + metadata["static_categorical_features"] = static_cat_count + metadata["static_continuous_features"] = static_cont_count + else: + metadata["static_categorical_features"] = 0 + metadata["static_continuous_features"] = 0 + + metadata.update( + { + "max_encoder_length": self.max_encoder_length, + "max_prediction_length": self.max_prediction_length, + "min_encoder_length": self.min_encoder_length, + "min_prediction_length": self.min_prediction_length, + } + ) + + return metadata + + @property + def metadata(self): + """Compute metadata for model initialization. + + This property returns a dictionary containing the shapes and key information + related to the time series model. The metadata includes: + + * ``encoder_cat``: Number of categorical variables in the encoder. + * ``encoder_cont``: Number of continuous variables in the encoder. + * ``decoder_cat``: Number of categorical variables in the decoder that are + known in advance. 
+ * ``decoder_cont``: Number of continuous variables in the decoder that are + known in advance. + * ``target``: Number of target variables. + + If static features are present, the following keys are added: + + * ``static_categorical_features``: Number of static categorical features + * ``static_continuous_features``: Number of static continuous features + + It also contains the following information: + + * ``max_encoder_length``: maximum encoder length + * ``max_prediction_length``: maximum prediction length + * ``min_encoder_length``: minimum encoder length + * ``min_prediction_length``: minimum prediction length + """ + if self._metadata is None: + self._metadata = self._prepare_metadata() + return self._metadata + + def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]: + """Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset. + + Preprocessing steps + -------------------- + + * Converts target (`y`) and features (`x`) to `torch.float32`. + * Masks time points that are at or before the cutoff time. + * Handles missing values in features by imputing with the mean of valid + historical values. + * Splits features into categorical and continuous subsets based on + predefined indices. + + + TODO: add scalers, target normalizers etc. + """ + processed_data = [] + + ##TODO AG comment: for me this implementation hardly depend on d1 layer. You are + ## loading all the d1 layer in memory. This is for sure the most common use case + ## BUT this can be used only for small dataset. I suggest you to have a look to + ## the implementation made by Sandeep + + # AS Comment: We can add chunking pnce the basic backbone is agreed upon by + # everyone. + + for idx in indices: + sample = self.time_series_dataset[idx.item()] + + target = sample["y"] + features = sample["x"] + times = sample["t"] + cutoff_time = sample["cutoff_time"] + + time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool) + + if isinstance(target, torch.Tensor): + target = target.float() + else: + target = torch.tensor(target, dtype=torch.float32) + + if isinstance(features, torch.Tensor): + features = features.float() + else: + features = torch.tensor(features, dtype=torch.float32) + + features_imputed = features.clone() + for i in range(features.shape[1]): + if torch.isnan(features[:, i]).any(): + valid_values = features[time_mask, i] + valid_values = valid_values[~torch.isnan(valid_values)] + if len(valid_values) > 0: + mean_value = valid_values.mean() + else: + mean_value = 0.0 + features_imputed[:, i] = torch.where( + torch.isnan(features[:, i]), mean_value, features[:, i] + ) ##TODO AG comment: for me interpolating is something + # to do before D1 layer + + categorical = ( + features_imputed[:, self.categorical_indices] + if self.categorical_indices + else torch.zeros((features_imputed.shape[0], 0)) + ) + continuous = ( + features_imputed[:, self.continuous_indices] + if self.continuous_indices + else torch.zeros((features_imputed.shape[0], 0)) + ) + ##TODO AG COMMMENT: this can slow down all the process, + # dictionaries can be very resource consuming + processed_data.append( + { + "features": {"categorical": categorical, "continuous": continuous}, + "target": target, + "static": sample.get("st", None), + "group": sample.get("group", torch.tensor([0])), + "length": len(target), + "time_mask": time_mask, + "times": times, + "cutoff_time": cutoff_time, + } + ) + + return processed_data + + class _ProcessedEncoderDecoderDataset(Dataset): + """PyTorch Dataset for processed 
encoder-decoder time series data. + + Parameters + ---------- + processed_data : List[Dict[str, Any]] + List of preprocessed time series samples. + windows : List[Tuple[int, int, int, int]] + List of window tuples containing + (series_idx, start_idx, enc_length, pred_length). + add_relative_time_idx : bool, default=False + Whether to include relative time indices. + """ + + def __init__( + self, + processed_data: List[Dict[str, Any]], + windows: List[Tuple[int, int, int, int]], + add_relative_time_idx: bool = False, + ): + self.processed_data = processed_data + self.windows = windows + self.add_relative_time_idx = add_relative_time_idx + + def __len__(self): + return len(self.windows) + + def __getitem__(self, idx): + """Retrieve a processed time series window for dataloader input. + + x : dict + Dictionary containing model inputs: + + * ``encoder_cat`` : tensor of shape (enc_length, n_cat_features) + Categorical features for the encoder. + * ``encoder_cont`` : tensor of shape (enc_length, n_cont_features) + Continuous features for the encoder. + * ``decoder_cat`` : tensor of shape (pred_length, n_cat_features) + Categorical features for the decoder. + * ``decoder_cont`` : tensor of shape (pred_length, n_cont_features) + Continuous features for the decoder. + * ``encoder_lengths`` : tensor of shape (1,) + Length of the encoder sequence. + * ``decoder_lengths`` : tensor of shape (1,) + Length of the decoder sequence. + * ``decoder_target_lengths`` : tensor of shape (1,) + Length of the decoder target sequence. + * ``groups`` : tensor of shape (1,) + Group identifier for the time series instance. + * ``encoder_time_idx`` : tensor of shape (enc_length,) + Time indices for the encoder sequence. + * ``decoder_time_idx`` : tensor of shape (pred_length,) + Time indices for the decoder sequence. + * ``target_scale`` : tensor of shape (1,) + Scaling factor for the target values. + * ``encoder_mask`` : tensor of shape (enc_length,) + Boolean mask indicating valid encoder time points. + * ``decoder_mask`` : tensor of shape (pred_length,) + Boolean mask indicating valid decoder time points. + + If static features are present, the following keys are added: + + * ``static_categorical_features`` : tensor of shape + (1, n_static_cat_features), optional + Static categorical features, if available. + * ``static_continuous_features`` : tensor of shape (1, 0), optional + Placeholder for static continuous features (currently empty). + + y : tensor of shape ``(pred_length, n_targets)`` + Target values for the decoder sequence. 
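+
+            For example, with ``enc_length=20``, ``pred_length=5``, three
+            categorical and five continuous features, ``x["encoder_cont"]``
+            has shape ``(20, 5)``, ``x["decoder_cat"]`` has shape ``(5, 3)``,
+            and ``y`` has shape ``(5, 1)`` for a single target.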
+ """ + + ## TODO AG commment: we need to be sure that the sample idx leads to + # something that is valid aka has no nan values + ## the check sould be done before and not during the getitem in my opinion + ## in my mind BEFORE calling the get_item you already know, given the + # index i, which + ## time series you need to retrieve from the d1 layer AND which slice of it + ## maybe it worth to precomputed as Sandeep is doing + series_idx, start_idx, enc_length, pred_length = self.windows[idx] + data = self.processed_data[series_idx] + + end_idx = start_idx + enc_length + pred_length + encoder_indices = slice(start_idx, start_idx + enc_length) + decoder_indices = slice(start_idx + enc_length, end_idx) + + target_scale = data["target"][encoder_indices] + target_scale = target_scale[~torch.isnan(target_scale)].abs().mean() + if torch.isnan(target_scale) or target_scale == 0: + target_scale = torch.tensor(1.0) + + encoder_mask = ( + data["time_mask"][encoder_indices] + if "time_mask" in data + else torch.ones(enc_length, dtype=torch.bool) + ) + decoder_mask = ( + data["time_mask"][decoder_indices] + if "time_mask" in data + else torch.zeros(pred_length, dtype=torch.bool) + ) + + x = { + "encoder_cat": data["features"]["categorical"][encoder_indices], + "encoder_cont": data["features"]["continuous"][encoder_indices], + "decoder_cat": data["features"]["categorical"][decoder_indices], + "decoder_cont": data["features"]["continuous"][decoder_indices], + "encoder_lengths": torch.tensor(enc_length), + ##TODO AG comment: not useful in the getitem + "decoder_lengths": torch.tensor(pred_length), + ##TODO AG comment: not useful in the getitem + "decoder_target_lengths": torch.tensor(pred_length), + ##TODO AG comment: not useful in the getitem + "groups": data["group"], + "encoder_time_idx": torch.arange(enc_length), + ##TODO AG comment: not useful in the getitem + "decoder_time_idx": torch.arange(enc_length, enc_length + pred_length), + ##TODO AG comment: not useful in the getitem + "target_scale": target_scale, + "encoder_mask": encoder_mask, + "decoder_mask": decoder_mask, + } + # AS comment: The getitem of original implementation also returns similar + # keys, should I remove these and directly add them to collate_fn? + if data["static"] is not None: + x["static_categorical_features"] = data["static"].unsqueeze(0) + x["static_continuous_features"] = torch.zeros((1, 0)) + + y = data["target"][decoder_indices] + if y.ndim == 1: + y = y.unsqueeze(-1) + + return x, y + + def _create_windows( + self, processed_data: List[Dict[str, Any]] + ) -> List[Tuple[int, int, int, int]]: + """Generate sliding windows for training, validation, and testing. + + Returns + ------- + List[Tuple[int, int, int, int]] + A list of tuples, where each tuple consists of: + - ``series_idx`` : int + Index of the time series in `processed_data`. + - ``start_idx`` : int + Start index of the encoder window. + - ``enc_length`` : int + Length of the encoder input sequence. + - ``pred_length`` : int + Length of the decoder output sequence. 
+ """ + windows = [] + + for idx, data in enumerate(processed_data): + sequence_length = data["length"] + + if sequence_length < self.max_encoder_length + self.max_prediction_length: + continue + + effective_min_prediction_idx = ( + self.min_prediction_idx + if self.min_prediction_idx is not None + else self.max_encoder_length + ) + + max_prediction_idx = sequence_length - self.max_prediction_length + 1 + + if max_prediction_idx <= effective_min_prediction_idx: + continue + + for start_idx in range( + 0, max_prediction_idx - effective_min_prediction_idx + ): + if ( + start_idx + self.max_encoder_length + self.max_prediction_length + <= sequence_length + ): + windows.append( + ( + idx, + start_idx, + self.max_encoder_length, + self.max_prediction_length, + ) + ) + + return windows + + def setup(self, stage: Optional[str] = None): + """Prepare the datasets for training, validation, testing, or prediction. + + Parameters + ---------- + stage : Optional[str], default=None + Specifies the stage of setup. Can be one of: + - ``"fit"`` : Prepares training and validation datasets. + - ``"test"`` : Prepares the test dataset. + - ``"predict"`` : Prepares the dataset for inference. + - ``None`` : Prepares all datasets. + """ + total_series = len(self.time_series_dataset) + self._split_indices = torch.randperm(total_series) + + self._train_size = int(self.train_val_test_split[0] * total_series) + self._val_size = int(self.train_val_test_split[1] * total_series) + + self._train_indices = self._split_indices[: self._train_size] + self._val_indices = self._split_indices[ + self._train_size : self._train_size + self._val_size + ] + self._test_indices = self._split_indices[self._train_size + self._val_size :] + + if stage is None or stage == "fit": + if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"): + self.train_processed = self._preprocess_data(self._train_indices) + self.val_processed = self._preprocess_data(self._val_indices) + + self.train_windows = self._create_windows(self.train_processed) + self.val_windows = self._create_windows(self.val_processed) + + self.train_dataset = self._ProcessedEncoderDecoderDataset( + self.train_processed, self.train_windows, self.add_relative_time_idx + ) + self.val_dataset = self._ProcessedEncoderDecoderDataset( + self.val_processed, self.val_windows, self.add_relative_time_idx + ) + # print(self.val_dataset[0]) + + elif stage is None or stage == "test": + if not hasattr(self, "test_dataset"): + self.test_processed = self._preprocess_data(self._test_indices) + self.test_windows = self._create_windows(self.test_processed) + + self.test_dataset = self._ProcessedEncoderDecoderDataset( + self.test_processed, self.test_windows, self.add_relative_time_idx + ) + elif stage == "predict": + predict_indices = torch.arange(len(self.time_series_dataset)) + self.predict_processed = self._preprocess_data(predict_indices) + self.predict_windows = self._create_windows(self.predict_processed) + self.predict_dataset = self._ProcessedEncoderDecoderDataset( + self.predict_processed, self.predict_windows, self.add_relative_time_idx + ) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + collate_fn=self.collate_fn, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + 
batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + def predict_dataloader(self): + return DataLoader( + self.predict_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + @staticmethod + def collate_fn(batch): + x_batch = { + "encoder_cat": torch.stack([x["encoder_cat"] for x, _ in batch]), + "encoder_cont": torch.stack([x["encoder_cont"] for x, _ in batch]), + "decoder_cat": torch.stack([x["decoder_cat"] for x, _ in batch]), + "decoder_cont": torch.stack([x["decoder_cont"] for x, _ in batch]), + "encoder_lengths": torch.stack([x["encoder_lengths"] for x, _ in batch]), + "decoder_lengths": torch.stack([x["decoder_lengths"] for x, _ in batch]), + "decoder_target_lengths": torch.stack( + [x["decoder_target_lengths"] for x, _ in batch] + ), + "groups": torch.stack([x["groups"] for x, _ in batch]), + "encoder_time_idx": torch.stack([x["encoder_time_idx"] for x, _ in batch]), + "decoder_time_idx": torch.stack([x["decoder_time_idx"] for x, _ in batch]), + "target_scale": torch.stack([x["target_scale"] for x, _ in batch]), + "encoder_mask": torch.stack([x["encoder_mask"] for x, _ in batch]), + "decoder_mask": torch.stack([x["decoder_mask"] for x, _ in batch]), + } + + if "static_categorical_features" in batch[0][0]: + x_batch["static_categorical_features"] = torch.stack( + [x["static_categorical_features"] for x, _ in batch] + ) + x_batch["static_continuous_features"] = torch.stack( + [x["static_continuous_features"] for x, _ in batch] + ) + + y_batch = torch.stack([y for _, y in batch]) + return x_batch, y_batch diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index 336eecd5f..f8284bdbe 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -2657,6 +2657,8 @@ def _coerce_to_list(obj): """ if obj is None: return [] + if isinstance(obj, str): + return [obj] return list(obj) @@ -2668,3 +2670,247 @@ def _coerce_to_dict(obj): if obj is None: return {} return deepcopy(obj) + + +class TimeSeries(Dataset): + """PyTorch Dataset for time series data stored in pandas DataFrame. + + Parameters + ---------- + data : pd.DataFrame + data frame with sequence data. + Column names must all be str, and contain str as referred to below. + data_future : pd.DataFrame, optional, default=None + data frame with future data. + Column names must all be str, and contain str as referred to below. + May contain only columns that are in time, group, weight, known, or static. + time : str, optional, default = first col not in group_ids, weight, target, static. + integer typed column denoting the time index within ``data``. + This column is used to determine the sequence of samples. + If there are no missing observations, + the time index should increase by ``+1`` for each subsequent sample. + The first time_idx for each series does not necessarily + have to be ``0`` but any value is allowed. + target : str or List[str], optional, default = last column (at iloc -1) + column(s) in ``data`` denoting the forecasting target. + Can be categorical or numerical dtype. + group : List[str], optional, default = None + list of column names identifying a time series instance within ``data``. + This means that the ``group`` together uniquely identify an instance, + and ``group`` together with ``time`` uniquely identify a single observation + within a time series instance. + If ``None``, the dataset is assumed to be a single time series. 
+ weight : str, optional, default=None + column name for weights. + If ``None``, it is assumed that there is no weight column. + num : list of str, optional, default = all columns with dtype in "fi" + list of numerical variables in ``data``, + list may also contain list of str, which are then grouped together. + cat : list of str, optional, default = all columns with dtype in "Obc" + list of categorical variables in ``data``, + list may also contain list of str, which are then grouped together + (e.g. useful for product categories). + known : list of str, optional, default = all variables + list of variables that change over time and are known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for special days or promotion categories). + unknown : list of str, optional, default = no variables + list of variables that are not known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for weather categories). + static : list of str, optional, default = all variables not in known, unknown + list of variables that do not change over time, + list may also contain list of str, which are then grouped together. + """ + + def __init__( + self, + data: pd.DataFrame, + data_future: Optional[pd.DataFrame] = None, + time: Optional[str] = None, + target: Optional[Union[str, List[str]]] = None, + group: Optional[List[str]] = None, + weight: Optional[str] = None, + num: Optional[List[Union[str, List[str]]]] = None, + cat: Optional[List[Union[str, List[str]]]] = None, + known: Optional[List[Union[str, List[str]]]] = None, + unknown: Optional[List[Union[str, List[str]]]] = None, + static: Optional[List[Union[str, List[str]]]] = None, + ): + + self.data = data + self.data_future = data_future + self.time = time + self.target = _coerce_to_list(target) + self.group = _coerce_to_list(group) + self.weight = weight + self.num = _coerce_to_list(num) + self.cat = _coerce_to_list(cat) + self.known = _coerce_to_list(known) + self.unknown = _coerce_to_list(unknown) + self.static = _coerce_to_list(static) + + self.feature_cols = [ + col + for col in data.columns + if col not in [self.time] + self.group + [self.weight] + self.target + ] + if self.group: + self._groups = self.data.groupby(self.group).groups + self._group_ids = list(self._groups.keys()) + else: + self._groups = {"_single_group": self.data.index} + self._group_ids = ["_single_group"] + + self._prepare_metadata() + + def _prepare_metadata(self): + """Prepare metadata for the dataset. + + The funcion returns metadata that contains: + + * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] } + Names of columns for y, x, and static features. + List elements are in same order as column dimensions. + Columns not appearing are assumed to be named (x0, x1, etc.), + (y0, y1, etc.), (st0, st1, etc.). + * ``col_type``: dict[str, str] + maps column names to data types "F" (numerical) and "C" (categorical). + Column names not occurring are assumed "F". + * ``col_known``: dict[str, str] + maps column names to "K" (future known) or "U" (future unknown). + Column names not occurring are assumed "K". 
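+
+        For example, a frame with columns ``time``, ``y`` (target), ``x1``
+        and ``c1``, where ``c1`` is categorical and only ``x1`` is known in
+        the future, produces metadata along the lines of::
+
+            {
+                "cols": {"y": ["y"], "x": ["x1", "c1"], "st": []},
+                "col_type": {"y": "F", "x1": "F", "c1": "C"},
+                "col_known": {"y": "U", "x1": "K", "c1": "U"},
+            }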
+ """ + self.metadata = { + "cols": { + "y": self.target, + "x": self.feature_cols, + "st": self.static, + }, + "col_type": {}, + "col_known": {}, + } + + all_cols = self.target + self.feature_cols + self.static + for col in all_cols: + self.metadata["col_type"][col] = "C" if col in self.cat else "F" + + self.metadata["col_known"][col] = "K" if col in self.known else "U" + + def __len__(self) -> int: + """Return number of time series in the dataset.""" + return len(self._group_ids) + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Get time series data for given index. + + It returns: + + * ``t``: ``numpy.ndarray`` of shape (n_timepoints,) + Time index for each time point in the past or present. Aligned with ``y``, + and ``x`` not ending in ``f``. + * ``y``: tensor of shape (n_timepoints, n_targets) + Target values for each time point. Rows are time points, aligned with ``t``. + * ``x``: tensor of shape (n_timepoints, n_features) + Features for each time point. Rows are time points, aligned with ``t``. + * ``group``: tensor of shape (n_groups) + Group identifiers for time series instances. + * ``st``: tensor of shape (n_static_features) + Static features. + * ``cutoff_time``: float or ``numpy.float64`` + Cutoff time for the time series instance. + + Optionally, the following str-keyed entry can be included: + + * ``weights``: tensor of shape (n_timepoints), only if weight is not None + """ + group_id = self._group_ids[index] + + if self.group: + mask = self._groups[group_id] + data = self.data.loc[mask] + else: + data = self.data + + cutoff_time = data[self.time].max() + + result = { + "t": data[self.time].values, + "y": torch.tensor(data[self.target].values), + "x": torch.tensor(data[self.feature_cols].values), + "group": torch.tensor([hash(str(group_id))]), + "st": torch.tensor(data[self.static].iloc[0].values if self.static else []), + "cutoff_time": cutoff_time, + } + + if self.data_future is not None: + if self.group: + future_mask = self.data_future.groupby(self.group).groups[group_id] + future_data = self.data_future.loc[future_mask] + else: + future_data = self.data_future + + combined_times = np.concatenate( + [data[self.time].values, future_data[self.time].values] + ) + combined_times = np.unique(combined_times) + combined_times.sort() + + num_timepoints = len(combined_times) + x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan) + y_merged = np.full((num_timepoints, len(self.target)), np.nan) + + current_time_indices = {t: i for i, t in enumerate(combined_times)} + for i, t in enumerate(data[self.time].values): + idx = current_time_indices[t] + x_merged[idx] = data[self.feature_cols].values[i] + y_merged[idx] = data[self.target].values[i] + + for i, t in enumerate(future_data[self.time].values): + if t in current_time_indices: + idx = current_time_indices[t] + for j, col in enumerate(self.known): + if col in self.feature_cols: + feature_idx = self.feature_cols.index(col) + x_merged[idx, feature_idx] = future_data[col].values[i] + + result.update( + { + "t": combined_times, + "x": torch.tensor(x_merged, dtype=torch.float32), + "y": torch.tensor(y_merged, dtype=torch.float32), + } + ) + + if self.weight: + if self.data_future is not None and self.weight in self.data_future.columns: + weights_merged = np.full(num_timepoints, np.nan) + for i, t in enumerate(data[self.time].values): + idx = current_time_indices[t] + weights_merged[idx] = data[self.weight].values[i] + + for i, t in enumerate(future_data[self.time].values): + if t in 
current_time_indices and self.weight in future_data.columns: + idx = current_time_indices[t] + weights_merged[idx] = future_data[self.weight].values[i] + + result["weights"] = torch.tensor(weights_merged, dtype=torch.float32) + else: + result["weights"] = torch.tensor( + data[self.weight].values, dtype=torch.float32 + ) + + return result + + def get_metadata(self) -> Dict: + """Return metadata about the dataset. + + Returns + ------- + Dict + Dictionary containing: + - cols: column names for y, x, and static features + - col_type: mapping of columns to their types (F/C) + - col_known: mapping of columns to their future known status (K/U) + """ + return self.metadata diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py new file mode 100644 index 000000000..34fd68895 --- /dev/null +++ b/tests/test_data/test_data_module.py @@ -0,0 +1,431 @@ +import numpy as np +import pandas as pd +import pytest +from torch.utils.data import DataLoader + +from pytorch_forecasting.data.data_modules import EncoderDecoderTimeSeriesDataModule +from pytorch_forecasting.data.timeseries import TimeSeries + + +@pytest.fixture +def sample_timeseries_data(): + """Generate a sample time series dataset for testing.""" + dates = pd.date_range(start="2023-01-01", periods=50, freq="D") + n_series = 100 + + data = [] + for i in range(n_series): + group_id = i + static_feat = i % 2 + + series = pd.DataFrame( + { + "time": (dates - dates[0]).days, + "group_id": group_id, + "category_1": np.random.randint(0, 3, len(dates), dtype=np.int32), + "category_2": np.random.randint(0, 5, len(dates), dtype=np.int32), + "value_1": np.random.randn(len(dates)).astype(np.float32), + "value_2": np.random.randn(len(dates)).astype(np.float32), + "known_future_1": np.random.randn(len(dates)).astype(np.float32), + "known_future_2": np.random.randint(0, 3, len(dates), dtype=np.int32), + "unknown_future_1": np.random.randn(len(dates)).astype(np.float32), + "target": np.sin(np.linspace(0, 8 * np.pi, len(dates))).astype( + np.float32 + ) + + np.random.randn(len(dates)).astype(np.float32) * 0.1, + "static_feat": np.full(len(dates), static_feat, dtype=np.int32), + } + ) + data.append(series) + + df = pd.concat(data, ignore_index=True) + + df = df.astype( + { + "time": np.int32, + "group_id": np.int32, + "category_1": np.int32, + "category_2": np.int32, + "value_1": np.float32, + "value_2": np.float32, + "known_future_1": np.float32, + "known_future_2": np.int32, + "unknown_future_1": np.float32, + "target": np.float32, + "static_feat": np.int32, + } + ) + + future_dates = pd.date_range(start="2023-02-20", periods=20, freq="D") + future_data = [] + for i in range(n_series): + group_id = i + future_series = pd.DataFrame( + { + "time": (future_dates - dates[0]).days, + "group_id": group_id, + "known_future_1": np.random.randn(len(future_dates)).astype(np.float32), + "known_future_2": np.random.randint( + 0, 3, len(future_dates), dtype=np.int32 + ), + } + ) + future_data.append(future_series) + + future_df = pd.concat(future_data, ignore_index=True) + + future_df = future_df.astype( + { + "time": np.int32, + "group_id": np.int32, + "known_future_1": np.float32, + "known_future_2": np.int32, + } + ) + + ts = TimeSeries( + data=df, + data_future=future_df, + time="time", + target="target", + group=["group_id"], + static=["static_feat"], + cat=["category_1", "category_2", "known_future_2"], + num=["value_1", "value_2", "known_future_1", "unknown_future_1"], + known=["known_future_1", "known_future_2"], + 
unknown=["unknown_future_1"], + ) + + return ts + + +def test_known_unknown_features(sample_timeseries_data): + """Test handling of known and unknown future features. + + This test checks: + + - Whether metadata correctly identifies known and unknown future features. + - Whether future data is correctly included in the dataset. + - The structure and presence of known future feature tensors in a sample. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + ) + + datamodule.setup() + batch = next(iter(datamodule.train_dataloader())) + x_batch, _ = batch + + # Verify metadata contains known/unknown information + metadata = sample_timeseries_data.get_metadata() + assert "col_known" in metadata + assert metadata["col_known"]["known_future_1"] == "K" + assert metadata["col_known"]["known_future_2"] == "K" + assert metadata["col_known"]["unknown_future_1"] == "U" + + +def test_initialization(sample_timeseries_data): + """Test the initialization of the EncoderDecoderTimeSeriesDataModule. + + This test verifies: + + - The correct assignment of encoder and prediction lengths. + - The default batch size is set correctly. + - Categorical and continuous features are correctly identified. + - Metadata correctly maps categorical and continuous features. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=30, + max_prediction_length=10, + ) + + assert datamodule.max_encoder_length == 30 + assert datamodule.max_prediction_length == 10 + assert datamodule.batch_size == 32 + + # Check correct identification of categorical and continuous features + assert ( + len(datamodule.categorical_indices) == 3 + ) # category_1, category_2, known_future_2 + assert ( + len(datamodule.continuous_indices) == 5 + ) # value_1, value_2, static_feat, known_future_1, unknown_future_1 + + # Verify the actual indices are correct + metadata = sample_timeseries_data.get_metadata() + feature_cols = metadata["cols"]["x"] + + # Verify categorical indices point to the right columns + for idx in datamodule.categorical_indices: + assert metadata["col_type"][feature_cols[idx]] == "C" + + # Verify continuous indices point to the right columns + for idx in datamodule.continuous_indices: + assert metadata["col_type"][feature_cols[idx]] == "F" + + +def test_setup_train_val_split(sample_timeseries_data): + """Test dataset splitting into train and validation sets. + + This test ensures: + + - The `setup` method properly splits the dataset. + - The train and validation datasets are correctly created. + - The size of the train dataset matches expectations based on the split ratio. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + train_val_test_split=(0.7, 0.15, 0.15), + ) + + datamodule.setup(stage="fit") + + # Verify dataset creation + assert hasattr(datamodule, "train_dataset") + assert hasattr(datamodule, "val_dataset") + + # Check split sizes + expected_train_size = int(0.7 * len(sample_timeseries_data)) + assert len(datamodule._train_indices) == expected_train_size + + +def test_data_loading(sample_timeseries_data): + """Test data loading and batch structure. + + This test checks: + + - The train dataloader is correctly instantiated. + - The batch contains all necessary components. + - The categorical and continuous features have the correct dimensions. 
+ - The target tensor has the expected shape. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + batch_size=16, + ) + + datamodule.setup(stage="fit") + train_loader = datamodule.train_dataloader() + + # Verify DataLoader + assert isinstance(train_loader, DataLoader) + assert train_loader.batch_size == 16 + + # Check batch structure + batch = next(iter(train_loader)) + x_batch, y_batch = batch + + # Verify all required components are present + expected_keys = { + "encoder_cat", + "encoder_cont", + "decoder_cat", + "decoder_cont", + "encoder_lengths", + "decoder_lengths", + "decoder_target_lengths", + "groups", + "encoder_time_idx", + "decoder_time_idx", + "target_scale", + } + assert all(key in x_batch for key in expected_keys) + + # Check shapes + batch_size = 16 + assert x_batch["encoder_cat"].shape == ( + batch_size, + 20, + 3, + ) # (batch, time, n_cat_features) + assert x_batch["encoder_cont"].shape == ( + batch_size, + 20, + 5, + ) # (batch, time, n_cont_features) + assert x_batch["decoder_cat"].shape == ( + batch_size, + 5, + 3, + ) # (batch, pred_length, n_cat_features) + assert x_batch["decoder_cont"].shape == ( + batch_size, + 5, + 5, + ) # (batch, pred_length, n_cont_features) + assert y_batch.shape == (batch_size, 5, 1) # (batch, pred_length, n_targets) + + +def test_different_settings(sample_timeseries_data): + """Test different configuration settings. + + This test verifies: + + - The model handles different encoder and prediction lengths correctly. + - Relative time indices and target scales are properly included. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=15, + min_encoder_length=10, + max_prediction_length=3, + min_prediction_length=2, + batch_size=8, + add_relative_time_idx=True, + add_target_scales=True, + ) + + datamodule.setup(stage="fit") + train_loader = datamodule.train_dataloader() + batch = next(iter(train_loader)) + x_batch, y_batch = batch + + assert x_batch["encoder_cat"].shape[1] == 15 # max_encoder_length + assert x_batch["decoder_cat"].shape[1] == 3 # max_prediction_length + assert x_batch["encoder_time_idx"].shape[1] == 15 + assert "target_scale" in x_batch # verify target scales are included + + +def test_static_features(sample_timeseries_data): + """Test that static features are correctly included. + + This test ensures: + + - Static categorical features are present in the batch. + - Static feature tensor dimensions are as expected. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + ) + + datamodule.setup(stage="fit") + train_loader = datamodule.train_dataloader() + batch = next(iter(train_loader)) + x_batch, _ = batch + + # Verify static features are present + assert "static_categorical_features" in x_batch + assert ( + x_batch["static_categorical_features"].dim() == 3 + ) # (batch_size, 1, n_static_features) + + +def test_group_handling(sample_timeseries_data): + """Test that group information is correctly processed. + + This test verifies: + + - The presence of group identifiers in the batch. + - Group tensor dimensions are as expected. 
+ """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + ) + + datamodule.setup(stage="fit") + train_loader = datamodule.train_dataloader() + batch = next(iter(train_loader)) + x_batch, _ = batch + + # Verify group information + assert "groups" in x_batch + assert x_batch["groups"].dim() == 2 # (batch_size, 1) + + +def test_window_creation(sample_timeseries_data): + """Test window creation for encoder-decoder time series. + + This test ensures: + + - Windows are correctly generated for each time series. + - Encoder and decoder window sizes match the expected values. + - Window indices reference valid series in the dataset. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + ) + + datamodule.setup(stage="fit") + + # Check that windows are created for each series in training set + processed_data = datamodule.train_processed + windows = datamodule.train_windows + + # Verify window parameters + for window in windows: + series_idx, start_idx, enc_length, pred_length = window + assert enc_length == 20 # max_encoder_length + assert pred_length == 5 # max_prediction_length + assert series_idx < len(processed_data) + + +def test_prediction_mode(sample_timeseries_data): + """Test the behavior of the datamodule in prediction mode. + + This test checks: + + - Whether the prediction dataset is properly created. + - The structure of the prediction batch. + - The presence of target scale information. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + ) + + datamodule.setup(stage="predict") + predict_loader = datamodule.predict_dataloader() + + # Check prediction dataset + assert hasattr(datamodule, "predict_dataset") + + # Verify prediction batch structure + batch = next(iter(predict_loader)) + x_batch, y_batch = batch + + assert x_batch["encoder_cat"].shape[1] == 20 + assert x_batch["decoder_cat"].shape[1] == 5 + assert "target_scale" in x_batch + + +@pytest.mark.parametrize( + "train_val_test_split", [(0.6, 0.2, 0.2), (0.8, 0.1, 0.1), (0.7, 0.15, 0.15)] +) +def test_different_splits(sample_timeseries_data, train_val_test_split): + """Test different train-validation-test splits. + + This test verifies: + + - The dataset is correctly split according to different ratios. + - The sizes of train, validation, and test sets match expected values. + """ + datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=20, + max_prediction_length=5, + train_val_test_split=train_val_test_split, + ) + + datamodule.setup(stage="fit") + total_size = len(sample_timeseries_data) + expected_train_size = int(train_val_test_split[0] * total_size) + expected_val_size = int(train_val_test_split[1] * total_size) + expected_test_size = int(train_val_test_split[2] * total_size) + + assert len(datamodule._train_indices) == expected_train_size + assert len(datamodule._val_indices) == expected_val_size + assert len(datamodule._test_indices) == expected_test_size