
Commit ade5a01

Merge remote-tracking branch 'origin/main' into feature/protein-qa
2 parents 8a23f73 + f8a3b9c

8 files changed: +305, -51 lines


.pylintrc

Lines changed: 1 addition & 0 deletions
@@ -452,6 +452,7 @@ disable=raw-checker-failed,
     R0917, # Too many positional arguments (6/5) (too-many-positional-arguments)
     C0103,
     E0401,
+    W0718, # Catching too general exception Exception (broad-except)
 
 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
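
For context, W0718 is the pylint message for handlers that catch the generic Exception class; the file scanner added later in this commit uses exactly that pattern around its worker futures, which is presumably why the message is now disabled. A minimal, hypothetical snippet (not from the repo) of code this disable lets pass lint:

def run_task(task):
    try:
        return task()
    except Exception as exc:  # reported as W0718 (broad-except) without the disable
        print(f"task failed: {exc}")
        return None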

graphgen/configs/search_config.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 pipeline:
   - name: read
     params:
-      input_file: resources/input_examples/search_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
+      input_file: resources/input_examples/search_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
 
   - name: search
     params:
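
The only change here is the demo input's extension, .json to .jsonl, matching the formats the comment already lists. As a reminder, a .jsonl file holds one standalone JSON object per line; the field name below is hypothetical, not taken from resources/input_examples:

{"content": "first document text ..."}
{"content": "second document text ..."}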

graphgen/graphgen.py

Lines changed: 15 additions & 13 deletions
@@ -67,15 +67,16 @@ def __init__(
         self.graph_storage: NetworkXStorage = NetworkXStorage(
             self.working_dir, namespace="graph"
         )
-        self.search_storage: JsonKVStorage = JsonKVStorage(
-            self.working_dir, namespace="search"
-        )
         self.rephrase_storage: JsonKVStorage = JsonKVStorage(
             self.working_dir, namespace="rephrase"
         )
         self.partition_storage: JsonListStorage = JsonListStorage(
             self.working_dir, namespace="partition"
         )
+        self.search_storage: JsonKVStorage = JsonKVStorage(
+            os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
+            namespace="search",
+        )
         self.qa_storage: JsonListStorage = JsonListStorage(
             os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
             namespace="qa",
@@ -94,23 +95,24 @@ async def read(self, read_config: Dict):
         """
         read files from input sources
         """
-        data = read_files(**read_config, cache_dir=self.working_dir)
-        if len(data) == 0:
-            logger.warning("No data to process")
-            return
+        doc_stream = read_files(**read_config, cache_dir=self.working_dir)
 
-        assert isinstance(data, list) and isinstance(data[0], dict)
+        batch = {}
+        for doc in doc_stream:
+            doc_id = compute_mm_hash(doc, prefix="doc-")
 
-        # TODO: configurable whether to use coreference resolution
+            batch[doc_id] = doc
+        if batch:
+            self.full_docs_storage.upsert(batch)
+            self.full_docs_storage.index_done_callback()
 
-        new_docs = {compute_mm_hash(doc, prefix="doc-"): doc for doc in data}
-        _add_doc_keys = self.full_docs_storage.filter_keys(list(new_docs.keys()))
-        new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
+        # TODO: configurable whether to use coreference resolution
 
+        _add_doc_keys = self.full_docs_storage.filter_keys(list(batch.keys()))
+        new_docs = {k: v for k, v in batch.items() if k in _add_doc_keys}
         if len(new_docs) == 0:
             logger.warning("All documents are already in the storage")
             return
-
         self.full_docs_storage.upsert(new_docs)
         self.full_docs_storage.index_done_callback()
 
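The read() rewrite stops materializing the whole return value of read_files: documents are consumed as a stream, keyed by a content hash, upserted as a batch, and only previously unseen keys survive the filter_keys dedup step. A minimal self-contained sketch of that pattern, with a plain dict standing in for the KV storage and hashlib standing in for compute_mm_hash (both are stand-ins, not the project's APIs):

import hashlib
from typing import Dict, Iterable


def _doc_hash(doc: Dict, prefix: str = "doc-") -> str:
    # Stand-in for compute_mm_hash: hash a stable representation of the doc.
    return prefix + hashlib.md5(repr(sorted(doc.items())).encode()).hexdigest()


def ingest(doc_stream: Iterable[Dict], storage: Dict[str, Dict]) -> Dict[str, Dict]:
    """Stream documents, key them by content hash, and keep only unseen ones."""
    batch = {_doc_hash(doc): doc for doc in doc_stream}              # hash while streaming
    new_docs = {k: v for k, v in batch.items() if k not in storage}  # dedup against storage
    storage.update(new_docs)                                         # analogous to upsert(...)
    return new_docs


docs = iter([{"text": "alpha"}, {"text": "beta"}, {"text": "alpha"}])
store: Dict[str, Dict] = {}
print(len(ingest(docs, store)))  # 2: the duplicate document hashes to the same key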
graphgen/models/searcher/db/uniprot_searcher.py

Lines changed: 13 additions & 10 deletions
@@ -27,12 +27,16 @@ def _get_pool():
     return ThreadPoolExecutor(max_workers=10)
 
 
+# ensure only one BLAST searcher at a time
+_blast_lock = asyncio.Lock()
+
+
 class UniProtSearch(BaseSearcher):
     """
     UniProt Search client to searcher with UniProt.
     1) Get the protein by accession number.
     2) Search with keywords or protein names (fuzzy searcher).
-    3) Search with FASTA sequence (BLAST searcher).
+    3) Search with FASTA sequence (BLAST searcher). Note that NCBIWWW does not support async.
     """
 
     def __init__(self, use_local_blast: bool = False, local_blast_db: str = "sp_db"):
@@ -230,22 +234,21 @@ async def search(
         if query.startswith(">") or re.fullmatch(
             r"[ACDEFGHIKLMNPQRSTVWY\s]+", query, re.I
         ):
-            coro = loop.run_in_executor(
-                _get_pool(), self.get_by_fasta, query, threshold
-            )
+            async with _blast_lock:
+                result = await loop.run_in_executor(
+                    _get_pool(), self.get_by_fasta, query, threshold
+                )
 
         # check if accession number
         elif re.fullmatch(r"[A-NR-Z0-9]{6,10}", query, re.I):
-            coro = loop.run_in_executor(_get_pool(), self.get_by_accession, query)
+            result = await loop.run_in_executor(
+                _get_pool(), self.get_by_accession, query
+            )
 
         else:
             # otherwise treat as keyword
-            coro = loop.run_in_executor(_get_pool(), self.get_best_hit, query)
+            result = await loop.run_in_executor(_get_pool(), self.get_best_hit, query)
 
-        result = await coro
         if result:
             result["_search_query"] = query
         return result
-
-
-# TODO: use local UniProt database for large-scale searchs
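
The substantive change here is that blocking BLAST calls are now serialized: a module-level asyncio.Lock is held while the call runs in the thread pool, so concurrent coroutines queue up for BLAST while accession and keyword lookups still use the pool freely. A small self-contained sketch of that pattern (blocking_blast below is a placeholder for the NCBIWWW-style blocking call, not the project's code):

import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

_pool = ThreadPoolExecutor(max_workers=10)
_blast_lock = asyncio.Lock()  # ensure only one blocking BLAST call at a time


def blocking_blast(query: str) -> str:
    time.sleep(1)  # placeholder for a slow request that cannot be awaited
    return f"hit for {query}"


async def search(query: str) -> str:
    loop = asyncio.get_running_loop()
    async with _blast_lock:  # other callers wait here instead of stacking up requests
        return await loop.run_in_executor(_pool, blocking_blast, query)


async def main():
    # The three searches are gathered concurrently, but the BLAST calls run one at a time.
    print(await asyncio.gather(*(search(q) for q in ["Q1", "Q2", "Q3"])))


asyncio.run(main())
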
Lines changed: 231 additions & 0 deletions
@@ -0,0 +1,231 @@
+import os
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Any, Dict, List, Set, Union
+
+from diskcache import Cache
+
+from graphgen.utils import logger
+
+
+class ParallelFileScanner:
+    def __init__(
+        self, cache_dir: str, allowed_suffix, rescan: bool = False, max_workers: int = 4
+    ):
+        self.cache = Cache(cache_dir)
+        self.allowed_suffix = set(allowed_suffix) if allowed_suffix else None
+        self.rescan = rescan
+        self.max_workers = max_workers
+
+    def scan(
+        self, paths: Union[str, List[str]], recursive: bool = True
+    ) -> Dict[str, Any]:
+        if isinstance(paths, str):
+            paths = [paths]
+
+        results = {}
+        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+            future_to_path = {}
+            for p in paths:
+                if os.path.exists(p):
+                    future = executor.submit(
+                        self._scan_files, Path(p).resolve(), recursive, set()
+                    )
+                    future_to_path[future] = p
+                else:
+                    logger.warning("[READ] Path does not exist: %s", p)
+
+            for future in as_completed(future_to_path):
+                path = future_to_path[future]
+                try:
+                    results[path] = future.result()
+                except Exception as e:
+                    logger.error("[READ] Error scanning path %s: %s", path, e)
+                    results[path] = {
+                        "error": str(e),
+                        "files": [],
+                        "dirs": [],
+                        "stats": {},
+                    }
+        return results
+
+    def _scan_files(
+        self, path: Path, recursive: bool, visited: Set[str]
+    ) -> Dict[str, Any]:
+        path_str = str(path)
+
+        # Avoid cycles due to symlinks
+        if path_str in visited:
+            logger.warning("[READ] Skipping already visited path: %s", path_str)
+            return self._empty_result(path_str)
+
+        # cache check
+        cache_key = f"scan::{path_str}::recursive::{recursive}"
+        cached = self.cache.get(cache_key)
+        if cached and not self.rescan:
+            logger.info("[READ] Using cached scan result for path: %s", path_str)
+            return cached["data"]
+
+        logger.info("[READ] Scanning path: %s", path_str)
+        files, dirs = [], []
+        stats = {"total_size": 0, "file_count": 0, "dir_count": 0, "errors": 0}
+
+        try:
+            path_stat = path.stat()
+            if path.is_file():
+                return self._scan_single_file(path, path_str, path_stat)
+            if path.is_dir():
+                with os.scandir(path_str) as entries:
+                    for entry in entries:
+                        try:
+                            entry_stat = entry.stat(follow_symlinks=False)
+
+                            if entry.is_dir():
+                                dirs.append(
+                                    {
+                                        "path": entry.path,
+                                        "name": entry.name,
+                                        "mtime": entry_stat.st_mtime,
+                                    }
+                                )
+                                stats["dir_count"] += 1
+                            else:
+                                # allowed suffix filter
+                                if not self._is_allowed_file(Path(entry.path)):
+                                    continue
+                                files.append(
+                                    {
+                                        "path": entry.path,
+                                        "name": entry.name,
+                                        "size": entry_stat.st_size,
+                                        "mtime": entry_stat.st_mtime,
+                                    }
+                                )
+                                stats["total_size"] += entry_stat.st_size
+                                stats["file_count"] += 1
+
+                        except OSError:
+                            stats["errors"] += 1
+
+        except (PermissionError, FileNotFoundError, OSError) as e:
+            logger.error("[READ] Failed to scan path %s: %s", path_str, e)
+            return {"error": str(e), "files": [], "dirs": [], "stats": stats}
+
+        if recursive:
+            sub_visited = visited | {path_str}
+            sub_results = self._scan_subdirs(dirs, sub_visited)
+
+            for sub_data in sub_results.values():
+                files.extend(sub_data.get("files", []))
+                stats["total_size"] += sub_data["stats"].get("total_size", 0)
+                stats["file_count"] += sub_data["stats"].get("file_count", 0)
+
+        result = {"path": path_str, "files": files, "dirs": dirs, "stats": stats}
+        self._cache_result(cache_key, result, path)
+        return result
+
+    def _scan_single_file(
+        self, path: Path, path_str: str, stat: os.stat_result
+    ) -> Dict[str, Any]:
+        """Scan a single file and return its metadata"""
+        if not self._is_allowed_file(path):
+            return self._empty_result(path_str)
+
+        return {
+            "path": path_str,
+            "files": [
+                {
+                    "path": path_str,
+                    "name": path.name,
+                    "size": stat.st_size,
+                    "mtime": stat.st_mtime,
+                }
+            ],
+            "dirs": [],
+            "stats": {
+                "total_size": stat.st_size,
+                "file_count": 1,
+                "dir_count": 0,
+                "errors": 0,
+            },
+        }
+
+    def _scan_subdirs(self, dir_list: List[Dict], visited: Set[str]) -> Dict[str, Any]:
+        """
+        Parallel scan subdirectories
+        :param dir_list
+        :param visited
+        :return:
+        """
+        results = {}
+        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+            futures = {
+                executor.submit(self._scan_files, Path(d["path"]), True, visited): d[
+                    "path"
+                ]
+                for d in dir_list
+            }
+
+            for future in as_completed(futures):
+                path = futures[future]
+                try:
+                    results[path] = future.result()
+                except Exception as e:
+                    logger.error("[READ] Error scanning subdirectory %s: %s", path, e)
+                    results[path] = {
+                        "error": str(e),
+                        "files": [],
+                        "dirs": [],
+                        "stats": {},
+                    }
+
+        return results
+
+    def _cache_result(self, key: str, result: Dict, path: Path):
+        """Cache the scan result"""
+        try:
+            self.cache.set(
+                key,
+                {
+                    "data": result,
+                    "dir_mtime": path.stat().st_mtime,
+                    "cached_at": time.time(),
+                },
+            )
+            logger.info("[READ] Cached scan result for path: %s", path)
+        except OSError as e:
+            logger.error("[READ] Failed to cache scan result for path %s: %s", path, e)
+
+    def _is_allowed_file(self, path: Path) -> bool:
+        """Check if the file has an allowed suffix"""
+        if self.allowed_suffix is None:
+            return True
+        suffix = path.suffix.lower().lstrip(".")
+        return suffix in self.allowed_suffix
+
+    def invalidate(self, path: str):
+        """Invalidate cache for a specific path"""
+        path = Path(path).resolve()
+        keys = [k for k in self.cache if k.startswith(f"scan::{path}")]
+        for k in keys:
+            self.cache.delete(k)
+        logger.info("[READ] Invalidated cache for path: %s", path)
+
+    def close(self):
+        self.cache.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()
+
+    @staticmethod
+    def _empty_result(path: str) -> Dict[str, Any]:
+        return {
+            "path": path,
+            "files": [],
+            "dirs": [],
+            "stats": {"total_size": 0, "file_count": 0, "dir_count": 0, "errors": 0},
+        }
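
Since the new scanner is a context manager, a caller can open it against a cache directory, scan one or more paths, and read back the per-path file lists and stats. A hedged usage sketch (the cache directory, input path, and suffix list below are illustrative, not taken from the repo):

with ParallelFileScanner(
    cache_dir="cache/scan", allowed_suffix=["json", "jsonl", "txt", "pdf"]
) as scanner:
    # "resources/input_examples" is a hypothetical input location
    results = scanner.scan("resources/input_examples", recursive=True)
    for root, data in results.items():
        print(root, data["stats"].get("file_count", 0), "files")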
