Skip to content

Commit 66b1e66

Browse files
Merge pull request #98 from CHERRY-ui8/feat/add-progress-bar-and-refactor-concurrent
refactor: refactor concurrent and add progress bar
2 parents: 2acab4c + 276b963 · commit 66b1e66

File tree

3 files changed

+97
-93
lines changed

3 files changed

+97
-93
lines changed

graphgen/graphgen.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -237,6 +237,7 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
237237
self.graph_storage,
238238
self.rephrase_storage,
239239
re_judge,
240+
progress_bar=self.progress_bar,
240241
)
241242

242243
await self.rephrase_storage.index_done_callback()

graphgen/operators/partition/pre_tokenize.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
import asyncio
22
from typing import List, Tuple
33

4+
import gradio as gr
5+
46
from graphgen.bases import BaseGraphStorage, BaseTokenizer
57
from graphgen.utils import run_concurrent
68

@@ -10,9 +12,11 @@ async def pre_tokenize(
1012
tokenizer: BaseTokenizer,
1113
edges: List[Tuple],
1214
nodes: List[Tuple],
15+
progress_bar: gr.Progress = None,
16+
max_concurrent: int = 1000,
1317
) -> Tuple[List, List]:
1418
"""为 edges/nodes 补 token-length 并回写存储,并发 1000,带进度条。"""
15-
sem = asyncio.Semaphore(1000)
19+
sem = asyncio.Semaphore(max_concurrent)
1620

1721
async def _patch_and_write(obj: Tuple, *, is_node: bool) -> Tuple:
1822
async with sem:
@@ -35,11 +39,15 @@ async def _patch_and_write(obj: Tuple, *, is_node: bool) -> Tuple:
3539
lambda e: _patch_and_write(e, is_node=False),
3640
edges,
3741
desc="Pre-tokenizing edges",
42+
unit="edge",
43+
progress_bar=progress_bar,
3844
),
3945
run_concurrent(
4046
lambda n: _patch_and_write(n, is_node=True),
4147
nodes,
4248
desc="Pre-tokenizing nodes",
49+
unit="node",
50+
progress_bar=progress_bar,
4351
),
4452
)
4553

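(Third changed file — its path was not captured in this extract; the hunks below modify the module that defines `judge_statement`, not pre_tokenize.py.)

Lines changed: 87 additions & 92 deletions
Original file line number | Diff line number | Diff line change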
@@ -1,20 +1,19 @@
1-
import asyncio
21
import math
32

4-
from tqdm.asyncio import tqdm as tqdm_async
3+
import gradio as gr
54

65
from graphgen.bases import BaseLLMWrapper
76
from graphgen.models import JsonKVStorage, NetworkXStorage
87
from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
9-
from graphgen.utils import logger, yes_no_loss_entropy
8+
from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy
109

1110

1211
async def judge_statement( # pylint: disable=too-many-statements
1312
trainee_llm_client: BaseLLMWrapper,
1413
graph_storage: NetworkXStorage,
1514
rephrase_storage: JsonKVStorage,
1615
re_judge: bool = False,
17-
max_concurrent: int = 1000,
16+
progress_bar: gr.Progress = None,
1817
) -> NetworkXStorage:
1918
"""
2019
Get all edges and nodes and judge them
@@ -23,128 +22,124 @@ async def judge_statement( # pylint: disable=too-many-statements
2322
:param graph_storage: graph storage instance
2423
:param rephrase_storage: rephrase storage instance
2524
:param re_judge: re-judge the relations
26-
:param max_concurrent: max concurrent
25+
:param progress_bar
2726
:return:
2827
"""
2928

30-
semaphore = asyncio.Semaphore(max_concurrent)
31-
3229
async def _judge_single_relation(
3330
edge: tuple,
3431
):
35-
async with semaphore:
36-
source_id = edge[0]
37-
target_id = edge[1]
38-
edge_data = edge[2]
39-
40-
if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
41-
logger.debug(
42-
"Edge %s -> %s already judged, loss: %s, skip",
43-
source_id,
44-
target_id,
45-
edge_data["loss"],
46-
)
47-
return source_id, target_id, edge_data
32+
source_id = edge[0]
33+
target_id = edge[1]
34+
edge_data = edge[2]
35+
36+
if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
37+
logger.debug(
38+
"Edge %s -> %s already judged, loss: %s, skip",
39+
source_id,
40+
target_id,
41+
edge_data["loss"],
42+
)
43+
return source_id, target_id, edge_data
4844

49-
description = edge_data["description"]
45+
description = edge_data["description"]
5046

51-
try:
52-
descriptions = await rephrase_storage.get_by_id(description)
53-
assert descriptions is not None
47+
try:
48+
descriptions = await rephrase_storage.get_by_id(description)
49+
assert descriptions is not None
5450

55-
judgements = []
56-
gts = [gt for _, gt in descriptions]
57-
for description, gt in descriptions:
58-
judgement = await trainee_llm_client.generate_topk_per_token(
59-
STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
60-
statement=description
61-
)
51+
judgements = []
52+
gts = [gt for _, gt in descriptions]
53+
for description, gt in descriptions:
54+
judgement = await trainee_llm_client.generate_topk_per_token(
55+
STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
56+
statement=description
6257
)
63-
judgements.append(judgement[0].top_candidates)
58+
)
59+
judgements.append(judgement[0].top_candidates)
6460

65-
loss = yes_no_loss_entropy(judgements, gts)
61+
loss = yes_no_loss_entropy(judgements, gts)
6662

67-
logger.debug(
68-
"Edge %s -> %s description: %s loss: %s",
69-
source_id,
70-
target_id,
71-
description,
72-
loss,
73-
)
63+
logger.debug(
64+
"Edge %s -> %s description: %s loss: %s",
65+
source_id,
66+
target_id,
67+
description,
68+
loss,
69+
)
7470

75-
edge_data["loss"] = loss
76-
except Exception as e: # pylint: disable=broad-except
77-
logger.error(
78-
"Error in judging relation %s -> %s: %s", source_id, target_id, e
79-
)
80-
logger.info("Use default loss 0.1")
81-
edge_data["loss"] = -math.log(0.1)
71+
edge_data["loss"] = loss
72+
except Exception as e: # pylint: disable=broad-except
73+
logger.error(
74+
"Error in judging relation %s -> %s: %s", source_id, target_id, e
75+
)
76+
logger.info("Use default loss 0.1")
77+
edge_data["loss"] = -math.log(0.1)
8278

83-
await graph_storage.update_edge(source_id, target_id, edge_data)
84-
return source_id, target_id, edge_data
79+
await graph_storage.update_edge(source_id, target_id, edge_data)
80+
return source_id, target_id, edge_data
8581

8682
edges = await graph_storage.get_all_edges()
8783

88-
results = []
89-
for result in tqdm_async(
90-
asyncio.as_completed([_judge_single_relation(edge) for edge in edges]),
91-
total=len(edges),
84+
await run_concurrent(
85+
_judge_single_relation,
86+
edges,
9287
desc="Judging relations",
93-
):
94-
results.append(await result)
88+
unit="relation",
89+
progress_bar=progress_bar,
90+
)
9591

9692
async def _judge_single_entity(
9793
node: tuple,
9894
):
99-
async with semaphore:
100-
node_id = node[0]
101-
node_data = node[1]
95+
node_id = node[0]
96+
node_data = node[1]
10297

103-
if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
104-
logger.debug(
105-
"Node %s already judged, loss: %s, skip", node_id, node_data["loss"]
106-
)
107-
return node_id, node_data
98+
if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
99+
logger.debug(
100+
"Node %s already judged, loss: %s, skip", node_id, node_data["loss"]
101+
)
102+
return node_id, node_data
108103

109-
description = node_data["description"]
104+
description = node_data["description"]
110105

111-
try:
112-
descriptions = await rephrase_storage.get_by_id(description)
113-
assert descriptions is not None
106+
try:
107+
descriptions = await rephrase_storage.get_by_id(description)
108+
assert descriptions is not None
114109

115-
judgements = []
116-
gts = [gt for _, gt in descriptions]
117-
for description, gt in descriptions:
118-
judgement = await trainee_llm_client.generate_topk_per_token(
119-
STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
120-
statement=description
121-
)
110+
judgements = []
111+
gts = [gt for _, gt in descriptions]
112+
for description, gt in descriptions:
113+
judgement = await trainee_llm_client.generate_topk_per_token(
114+
STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
115+
statement=description
122116
)
123-
judgements.append(judgement[0].top_candidates)
117+
)
118+
judgements.append(judgement[0].top_candidates)
124119

125-
loss = yes_no_loss_entropy(judgements, gts)
120+
loss = yes_no_loss_entropy(judgements, gts)
126121

127-
logger.debug(
128-
"Node %s description: %s loss: %s", node_id, description, loss
129-
)
122+
logger.debug(
123+
"Node %s description: %s loss: %s", node_id, description, loss
124+
)
130125

131-
node_data["loss"] = loss
132-
except Exception as e: # pylint: disable=broad-except
133-
logger.error("Error in judging entity %s: %s", node_id, e)
134-
logger.error("Use default loss 0.1")
135-
node_data["loss"] = -math.log(0.1)
126+
node_data["loss"] = loss
127+
except Exception as e: # pylint: disable=broad-except
128+
logger.error("Error in judging entity %s: %s", node_id, e)
129+
logger.error("Use default loss 0.1")
130+
node_data["loss"] = -math.log(0.1)
136131

137-
await graph_storage.update_node(node_id, node_data)
138-
return node_id, node_data
132+
await graph_storage.update_node(node_id, node_data)
133+
return node_id, node_data
139134

140135
nodes = await graph_storage.get_all_nodes()
141136

142-
results = []
143-
for result in tqdm_async(
144-
asyncio.as_completed([_judge_single_entity(node) for node in nodes]),
145-
total=len(nodes),
137+
await run_concurrent(
138+
_judge_single_entity,
139+
nodes,
146140
desc="Judging entities",
147-
):
148-
results.append(await result)
141+
unit="entity",
142+
progress_bar=progress_bar,
143+
)
149144

150145
return graph_storage

0 commit comments

Comments
 (0)