Skip to content

Commit 37f1002

Browse files
Merge pull request #100 from open-sciencelab/fix/fix-async-storage
refactor: refactor storage methods to non-async as not necessary
2 parents f34792c + 86856d6 commit 37f1002

File tree

17 files changed

+197
-156
lines changed

17 files changed

+197
-156
lines changed

graphgen/bases/base_partitioner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,16 @@ async def community2batch(
3939
edges = comm.edges
4040
nodes_data = []
4141
for node in nodes:
42-
node_data = await g.get_node(node)
42+
node_data = g.get_node(node)
4343
if node_data:
4444
nodes_data.append((node, node_data))
4545
edges_data = []
4646
for u, v in edges:
47-
edge_data = await g.get_edge(u, v)
47+
edge_data = g.get_edge(u, v)
4848
if edge_data:
4949
edges_data.append((u, v, edge_data))
5050
else:
51-
edge_data = await g.get_edge(v, u)
51+
edge_data = g.get_edge(v, u)
5252
if edge_data:
5353
edges_data.append((v, u, edge_data))
5454
batches.append((nodes_data, edges_data))

graphgen/bases/base_storage.py

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,103 +9,99 @@ class StorageNameSpace:
99
working_dir: str = None
1010
namespace: str = None
1111

12-
async def index_done_callback(self):
12+
def index_done_callback(self):
1313
"""commit the storage operations after indexing"""
1414

15-
async def query_done_callback(self):
15+
def query_done_callback(self):
1616
"""commit the storage operations after querying"""
1717

1818

1919
class BaseListStorage(Generic[T], StorageNameSpace):
20-
async def all_items(self) -> list[T]:
20+
def all_items(self) -> list[T]:
2121
raise NotImplementedError
2222

23-
async def get_by_index(self, index: int) -> Union[T, None]:
23+
def get_by_index(self, index: int) -> Union[T, None]:
2424
raise NotImplementedError
2525

26-
async def append(self, data: T):
26+
def append(self, data: T):
2727
raise NotImplementedError
2828

29-
async def upsert(self, data: list[T]):
29+
def upsert(self, data: list[T]):
3030
raise NotImplementedError
3131

32-
async def drop(self):
32+
def drop(self):
3333
raise NotImplementedError
3434

3535

3636
class BaseKVStorage(Generic[T], StorageNameSpace):
37-
async def all_keys(self) -> list[str]:
37+
def all_keys(self) -> list[str]:
3838
raise NotImplementedError
3939

40-
async def get_by_id(self, id: str) -> Union[T, None]:
40+
def get_by_id(self, id: str) -> Union[T, None]:
4141
raise NotImplementedError
4242

43-
async def get_by_ids(
43+
def get_by_ids(
4444
self, ids: list[str], fields: Union[set[str], None] = None
4545
) -> list[Union[T, None]]:
4646
raise NotImplementedError
4747

48-
async def get_all(self) -> dict[str, T]:
48+
def get_all(self) -> dict[str, T]:
4949
raise NotImplementedError
5050

51-
async def filter_keys(self, data: list[str]) -> set[str]:
51+
def filter_keys(self, data: list[str]) -> set[str]:
5252
"""return the subset of the given keys that do not yet exist in the storage"""
5353
raise NotImplementedError
5454

55-
async def upsert(self, data: dict[str, T]):
55+
def upsert(self, data: dict[str, T]):
5656
raise NotImplementedError
5757

58-
async def drop(self):
58+
def drop(self):
5959
raise NotImplementedError
6060

6161

6262
class BaseGraphStorage(StorageNameSpace):
63-
async def has_node(self, node_id: str) -> bool:
63+
def has_node(self, node_id: str) -> bool:
6464
raise NotImplementedError
6565

66-
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
66+
def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
6767
raise NotImplementedError
6868

69-
async def node_degree(self, node_id: str) -> int:
69+
def node_degree(self, node_id: str) -> int:
7070
raise NotImplementedError
7171

72-
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
72+
def edge_degree(self, src_id: str, tgt_id: str) -> int:
7373
raise NotImplementedError
7474

75-
async def get_node(self, node_id: str) -> Union[dict, None]:
75+
def get_node(self, node_id: str) -> Union[dict, None]:
7676
raise NotImplementedError
7777

78-
async def update_node(self, node_id: str, node_data: dict[str, str]):
78+
def update_node(self, node_id: str, node_data: dict[str, str]):
7979
raise NotImplementedError
8080

81-
async def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
81+
def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
8282
raise NotImplementedError
8383

84-
async def get_edge(
85-
self, source_node_id: str, target_node_id: str
86-
) -> Union[dict, None]:
84+
def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
8785
raise NotImplementedError
8886

89-
async def update_edge(
87+
def update_edge(
9088
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
9189
):
9290
raise NotImplementedError
9391

94-
async def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
92+
def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
9593
raise NotImplementedError
9694

97-
async def get_node_edges(
98-
self, source_node_id: str
99-
) -> Union[list[tuple[str, str]], None]:
95+
def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
10096
raise NotImplementedError
10197

102-
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
98+
def upsert_node(self, node_id: str, node_data: dict[str, str]):
10399
raise NotImplementedError
104100

105-
async def upsert_edge(
101+
def upsert_edge(
106102
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
107103
):
108104
raise NotImplementedError
109105

110-
async def delete_node(self, node_id: str):
106+
def delete_node(self, node_id: str):
111107
raise NotImplementedError

graphgen/graphgen.py

Lines changed: 34 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,15 @@ async def read(self, read_config: Dict):
104104
# TODO: make the use of coreference resolution configurable
105105

106106
new_docs = {compute_mm_hash(doc, prefix="doc-"): doc for doc in data}
107-
_add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
107+
_add_doc_keys = self.full_docs_storage.filter_keys(list(new_docs.keys()))
108108
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
109109

110110
if len(new_docs) == 0:
111111
logger.warning("All documents are already in the storage")
112112
return
113113

114-
await self.full_docs_storage.upsert(new_docs)
115-
await self.full_docs_storage.index_done_callback()
114+
self.full_docs_storage.upsert(new_docs)
115+
self.full_docs_storage.index_done_callback()
116116

117117
@op("chunk", deps=["read"])
118118
@async_to_sync_method
@@ -121,7 +121,7 @@ async def chunk(self, chunk_config: Dict):
121121
chunk documents into smaller pieces from full_docs_storage if not already present
122122
"""
123123

124-
new_docs = await self.meta_storage.get_new_data(self.full_docs_storage)
124+
new_docs = self.meta_storage.get_new_data(self.full_docs_storage)
125125
if len(new_docs) == 0:
126126
logger.warning("All documents are already in the storage")
127127
return
@@ -133,9 +133,7 @@ async def chunk(self, chunk_config: Dict):
133133
**chunk_config,
134134
)
135135

136-
_add_chunk_keys = await self.chunks_storage.filter_keys(
137-
list(inserting_chunks.keys())
138-
)
136+
_add_chunk_keys = self.chunks_storage.filter_keys(list(inserting_chunks.keys()))
139137
inserting_chunks = {
140138
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
141139
}
@@ -144,10 +142,10 @@ async def chunk(self, chunk_config: Dict):
144142
logger.warning("All chunks are already in the storage")
145143
return
146144

147-
await self.chunks_storage.upsert(inserting_chunks)
148-
await self.chunks_storage.index_done_callback()
149-
await self.meta_storage.mark_done(self.full_docs_storage)
150-
await self.meta_storage.index_done_callback()
145+
self.chunks_storage.upsert(inserting_chunks)
146+
self.chunks_storage.index_done_callback()
147+
self.meta_storage.mark_done(self.full_docs_storage)
148+
self.meta_storage.index_done_callback()
151149

152150
@op("build_kg", deps=["chunk"])
153151
@async_to_sync_method
@@ -156,7 +154,7 @@ async def build_kg(self):
156154
build knowledge graph from text chunks
157155
"""
158156
# Step 1: get new chunks according to meta and chunks storage
159-
inserting_chunks = await self.meta_storage.get_new_data(self.chunks_storage)
157+
inserting_chunks = self.meta_storage.get_new_data(self.chunks_storage)
160158
if len(inserting_chunks) == 0:
161159
logger.warning("All chunks are already in the storage")
162160
return
@@ -174,9 +172,9 @@ async def build_kg(self):
174172
return
175173

176174
# Step 3: mark meta
177-
await self.graph_storage.index_done_callback()
178-
await self.meta_storage.mark_done(self.chunks_storage)
179-
await self.meta_storage.index_done_callback()
175+
self.graph_storage.index_done_callback()
176+
self.meta_storage.mark_done(self.chunks_storage)
177+
self.meta_storage.index_done_callback()
180178

181179
return _add_entities_and_relations
182180

@@ -185,7 +183,7 @@ async def build_kg(self):
185183
async def search(self, search_config: Dict):
186184
logger.info("[Search] %s ...", ", ".join(search_config["data_sources"]))
187185

188-
seeds = await self.meta_storage.get_new_data(self.full_docs_storage)
186+
seeds = self.meta_storage.get_new_data(self.full_docs_storage)
189187
if len(seeds) == 0:
190188
logger.warning("All documents are already been searched")
191189
return
@@ -194,19 +192,17 @@ async def search(self, search_config: Dict):
194192
search_config=search_config,
195193
)
196194

197-
_add_search_keys = await self.search_storage.filter_keys(
198-
list(search_results.keys())
199-
)
195+
_add_search_keys = self.search_storage.filter_keys(list(search_results.keys()))
200196
search_results = {
201197
k: v for k, v in search_results.items() if k in _add_search_keys
202198
}
203199
if len(search_results) == 0:
204200
logger.warning("All search results are already in the storage")
205201
return
206-
await self.search_storage.upsert(search_results)
207-
await self.search_storage.index_done_callback()
208-
await self.meta_storage.mark_done(self.full_docs_storage)
209-
await self.meta_storage.index_done_callback()
202+
self.search_storage.upsert(search_results)
203+
self.search_storage.index_done_callback()
204+
self.meta_storage.mark_done(self.full_docs_storage)
205+
self.meta_storage.index_done_callback()
210206

211207
@op("quiz_and_judge", deps=["build_kg"])
212208
@async_to_sync_method
@@ -240,8 +236,8 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
240236
progress_bar=self.progress_bar,
241237
)
242238

243-
await self.rephrase_storage.index_done_callback()
244-
await _update_relations.index_done_callback()
239+
self.rephrase_storage.index_done_callback()
240+
_update_relations.index_done_callback()
245241

246242
logger.info("Shutting down trainee LLM client.")
247243
self.trainee_llm_client.shutdown()
@@ -258,7 +254,7 @@ async def partition(self, partition_config: Dict):
258254
self.tokenizer_instance,
259255
partition_config,
260256
)
261-
await self.partition_storage.upsert(batches)
257+
self.partition_storage.upsert(batches)
262258
return batches
263259

264260
@op("extract", deps=["chunk"])
@@ -276,10 +272,10 @@ async def extract(self, extract_config: Dict):
276272
logger.warning("No information extracted")
277273
return
278274

279-
await self.extract_storage.upsert(results)
280-
await self.extract_storage.index_done_callback()
281-
await self.meta_storage.mark_done(self.chunks_storage)
282-
await self.meta_storage.index_done_callback()
275+
self.extract_storage.upsert(results)
276+
self.extract_storage.index_done_callback()
277+
self.meta_storage.mark_done(self.chunks_storage)
278+
self.meta_storage.index_done_callback()
283279

284280
@op("generate", deps=["partition"])
285281
@async_to_sync_method
@@ -303,17 +299,17 @@ async def generate(self, generate_config: Dict):
303299
return
304300

305301
# Step 3: store the generated QA pairs
306-
await self.qa_storage.upsert(results)
307-
await self.qa_storage.index_done_callback()
302+
self.qa_storage.upsert(results)
303+
self.qa_storage.index_done_callback()
308304

309305
@async_to_sync_method
310306
async def clear(self):
311-
await self.full_docs_storage.drop()
312-
await self.chunks_storage.drop()
313-
await self.search_storage.drop()
314-
await self.graph_storage.clear()
315-
await self.rephrase_storage.drop()
316-
await self.qa_storage.drop()
307+
self.full_docs_storage.drop()
308+
self.chunks_storage.drop()
309+
self.search_storage.drop()
310+
self.graph_storage.clear()
311+
self.rephrase_storage.drop()
312+
self.qa_storage.drop()
317313

318314
logger.info("All caches are cleared")
319315

graphgen/models/kg_builder/light_rag_kg_builder.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ async def merge_nodes(
105105
source_ids = []
106106
descriptions = []
107107

108-
node = await kg_instance.get_node(entity_name)
108+
node = kg_instance.get_node(entity_name)
109109
if node is not None:
110110
entity_types.append(node["entity_type"])
111111
source_ids.extend(
@@ -134,7 +134,7 @@ async def merge_nodes(
134134
"description": description,
135135
"source_id": source_id,
136136
}
137-
await kg_instance.upsert_node(entity_name, node_data=node_data)
137+
kg_instance.upsert_node(entity_name, node_data=node_data)
138138

139139
async def merge_edges(
140140
self,
@@ -146,7 +146,7 @@ async def merge_edges(
146146
source_ids = []
147147
descriptions = []
148148

149-
edge = await kg_instance.get_edge(src_id, tgt_id)
149+
edge = kg_instance.get_edge(src_id, tgt_id)
150150
if edge is not None:
151151
source_ids.extend(
152152
split_string_by_multi_markers(edge["source_id"], ["<SEP>"])
@@ -161,8 +161,8 @@ async def merge_edges(
161161
)
162162

163163
for insert_id in [src_id, tgt_id]:
164-
if not await kg_instance.has_node(insert_id):
165-
await kg_instance.upsert_node(
164+
if not kg_instance.has_node(insert_id):
165+
kg_instance.upsert_node(
166166
insert_id,
167167
node_data={
168168
"source_id": source_id,
@@ -175,7 +175,7 @@ async def merge_edges(
175175
f"({src_id}, {tgt_id})", description
176176
)
177177

178-
await kg_instance.upsert_edge(
178+
kg_instance.upsert_edge(
179179
src_id,
180180
tgt_id,
181181
edge_data={"source_id": source_id, "description": description},

graphgen/models/partitioner/anchor_bfs_partitioner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ async def partition(
3636
max_units_per_community: int = 1,
3737
**kwargs: Any,
3838
) -> List[Community]:
39-
nodes = await g.get_all_nodes() # List[tuple[id, meta]]
40-
edges = await g.get_all_edges() # List[tuple[u, v, meta]]
39+
nodes = g.get_all_nodes() # List[tuple[id, meta]]
40+
edges = g.get_all_edges() # List[tuple[u, v, meta]]
4141

4242
adj, _ = self._build_adjacency_list(nodes, edges)
4343

graphgen/models/partitioner/bfs_partitioner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ async def partition(
2323
max_units_per_community: int = 1,
2424
**kwargs: Any,
2525
) -> List[Community]:
26-
nodes = await g.get_all_nodes()
27-
edges = await g.get_all_edges()
26+
nodes = g.get_all_nodes()
27+
edges = g.get_all_edges()
2828

2929
adj, _ = self._build_adjacency_list(nodes, edges)
3030

0 commit comments

Comments
 (0)