Skip to content

Commit 192cd41

Browse files
Add files via upload
1 parent c3dade6 commit 192cd41

File tree

2 files changed

+65
-3
lines changed

2 files changed

+65
-3
lines changed

cal_column_similarity.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from relation_features import make_data_from
2+
from utils import make_csv_from_json
23
from train import test
34
import numpy as np
45
import pandas as pd
@@ -16,6 +17,7 @@ def create_similarity_matrix(pth,preds,pred_labels_list,strategy="one-to-many"):
1617
"""
1718
Create a similarity matrix from the prediction
1819
"""
20+
predicted_pairs = []
1921
preds = np.array(preds)
2022
preds = np.mean(preds,axis=0)
2123
pred_labels_list = np.array(pred_labels_list)
@@ -28,6 +30,7 @@ def create_similarity_matrix(pth,preds,pred_labels_list,strategy="one-to-many"):
2830
df2_cols = df2.columns
2931
# create similarity matrix for pred values
3032
preds_matrix = np.array(preds).reshape(len(df1_cols),len(df2_cols))
33+
# create similarity matrix for pred labels
3134
if strategy == "one-to-many":
3235
pred_labels_matrix = np.array(pred_labels).reshape(len(df1_cols),len(df2_cols))
3336
elif strategy == "one-to-one":
@@ -41,11 +44,19 @@ def create_similarity_matrix(pth,preds,pred_labels_list,strategy="one-to-many"):
4144
pred_labels_matrix[i,j] = 1
4245
df_pred = pd.DataFrame(preds_matrix,columns=df2_cols,index=df1_cols)
4346
df_pred_labels = pd.DataFrame(pred_labels_matrix,columns=df2_cols,index=df1_cols)
44-
return df_pred,df_pred_labels
47+
for i in range(len(df_pred_labels)):
48+
for j in range(len(df_pred_labels.iloc[i])):
49+
if df_pred_labels.iloc[i,j] == 1:
50+
predicted_pairs.append((df_pred.index[i],df_pred.columns[j],df_pred.iloc[i,j]))
51+
return df_pred,df_pred_labels,predicted_pairs
4552

4653
if __name__ == '__main__':
4754
pth = args.path
4855
model_pth = args.model
56+
# transform jsonl or json file to csv
57+
for file in os.listdir(args.path):
58+
if file.endswith('.json') or file.endswith('.jsonl'):
59+
make_csv_from_json(pth+"/"+file)
4960

5061
features,_ = make_data_from(pth,"test")
5162
preds = []
@@ -64,6 +75,9 @@ def create_similarity_matrix(pth,preds,pred_labels_list,strategy="one-to-many"):
6475
pred_labels_list.append(pred_labels)
6576
del bst
6677

67-
df_pred,df_pred_labels = create_similarity_matrix(pth,preds,pred_labels_list,strategy=args.strategy)
78+
df_pred,df_pred_labels,predicted_pairs = create_similarity_matrix(pth,preds,pred_labels_list,strategy=args.strategy)
6879
df_pred.to_csv(pth+"/similarity_matrix_value.csv",index=True)
69-
df_pred_labels.to_csv(pth+"/similarity_matrix_label.csv",index=True)
80+
df_pred_labels.to_csv(pth+"/similarity_matrix_label.csv",index=True)
81+
82+
for pair_tuple in predicted_pairs:
83+
print(pair_tuple)

utils.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import pandas as pd
2+
import json
3+
from collections import defaultdict
4+
import re
5+
6+
def find_all_keys_values(json_data,parent_key):
7+
"""
8+
Find all keys that don't have list or dictionary values and their values. Key should be saved with its parent key like "parent-key.key".
9+
"""
10+
key_values = defaultdict(list)
11+
for key, value in json_data.items():
12+
if isinstance(value, dict):
13+
child_key_values = find_all_keys_values(value,key)
14+
for child_key, child_value in child_key_values.items():
15+
key_values[child_key].extend(child_value)
16+
elif isinstance(value, list):
17+
for item in value:
18+
if isinstance(item, dict):
19+
child_key_values = find_all_keys_values(item,key)
20+
for child_key, child_value in child_key_values.items():
21+
key_values[child_key].extend(child_value)
22+
else:
23+
key_values[parent_key+"."+key].append(item)
24+
else:
25+
key_values[parent_key+"."+key].append(value)
26+
return key_values
27+
28+
def make_csv_from_json(file_path):
29+
"""
30+
Make csv file from json file.
31+
"""
32+
with open(file_path, 'r', encoding='utf-8') as f:
33+
data = json.load(f)
34+
35+
# find key_values
36+
if isinstance(data, dict):
37+
key_values = find_all_keys_values(data,"")
38+
elif isinstance(data, list):
39+
key_values = find_all_keys_values({"data":data},"")
40+
else:
41+
raise ValueError('Your input JsonData is not a dictionary or list')
42+
43+
key_values = {k:v for k,v in key_values.items() if len(v)>1}
44+
45+
df = pd.DataFrame({k:pd.Series(v) for k,v in key_values.items()})
46+
# save to csv
47+
save_pth = re.sub(r'\.jsonl?','.csv',file_path)
48+
df.to_csv(save_pth, index=False, encoding='utf-8')

0 commit comments

Comments
 (0)