Skip to content

Commit bfbe6b2

Browse files
authored
[BUG] EXPERIMENTAL PR: Solve the bug in data_module (#1834)
This PR solves the bug in `data_module` where the `static_categorical_features` and `static_continuous_features` were not correctly calculated in `__getitem__` of nested class
1 parent 5685c59 commit bfbe6b2

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

pytorch_forecasting/data/data_module.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -519,8 +519,40 @@ def __getitem__(self, idx):
519519
"decoder_mask": decoder_mask,
520520
}
521521
if data["static"] is not None:
522-
x["static_categorical_features"] = data["static"].unsqueeze(0)
523-
x["static_continuous_features"] = torch.zeros((1, 0))
522+
raw_st_tensor = data.get("static")
523+
static_col_names = self.data_module.time_series_metadata["cols"]["st"]
524+
525+
is_categorical_mask = torch.tensor(
526+
[
527+
self.data_module.time_series_metadata["col_type"].get(col_name)
528+
== "C"
529+
for col_name in static_col_names
530+
],
531+
dtype=torch.bool,
532+
)
533+
534+
is_continuous_mask = ~is_categorical_mask
535+
536+
st_cat_values_for_item = raw_st_tensor[is_categorical_mask]
537+
st_cont_values_for_item = raw_st_tensor[is_continuous_mask]
538+
539+
if st_cat_values_for_item.shape[0] > 0:
540+
x["static_categorical_features"] = st_cat_values_for_item.unsqueeze(
541+
0
542+
)
543+
else:
544+
x["static_categorical_features"] = torch.zeros(
545+
(1, 0), dtype=torch.float32
546+
)
547+
548+
if st_cont_values_for_item.shape[0] > 0:
549+
x["static_continuous_features"] = st_cont_values_for_item.unsqueeze(
550+
0
551+
)
552+
else:
553+
x["static_continuous_features"] = torch.zeros(
554+
(1, 0), dtype=torch.float32
555+
)
524556

525557
y = data["target"][decoder_indices]
526558
if y.ndim == 1:

tests/test_data/test_data_module.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,14 @@ def test_with_static_features():
405405
x, y = dm.train_dataset[0]
406406
assert "static_categorical_features" in x
407407
assert "static_continuous_features" in x
408+
assert (
409+
x["static_categorical_features"].shape[1]
410+
== metadata["static_categorical_features"]
411+
)
412+
assert (
413+
x["static_continuous_features"].shape[1]
414+
== metadata["static_continuous_features"]
415+
)
408416

409417

410418
def test_different_train_val_test_split(sample_timeseries_data):

0 commit comments

Comments
 (0)