diff --git a/pytorch_forecasting/models/__init__.py b/pytorch_forecasting/models/__init__.py index 29aeb24f5..0a9d600f8 100644 --- a/pytorch_forecasting/models/__init__.py +++ b/pytorch_forecasting/models/__init__.py @@ -19,6 +19,7 @@ TemporalFusionTransformer, ) from pytorch_forecasting.models.tide import TiDEModel +from pytorch_forecasting.models.timexer import TimeXer __all__ = [ "NBeats", @@ -37,4 +38,5 @@ "MultiEmbedding", "DecoderMLP", "TiDEModel", + "TimeXer", ] diff --git a/pytorch_forecasting/models/deepar/_deepar_metadata.py b/pytorch_forecasting/models/deepar/_deepar_metadata.py index ad0e210a5..28754298a 100644 --- a/pytorch_forecasting/models/deepar/_deepar_metadata.py +++ b/pytorch_forecasting/models/deepar/_deepar_metadata.py @@ -9,6 +9,7 @@ class DeepARMetadata(_BasePtForecaster): _tags = { "info:name": "DeepAR", "info:compute": 3, + "object_type": "ptf-v1", "authors": ["jdb78"], "capability:exogenous": True, "capability:multivariate": True, diff --git a/pytorch_forecasting/models/nbeats/_nbeats_metadata.py b/pytorch_forecasting/models/nbeats/_nbeats_metadata.py index 9910a0ba1..f644b378a 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats_metadata.py +++ b/pytorch_forecasting/models/nbeats/_nbeats_metadata.py @@ -9,6 +9,7 @@ class NBeatsMetadata(_BasePtForecaster): _tags = { "info:name": "NBeats", "info:compute": 1, + "object_type": "ptf-v1", "authors": ["jdb78"], "capability:exogenous": False, "capability:multivariate": False, diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py new file mode 100644 index 000000000..91e2440ed --- /dev/null +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py @@ -0,0 +1,61 @@ +"""TFT metadata container.""" + +from pytorch_forecasting.models.base._base_object import _BasePtForecaster + + +class TFTMetadata(_BasePtForecaster): + """TFT metadata container.""" + + _tags = { + "info:name": "TFT", + "object_type": "ptf-v2", + "authors": ["phoeenniixx"], + "capability:exogenous": True, + "capability:multivariate": True, + "capability:pred_int": True, + "capability:flexible_history_length": False, + } + + @classmethod + def get_model_cls(cls): + """Get model class.""" + from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT + + return TFT + + @classmethod + def get_test_train_params(cls): + """Return testing parameter settings for the trainer. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. 
+ `create_test_instance` uses the first (or only) dictionary in `params` + """ + return [ + {}, + dict( + hidden_size=25, + attention_head_size=5, + ), + dict( + data_loader_kwargs=dict(max_encoder_length=5, max_prediction_length=3) + ), + dict( + hidden_size=24, + attention_head_size=8, + data_loader_kwargs=dict( + max_encoder_length=5, + max_prediction_length=3, + add_relative_time_idx=False, + ), + ), + dict( + hidden_size=12, + data_loader_kwargs=dict(max_encoder_length=7, max_prediction_length=10), + ), + dict(attention_head_size=2), + ] diff --git a/pytorch_forecasting/models/tide/_tide_metadata.py b/pytorch_forecasting/models/tide/_tide_metadata.py index 502229b71..49a2acc67 100644 --- a/pytorch_forecasting/models/tide/_tide_metadata.py +++ b/pytorch_forecasting/models/tide/_tide_metadata.py @@ -9,6 +9,7 @@ class TiDEModelMetadata(_BasePtForecaster): _tags = { "info:name": "TiDEModel", "info:compute": 3, + "object_type": "ptf-v1", "authors": ["Sohaib-Ahmed21"], "capability:exogenous": True, "capability:multivariate": True, diff --git a/pytorch_forecasting/models/timexer/__init__.py b/pytorch_forecasting/models/timexer/__init__.py new file mode 100644 index 000000000..8d3d51d94 --- /dev/null +++ b/pytorch_forecasting/models/timexer/__init__.py @@ -0,0 +1,29 @@ +""" +TimeXer model for forecasting time series. +""" + +from pytorch_forecasting.models.timexer._timexer import TimeXer +from pytorch_forecasting.models.timexer.sub_modules import ( + AttentionLayer, + DataEmbedding_inverted, + Encoder, + EncoderLayer, + EnEmbedding, + FlattenHead, + FullAttention, + PositionalEmbedding, + TriangularCausalMask, +) + +__all__ = [ + "TimeXer", + "TriangularCausalMask", + "FullAttention", + "AttentionLayer", + "DataEmbedding_inverted", + "PositionalEmbedding", + "FlattenHead", + "EnEmbedding", + "Encoder", + "EncoderLayer", +] diff --git a/pytorch_forecasting/models/timexer/_timexer.py b/pytorch_forecasting/models/timexer/_timexer.py new file mode 100644 index 000000000..94eb51a25 --- /dev/null +++ b/pytorch_forecasting/models/timexer/_timexer.py @@ -0,0 +1,267 @@ +""" +Time Series Transformer with eXogenous variables (TimeXer) +--------------------------------------------------------- +""" + +####################################################### +# Note: This is an example version to demonstrate the +# working of the TimeXer model with the exisiting v2 +# designs. The pending work includes building the D2 +# layer and base tslib model. 
+###################################################### + +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn +from torch.optim import Optimizer + +from pytorch_forecasting.metrics import QuantileLoss +from pytorch_forecasting.models.base._base_model_v2 import BaseModel +from pytorch_forecasting.models.timexer.sub_modules import ( + AttentionLayer, + DataEmbedding_inverted, + Encoder, + EncoderLayer, + EnEmbedding, + FlattenHead, + FullAttention, +) + + +class TimeXer(BaseModel): + def __init__( + self, + context_length: int, + prediction_length: int, + loss: nn.Module, + logging_metrics: Optional[list[nn.Module]] = None, + optimizer: Optional[Union[Optimizer, str]] = "adam", + optimizer_params: Optional[dict] = None, + lr_scheduler: Optional[str] = None, + lr_scheduler_params: Optional[dict] = None, + task_name: str = "long_term_forecast", + features: str = "MS", + enc_in: int = None, + d_model: int = 512, + n_heads: int = 8, + e_layers: int = 2, + d_ff: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable] = "torch.nn.functional.relu", + patch_length: int = 24, + use_norm: bool = False, + factor: int = 5, + embed_type: str = "fixed", + freq: str = "h", + metadata: Optional[dict] = None, + target_positions: torch.LongTensor = None, + ): + """An implementation of the TimeXer model. + TimeXer empowers the canonical transformer with the ability to reconcile + endogenous and exogenous information without any architectural modifications + and achieves consistent state-of-the-art performance across twelve real-world + forecasting benchmarks. + TimeXer employs patch-level and variate-level representations respectively for + endogenous and exogenous variables, with an endogenous global token as a bridge + in-between. With this design, TimeXer can jointly capture intra-endogenous + temporal dependencies and exogenous-to-endogenous correlations. + TimeXer model for time series forecasting with exogenous variables. 
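+
+        Parameters
+        ----------
+        context_length : int
+            Length of the encoder (look-back) window; it is split into
+            ``context_length // patch_length`` endogenous patches.
+        prediction_length : int
+            Number of future steps produced by the prediction head.
+        loss : nn.Module
+            Training loss; passing a ``QuantileLoss`` makes the head emit one
+            output per quantile.
+        features : str
+            ``"MS"`` (multivariate input, single target) routes through
+            ``_forecast``, ``"M"`` through ``_forecast_multi``.
+        patch_length : int
+            Patch size used to tokenize the endogenous series.
+        d_model, n_heads, e_layers, d_ff, dropout
+            Standard transformer encoder hyperparameters.
+        metadata : dict
+            Feature-count metadata from the v2 data module; ``metadata["target"]``
+            and ``metadata["encoder_cont"]`` size the embeddings and output head.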
+ """ + super().__init__( + loss=loss, + logging_metrics=logging_metrics, + optimizer=optimizer, + optimizer_params=optimizer_params or {}, + lr_scheduler=lr_scheduler, + lr_scheduler_params=lr_scheduler_params or {}, + ) + + self.context_length = context_length + self.prediction_length = prediction_length + self.task_name = task_name + self.features = features + self.d_model = d_model + self.n_heads = n_heads + self.e_layers = e_layers + self.d_ff = d_ff + self.activation = activation + self.patch_length = patch_length + self.use_norm = use_norm + self.factor = factor + self.embed_type = embed_type + self.freq = freq + self.metadata = metadata + self.n_target_vars = self.metadata["target"] + self.target_positions = target_positions + self.enc_in = self.metadata["encoder_cont"] + self.patch_num = self.context_length // self.patch_length + self.dropout = dropout + + self.n_quantiles = None + + if isinstance(loss, QuantileLoss): + self.n_quantiles = len(loss.quantiles) + + self.en_embedding = EnEmbedding( + self.n_target_vars, + self.d_model, + self.patch_length, + self.dropout, + ) + + self.ex_embedding = DataEmbedding_inverted( + self.context_length, + self.d_model, + self.embed_type, + self.freq, + self.dropout, + ) + + self.encoder = Encoder( + [ + EncoderLayer( + AttentionLayer( + FullAttention( + False, + self.factor, + attention_dropout=self.dropout, + output_attention=False, + ), + self.d_model, + self.n_heads, + ), + AttentionLayer( + FullAttention( + False, + self.factor, + attention_dropout=self.dropout, + output_attention=False, + ), + self.d_model, + self.n_heads, + ), + self.d_model, + self.d_ff, + dropout=self.dropout, + activation=self.activation, + ) + for l in range(self.e_layers) + ], + norm_layer=torch.nn.LayerNorm(self.d_model), + ) + self.head_nf = self.d_model * (self.patch_num + 1) + self.head = FlattenHead( + self.enc_in, + self.head_nf, + self.prediction_length, + head_dropout=self.dropout, + n_quantiles=self.n_quantiles, + ) + + def _forecast(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Forecast for univariate or multivariate with single target (MS) case. + Args: + x: Dictionary containing entries for encoder_cat, encoder_cont + """ + batch_size = x["encoder_cont"].shape[0] + encoder_cont = x["encoder_cont"] + encoder_time_idx = x.get("encoder_time_idx", None) + past_target = x.get( + "target", + torch.zeros(batch_size, self.prediction_length, 0, device=self.device), + ) + + if encoder_time_idx is not None and encoder_time_idx.dim() == 2: + # change [batch_size, time_steps] to [batch_size, time_steps, features] + encoder_time_idx = encoder_time_idx.unsqueeze(-1) + + en_embed, n_vars = self.en_embedding(past_target.permute(0, 2, 1)) + ex_embed = self.ex_embedding(encoder_cont, encoder_time_idx) + + enc_out = self.encoder(en_embed, ex_embed) + enc_out = torch.reshape( + enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]) + ) + + enc_out = enc_out.permute(0, 1, 3, 2) + + dec_out = self.head(enc_out) + if self.n_quantiles is not None: + dec_out = dec_out.permute(0, 2, 1, 3) + else: + dec_out = dec_out.permute(0, 2, 1) + + return dec_out + + def _forecast_multi(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Forecast for multivariate with multiple targets (M) case. 
+ Args: + x: Dictionary containing entries for encoder_cat, encoder_cont + Returns: + Dictionary with predictions + """ + + batch_size = x["encoder_cont"].shape[0] + encoder_cont = x.get( + "encoder_cont", + torch.zeros(batch_size, self.prediction_length, device=self.device), + ) + encoder_time_idx = x.get("encoder_time_idx", None) + encoder_targets = x.get( + "target", + torch.zeros(batch_size, self.prediction_length, device=self.device), + ) + en_embed, n_vars = self.en_embedding(encoder_targets.permute(0, 2, 1)) + ex_embed = self.ex_embedding(encoder_cont, encoder_time_idx) + + # batch_size x sequence_length x d_model + enc_out = self.encoder(en_embed, ex_embed) + + enc_out = torch.reshape( + enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]) + ) # batch_size x n_vars x sequence_length x d_model + + enc_out = enc_out.permute(0, 1, 3, 2) + + dec_out = self.head(enc_out) + if self.n_quantiles is not None: + dec_out = dec_out.permute(0, 2, 1, 3) + else: + dec_out = dec_out.permute(0, 2, 1) + + return dec_out + + def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Forward pass of the model. + Args: + x: Dictionary containing model inputs + Returns: + Dictionary with model outputs + """ + if ( + self.task_name == "long_term_forecast" + or self.task_name == "short_term_forecast" + ): # noqa: E501 + if self.features == "M": + out = self._forecast_multi(x) + else: + out = self._forecast(x) + prediction = out[:, : self.prediction_length, :] + + # note: prediction.size(2) is the number of target variables i.e n_targets + target_indices = range(prediction.size(2)) + + if self.n_quantiles is not None: + prediction = [prediction[..., i, :] for i in target_indices] + else: + if len(target_indices) == 1: + prediction = prediction[..., 0] + else: + prediction = [prediction[..., i] for i in target_indices] + return {"prediction": prediction} + else: + return None diff --git a/pytorch_forecasting/models/timexer/sub_modules.py b/pytorch_forecasting/models/timexer/sub_modules.py new file mode 100644 index 000000000..c13b9fc61 --- /dev/null +++ b/pytorch_forecasting/models/timexer/sub_modules.py @@ -0,0 +1,251 @@ +""" +Implementation of `nn.Modules` for TimeXer model. 
+""" + +import math +from math import sqrt + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class TriangularCausalMask: + def __init__(self, B, L, device="cpu"): + mask_shape = [B, 1, L, L] + with torch.no_grad(): + self._mask = torch.triu( + torch.ones(mask_shape, dtype=torch.bool), diagonal=1 + ).to(device) + + @property + def mask(self): + return self._mask + + +class FullAttention(nn.Module): + def __init__( + self, + mask_flag=True, + factor=5, + scale=None, + attention_dropout=0.1, + output_attention=False, + ): + super().__init__() + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): + B, L, H, E = queries.shape + _, S, _, D = values.shape + scale = self.scale or 1.0 / sqrt(E) + + scores = torch.einsum("blhe,bshe->bhls", queries, keys) + + if self.mask_flag: + if attn_mask is None: + attn_mask = TriangularCausalMask(B, L, device=queries.device) + scores.masked_fill_(attn_mask.mask, -np.abs) + A = self.dropout(torch.softmax(scale * scores, dim=-1)) + V = torch.einsum("bhls,bshd->blhd", A, values) + + if self.output_attention: + return V.contiguous(), A + else: + return V.contiguous(), None + + +class AttentionLayer(nn.Module): + def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None): + super().__init__() + + d_keys = d_keys or (d_model // n_heads) + d_values = d_values or (d_model // n_heads) + + self.inner_attention = attention + self.query_projection = nn.Linear(d_model, d_keys * n_heads) + self.key_projection = nn.Linear(d_model, d_keys * n_heads) + self.value_projection = nn.Linear(d_model, d_values * n_heads) + self.out_projection = nn.Linear(d_values * n_heads, d_model) + self.n_heads = n_heads + + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): + B, L, _ = queries.shape + _, S, _ = keys.shape + H = self.n_heads + + queries = self.query_projection(queries).view(B, L, H, -1) + keys = self.key_projection(keys).view(B, S, H, -1) + values = self.value_projection(values).view(B, S, H, -1) + + out, attn = self.inner_attention( + queries, keys, values, attn_mask, tau=tau, delta=delta + ) + out = out.view(B, L, -1) + + return self.out_projection(out), attn + + +class DataEmbedding_inverted(nn.Module): + def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1): + super().__init__() + self.value_embedding = nn.Linear(c_in, d_model) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, x_mark): + x = x.permute(0, 2, 1) + # x: [Batch Variate Time] + if x_mark is None: + x = self.value_embedding(x) + else: + x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) + # x: [Batch Variate d_model] + return self.dropout(x) + + +class PositionalEmbedding(nn.Module): + def __init__(self, d_model, max_len=5000): + super().__init__() + # Compute the positional encodings once in log space. 
+ pe = torch.zeros(max_len, d_model).float() + pe.require_grad = False + + position = torch.arange(0, max_len).float().unsqueeze(1) + div_term = ( + torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) + ).exp() + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, x): + return self.pe[:, : x.size(1)] + + +class FlattenHead(nn.Module): + def __init__(self, n_vars, nf, target_window, head_dropout=0, n_quantiles=None): + super().__init__() + self.n_vars = n_vars + self.flatten = nn.Flatten(start_dim=-2) + self.linear = nn.Linear(nf, target_window) + self.n_quantiles = n_quantiles + + if self.n_quantiles is not None: + self.linear = nn.Linear(nf, target_window * n_quantiles) + else: + self.linear = nn.Linear(nf, target_window) + self.dropout = nn.Dropout(head_dropout) + + def forward(self, x): + x = self.flatten(x) + x = self.linear(x) + x = self.dropout(x) + + if self.n_quantiles is not None: + batch_size, n_vars = x.shape[0], x.shape[1] + x = x.reshape(batch_size, n_vars, -1, self.n_quantiles) + return x + + +class EnEmbedding(nn.Module): + def __init__(self, n_vars, d_model, patch_len, dropout): + super().__init__() + + self.patch_len = patch_len + + self.value_embedding = nn.Linear(patch_len, d_model, bias=False) + self.glb_token = nn.Parameter(torch.randn(1, n_vars, 1, d_model)) + self.position_embedding = PositionalEmbedding(d_model) + + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + n_vars = x.shape[1] + glb = self.glb_token.repeat((x.shape[0], 1, 1, 1)) + + x = x.unfold(dimension=-1, size=self.patch_len, step=self.patch_len) + x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3])) + # Input encoding + x = self.value_embedding(x) + self.position_embedding(x) + x = torch.reshape(x, (-1, n_vars, x.shape[-2], x.shape[-1])) + x = torch.cat([x, glb], dim=2) + x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3])) + return self.dropout(x), n_vars + + +class Encoder(nn.Module): + def __init__(self, layers, norm_layer=None, projection=None): + super().__init__() + self.layers = nn.ModuleList(layers) + self.norm = norm_layer + self.projection = projection + + def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): + for layer in self.layers: + x = layer( + x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta + ) + + if self.norm is not None: + x = self.norm(x) + + if self.projection is not None: + x = self.projection(x) + return x + + +class EncoderLayer(nn.Module): + def __init__( + self, + self_attention, + cross_attention, + d_model, + d_ff=None, + dropout=0.1, + activation="relu", + ): + super().__init__() + d_ff = d_ff or 4 * d_model + self.self_attention = self_attention + self.cross_attention = cross_attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): + B, L, D = cross.shape + x = x + self.dropout( + self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0] + ) + x = self.norm1(x) + + x_glb_ori = x[:, -1, :].unsqueeze(1) + x_glb = 
torch.reshape(x_glb_ori, (B, -1, D)) + x_glb_attn = self.dropout( + self.cross_attention( + x_glb, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta + )[0] + ) + x_glb_attn = torch.reshape( + x_glb_attn, (x_glb_attn.shape[0] * x_glb_attn.shape[1], x_glb_attn.shape[2]) + ).unsqueeze(1) + x_glb = x_glb_ori + x_glb_attn + x_glb = self.norm2(x_glb) + + y = x = torch.cat([x[:, :-1, :], x_glb], dim=1) + + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + + return self.norm3(x + y) diff --git a/pytorch_forecasting/models/timexer/timexer_v2_metadata.py b/pytorch_forecasting/models/timexer/timexer_v2_metadata.py new file mode 100644 index 000000000..ee3941bef --- /dev/null +++ b/pytorch_forecasting/models/timexer/timexer_v2_metadata.py @@ -0,0 +1,48 @@ +"""TimeXer metadata container.""" + +from pytorch_forecasting.models.base._base_object import _BasePtForecaster + + +class TimeXerMetadata(_BasePtForecaster): + """TimeXer metadata container.""" + + _tags = { + "info:name": "TimeXer", + "object_type": "ptf-v2", + "authors": ["PranavBhatP"], + } + + @classmethod + def get_model_cls(cls): + """Get model class.""" + from pytorch_forecasting.models import TimeXer + + return TimeXer + + @classmethod + def get_test_train_params(cls): + """Return testing parameter settings for the trainer. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. + `create_test_instance` uses the first (or only) dictionary in `params` + """ + import torch.nn as nn + + return [ + dict( + context_length=30, + prediction_length=1, + d_model=32, + n_heads=2, + e_layers=1, + d_ff=64, + patch_length=1, + task_name="long_term_forecast", + features="MS", + ), + ] diff --git a/pytorch_forecasting/tests/_conftest.py b/pytorch_forecasting/tests/_conftest.py index c5bfd775b..a3b2bba5d 100644 --- a/pytorch_forecasting/tests/_conftest.py +++ b/pytorch_forecasting/tests/_conftest.py @@ -1,10 +1,15 @@ +from datetime import datetime + import numpy as np +import pandas as pd import pytest import torch from pytorch_forecasting import TimeSeriesDataSet from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder +from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data +from pytorch_forecasting.data.timeseries import TimeSeries torch.manual_seed(23) @@ -92,6 +97,233 @@ def make_dataloaders(data_with_covariates, **kwargs): return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) +@pytest.fixture(scope="session") +def data_with_covariates_v2(): + """Create synthetic time series data with all numerical features.""" + + start_date = datetime(2015, 1, 1) + end_date = datetime(2017, 12, 31) + dates = pd.date_range(start_date, end_date, freq="M") + + agencies = [0, 1] + skus = [0, 1] + data_list = [] + + for agency in agencies: + for sku in skus: + for date in dates: + time_idx = (date.year - 2015) * 12 + date.month - 1 + + volume = ( + np.random.exponential(2) + + 0.1 * time_idx + + 0.5 * np.sin(date.month * np.pi / 6) + ) + volume = max(0.001, volume) + month = date.month + year = date.year + quarter = (date.month - 1) // 3 + 1 + + seasonal_1 = np.sin(2 * np.pi * date.month / 12) + seasonal_2 = 
np.cos(2 * np.pi * date.month / 12) + + agency_feature_1 = agency * 10 + np.random.normal(0, 0.1) + agency_feature_2 = agency * 5 + np.random.normal(0, 0.1) + + sku_feature_1 = sku * 8 + np.random.normal(0, 0.1) + sku_feature_2 = sku * 3 + np.random.normal(0, 0.1) + + trend = time_idx * 0.1 + noise = np.random.normal(0, 0.1) + + special_event_1 = 1 if date.month in [12, 1] else 0 + special_event_2 = 1 if date.month in [6, 7, 8] else 0 + + data_list.append( + { + "date": date, + "time_idx": time_idx, + "agency_encoded": agency, + "sku_encoded": sku, + "volume": volume, + "target": volume, + "weight": 1.0 + np.sqrt(volume), + "month": month, + "year": year, + "quarter": quarter, + "seasonal_1": seasonal_1, + "seasonal_2": seasonal_2, + "agency_feature_1": agency_feature_1, + "agency_feature_2": agency_feature_2, + "sku_feature_1": sku_feature_1, + "sku_feature_2": sku_feature_2, + "trend": trend, + "noise": noise, + "special_event_1": special_event_1, + "special_event_2": special_event_2, + "log_volume": np.log1p(volume), + } + ) + + data = pd.DataFrame(data_list) + + numeric_cols = [col for col in data.columns if col != "date"] + for col in numeric_cols: + data[col] = pd.to_numeric(data[col], errors="coerce") + data[numeric_cols] = data[numeric_cols].fillna(0) + + return data + + +def make_dataloaders_v2(data_with_covariates, **kwargs): + """Create dataloaders with consistent encoder/decoder features.""" + + training_cutoff = "2016-09-01" + max_encoder_length = kwargs.get("max_encoder_length", 4) + max_prediction_length = kwargs.get("max_prediction_length", 3) + + target_col = kwargs.get("target", "target") + group_cols = kwargs.get("group_ids", ["agency_encoded", "sku_encoded"]) + add_relative_time_idx = kwargs.get("add_relative_time_idx", True) + + known_features = [ + "month", + "year", + "quarter", + "seasonal_1", + "seasonal_2", + "special_event_1", + "special_event_2", + "trend", + ] + unknown_features = [ + "agency_feature_1", + "agency_feature_2", + "sku_feature_1", + "sku_feature_2", + "noise", + "log_volume", + ] + + numerical_features = known_features + unknown_features + categorical_features = [] + static_features = group_cols + + for col in numerical_features + categorical_features + group_cols + [target_col]: + if col in data_with_covariates.columns: + data_with_covariates[col] = pd.to_numeric( + data_with_covariates[col], errors="coerce" + ).fillna(0) + + for col in categorical_features + group_cols: + if col in data_with_covariates.columns: + data_with_covariates[col] = data_with_covariates[col].astype("int64") + + if "weight" in data_with_covariates.columns: + data_with_covariates["weight"] = pd.to_numeric( + data_with_covariates["weight"], errors="coerce" + ).fillna(1.0) + + training_data = data_with_covariates[ + data_with_covariates.date < training_cutoff + ].copy() + validation_data = data_with_covariates.copy() + + required_columns = ( + ["time_idx", target_col, "weight", "date"] + + group_cols + + numerical_features + + categorical_features + ) + + available_columns = [ + col for col in required_columns if col in data_with_covariates.columns + ] + + training_data_clean = training_data[available_columns].copy() + validation_data_clean = validation_data[available_columns].copy() + + if "date" in training_data_clean.columns: + training_data_clean = training_data_clean.drop("date", axis=1) + if "date" in validation_data_clean.columns: + validation_data_clean = validation_data_clean.drop("date", axis=1) + + training_dataset = TimeSeries( + data=training_data_clean, + 
time="time_idx", + target=[target_col], + group=group_cols, + weight="weight", + num=numerical_features, + cat=categorical_features if categorical_features else None, + known=known_features, + unknown=unknown_features, + static=static_features, + ) + + validation_dataset = TimeSeries( + data=validation_data_clean, + time="time_idx", + target=[target_col], + group=group_cols, + weight="weight", + num=numerical_features, + cat=categorical_features if categorical_features else None, + known=known_features, + unknown=unknown_features, + static=static_features, + ) + + training_max_time_idx = training_data["time_idx"].max() + 1 + + train_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=training_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + add_relative_time_idx=add_relative_time_idx, + batch_size=2, + num_workers=0, + train_val_test_split=(0.8, 0.2, 0.0), + ) + + val_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=validation_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + min_prediction_idx=training_max_time_idx, + add_relative_time_idx=add_relative_time_idx, + batch_size=2, + num_workers=0, + train_val_test_split=(0.0, 1.0, 0.0), + ) + + test_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=validation_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + min_prediction_idx=training_max_time_idx, + add_relative_time_idx=add_relative_time_idx, + batch_size=1, + num_workers=0, + train_val_test_split=(0.0, 0.0, 1.0), + ) + + train_datamodule.setup("fit") + val_datamodule.setup("fit") + test_datamodule.setup("test") + + train_dataloader = train_datamodule.train_dataloader() + val_dataloader = val_datamodule.val_dataloader() + test_dataloader = test_datamodule.test_dataloader() + + return { + "train": train_dataloader, + "val": val_dataloader, + "test": test_dataloader, + "data_module": train_datamodule, + } + + @pytest.fixture( params=[ dict(), diff --git a/pytorch_forecasting/tests/_data_scenarios.py b/pytorch_forecasting/tests/_data_scenarios.py index 062db97dd..d39f6d988 100644 --- a/pytorch_forecasting/tests/_data_scenarios.py +++ b/pytorch_forecasting/tests/_data_scenarios.py @@ -1,10 +1,15 @@ +from datetime import datetime + import numpy as np +import pandas as pd import pytest import torch from pytorch_forecasting import TimeSeriesDataSet from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder +from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data +from pytorch_forecasting.data.timeseries import TimeSeries torch.manual_seed(23) @@ -87,6 +92,232 @@ def make_dataloaders(data_with_covariates, **kwargs): return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) +def data_with_covariates_v2(): + """Create synthetic time series data with all numerical features.""" + + start_date = datetime(2015, 1, 1) + end_date = datetime(2017, 12, 31) + dates = pd.date_range(start_date, end_date, freq="M") + + agencies = [0, 1] + skus = [0, 1] + data_list = [] + + for agency in agencies: + for sku in skus: + for date in dates: + time_idx = (date.year - 2015) * 12 + date.month - 1 + + volume = ( + np.random.exponential(2) + + 0.1 * time_idx + + 0.5 * np.sin(date.month * np.pi / 6) + ) + volume = max(0.001, volume) + month = date.month + 
year = date.year + quarter = (date.month - 1) // 3 + 1 + + seasonal_1 = np.sin(2 * np.pi * date.month / 12) + seasonal_2 = np.cos(2 * np.pi * date.month / 12) + + agency_feature_1 = agency * 10 + np.random.normal(0, 0.1) + agency_feature_2 = agency * 5 + np.random.normal(0, 0.1) + + sku_feature_1 = sku * 8 + np.random.normal(0, 0.1) + sku_feature_2 = sku * 3 + np.random.normal(0, 0.1) + + trend = time_idx * 0.1 + noise = np.random.normal(0, 0.1) + + special_event_1 = 1 if date.month in [12, 1] else 0 + special_event_2 = 1 if date.month in [6, 7, 8] else 0 + + data_list.append( + { + "date": date, + "time_idx": time_idx, + "agency_encoded": agency, + "sku_encoded": sku, + "volume": volume, + "target": volume, + "weight": 1.0 + np.sqrt(volume), + "month": month, + "year": year, + "quarter": quarter, + "seasonal_1": seasonal_1, + "seasonal_2": seasonal_2, + "agency_feature_1": agency_feature_1, + "agency_feature_2": agency_feature_2, + "sku_feature_1": sku_feature_1, + "sku_feature_2": sku_feature_2, + "trend": trend, + "noise": noise, + "special_event_1": special_event_1, + "special_event_2": special_event_2, + "log_volume": np.log1p(volume), + } + ) + + data = pd.DataFrame(data_list) + + numeric_cols = [col for col in data.columns if col != "date"] + for col in numeric_cols: + data[col] = pd.to_numeric(data[col], errors="coerce") + data[numeric_cols] = data[numeric_cols].fillna(0) + + return data + + +def make_dataloaders_v2(data_with_covariates, **kwargs): + """Create dataloaders with consistent encoder/decoder features.""" + + training_cutoff = "2016-09-01" + max_encoder_length = 4 + max_prediction_length = 3 + + target_col = kwargs.get("target", "target") + group_cols = kwargs.get("group_ids", ["agency_encoded", "sku_encoded"]) + add_relative_time_idx = kwargs.get("add_relative_time_idx", True) + + known_features = [ + "month", + "year", + "quarter", + "seasonal_1", + "seasonal_2", + "special_event_1", + "special_event_2", + "trend", + ] + unknown_features = [ + "agency_feature_1", + "agency_feature_2", + "sku_feature_1", + "sku_feature_2", + "noise", + "log_volume", + ] + + numerical_features = known_features + unknown_features + categorical_features = [] + static_features = group_cols + + for col in numerical_features + categorical_features + group_cols + [target_col]: + if col in data_with_covariates.columns: + data_with_covariates[col] = pd.to_numeric( + data_with_covariates[col], errors="coerce" + ).fillna(0) + + for col in categorical_features + group_cols: + if col in data_with_covariates.columns: + data_with_covariates[col] = data_with_covariates[col].astype("int64") + + if "weight" in data_with_covariates.columns: + data_with_covariates["weight"] = pd.to_numeric( + data_with_covariates["weight"], errors="coerce" + ).fillna(1.0) + + training_data = data_with_covariates[ + data_with_covariates.date < training_cutoff + ].copy() + validation_data = data_with_covariates.copy() + + required_columns = ( + ["time_idx", target_col, "weight", "date"] + + group_cols + + numerical_features + + categorical_features + ) + + available_columns = [ + col for col in required_columns if col in data_with_covariates.columns + ] + + training_data_clean = training_data[available_columns].copy() + validation_data_clean = validation_data[available_columns].copy() + + if "date" in training_data_clean.columns: + training_data_clean = training_data_clean.drop("date", axis=1) + if "date" in validation_data_clean.columns: + validation_data_clean = validation_data_clean.drop("date", axis=1) + + 
training_dataset = TimeSeries( + data=training_data_clean, + time="time_idx", + target=[target_col], + group=group_cols, + weight="weight", + num=numerical_features, + cat=categorical_features if categorical_features else None, + known=known_features, + unknown=unknown_features, + static=static_features, + ) + + validation_dataset = TimeSeries( + data=validation_data_clean, + time="time_idx", + target=[target_col], + group=group_cols, + weight="weight", + num=numerical_features, + cat=categorical_features if categorical_features else None, + known=known_features, + unknown=unknown_features, + static=static_features, + ) + + training_max_time_idx = training_data["time_idx"].max() + 1 + + train_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=training_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + add_relative_time_idx=add_relative_time_idx, + batch_size=2, + num_workers=0, + train_val_test_split=(0.8, 0.2, 0.0), + ) + + val_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=validation_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + min_prediction_idx=training_max_time_idx, + add_relative_time_idx=add_relative_time_idx, + batch_size=2, + num_workers=0, + train_val_test_split=(0.0, 1.0, 0.0), + ) + + test_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=validation_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + min_prediction_idx=training_max_time_idx, + add_relative_time_idx=add_relative_time_idx, + batch_size=1, + num_workers=0, + train_val_test_split=(0.0, 0.0, 1.0), + ) + + train_datamodule.setup("fit") + val_datamodule.setup("fit") + test_datamodule.setup("test") + + train_dataloader = train_datamodule.train_dataloader() + val_dataloader = val_datamodule.val_dataloader() + test_dataloader = test_datamodule.test_dataloader() + + return { + "train": train_dataloader, + "val": val_dataloader, + "test": test_dataloader, + "data_module": train_datamodule, + } + + @pytest.fixture( params=[ dict(), diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 45d4772d5..add3ef7ba 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -245,6 +245,8 @@ def _integration( class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator): """Generic tests for all objects in the mini package.""" + object_type_filter = "ptf-v1" + def test_doctest_examples(self, object_class): """Runs doctests for estimator class.""" from skbase.utils.doctest_run import run_doctest diff --git a/pytorch_forecasting/tests/test_all_estimators_v2.py b/pytorch_forecasting/tests/test_all_estimators_v2.py new file mode 100644 index 000000000..857ea1f76 --- /dev/null +++ b/pytorch_forecasting/tests/test_all_estimators_v2.py @@ -0,0 +1,145 @@ +"""Automated tests based on the skbase test suite template.""" + +from inspect import isclass +import shutil + +import lightning.pytorch as pl +from lightning.pytorch.callbacks import EarlyStopping +from lightning.pytorch.loggers import TensorBoardLogger +import torch.nn as nn + +from pytorch_forecasting.tests._conftest import make_dataloaders_v2 as make_dataloaders +from pytorch_forecasting.tests.test_all_estimators import ( + BaseFixtureGenerator, + PackageConfig, +) + +# whether to test only estimators from modules that are changed w.r.t. 
main +# default is False, can be set to True by pytest --only_changed_modules True flag +ONLY_CHANGED_MODULES = False + + +def _integration( + estimator_cls, + data_with_covariates, + tmp_path, + data_loader_kwargs={}, + clip_target: bool = False, + trainer_kwargs=None, + **kwargs, +): + data_with_covariates = data_with_covariates.copy() + if clip_target: + data_with_covariates["target"] = data_with_covariates["volume"].clip(1e-3, 1.0) + else: + data_with_covariates["target"] = data_with_covariates["volume"] + + data_loader_default_kwargs = dict( + target="target", + group_ids=["agency_encoded", "sku_encoded"], + add_relative_time_idx=True, + ) + data_loader_default_kwargs.update(data_loader_kwargs) + + dataloaders_with_covariates = make_dataloaders( + data_with_covariates, **data_loader_default_kwargs + ) + + train_dataloader = dataloaders_with_covariates["train"] + val_dataloader = dataloaders_with_covariates["val"] + test_dataloader = dataloaders_with_covariates["test"] + + early_stop_callback = EarlyStopping( + monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min" + ) + + logger = TensorBoardLogger(tmp_path) + if trainer_kwargs is None: + trainer_kwargs = {} + trainer = pl.Trainer( + max_epochs=3, + gradient_clip_val=0.1, + callbacks=[early_stop_callback], + enable_checkpointing=True, + default_root_dir=tmp_path, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + logger=logger, + **trainer_kwargs, + ) + training_data_module = dataloaders_with_covariates["data_module"] + metadata = training_data_module.metadata + + assert metadata["encoder_cont"] == 14 # 14 features (8 known + 6 unknown) + assert metadata["encoder_cat"] == 0 + assert metadata["decoder_cont"] == 8 # 8 (only known features) + assert metadata["decoder_cat"] == 0 + assert metadata["static_categorical_features"] == 0 + assert ( + metadata["static_continuous_features"] == 2 + ) # 2 (agency_encoded, sku_encoded) + assert metadata["target"] == 1 + + batch_x, batch_y = next(iter(train_dataloader)) + + assert batch_x["encoder_cont"].shape[2] == metadata["encoder_cont"] + assert batch_x["encoder_cat"].shape[2] == metadata["encoder_cat"] + + assert batch_x["decoder_cont"].shape[2] == metadata["decoder_cont"] + assert batch_x["decoder_cat"].shape[2] == metadata["decoder_cat"] + + if "static_categorical_features" in batch_x: + assert ( + batch_x["static_categorical_features"].shape[2] + == metadata["static_categorical_features"] + ) + + if "static_continuous_features" in batch_x: + assert ( + batch_x["static_continuous_features"].shape[2] + == metadata["static_continuous_features"] + ) + + assert batch_y.shape[2] == metadata["target"] + + net = estimator_cls( + metadata=metadata, + loss=nn.MSELoss(), + **kwargs, + ) + + try: + trainer.fit( + net, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, + ) + test_outputs = trainer.test(net, dataloaders=test_dataloader) + assert len(test_outputs) > 0 + finally: + shutil.rmtree(tmp_path, ignore_errors=True) + + +class TestAllPtForecastersV2(PackageConfig, BaseFixtureGenerator): + """Generic tests for all objects in the mini package.""" + + object_type_filter = "ptf-v2" + + def test_doctest_examples(self, object_class): + """Runs doctests for estimator class.""" + from skbase.utils.doctest_run import run_doctest + + run_doctest(object_class, name=f"class {object_class.__name__}") + + def test_integration( + self, + object_metadata, + trainer_kwargs, + tmp_path, + ): + from pytorch_forecasting.tests._data_scenarios import 
data_with_covariates_v2 + + data_with_covariates = data_with_covariates_v2() + object_class = object_metadata.get_model_cls() + _integration(object_class, data_with_covariates, tmp_path, **trainer_kwargs)
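
For reviewers who want to exercise the new v2 path end to end, below is a minimal usage sketch with illustrative values only. It mirrors the `_integration` harness added in this diff and assumes the `data_with_covariates_v2` / `make_dataloaders_v2` helpers and the `EncoderDecoderTimeSeriesDataModule` metadata shown above:

```python
import lightning.pytorch as pl
import torch.nn as nn

from pytorch_forecasting.models import TimeXer
from pytorch_forecasting.tests._data_scenarios import (
    data_with_covariates_v2,
    make_dataloaders_v2,
)

# Build the synthetic data and the v2 data module (encoder length 4, horizon 3).
loaders = make_dataloaders_v2(data_with_covariates_v2())
dm = loaders["data_module"]

# Instantiate TimeXer against the feature counts reported by the data module.
model = TimeXer(
    context_length=4,      # matches max_encoder_length in make_dataloaders_v2
    prediction_length=3,   # matches max_prediction_length
    loss=nn.MSELoss(),
    d_model=32,
    n_heads=2,
    e_layers=1,
    d_ff=64,
    patch_length=1,
    features="MS",
    metadata=dm.metadata,
)

trainer = pl.Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2)
trainer.fit(model, train_dataloaders=loaders["train"], val_dataloaders=loaders["val"])
```

The context and prediction lengths are chosen to match the encoder/decoder windows hard-coded in `make_dataloaders_v2`, so the `FlattenHead` input size lines up with the number of endogenous patches plus the global token.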