Skip to content

Commit d2c4135

Browse files
committed
fix: allow multiple input sources and ensure sequence extraction in omics QA
1 parent 741ebd3 commit d2c4135

File tree

6 files changed

+192
-52
lines changed

6 files changed

+192
-52
lines changed

examples/generate/generate_omics_qa/omics_qa_config.yaml

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ nodes:
1010
dependencies: []
1111
params:
1212
input_path:
13-
# For DNA: examples/input_examples/search_dna_demo.jsonl
14-
# For RNA: examples/input_examples/search_rna_demo.jsonl
15-
# For Protein: examples/input_examples/search_protein_demo.jsonl
16-
- examples/input_examples/search_protein_demo.jsonl # Change this to dna/rna/protein demo file as needed
13+
# Three input files to generate DNA, RNA, and protein data together
14+
- examples/input_examples/search_dna_demo.jsonl
15+
- examples/input_examples/search_rna_demo.jsonl
16+
- examples/input_examples/search_protein_demo.jsonl
1717

1818
- id: search_data
1919
op_name: search
@@ -24,25 +24,26 @@ nodes:
2424
replicas: 1
2525
batch_size: 10
2626
params:
27-
data_sources: [uniprot] # Change to [ncbi] for DNA or [rnacentral] for RNA
27+
data_sources: [ncbi, rnacentral, uniprot] # Multi-omics: use all three data sources
2828
# DNA search parameters
2929
ncbi_params:
3030
email: your_email@example.com # Required for NCBI
3131
tool: GraphGen
3232
use_local_blast: true
33-
local_blast_db: refseq_release/refseq_release
33+
local_blast_db: databases/refseq_232_old/refseq_232
3434
blast_num_threads: 2
3535
max_concurrent: 5
3636
# RNA search parameters
3737
rnacentral_params:
3838
use_local_blast: true
39-
local_blast_db: rnacentral_ensembl_gencode_YYYYMMDD/ensembl_gencode_YYYYMMDD
39+
local_blast_db: databases/rnacentral_merged_20251213/rnacentral_merged_20251213
4040
blast_num_threads: 2
4141
max_concurrent: 5
4242
# Protein search parameters
4343
uniprot_params:
4444
use_local_blast: true
45-
local_blast_db: ${RELEASE}/uniprot_sprot
45+
# local_blast_db: ${RELEASE}/uniprot_sprot
46+
local_blast_db: databases/2025_04/uniprot_sprot
4647
blast_num_threads: 2
4748
max_concurrent: 5
4849

@@ -76,7 +77,7 @@ nodes:
7677
params:
7778
method: anchor_bfs # partition method
7879
method_params:
79-
anchor_type: protein # node type (dna, rna, or protein)
80+
anchor_type: [dna, rna, protein] # Multi-omics: support multiple anchor types (list or single string)
8081
max_units_per_community: 10 # max nodes and edges per community
8182

8283
- id: generate

