99parser .add_argument ("-p" ,"--path" , help = "path to the folder containing the test data" )
1010parser .add_argument ("-m" , "--model" , help = "path to the model" )
1111parser .add_argument ("-t" , "--threshold" , help = "threshold for inference" )
12+ parser .add_argument ("-s" , "--strategy" , help = "one-to-one or one-to-many" , default = "one-to-many" )
1213args = parser .parse_args ()
1314
14- def create_similarity_matrix (pth ,pred ):
15+ def create_similarity_matrix (pth ,preds , pred_labels_list , strategy = "one-to-many" ):
1516 """
1617 Create a similarity matrix from the prediction
1718 """
19+ preds = np .array (preds )
20+ preds = np .mean (preds ,axis = 0 )
21+ pred_labels_list = np .array (pred_labels_list )
22+ pred_labels = np .mean (pred_labels_list ,axis = 0 )
23+ pred_labels = np .where (pred_labels > 0.5 ,1 ,0 )
24+ # read column names
1825 df1 = pd .read_csv (pth + "/Table1.csv" )
1926 df2 = pd .read_csv (pth + "/Table2.csv" )
2027 df1_cols = df1 .columns
2128 df2_cols = df2 .columns
22- # create similarity matrix for pred values
23- sim_matrix = np .zeros ((len (df1_cols ),len (df2_cols )))
24- for i in range (len (df1_cols )):
25- for j in range (len (df2_cols )):
26- sim_matrix [i ,j ] = pred [i * len (df2_cols )+ j ]
27- # create dataframe
28- df = pd .DataFrame (sim_matrix ,index = df1_cols ,columns = df2_cols )
29- return df
29+ # create similarity matrix for pred values
30+ preds_matrix = np .array (preds ).reshape (len (df1_cols ),len (df2_cols ))
31+ if strategy == "one-to-many" :
32+ pred_labels_matrix = np .array (pred_labels ).reshape (len (df1_cols ),len (df2_cols ))
33+ elif strategy == "one-to-one" :
34+ pred_labels_matrix = np .zeros ((len (df1_cols ),len (df2_cols )))
35+ for i in range (len (df1_cols )):
36+ for j in range (len (df2_cols )):
37+ if pred_labels [i * len (df2_cols )+ j ] == 1 :
38+ max_row = max (preds_matrix [i ,:])
39+ max_col = max (preds_matrix [:,j ])
40+ if preds_matrix [i ,j ] == max_row and preds_matrix [i ,j ] == max_col :
41+ pred_labels_matrix [i ,j ] = 1
42+ df_pred = pd .DataFrame (preds_matrix ,columns = df2_cols ,index = df1_cols )
43+ df_pred_labels = pd .DataFrame (pred_labels_matrix ,columns = df2_cols ,index = df1_cols )
44+ return df_pred ,df_pred_labels
3045
3146if __name__ == '__main__' :
3247 pth = args .path
@@ -48,13 +63,7 @@ def create_similarity_matrix(pth,pred):
4863 preds .append (pred )
4964 pred_labels_list .append (pred_labels )
5065 del bst
51- preds = np .array (preds )
52- preds = np .mean (preds ,axis = 0 )
53- pred_labels_list = np .array (pred_labels_list )
54- pred_labels = np .mean (pred_labels_list ,axis = 0 )
55- pred_labels = np .where (pred_labels > 0.5 ,1 ,0 )
5666
57- df_pred = create_similarity_matrix (pth ,pred )
58- df_pred_labels = create_similarity_matrix (pth ,pred_labels )
59- df_pred .to_csv (pth + "/similarity_matrix_value.csv" )
60- df_pred_labels .to_csv (pth + "/similarity_matrix_label.csv" )
67+ df_pred ,df_pred_labels = create_similarity_matrix (pth ,preds ,pred_labels_list ,strategy = args .strategy )
68+ df_pred .to_csv (pth + "/similarity_matrix_value.csv" ,index = True )
69+ df_pred_labels .to_csv (pth + "/similarity_matrix_label.csv" ,index = True )
0 commit comments