Skip to content

Commit c4155f2

Browse files
authored
Merge branch 'sktime:main' into newDataset
2 parents fda5f7e + 1a2d83c commit c4155f2

File tree

18 files changed

+2892
-2754
lines changed

18 files changed

+2892
-2754
lines changed

docs/source/getting-started.rst

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,11 @@ Otherwise, you can proceed with
2323
2424
2525
Alternatively, to install the package via ``conda``:
26-
2726
.. code-block:: bash
2827
2928
conda install pytorch-forecasting pytorch>=1.7 -c pytorch -c conda-forge
3029
31-
PyTorch Forecasting is now installed from the conda-forge channel while PyTorch is install from the pytorch channel.
30+
PyTorch Forecasting is now installed from the conda-forge channel while PyTorch is installed from the pytorch channel.
3231

3332
To use the MQF2 loss (multivariate quantile loss), also install
3433

@@ -54,7 +53,7 @@ The general setup for training and testing a model is
5453
Similarly, a test dataset or later a dataset for inference can be created. You can store the dataset parameters
5554
directly if you do not wish to load the entire training dataset at inference time.
5655

57-
#. Instantiate a model using the its ``.from_dataset()`` method.
56+
#. Instantiate a model using the ``.from_dataset()`` method.
5857
#. Create a ``lightning.Trainer()`` object.
5958
#. Find the optimal learning rate with its ``.tuner.lr_find()`` method.
6059
#. Train the model with early stopping on the training dataset and use the tensorboard logs
@@ -65,7 +64,7 @@ The general setup for training and testing a model is
6564
#. Load the model from the model checkpoint and apply it to new data.
6665

6766

68-
The :ref:`Tutorials <tutorials>` section provides detailled guidance and examples on how to use models and implement new ones.
67+
The :ref:`Tutorials <tutorials>` section provides detailed guidance and examples on how to use models and implement new ones.
6968

7069

7170
Example

docs/source/tutorials/ar.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@
268268
],
269269
"source": [
270270
"pl.seed_everything(42)\n",
271-
"trainer = pl.Trainer(accelerator=\"auto\", gradient_clip_val=0.01)\n",
271+
"trainer = pl.Trainer(accelerator=\"auto\", gradient_clip_val=0.1)\n",
272272
"net = NBeats.from_dataset(training, learning_rate=3e-2, weight_decay=1e-2, widths=[32, 512], backcast_loss_ratio=0.1)"
273273
]
274274
},
@@ -448,7 +448,7 @@
448448
" max_epochs=3,\n",
449449
" accelerator=\"auto\",\n",
450450
" enable_model_summary=True,\n",
451-
" gradient_clip_val=0.01,\n",
451+
" gradient_clip_val=0.1,\n",
452452
" callbacks=[early_stop_callback],\n",
453453
" limit_train_batches=150,\n",
454454
")\n",

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ dev = [
106106

107107
# docs - dependencies for building the documentation
108108
docs = [
109-
"sphinx>3.2",
109+
"sphinx>3.2,<7.2.6",
110110
"pydata-sphinx-theme",
111111
"nbsphinx",
112112
"pandoc",

pytorch_forecasting/metrics/_mqf2_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from typing import List, Optional, Tuple
44

5-
from cpflows.flows import DeepConvexFlow, SequentialFlow
65
import torch
76
from torch.distributions import (
87
AffineTransform,
@@ -12,6 +11,11 @@
1211
)
1312
import torch.nn.functional as F
1413

14+
from pytorch_forecasting.utils._dependencies import _safe_import
15+
16+
DeepConvexFlow = _safe_import("cpflows.flows.DeepConvexFlow")
17+
SequentialFlow = _safe_import("cpflows.flows.SequentialFlow")
18+
1519

1620
class DeepConvexNet(DeepConvexFlow):
1721
r"""

pytorch_forecasting/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Models for timeseries forecasting.
33
"""
44

5-
from pytorch_forecasting.models.base_model import (
5+
from pytorch_forecasting.models.base import (
66
AutoRegressiveBaseModel,
77
AutoRegressiveBaseModelWithCovariates,
88
BaseModel,
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""Base classes for pytorch-foercasting models."""
2+
3+
from pytorch_forecasting.models.base._base_model import (
4+
AutoRegressiveBaseModel,
5+
AutoRegressiveBaseModelWithCovariates,
6+
BaseModel,
7+
BaseModelWithCovariates,
8+
Prediction,
9+
)
10+
11+
__all__ = [
12+
"AutoRegressiveBaseModel",
13+
"AutoRegressiveBaseModelWithCovariates",
14+
"BaseModel",
15+
"BaseModelWithCovariates",
16+
"Prediction",
17+
]

0 commit comments

Comments
 (0)