graphgen/models/generator/omics_qa_generator.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -230,28 +230,36 @@ async def generate(
230230
# Detect molecule type from nodes
231231
molecule_type = self._detect_molecule_type(nodes)
232232

233-
# Extract caption for each node and attach to QA pairs
234-
# Only attach caption once per batch (from the first relevant node)
233+
# Extract captions for all molecule types from nodes
234+
captions = {"dna": None, "rna": None, "protein": None}
235235
caption_attached = False
236+
236237
for node in nodes:
237238
node_id, node_data = node
238-
caption = self._extract_caption(node_data, molecule_type)
239239

240-
if caption and not caption_attached:
241-
# Attach caption to all QA pairs
242-
for qa in qa_pairs.values():
243-
# Use molecule_type as the key (dna, rna, or protein)
244-
qa[molecule_type] = caption
245-
caption_attached = True
246-
break # Only need to attach once per batch
240+
# Check for pre-extracted captions (from partition_service)
241+
for mol_type in ["dna", "rna", "protein"]:
242+
caption_key = f"{mol_type}_caption"
243+
if caption_key in node_data and node_data[caption_key]:
244+
captions[mol_type] = node_data[caption_key]
245+
caption_attached = True
246+
247+
# If no pre-extracted captions, extract from node_data using the detected molecule_type
248+
if not caption_attached:
249+
caption = self._extract_caption(node_data, molecule_type)
250+
if caption:
251+
captions[molecule_type] = caption
252+
caption_attached = True
253+
break # Only need to extract once per batch
254+
255+
# Attach all captions to QA pairs
256+
for qa in qa_pairs.values():
257+
qa["dna"] = captions["dna"] if captions["dna"] else ""
258+
qa["rna"] = captions["rna"] if captions["rna"] else ""
259+
qa["protein"] = captions["protein"] if captions["protein"] else ""
247260

248261
if not caption_attached:
249262
logger.warning(f"No caption extracted for molecule_type={molecule_type}. Node data sample: {dict(list(nodes[0][1].items())[:5]) if nodes else 'No nodes'}")
250-
# Still attach empty captions to maintain format consistency
251-
for qa in qa_pairs.values():
252-
qa.setdefault("dna", "")
253-
qa.setdefault("rna", "")
254-
qa.setdefault("protein", "")
255263

256264
result.update(qa_pairs)
257265
return result

graphgen/models/kg_builder/omics_kg_builder.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,12 +179,17 @@ async def merge_nodes(
179179
set([dp["source_id"] for dp in node_data] + source_ids)
180180
)
181181

182-
node_data = {
182+
node_data_dict = {
183183
"entity_type": entity_type,
184184
"description": description,
185185
"source_id": source_id,
186186
}
187-
kg_instance.upsert_node(entity_name, node_data=node_data)
187+
188+
# Preserve sequence from existing node if present (e.g., added by partition_service)
189+
if node is not None and "sequence" in node and node["sequence"]:
190+
node_data_dict["sequence"] = node["sequence"]
191+
192+
kg_instance.upsert_node(entity_name, node_data=node_data_dict)
188193

189194
async def merge_edges(
190195
self,
@@ -194,6 +199,12 @@ async def merge_edges(
194199
"""Merge extracted edges into the knowledge graph."""
195200
(src_id, tgt_id), edge_data = edges_data
196201

202+
# Skip self-loops (edges where source and target are the same)
203+
# This can happen when the LLM extracts invalid relationships
204+
if src_id == tgt_id:
205+
logger.debug("Skipping self-loop edge: (%s, %s)", src_id, tgt_id)
206+
return
207+
197208
source_ids = []
198209
descriptions = []
199210

graphgen/models/partitioner/anchor_bfs_partitioner.py

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import random
22
from collections import deque
3-
from typing import Any, Iterable, List, Literal, Set, Tuple
3+
from typing import Any, Iterable, List, Literal, Set, Tuple, Union
44

55
from graphgen.bases import BaseGraphStorage
66
from graphgen.bases.datatypes import Community
@@ -18,16 +18,26 @@ class AnchorBFSPartitioner(BFSPartitioner):
1818
2. Expand the community using BFS until the max unit size is reached.(A unit is a node or an edge.)
1919
3. Non-anchor units can only be "pulled" into a community and never become seeds themselves.
2020
For example, for VQA tasks, we may want to use image nodes as anchors and expand to nearby text nodes and edges.
21+
22+
Supports multiple anchor types for multi-omics data: anchor_type can be a single string or a list of strings.
23+
When a list is provided, nodes matching any of the types in the list can serve as anchors.
2124
"""
2225

2326
def __init__(
2427
self,
2528
*,
26-
anchor_type: Literal["image"] = "image",
29+
anchor_type: Union[
30+
Literal["image", "dna", "rna", "protein"],
31+
List[Literal["dna", "rna", "protein"]],
32+
] = "image",
2733
anchor_ids: Set[str] | None = None,
2834
) -> None:
2935
super().__init__()
30-
self.anchor_type = anchor_type
36+
# Normalize anchor_type to always be a list for internal processing
37+
if isinstance(anchor_type, str):
38+
self.anchor_types = [anchor_type]
39+
else:
40+
self.anchor_types = list(anchor_type)
3141
self.anchor_ids = anchor_ids
3242

3343
def partition(
@@ -68,10 +78,53 @@ def _pick_anchor_ids(
6878
return self.anchor_ids
6979

7080
anchor_ids: Set[str] = set()
81+
anchor_types_lower = [at.lower() for at in self.anchor_types]
82+
7183
for node_id, meta in nodes:
84+
# Check if node matches any of the anchor types
85+
matched = False
86+
87+
# Check 1: entity_type (for image, etc.)
7288
node_type = str(meta.get("entity_type", "")).lower()
73-
if self.anchor_type.lower() in node_type:
89+
for anchor_type_lower in anchor_types_lower:
90+
if anchor_type_lower in node_type:
91+
anchor_ids.add(node_id)
92+
matched = True
93+
break
94+
95+
if matched:
96+
continue
97+
98+
# Check 2: molecule_type (for omics data: dna, rna, protein)
99+
molecule_type = str(meta.get("molecule_type", "")).lower()
100+
if molecule_type in anchor_types_lower:
74101
anchor_ids.add(node_id)
102+
continue
103+
104+
# Check 3: source_id prefix (for omics data: dna-, rna-, protein-)
105+
source_id = str(meta.get("source_id", "")).lower()
106+
for anchor_type_lower in anchor_types_lower:
107+
if source_id.startswith(f"{anchor_type_lower}-"):
108+
anchor_ids.add(node_id)
109+
matched = True
110+
break
111+
112+
if matched:
113+
continue
114+
115+
# Check 4: Check if source_id contains multiple IDs separated by <SEP>
116+
if "<sep>" in source_id:
117+
source_ids = source_id.split("<sep>")
118+
for sid in source_ids:
119+
sid = sid.strip()
120+
for anchor_type_lower in anchor_types_lower:
121+
if sid.startswith(f"{anchor_type_lower}-"):
122+
anchor_ids.add(node_id)
123+
matched = True
124+
break
125+
if matched:
126+
break
127+
75128
return anchor_ids
76129

77130
@staticmethod
@@ -113,7 +166,21 @@ def _grow_community(
113166
if it in used_e:
114167
continue
115168
used_e.add(it)
116-
u, v = it
169+
# Convert frozenset to tuple for edge representation
170+
# Note: Self-loops should be filtered during graph construction,
171+
# but we handle edge cases defensively
172+
try:
173+
u, v = tuple(it)
174+
except ValueError:
175+
# Handle edge case: frozenset with unexpected number of elements
176+
# This should not happen if graph construction is correct
177+
edge_nodes = list(it)
178+
if len(edge_nodes) == 1:
179+
# Self-loop edge (should have been filtered during graph construction)
180+
u, v = edge_nodes[0], edge_nodes[0]
181+
else:
182+
# Invalid edge, skip it
183+
continue
117184
comm_e.append((u, v))
118185
cnt += 1
119186
for n in it:

graphgen/operators/partition/partition_service.py

Lines changed: 69 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,13 @@ def partition(self) -> Iterable[pd.DataFrame]:
6464
partitioner = LeidenPartitioner()
6565
elif method == "anchor_bfs":
6666
logger.info("Partitioning knowledge graph using Anchor BFS method.")
67+
anchor_type = method_params.get("anchor_type")
68+
if isinstance(anchor_type, list):
69+
logger.info("Using multiple anchor types: %s", anchor_type)
70+
else:
71+
logger.info("Using single anchor type: %s", anchor_type)
6772
partitioner = AnchorBFSPartitioner(
68-
anchor_type=method_params.get("anchor_type"),
73+
anchor_type=anchor_type,
6974
anchor_ids=set(method_params.get("anchor_ids", []))
7075
if method_params.get("anchor_ids")
7176
else None,
@@ -187,41 +192,86 @@ def _attach_additional_data_to_node(self, batch: tuple) -> tuple:
187192
logger.warning("No chunks found for node %s (type: %s) with source_ids: %s", node_id, molecule_type, source_ids)
188193
continue
189194

190-
first_chunk = omics_chunks[0]
191-
def get_chunk_value(field: str):
195+
def get_chunk_value(chunk: dict, field: str):
192196
# First check root level of chunk
193-
if field in first_chunk:
194-
return first_chunk[field]
197+
if field in chunk:
198+
return chunk[field]
195199
# Then check metadata if it exists and is a dict
196-
chunk_metadata = first_chunk.get("metadata")
200+
chunk_metadata = chunk.get("metadata")
197201
if isinstance(chunk_metadata, dict) and field in chunk_metadata:
198202
return chunk_metadata[field]
199203
return None
200204

201-
# Attach sequence if not already present
205+
# Group chunks by molecule type to preserve all types of sequences
206+
chunks_by_type = {"dna": [], "rna": [], "protein": []}
207+
for chunk in omics_chunks:
208+
chunk_id = chunk.get("_chunk_id", "").lower()
209+
if chunk_id.startswith("dna-"):
210+
chunks_by_type["dna"].append(chunk)
211+
elif chunk_id.startswith("rna-"):
212+
chunks_by_type["rna"].append(chunk)
213+
elif chunk_id.startswith("protein-"):
214+
chunks_by_type["protein"].append(chunk)
215+
216+
# Field mappings for each molecule type
217+
field_mapping = {
218+
"protein": ["protein_name", "gene_names", "organism", "function", "sequence", "id", "database", "entry_name", "uniprot_id"],
219+
"dna": ["gene_name", "gene_description", "organism", "chromosome", "genomic_location", "function", "gene_type", "sequence", "id", "database"],
220+
"rna": ["rna_type", "description", "organism", "related_genes", "gene_name", "so_term", "sequence", "id", "database", "rnacentral_id"],
221+
}
222+
223+
# Extract and store captions for each molecule type
224+
for mol_type in ["dna", "rna", "protein"]:
225+
type_chunks = chunks_by_type[mol_type]
226+
if not type_chunks:
227+
continue
228+
229+
# Use the first chunk of this type
230+
type_chunk = type_chunks[0]
231+
caption = {}
232+
233+
# Extract all relevant fields for this molecule type
234+
for field in field_mapping.get(mol_type, []):
235+
value = get_chunk_value(type_chunk, field)
236+
if value:
237+
caption[field] = value
238+
239+
# Store caption if it has any data
240+
if caption:
241+
caption_key = f"{mol_type}_caption"
242+
node_data[caption_key] = caption
243+
logger.debug("Stored %s caption for node %s with %d fields", mol_type, node_id, len(caption))
244+
245+
# For backward compatibility, also attach sequence and other fields from the primary molecule type
246+
# Use the detected molecule_type or default to the first available type
247+
primary_chunk = None
248+
if chunks_by_type.get(molecule_type):
249+
primary_chunk = chunks_by_type[molecule_type][0]
250+
elif chunks_by_type["dna"]:
251+
primary_chunk = chunks_by_type["dna"][0]
252+
elif chunks_by_type["rna"]:
253+
primary_chunk = chunks_by_type["rna"][0]
254+
elif chunks_by_type["protein"]:
255+
primary_chunk = chunks_by_type["protein"][0]
256+
else:
257+
primary_chunk = omics_chunks[0]
258+
259+
# Attach sequence if not already present (for backward compatibility)
202260
if "sequence" not in node_data:
203-
sequence = get_chunk_value("sequence")
261+
sequence = get_chunk_value(primary_chunk, "sequence")
204262
if sequence:
205263
node_data["sequence"] = sequence
206-
else:
207-
logger.warning("No sequence found in chunk for node %s. Chunk keys: %s", node_id, list(first_chunk.keys())[:15])
208264

209265
# Attach molecule_type if not present
210266
if "molecule_type" not in node_data:
211-
chunk_molecule_type = get_chunk_value("molecule_type")
267+
chunk_molecule_type = get_chunk_value(primary_chunk, "molecule_type")
212268
if chunk_molecule_type:
213269
node_data["molecule_type"] = chunk_molecule_type
214270

215-
# Attach molecule-specific fields
216-
field_mapping = {
217-
"protein": ["protein_name", "gene_names", "organism", "function", "id", "database", "entry_name", "uniprot_id"],
218-
"dna": ["gene_name", "gene_description", "organism", "chromosome", "genomic_location", "function", "gene_type", "id", "database"],
219-
"rna": ["rna_type", "description", "organism", "related_genes", "gene_name", "so_term", "id", "database", "rnacentral_id"],
220-
}
221-
271+
# Attach molecule-specific fields from primary chunk (for backward compatibility)
222272
for field in field_mapping.get(molecule_type, []):
223273
if field not in node_data:
224-
value = get_chunk_value(field)
274+
value = get_chunk_value(primary_chunk, field)
225275
if value:
226276
node_data[field] = value
227277

graphgen/operators/search/search_service.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,11 @@ def process(self, batch: pd.DataFrame) -> pd.DataFrame:
242242
if doc_type not in ["text", "dna", "rna", "protein"]:
243243
doc_type = "text"
244244

245-
# Generate document ID from result ID or search query
246-
doc_id = result.get("id") or result.get("_search_query") or f"search-{len(result_rows)}"
245+
# Convert to string to handle Ray Data ListElement and other types
246+
raw_doc_id = result.get("id") or result.get("_search_query") or f"search-{len(result_rows)}"
247+
doc_id = str(raw_doc_id)
248+
249+
# Ensure doc_id starts with "doc-" prefix
247250
if not doc_id.startswith("doc-"):
248251
doc_id = f"doc-{doc_id}"
249252

0 commit comments

Comments
 (0)