Changes from all commits
97 commits
b3644a6
test suite
fkiraly Feb 22, 2025
a1d64c6
Merge branch 'main' into test-suite
fkiraly Feb 22, 2025
4b2486e
skeleton
fkiraly Feb 22, 2025
02b0ce6
skeleton
fkiraly Feb 22, 2025
41cbf66
Update test_all_estimators.py
fkiraly Feb 23, 2025
cef62d3
Update _base_object.py
fkiraly Feb 23, 2025
bc2e93b
Update _lookup.py
fkiraly Feb 23, 2025
eee1c86
Update _lookup.py
fkiraly Feb 23, 2025
164fe0d
base metadatda
fkiraly Feb 23, 2025
20e88d0
registry
fkiraly Feb 23, 2025
318c1fb
fix private name
fkiraly Feb 23, 2025
012ab3d
Update _base_object.py
fkiraly Feb 23, 2025
86365a0
test failure
fkiraly Feb 23, 2025
f6dee46
Update test_all_estimators.py
fkiraly Feb 23, 2025
9b0e4ec
Update test_all_estimators.py
fkiraly Feb 23, 2025
7de5285
Update test_all_estimators.py
fkiraly Feb 23, 2025
57dfe3a
test folders
fkiraly Feb 23, 2025
c9f12db
Update test.yml
fkiraly Feb 23, 2025
fa8144e
test integration
fkiraly Feb 23, 2025
232a510
fixes
fkiraly Feb 23, 2025
1c8d4b5
Update _conftest.py
fkiraly Feb 23, 2025
f632e32
try scenarios
fkiraly Feb 23, 2025
252598d
D1, D2 layer commit
phoeenniixx Apr 6, 2025
d0d1c3e
remove one comment
phoeenniixx Apr 6, 2025
80e64d2
model layer commit
phoeenniixx Apr 6, 2025
6364780
update docstring
phoeenniixx Apr 6, 2025
82b3dc7
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 6, 2025
257183c
update data_module.py
phoeenniixx Apr 10, 2025
9cdcb19
update data_module.py
phoeenniixx Apr 10, 2025
a83bf32
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 10, 2025
ac56d4f
Add disclaimer
phoeenniixx Apr 10, 2025
0e7e36f
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 10, 2025
4bfff21
update docstring
phoeenniixx Apr 11, 2025
ef98273
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 11, 2025
8a53ed6
Add tests for D1,D2 layer
phoeenniixx Apr 19, 2025
9f9df31
Merge branch 'main' into refactor-d1-d2
phoeenniixx Apr 19, 2025
cdecb77
Code quality
phoeenniixx Apr 19, 2025
86360fd
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 19, 2025
20aafb7
refactor file
fkiraly Apr 30, 2025
043820d
warning
fkiraly Apr 30, 2025
1720a15
linting
fkiraly May 1, 2025
af44474
move coercion to utils
fkiraly May 1, 2025
a3cb8b7
linting
fkiraly May 1, 2025
75d7fb5
Update _timeseries_v2.py
fkiraly May 1, 2025
1b946e6
Update __init__.py
fkiraly May 1, 2025
3edb08b
Update __init__.py
fkiraly May 1, 2025
a4bc9d8
Merge branch 'main' into pr/1811
fkiraly May 1, 2025
4c0d570
Merge branch 'pr/1811' into pr/1812
fkiraly May 1, 2025
ef37f55
Merge branch 'main' into test-suite
fkiraly May 1, 2025
a669134
Update _lookup.py
fkiraly May 4, 2025
d78bf5d
Update _lookup.py
fkiraly May 4, 2025
e350291
update tests
phoeenniixx May 11, 2025
f90c94f
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx May 11, 2025
3099691
update tft_v2
phoeenniixx May 11, 2025
77cb979
warnings and init attr handling
fkiraly May 13, 2025
28df3c3
Merge branch 'refactor-d1-d2' of https://github.com/phoeenniixx/pytor…
fkiraly May 13, 2025
f8c94e6
simplify TimeSeries.__getitem__
fkiraly May 13, 2025
c289255
Update _timeseries_v2.py
fkiraly May 13, 2025
9467f38
Update data_module.py
fkiraly May 13, 2025
c3b40ad
backwards compat of private/public attrs
fkiraly May 13, 2025
c007310
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx May 13, 2025
2e25052
Merge branch 'main' into refactor-model
phoeenniixx May 13, 2025
38c28dc
add tests
phoeenniixx May 14, 2025
9d80eb8
add tests
phoeenniixx May 14, 2025
a8ccfe3
add tests
phoeenniixx May 14, 2025
f900ba5
add more docstrings
phoeenniixx May 14, 2025
ed1b799
add note about the commented out tests
phoeenniixx May 14, 2025
c947910
Merge branch 'main' into refactor-model
phoeenniixx May 16, 2025
c0ceb8a
add the commented out tests
phoeenniixx May 16, 2025
3828c26
remove note
phoeenniixx May 16, 2025
6d6d18e
Merge branch 'main' into refactor-model
phoeenniixx May 18, 2025
3144865
Merge branch 'test-suite' of https://github.com/sktime/pytorch-foreca…
phoeenniixx May 20, 2025
30b541b
make the modules private
phoeenniixx May 20, 2025
3f1e11f
Merge remote-tracking branch 'origin/refactor-model' into refactor-model
phoeenniixx May 20, 2025
5cc3ff1
initial commit
phoeenniixx May 20, 2025
1bcf181
Merge branch 'refactor-model' into test-framework
phoeenniixx May 20, 2025
f18e09d
add TFTMetadata class
phoeenniixx May 20, 2025
e1e360e
add TFTMetadata class
phoeenniixx May 20, 2025
168e16a
Merge branch 'main' into test-framework
phoeenniixx May 22, 2025
92c12bf
add TFT tests
phoeenniixx May 25, 2025
1d478d5
remove refactored TFT
phoeenniixx May 27, 2025
f9992f2
Merge branch 'main' into test-framework
phoeenniixx May 28, 2025
d049019
update test_all_estimators
phoeenniixx May 28, 2025
e72486b
linting
phoeenniixx May 28, 2025
7443b0b
Merge branch 'main' into test-framework
phoeenniixx May 29, 2025
a734f26
refactor
phoeenniixx May 29, 2025
7f466b2
Add more test_params
phoeenniixx May 29, 2025
0968452
Add metadata tests
phoeenniixx May 31, 2025
525bbb9
Merge branch 'main' into test-framework
phoeenniixx Jun 1, 2025
48284cf
add timexer
phoeenniixx Jun 1, 2025
4267da6
Merge branch 'main' into test-framework
phoeenniixx Jun 1, 2025
4e8f863
add object-filter to ptf-v1
phoeenniixx Jun 1, 2025
487c1ed
Merge branch 'test-framework' into timexer
phoeenniixx Jun 1, 2025
ab2bf2b
linting
phoeenniixx Jun 1, 2025
69822b8
update params
phoeenniixx Jun 1, 2025
9eff8f1
update params
phoeenniixx Jun 1, 2025
4b4ab11
Merge branch 'main' into timexer
phoeenniixx Jun 5, 2025
2 changes: 2 additions & 0 deletions pytorch_forecasting/models/__init__.py
@@ -19,6 +19,7 @@
TemporalFusionTransformer,
)
from pytorch_forecasting.models.tide import TiDEModel
from pytorch_forecasting.models.timexer import TimeXer

__all__ = [
"NBeats",
@@ -37,4 +38,5 @@
"MultiEmbedding",
"DecoderMLP",
"TiDEModel",
"TimeXer",
]
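With the export added above, the new model becomes importable from the models namespace, e.g.:

from pytorch_forecasting.models import TimeXer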
1 change: 1 addition & 0 deletions pytorch_forecasting/models/deepar/_deepar_metadata.py
@@ -9,6 +9,7 @@ class DeepARMetadata(_BasePtForecaster):
_tags = {
"info:name": "DeepAR",
"info:compute": 3,
"object_type": "ptf-v1",
"authors": ["jdb78"],
"capability:exogenous": True,
"capability:multivariate": True,
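The new "object_type" tag marks this as a v1 (pre-rework) estimator, which the test framework can use to filter objects. A minimal sketch of reading the tag, assuming _BasePtForecaster exposes the skbase-style get_class_tag classmethod:

from pytorch_forecasting.models.deepar._deepar_metadata import DeepARMetadata

# class-level tag lookup; "ptf-v1" marks the pre-rework model interface
assert DeepARMetadata.get_class_tag("object_type") == "ptf-v1"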
1 change: 1 addition & 0 deletions pytorch_forecasting/models/nbeats/_nbeats_metadata.py
@@ -9,6 +9,7 @@ class NBeatsMetadata(_BasePtForecaster):
_tags = {
"info:name": "NBeats",
"info:compute": 1,
"object_type": "ptf-v1",
"authors": ["jdb78"],
"capability:exogenous": False,
"capability:multivariate": False,
@@ -0,0 +1,61 @@
"""TFT metadata container."""

from pytorch_forecasting.models.base._base_object import _BasePtForecaster


class TFTMetadata(_BasePtForecaster):
"""TFT metadata container."""

_tags = {
"info:name": "TFT",
"object_type": "ptf-v2",
"authors": ["phoeenniixx"],
"capability:exogenous": True,
"capability:multivariate": True,
"capability:pred_int": True,
"capability:flexible_history_length": False,
}

@classmethod
def get_model_cls(cls):
"""Get model class."""
from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT

return TFT

@classmethod
def get_test_train_params(cls):
"""Return testing parameter settings for the trainer.

Returns
-------
params : dict or list of dict, default = {}
Parameters to create testing instances of the class
Each dict are parameters to construct an "interesting" test instance, i.e.,
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
`create_test_instance` uses the first (or only) dictionary in `params`
"""
return [
{},
dict(
hidden_size=25,
attention_head_size=5,
),
dict(
data_loader_kwargs=dict(max_encoder_length=5, max_prediction_length=3)
),
dict(
hidden_size=24,
attention_head_size=8,
data_loader_kwargs=dict(
max_encoder_length=5,
max_prediction_length=3,
add_relative_time_idx=False,
),
),
dict(
hidden_size=12,
data_loader_kwargs=dict(max_encoder_length=7, max_prediction_length=10),
),
dict(attention_head_size=2),
]
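A sketch of how the test framework is expected to consume this container: get_model_cls resolves the concrete class lazily (avoiding import cycles), and each parameter dict yields one test instance. The data_loader_kwargs entries are presumably routed to the data module by the test harness rather than to the model constructor:

model_cls = TFTMetadata.get_model_cls()  # lazily resolves the v2 TFT class
param_sets = TFTMetadata.get_test_train_params()
# `create_test_instance` uses param_sets[0]; here, the empty dict
# exercises the model's constructor defaults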
1 change: 1 addition & 0 deletions pytorch_forecasting/models/tide/_tide_metadata.py
@@ -9,6 +9,7 @@ class TiDEModelMetadata(_BasePtForecaster):
_tags = {
"info:name": "TiDEModel",
"info:compute": 3,
"object_type": "ptf-v1",
"authors": ["Sohaib-Ahmed21"],
"capability:exogenous": True,
"capability:multivariate": True,
29 changes: 29 additions & 0 deletions pytorch_forecasting/models/timexer/__init__.py
@@ -0,0 +1,29 @@
"""
TimeXer model for forecasting time series.
"""

from pytorch_forecasting.models.timexer._timexer import TimeXer
from pytorch_forecasting.models.timexer.sub_modules import (
AttentionLayer,
DataEmbedding_inverted,
Encoder,
EncoderLayer,
EnEmbedding,
FlattenHead,
FullAttention,
PositionalEmbedding,
TriangularCausalMask,
)

__all__ = [
"TimeXer",
"TriangularCausalMask",
"FullAttention",
"AttentionLayer",
"DataEmbedding_inverted",
"PositionalEmbedding",
"FlattenHead",
"EnEmbedding",
"Encoder",
"EncoderLayer",
]
267 changes: 267 additions & 0 deletions pytorch_forecasting/models/timexer/_timexer.py
@@ -0,0 +1,267 @@
"""
Time Series Transformer with eXogenous variables (TimeXer)
---------------------------------------------------------
"""

#######################################################
# Note: This is an example version to demonstrate the
# working of the TimeXer model with the existing v2
# designs. The pending work includes building the D2
# layer and base tslib model.
#######################################################

from typing import Callable, Optional, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer

from pytorch_forecasting.metrics import QuantileLoss
from pytorch_forecasting.models.base._base_model_v2 import BaseModel
from pytorch_forecasting.models.timexer.sub_modules import (
AttentionLayer,
DataEmbedding_inverted,
Encoder,
EncoderLayer,
EnEmbedding,
FlattenHead,
FullAttention,
)


class TimeXer(BaseModel):
def __init__(
self,
context_length: int,
prediction_length: int,
loss: nn.Module,
logging_metrics: Optional[list[nn.Module]] = None,
optimizer: Optional[Union[Optimizer, str]] = "adam",
optimizer_params: Optional[dict] = None,
lr_scheduler: Optional[str] = None,
lr_scheduler_params: Optional[dict] = None,
task_name: str = "long_term_forecast",
features: str = "MS",
enc_in: Optional[int] = None,
d_model: int = 512,
n_heads: int = 8,
e_layers: int = 2,
d_ff: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable] = "torch.nn.functional.relu",
patch_length: int = 24,
use_norm: bool = False,
factor: int = 5,
embed_type: str = "fixed",
freq: str = "h",
metadata: Optional[dict] = None,
target_positions: Optional[torch.LongTensor] = None,
):
"""An implementation of the TimeXer model.
TimeXer empowers the canonical transformer with the ability to reconcile
endogenous and exogenous information without any architectural modifications
and achieves consistent state-of-the-art performance across twelve real-world
forecasting benchmarks.
TimeXer employs patch-level and variate-level representations respectively for
endogenous and exogenous variables, with an endogenous global token as a bridge
in-between. With this design, TimeXer can jointly capture intra-endogenous
temporal dependencies and exogenous-to-endogenous correlations.
TimeXer model for time series forecasting with exogenous variables.
"""
super().__init__(
loss=loss,
logging_metrics=logging_metrics,
optimizer=optimizer,
optimizer_params=optimizer_params or {},
lr_scheduler=lr_scheduler,
lr_scheduler_params=lr_scheduler_params or {},
)

self.context_length = context_length
self.prediction_length = prediction_length
self.task_name = task_name
self.features = features
self.d_model = d_model
self.n_heads = n_heads
self.e_layers = e_layers
self.d_ff = d_ff
self.activation = activation
self.patch_length = patch_length
self.use_norm = use_norm
self.factor = factor
self.embed_type = embed_type
self.freq = freq
self.metadata = metadata
self.n_target_vars = self.metadata["target"]
self.target_positions = target_positions
self.enc_in = self.metadata["encoder_cont"]
self.patch_num = self.context_length // self.patch_length
self.dropout = dropout

self.n_quantiles = None

if isinstance(loss, QuantileLoss):
self.n_quantiles = len(loss.quantiles)

self.en_embedding = EnEmbedding(
self.n_target_vars,
self.d_model,
self.patch_length,
self.dropout,
)

self.ex_embedding = DataEmbedding_inverted(
self.context_length,
self.d_model,
self.embed_type,
self.freq,
self.dropout,
)

self.encoder = Encoder(
[
EncoderLayer(
AttentionLayer(
FullAttention(
False,
self.factor,
attention_dropout=self.dropout,
output_attention=False,
),
self.d_model,
self.n_heads,
),
AttentionLayer(
FullAttention(
False,
self.factor,
attention_dropout=self.dropout,
output_attention=False,
),
self.d_model,
self.n_heads,
),
self.d_model,
self.d_ff,
dropout=self.dropout,
activation=self.activation,
)
for _ in range(self.e_layers)
],
norm_layer=torch.nn.LayerNorm(self.d_model),
)
self.head_nf = self.d_model * (self.patch_num + 1)
self.head = FlattenHead(
self.enc_in,
self.head_nf,
self.prediction_length,
head_dropout=self.dropout,
n_quantiles=self.n_quantiles,
)

def _forecast(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Forecast for the univariate, or multivariate single-target ("MS"), case.

Args:
x: Dictionary containing entries for encoder_cat, encoder_cont.

Returns:
Prediction tensor of shape (batch_size, prediction_length, n_vars), with
a trailing quantile dimension when a QuantileLoss is used.
"""
batch_size = x["encoder_cont"].shape[0]
encoder_cont = x["encoder_cont"]
encoder_time_idx = x.get("encoder_time_idx", None)
past_target = x.get(
"target",
torch.zeros(batch_size, self.prediction_length, 0, device=self.device),
)

if encoder_time_idx is not None and encoder_time_idx.dim() == 2:
# change [batch_size, time_steps] to [batch_size, time_steps, features]
encoder_time_idx = encoder_time_idx.unsqueeze(-1)

en_embed, n_vars = self.en_embedding(past_target.permute(0, 2, 1))
ex_embed = self.ex_embedding(encoder_cont, encoder_time_idx)

enc_out = self.encoder(en_embed, ex_embed)
enc_out = torch.reshape(
enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])
)

enc_out = enc_out.permute(0, 1, 3, 2)

dec_out = self.head(enc_out)
if self.n_quantiles is not None:
dec_out = dec_out.permute(0, 2, 1, 3)
else:
dec_out = dec_out.permute(0, 2, 1)

return dec_out

def _forecast_multi(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Forecast for the multivariate, multiple-target ("M") case.

Args:
x: Dictionary containing entries for encoder_cat, encoder_cont.

Returns:
Prediction tensor of shape (batch_size, prediction_length, n_vars), with
a trailing quantile dimension when a QuantileLoss is used.
"""

batch_size = x["encoder_cont"].shape[0]
encoder_cont = x.get(
"encoder_cont",
torch.zeros(batch_size, self.prediction_length, device=self.device),
)
encoder_time_idx = x.get("encoder_time_idx", None)
encoder_targets = x.get(
"target",
torch.zeros(batch_size, self.prediction_length, device=self.device),
)
en_embed, n_vars = self.en_embedding(encoder_targets.permute(0, 2, 1))
ex_embed = self.ex_embedding(encoder_cont, encoder_time_idx)

# batch_size x sequence_length x d_model
enc_out = self.encoder(en_embed, ex_embed)

enc_out = torch.reshape(
enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])
) # batch_size x n_vars x sequence_length x d_model

enc_out = enc_out.permute(0, 1, 3, 2)

dec_out = self.head(enc_out)
if self.n_quantiles is not None:
dec_out = dec_out.permute(0, 2, 1, 3)
else:
dec_out = dec_out.permute(0, 2, 1)

return dec_out

def forward(self, x: dict[str, torch.Tensor]) -> Optional[dict[str, torch.Tensor]]:
"""
Forward pass of the model.

Args:
x: Dictionary containing model inputs.

Returns:
Dictionary with model outputs, or None for unsupported task names.
"""
if self.task_name in ("long_term_forecast", "short_term_forecast"):
if self.features == "M":
out = self._forecast_multi(x)
else:
out = self._forecast(x)
prediction = out[:, : self.prediction_length, :]

# note: prediction.size(2) is the number of target variables, i.e., n_targets
target_indices = range(prediction.size(2))

if self.n_quantiles is not None:
prediction = [prediction[..., i, :] for i in target_indices]
else:
if len(target_indices) == 1:
prediction = prediction[..., 0]
else:
prediction = [prediction[..., i] for i in target_indices]
return {"prediction": prediction}
else:
return None
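For orientation, a minimal smoke-test sketch of the forward pass (illustrative only, not executed here; the metadata keys "target" and "encoder_cont" mirror their use in __init__, and context_length is chosen as a multiple of patch_length so that patch_num is a positive integer):

import torch
import torch.nn as nn

from pytorch_forecasting.models.timexer import TimeXer

batch_size, context_length, prediction_length = 4, 24, 6
n_targets, n_features = 1, 3  # single target ("MS"), three encoder covariates

model = TimeXer(
    context_length=context_length,
    prediction_length=prediction_length,
    loss=nn.MSELoss(),
    d_model=32,  # small dimensions keep the sketch cheap
    d_ff=64,
    patch_length=8,  # 24 // 8 = 3 patches per series
    metadata={"target": n_targets, "encoder_cont": n_features},
)

x = {
    "encoder_cont": torch.randn(batch_size, context_length, n_features),
    "target": torch.randn(batch_size, context_length, n_targets),
}
out = model(x)
# with a point loss and a single target, forward squeezes the variable
# axis, so the expected shape is (batch_size, prediction_length)
print(out["prediction"].shape)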