@@ -519,34 +519,36 @@ def __getitem__(self, idx):
519519 "decoder_mask" : decoder_mask ,
520520 }
521521 if data ["static" ] is not None :
522- st_cat_values_for_item = []
523- st_cont_values_for_item = []
524522 raw_st_tensor = data .get ("static" )
525-
526523 static_col_names = self .data_module .time_series_metadata ["cols" ]["st" ]
527- for i , col_name in enumerate ( static_col_names ):
528- feature_value = raw_st_tensor [ i ]
529- if (
524+
525+ is_categorical_mask = torch . tensor (
526+ [
530527 self .data_module .time_series_metadata ["col_type" ].get (col_name )
531528 == "C"
532- ):
533- st_cat_values_for_item .append (feature_value )
534- else :
535- st_cont_values_for_item .append (feature_value )
536-
537- if st_cat_values_for_item :
538- x ["static_categorical_features" ] = torch .stack (
539- st_cat_values_for_item
540- ).unsqueeze (0 )
529+ for col_name in static_col_names
530+ ],
531+ dtype = torch .bool ,
532+ )
533+
534+ is_continuous_mask = ~ is_categorical_mask
535+
536+ st_cat_values_for_item = raw_st_tensor [is_categorical_mask ]
537+ st_cont_values_for_item = raw_st_tensor [is_continuous_mask ]
538+
539+ if st_cat_values_for_item .shape [0 ] > 0 :
540+ x ["static_categorical_features" ] = st_cat_values_for_item .unsqueeze (
541+ 0
542+ )
541543 else :
542544 x ["static_categorical_features" ] = torch .zeros (
543545 (1 , 0 ), dtype = torch .float32
544546 )
545547
546- if st_cont_values_for_item :
547- x ["static_continuous_features" ] = torch . stack (
548- st_cont_values_for_item
549- ). unsqueeze ( 0 )
548+ if st_cont_values_for_item . shape [ 0 ] > 0 :
549+ x ["static_continuous_features" ] = st_cont_values_for_item . unsqueeze (
550+ 0
551+ )
550552 else :
551553 x ["static_continuous_features" ] = torch .zeros (
552554 (1 , 0 ), dtype = torch .float32
0 commit comments