Skip to content

Commit b85d44c

Browse files
Add files via upload
1 parent a1037c7 commit b85d44c

File tree

1 file changed

+27
-18
lines changed

1 file changed

+27
-18
lines changed

cal_column_similarity.py

Lines changed: 27 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -9,24 +9,39 @@
99
parser.add_argument("-p","--path", help="path to the folder containing the test data")
1010
parser.add_argument("-m", "--model", help="path to the model")
1111
parser.add_argument("-t", "--threshold", help="threshold for inference")
12+
parser.add_argument("-s", "--strategy", help="one-to-one or one-to-many", default="one-to-many")
1213
args = parser.parse_args()
1314

14-
def create_similarity_matrix(pth,pred):
15+
def create_similarity_matrix(pth,preds,pred_labels_list,strategy="one-to-many"):
1516
"""
1617
Create a similarity matrix from the prediction
1718
"""
19+
preds = np.array(preds)
20+
preds = np.mean(preds,axis=0)
21+
pred_labels_list = np.array(pred_labels_list)
22+
pred_labels = np.mean(pred_labels_list,axis=0)
23+
pred_labels = np.where(pred_labels>0.5,1,0)
24+
# read column names
1825
df1 = pd.read_csv(pth+"/Table1.csv")
1926
df2 = pd.read_csv(pth+"/Table2.csv")
2027
df1_cols = df1.columns
2128
df2_cols = df2.columns
22-
# create similarity matrix for pred values
23-
sim_matrix = np.zeros((len(df1_cols),len(df2_cols)))
24-
for i in range(len(df1_cols)):
25-
for j in range(len(df2_cols)):
26-
sim_matrix[i,j] = pred[i*len(df2_cols)+j]
27-
# create dataframe
28-
df = pd.DataFrame(sim_matrix,index=df1_cols,columns=df2_cols)
29-
return df
29+
# create similarity matrix for pred values
30+
preds_matrix = np.array(preds).reshape(len(df1_cols),len(df2_cols))
31+
if strategy == "one-to-many":
32+
pred_labels_matrix = np.array(pred_labels).reshape(len(df1_cols),len(df2_cols))
33+
elif strategy == "one-to-one":
34+
pred_labels_matrix = np.zeros((len(df1_cols),len(df2_cols)))
35+
for i in range(len(df1_cols)):
36+
for j in range(len(df2_cols)):
37+
if pred_labels[i*len(df2_cols)+j] == 1:
38+
max_row = max(preds_matrix[i,:])
39+
max_col = max(preds_matrix[:,j])
40+
if preds_matrix[i,j] == max_row and preds_matrix[i,j] == max_col:
41+
pred_labels_matrix[i,j] = 1
42+
df_pred = pd.DataFrame(preds_matrix,columns=df2_cols,index=df1_cols)
43+
df_pred_labels = pd.DataFrame(pred_labels_matrix,columns=df2_cols,index=df1_cols)
44+
return df_pred,df_pred_labels
3045

3146
if __name__ == '__main__':
3247
pth = args.path
@@ -48,13 +63,7 @@ def create_similarity_matrix(pth,pred):
4863
preds.append(pred)
4964
pred_labels_list.append(pred_labels)
5065
del bst
51-
preds = np.array(preds)
52-
preds = np.mean(preds,axis=0)
53-
pred_labels_list = np.array(pred_labels_list)
54-
pred_labels = np.mean(pred_labels_list,axis=0)
55-
pred_labels = np.where(pred_labels>0.5,1,0)
5666

57-
df_pred = create_similarity_matrix(pth,pred)
58-
df_pred_labels = create_similarity_matrix(pth,pred_labels)
59-
df_pred.to_csv(pth+"/similarity_matrix_value.csv")
60-
df_pred_labels.to_csv(pth+"/similarity_matrix_label.csv")
67+
df_pred,df_pred_labels = create_similarity_matrix(pth,preds,pred_labels_list,strategy=args.strategy)
68+
df_pred.to_csv(pth+"/similarity_matrix_value.csv",index=True)
69+
df_pred_labels.to_csv(pth+"/similarity_matrix_label.csv",index=True)

0 commit comments

Comments (0)