Skip to content

Commit c07d6e4

Browse files
committed
better implementation
1 parent 7324ca8 commit c07d6e4

File tree

1 file changed

+21
-19
lines changed

1 file changed

+21
-19
lines changed

pytorch_forecasting/data/data_module.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -519,34 +519,36 @@ def __getitem__(self, idx):
519519
"decoder_mask": decoder_mask,
520520
}
521521
if data["static"] is not None:
522-
st_cat_values_for_item = []
523-
st_cont_values_for_item = []
524522
raw_st_tensor = data.get("static")
525-
526523
static_col_names = self.data_module.time_series_metadata["cols"]["st"]
527-
for i, col_name in enumerate(static_col_names):
528-
feature_value = raw_st_tensor[i]
529-
if (
524+
525+
is_categorical_mask = torch.tensor(
526+
[
530527
self.data_module.time_series_metadata["col_type"].get(col_name)
531528
== "C"
532-
):
533-
st_cat_values_for_item.append(feature_value)
534-
else:
535-
st_cont_values_for_item.append(feature_value)
536-
537-
if st_cat_values_for_item:
538-
x["static_categorical_features"] = torch.stack(
539-
st_cat_values_for_item
540-
).unsqueeze(0)
529+
for col_name in static_col_names
530+
],
531+
dtype=torch.bool,
532+
)
533+
534+
is_continuous_mask = ~is_categorical_mask
535+
536+
st_cat_values_for_item = raw_st_tensor[is_categorical_mask]
537+
st_cont_values_for_item = raw_st_tensor[is_continuous_mask]
538+
539+
if st_cat_values_for_item.shape[0] > 0:
540+
x["static_categorical_features"] = st_cat_values_for_item.unsqueeze(
541+
0
542+
)
541543
else:
542544
x["static_categorical_features"] = torch.zeros(
543545
(1, 0), dtype=torch.float32
544546
)
545547

546-
if st_cont_values_for_item:
547-
x["static_continuous_features"] = torch.stack(
548-
st_cont_values_for_item
549-
).unsqueeze(0)
548+
if st_cont_values_for_item.shape[0] > 0:
549+
x["static_continuous_features"] = st_cont_values_for_item.unsqueeze(
550+
0
551+
)
550552
else:
551553
x["static_continuous_features"] = torch.zeros(
552554
(1, 0), dtype=torch.float32

0 commit comments

Comments
 (0)