diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index c8252014d..f6706275f 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -519,8 +519,40 @@ def __getitem__(self, idx): "decoder_mask": decoder_mask, } if data["static"] is not None: - x["static_categorical_features"] = data["static"].unsqueeze(0) - x["static_continuous_features"] = torch.zeros((1, 0)) + raw_st_tensor = data.get("static") + static_col_names = self.data_module.time_series_metadata["cols"]["st"] + + is_categorical_mask = torch.tensor( + [ + self.data_module.time_series_metadata["col_type"].get(col_name) + == "C" + for col_name in static_col_names + ], + dtype=torch.bool, + ) + + is_continuous_mask = ~is_categorical_mask + + st_cat_values_for_item = raw_st_tensor[is_categorical_mask] + st_cont_values_for_item = raw_st_tensor[is_continuous_mask] + + if st_cat_values_for_item.shape[0] > 0: + x["static_categorical_features"] = st_cat_values_for_item.unsqueeze( + 0 + ) + else: + x["static_categorical_features"] = torch.zeros( + (1, 0), dtype=torch.float32 + ) + + if st_cont_values_for_item.shape[0] > 0: + x["static_continuous_features"] = st_cont_values_for_item.unsqueeze( + 0 + ) + else: + x["static_continuous_features"] = torch.zeros( + (1, 0), dtype=torch.float32 + ) y = data["target"][decoder_indices] if y.ndim == 1: diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py index 4051b852c..cad78aecd 100644 --- a/tests/test_data/test_data_module.py +++ b/tests/test_data/test_data_module.py @@ -405,6 +405,14 @@ def test_with_static_features(): x, y = dm.train_dataset[0] assert "static_categorical_features" in x assert "static_continuous_features" in x + assert ( + x["static_categorical_features"].shape[1] + == metadata["static_categorical_features"] + ) + assert ( + x["static_continuous_features"].shape[1] + == metadata["static_continuous_features"] + ) def test_different_train_val_test_split(sample_timeseries_data):