Skip to content

Commit 7324ca8

Browse files
committed
correctly calculate the cat and cont
1 parent 5685c59 commit 7324ca8

File tree

2 files changed

+40
-2
lines changed

2 files changed

+40
-2
lines changed

pytorch_forecasting/data/data_module.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -519,8 +519,38 @@ def __getitem__(self, idx):
519519
"decoder_mask": decoder_mask,
520520
}
521521
if data["static"] is not None:
522-
x["static_categorical_features"] = data["static"].unsqueeze(0)
523-
x["static_continuous_features"] = torch.zeros((1, 0))
522+
st_cat_values_for_item = []
523+
st_cont_values_for_item = []
524+
raw_st_tensor = data.get("static")
525+
526+
static_col_names = self.data_module.time_series_metadata["cols"]["st"]
527+
for i, col_name in enumerate(static_col_names):
528+
feature_value = raw_st_tensor[i]
529+
if (
530+
self.data_module.time_series_metadata["col_type"].get(col_name)
531+
== "C"
532+
):
533+
st_cat_values_for_item.append(feature_value)
534+
else:
535+
st_cont_values_for_item.append(feature_value)
536+
537+
if st_cat_values_for_item:
538+
x["static_categorical_features"] = torch.stack(
539+
st_cat_values_for_item
540+
).unsqueeze(0)
541+
else:
542+
x["static_categorical_features"] = torch.zeros(
543+
(1, 0), dtype=torch.float32
544+
)
545+
546+
if st_cont_values_for_item:
547+
x["static_continuous_features"] = torch.stack(
548+
st_cont_values_for_item
549+
).unsqueeze(0)
550+
else:
551+
x["static_continuous_features"] = torch.zeros(
552+
(1, 0), dtype=torch.float32
553+
)
524554

525555
y = data["target"][decoder_indices]
526556
if y.ndim == 1:

tests/test_data/test_data_module.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,14 @@ def test_with_static_features():
405405
x, y = dm.train_dataset[0]
406406
assert "static_categorical_features" in x
407407
assert "static_continuous_features" in x
408+
assert (
409+
x["static_categorical_features"].shape[1]
410+
== metadata["static_categorical_features"]
411+
)
412+
assert (
413+
x["static_continuous_features"].shape[1]
414+
== metadata["static_continuous_features"]
415+
)
408416

409417

410418
def test_different_train_val_test_split(sample_timeseries_data):

0 commit comments

Comments
 (0)