-import asyncio
 import math
 
-from tqdm.asyncio import tqdm as tqdm_async
+import gradio as gr
 
 from graphgen.bases import BaseLLMWrapper
 from graphgen.models import JsonKVStorage, NetworkXStorage
 from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
-from graphgen.utils import logger, yes_no_loss_entropy
+from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy
 
 
 async def judge_statement(  # pylint: disable=too-many-statements
     trainee_llm_client: BaseLLMWrapper,
     graph_storage: NetworkXStorage,
     rephrase_storage: JsonKVStorage,
     re_judge: bool = False,
-    max_concurrent: int = 1000,
+    progress_bar: gr.Progress = None,
 ) -> NetworkXStorage:
     """
     Get all edges and nodes and judge them
@@ -23,128 +22,124 @@ async def judge_statement(  # pylint: disable=too-many-statements
     :param graph_storage: graph storage instance
     :param rephrase_storage: rephrase storage instance
     :param re_judge: re-judge the relations
-    :param max_concurrent: max concurrent
+    :param progress_bar: gradio progress bar
     :return:
     """
 
-    semaphore = asyncio.Semaphore(max_concurrent)
-
     async def _judge_single_relation(
         edge: tuple,
     ):
-        async with semaphore:
-            source_id = edge[0]
-            target_id = edge[1]
-            edge_data = edge[2]
-
-            if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
-                logger.debug(
-                    "Edge %s -> %s already judged, loss: %s, skip",
-                    source_id,
-                    target_id,
-                    edge_data["loss"],
-                )
-                return source_id, target_id, edge_data
+        source_id = edge[0]
+        target_id = edge[1]
+        edge_data = edge[2]
+
+        if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
+            logger.debug(
+                "Edge %s -> %s already judged, loss: %s, skip",
+                source_id,
+                target_id,
+                edge_data["loss"],
+            )
+            return source_id, target_id, edge_data
 
-            description = edge_data["description"]
+        description = edge_data["description"]
 
-            try:
-                descriptions = await rephrase_storage.get_by_id(description)
-                assert descriptions is not None
+        try:
+            descriptions = await rephrase_storage.get_by_id(description)
+            assert descriptions is not None
 
-                judgements = []
-                gts = [gt for _, gt in descriptions]
-                for description, gt in descriptions:
-                    judgement = await trainee_llm_client.generate_topk_per_token(
-                        STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
-                            statement=description
-                        )
+            judgements = []
+            gts = [gt for _, gt in descriptions]
+            for description, gt in descriptions:
+                judgement = await trainee_llm_client.generate_topk_per_token(
+                    STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
+                        statement=description
                     )
-                    judgements.append(judgement[0].top_candidates)
+                )
+                judgements.append(judgement[0].top_candidates)
 
-                loss = yes_no_loss_entropy(judgements, gts)
+            loss = yes_no_loss_entropy(judgements, gts)
 
-                logger.debug(
-                    "Edge %s -> %s description: %s loss: %s",
-                    source_id,
-                    target_id,
-                    description,
-                    loss,
-                )
+            logger.debug(
+                "Edge %s -> %s description: %s loss: %s",
+                source_id,
+                target_id,
+                description,
+                loss,
+            )
 
-                edge_data["loss"] = loss
-            except Exception as e:  # pylint: disable=broad-except
-                logger.error(
-                    "Error in judging relation %s -> %s: %s", source_id, target_id, e
-                )
-                logger.info("Use default loss 0.1")
-                edge_data["loss"] = -math.log(0.1)
+            edge_data["loss"] = loss
+        except Exception as e:  # pylint: disable=broad-except
+            logger.error(
+                "Error in judging relation %s -> %s: %s", source_id, target_id, e
+            )
+            logger.info("Use default loss 0.1")
+            edge_data["loss"] = -math.log(0.1)
 
-            await graph_storage.update_edge(source_id, target_id, edge_data)
-            return source_id, target_id, edge_data
+        await graph_storage.update_edge(source_id, target_id, edge_data)
+        return source_id, target_id, edge_data
 
     edges = await graph_storage.get_all_edges()
 
-    results = []
-    for result in tqdm_async(
-        asyncio.as_completed([_judge_single_relation(edge) for edge in edges]),
-        total=len(edges),
+    await run_concurrent(
+        _judge_single_relation,
+        edges,
         desc="Judging relations",
-    ):
-        results.append(await result)
+        unit="relation",
+        progress_bar=progress_bar,
+    )
 
     async def _judge_single_entity(
         node: tuple,
     ):
-        async with semaphore:
-            node_id = node[0]
-            node_data = node[1]
+        node_id = node[0]
+        node_data = node[1]
 
-            if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
-                logger.debug(
-                    "Node %s already judged, loss: %s, skip", node_id, node_data["loss"]
-                )
-                return node_id, node_data
+        if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
+            logger.debug(
+                "Node %s already judged, loss: %s, skip", node_id, node_data["loss"]
+            )
+            return node_id, node_data
 
-            description = node_data["description"]
+        description = node_data["description"]
 
-            try:
-                descriptions = await rephrase_storage.get_by_id(description)
-                assert descriptions is not None
+        try:
+            descriptions = await rephrase_storage.get_by_id(description)
+            assert descriptions is not None
 
-                judgements = []
-                gts = [gt for _, gt in descriptions]
-                for description, gt in descriptions:
-                    judgement = await trainee_llm_client.generate_topk_per_token(
-                        STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
-                            statement=description
-                        )
+            judgements = []
+            gts = [gt for _, gt in descriptions]
+            for description, gt in descriptions:
+                judgement = await trainee_llm_client.generate_topk_per_token(
+                    STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
+                        statement=description
                    )
-                    judgements.append(judgement[0].top_candidates)
+                )
+                judgements.append(judgement[0].top_candidates)
 
-                loss = yes_no_loss_entropy(judgements, gts)
+            loss = yes_no_loss_entropy(judgements, gts)
 
-                logger.debug(
-                    "Node %s description: %s loss: %s", node_id, description, loss
-                )
+            logger.debug(
+                "Node %s description: %s loss: %s", node_id, description, loss
+            )
 
-                node_data["loss"] = loss
-            except Exception as e:  # pylint: disable=broad-except
-                logger.error("Error in judging entity %s: %s", node_id, e)
-                logger.error("Use default loss 0.1")
-                node_data["loss"] = -math.log(0.1)
+            node_data["loss"] = loss
+        except Exception as e:  # pylint: disable=broad-except
+            logger.error("Error in judging entity %s: %s", node_id, e)
+            logger.error("Use default loss 0.1")
+            node_data["loss"] = -math.log(0.1)
 
-            await graph_storage.update_node(node_id, node_data)
-            return node_id, node_data
+        await graph_storage.update_node(node_id, node_data)
+        return node_id, node_data
 
     nodes = await graph_storage.get_all_nodes()
 
-    results = []
-    for result in tqdm_async(
-        asyncio.as_completed([_judge_single_entity(node) for node in nodes]),
-        total=len(nodes),
+    await run_concurrent(
+        _judge_single_entity,
+        nodes,
         desc="Judging entities",
-    ):
-        results.append(await result)
+        unit="entity",
+        progress_bar=progress_bar,
+    )
 
     return graph_storage
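
The run_concurrent helper used above comes from graphgen.utils and its implementation is not part of this diff. Below is a minimal, hypothetical sketch of what such a helper could look like, written only to match the call sites in judge_statement (callable, items, desc, unit, progress_bar); the concurrency cap, the returned list, and the Gradio progress wiring are assumptions for illustration, not the project's actual code.

# Hypothetical sketch: the real graphgen.utils.run_concurrent may differ.
import asyncio
from typing import Awaitable, Callable, Iterable, List, Optional, TypeVar

import gradio as gr
from tqdm.asyncio import tqdm as tqdm_async

T = TypeVar("T")
R = TypeVar("R")


async def run_concurrent(
    func: Callable[[T], Awaitable[R]],
    items: Iterable[T],
    desc: str = "",
    unit: str = "it",
    progress_bar: Optional[gr.Progress] = None,
    max_concurrent: int = 1000,  # assumed default, mirroring the removed parameter
) -> List[R]:
    """Run func over items with bounded concurrency and progress reporting."""
    items = list(items)
    semaphore = asyncio.Semaphore(max_concurrent)

    async def _bounded(item: T) -> R:
        # Cap the number of in-flight calls, as the old per-function semaphore did.
        async with semaphore:
            return await func(item)

    results: List[R] = []
    done = 0
    for coro in tqdm_async(
        asyncio.as_completed([_bounded(item) for item in items]),
        total=len(items),
        desc=desc,
        unit=unit,
    ):
        results.append(await coro)
        done += 1
        if progress_bar is not None:
            # gr.Progress instances are callable with a completed fraction and a description.
            progress_bar(done / len(items), desc=desc)
    return results

Centralizing the semaphore and the progress reporting in one helper keeps judge_statement focused on the judging logic, and leaving progress_bar optional (default None) means the function still runs unchanged outside a Gradio UI; inside a Gradio event handler, a gr.Progress() default argument can be forwarded as progress_bar.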