From 7324ca8b636e6ed2b252348acdee1523e50f9490 Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Wed, 14 May 2025 22:25:35 +0530
Subject: [PATCH 1/2] correctly calculate the cat and cont

---
 pytorch_forecasting/data/data_module.py | 34 +++++++++++++++++++++++--
 tests/test_data/test_data_module.py     |  8 ++++++
 2 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
index c8252014d..352024792 100644
--- a/pytorch_forecasting/data/data_module.py
+++ b/pytorch_forecasting/data/data_module.py
@@ -519,8 +519,38 @@ def __getitem__(self, idx):
             "decoder_mask": decoder_mask,
         }
         if data["static"] is not None:
-            x["static_categorical_features"] = data["static"].unsqueeze(0)
-            x["static_continuous_features"] = torch.zeros((1, 0))
+            st_cat_values_for_item = []
+            st_cont_values_for_item = []
+            raw_st_tensor = data.get("static")
+
+            static_col_names = self.data_module.time_series_metadata["cols"]["st"]
+            for i, col_name in enumerate(static_col_names):
+                feature_value = raw_st_tensor[i]
+                if (
+                    self.data_module.time_series_metadata["col_type"].get(col_name)
+                    == "C"
+                ):
+                    st_cat_values_for_item.append(feature_value)
+                else:
+                    st_cont_values_for_item.append(feature_value)
+
+            if st_cat_values_for_item:
+                x["static_categorical_features"] = torch.stack(
+                    st_cat_values_for_item
+                ).unsqueeze(0)
+            else:
+                x["static_categorical_features"] = torch.zeros(
+                    (1, 0), dtype=torch.float32
+                )
+
+            if st_cont_values_for_item:
+                x["static_continuous_features"] = torch.stack(
+                    st_cont_values_for_item
+                ).unsqueeze(0)
+            else:
+                x["static_continuous_features"] = torch.zeros(
+                    (1, 0), dtype=torch.float32
+                )
 
         y = data["target"][decoder_indices]
         if y.ndim == 1:
diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py
index 4051b852c..cad78aecd 100644
--- a/tests/test_data/test_data_module.py
+++ b/tests/test_data/test_data_module.py
@@ -405,6 +405,14 @@ def test_with_static_features():
     x, y = dm.train_dataset[0]
     assert "static_categorical_features" in x
     assert "static_continuous_features" in x
+    assert (
+        x["static_categorical_features"].shape[1]
+        == metadata["static_categorical_features"]
+    )
+    assert (
+        x["static_continuous_features"].shape[1]
+        == metadata["static_continuous_features"]
+    )
 
 
 def test_different_train_val_test_split(sample_timeseries_data):

From c07d6e4b2a64c15fe0cf2238ff2270d56ad66cc9 Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Fri, 16 May 2025 13:07:41 +0530
Subject: [PATCH 2/2] better implementation

---
 pytorch_forecasting/data/data_module.py | 40 +++++++++++++++++--------
 1 file changed, 21 insertions(+), 19 deletions(-)

diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
index 352024792..f6706275f 100644
--- a/pytorch_forecasting/data/data_module.py
+++ b/pytorch_forecasting/data/data_module.py
@@ -519,34 +519,36 @@ def __getitem__(self, idx):
             "decoder_mask": decoder_mask,
         }
         if data["static"] is not None:
-            st_cat_values_for_item = []
-            st_cont_values_for_item = []
             raw_st_tensor = data.get("static")
-
             static_col_names = self.data_module.time_series_metadata["cols"]["st"]
-            for i, col_name in enumerate(static_col_names):
-                feature_value = raw_st_tensor[i]
-                if (
+
+            is_categorical_mask = torch.tensor(
+                [
                     self.data_module.time_series_metadata["col_type"].get(col_name)
                     == "C"
-                ):
-                    st_cat_values_for_item.append(feature_value)
-                else:
-                    st_cont_values_for_item.append(feature_value)
-
-            if st_cat_values_for_item:
-                x["static_categorical_features"] = torch.stack(
-                    st_cat_values_for_item
-                ).unsqueeze(0)
+                    for col_name in static_col_names
+                ],
+                dtype=torch.bool,
+            )
+
+            is_continuous_mask = ~is_categorical_mask
+
+            st_cat_values_for_item = raw_st_tensor[is_categorical_mask]
+            st_cont_values_for_item = raw_st_tensor[is_continuous_mask]
+
+            if st_cat_values_for_item.shape[0] > 0:
+                x["static_categorical_features"] = st_cat_values_for_item.unsqueeze(
+                    0
+                )
             else:
                 x["static_categorical_features"] = torch.zeros(
                     (1, 0), dtype=torch.float32
                 )
 
-            if st_cont_values_for_item:
-                x["static_continuous_features"] = torch.stack(
-                    st_cont_values_for_item
-                ).unsqueeze(0)
+            if st_cont_values_for_item.shape[0] > 0:
+                x["static_continuous_features"] = st_cont_values_for_item.unsqueeze(
+                    0
+                )
             else:
                 x["static_continuous_features"] = torch.zeros(
                     (1, 0), dtype=torch.float32
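
For reference, a minimal standalone sketch of the mask-based split that PATCH 2/2 introduces. The toy column names, the col_type mapping, and the use of "F" to mark a continuous column are assumptions made for this example only; in the data module these values come from self.data_module.time_series_metadata.

import torch

# Assumed toy metadata, mirroring the shape of time_series_metadata:
# names of the static columns and their types ("C" marks categorical;
# "F" for continuous is an assumption for this sketch).
static_col_names = ["store_id", "region", "avg_price"]
col_type = {"store_id": "C", "region": "C", "avg_price": "F"}

# One item's raw static tensor, one value per static column.
raw_st_tensor = torch.tensor([3.0, 1.0, 9.99])

# Boolean mask marking the categorical columns, as in the patch.
is_categorical_mask = torch.tensor(
    [col_type.get(name) == "C" for name in static_col_names],
    dtype=torch.bool,
)
is_continuous_mask = ~is_categorical_mask

# Boolean indexing splits the tensor without a per-column Python loop.
st_cat = raw_st_tensor[is_categorical_mask]    # tensor([3., 1.])
st_cont = raw_st_tensor[is_continuous_mask]    # tensor([9.9900])

# Add the batch dimension, falling back to an empty (1, 0) tensor when a
# group has no features, matching what test_with_static_features checks.
x = {
    "static_categorical_features": (
        st_cat.unsqueeze(0) if st_cat.shape[0] > 0 else torch.zeros((1, 0))
    ),
    "static_continuous_features": (
        st_cont.unsqueeze(0) if st_cont.shape[0] > 0 else torch.zeros((1, 0))
    ),
}

print(x["static_categorical_features"].shape)  # torch.Size([1, 2])
print(x["static_continuous_features"].shape)   # torch.Size([1, 1])

Compared with the element-by-element loop and torch.stack in PATCH 1/2, the mask-based version builds the mask once per item and lets tensor indexing do the split, while keeping the same (1, 0) empty-tensor fallback.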