Skip to content

Commit 21e2576

Browse files
committed
feat: enhance OmicsQAGenerator and PartitionService for data extraction
1 parent 9281181 commit 21e2576

File tree

3 files changed

+213
-75
lines changed

3 files changed

+213
-75
lines changed

graphgen/models/generator/omics_qa_generator.py

Lines changed: 142 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -86,42 +86,60 @@ def _extract_caption(node_data: dict, molecule_type: str) -> Optional[dict]:
8686
elif isinstance(node_data[caption_key], dict):
8787
return node_data[caption_key]
8888

89-
# Extract from metadata or node data based on molecule type
89+
# Field mappings for each molecule type
90+
field_mapping = {
91+
"protein": ["protein_name", "gene_names", "organism", "function", "sequence", "id", "database", "entry_name", "uniprot_id"],
92+
"dna": ["gene_name", "gene_description", "organism", "chromosome", "genomic_location", "function", "gene_type", "id", "database", "sequence"],
93+
"rna": ["rna_type", "description", "organism", "related_genes", "gene_name", "so_term", "id", "database", "rnacentral_id", "sequence"],
94+
}
95+
96+
# Extract fields based on molecule type
9097
caption = {}
98+
caption_fields = field_mapping.get(molecule_type_lower, [])
99+
for field in caption_fields:
100+
if field in node_data and node_data[field]:
101+
caption[field] = node_data[field]
91102

103+
# Special handling for protein: check search results and existing protein field
92104
if molecule_type_lower == "protein":
93-
# Extract protein-specific fields
94-
if "protein" in node_data and node_data["protein"]:
95-
if isinstance(node_data["protein"], list) and len(node_data["protein"]) > 0:
96-
return node_data["protein"][0] if isinstance(node_data["protein"][0], dict) else node_data["protein"]
97-
elif isinstance(node_data["protein"], dict):
98-
return node_data["protein"]
105+
# Check for search result data (from UniProt search)
106+
if "_search_results" in node_data:
107+
search_results = node_data["_search_results"]
108+
if isinstance(search_results, list) and len(search_results) > 0:
109+
first_result = search_results[0]
110+
if isinstance(first_result, dict):
111+
search_caption = {
112+
"id": first_result.get("id", ""),
113+
"protein_name": first_result.get("protein_name", ""),
114+
"gene_names": first_result.get("gene_names", []),
115+
"organism": first_result.get("organism", ""),
116+
"function": first_result.get("function", []),
117+
"sequence": node_data.get("sequence") or first_result.get("sequence", ""),
118+
"database": "UniProt"
119+
}
120+
# Remove empty fields and return if any data exists
121+
search_caption = {k: v for k, v in search_caption.items() if v}
122+
if search_caption:
123+
return search_caption
99124

100-
# Fallback: extract from node data fields
101-
caption_fields = ["protein_name", "gene_names", "organism", "function", "sequence", "id", "database"]
102-
for field in caption_fields:
103-
if field in node_data:
104-
caption[field] = node_data[field]
105-
106-
elif molecule_type_lower == "dna":
107-
# Extract DNA-specific fields
108-
caption_fields = [
109-
"gene_name", "gene_description", "organism", "chromosome",
110-
"genomic_location", "function", "gene_type", "id", "database"
111-
]
112-
for field in caption_fields:
113-
if field in node_data:
114-
caption[field] = node_data[field]
125+
# Merge with existing protein field if present
126+
if "protein" in node_data and node_data["protein"]:
127+
existing_protein = node_data["protein"]
128+
if isinstance(existing_protein, list) and len(existing_protein) > 0:
129+
existing_protein = existing_protein[0] if isinstance(existing_protein[0], dict) else existing_protein
130+
if isinstance(existing_protein, dict):
131+
for key, value in existing_protein.items():
132+
if key not in caption and value:
133+
caption[key] = value
134+
# Ensure sequence from node_data takes precedence
135+
if "sequence" in node_data and node_data["sequence"]:
136+
caption["sequence"] = node_data["sequence"]
115137

116-
elif molecule_type_lower == "rna":
117-
# Extract RNA-specific fields
118-
caption_fields = [
119-
"rna_type", "description", "organism", "related_genes",
120-
"gene_name", "so_term", "id", "database", "rnacentral_id"
121-
]
122-
for field in caption_fields:
123-
if field in node_data:
124-
caption[field] = node_data[field]
138+
# Fallback to description if no caption found
139+
if not caption and "description" in node_data:
140+
description = node_data["description"]
141+
if isinstance(description, str) and len(description) > 10:
142+
caption["description"] = description
125143

126144
return caption if caption else None
127145

@@ -134,24 +152,58 @@ def _detect_molecule_type(nodes: list[tuple[str, dict]]) -> str:
134152
:param nodes: List of (node_id, node_data) tuples
135153
:return: Detected molecule type ("dna", "rna", "protein", or "unknown")
136154
"""
155+
if not nodes:
156+
return "unknown"
157+
158+
# Type indicators for each molecule type
159+
type_indicators = {
160+
"protein": {
161+
"fields": ["protein_name", "uniprot_id", "entry_name", "protein_caption"],
162+
"source_prefix": "protein-",
163+
"description_keywords": ["protein"],
164+
},
165+
"dna": {
166+
"fields": ["gene_name", "chromosome", "genomic_location"],
167+
"source_prefix": "dna-",
168+
"description_keywords": ["gene", "dna", "chromosome"],
169+
},
170+
"rna": {
171+
"fields": ["rna_type", "rnacentral_id"],
172+
"source_prefix": "rna-",
173+
"description_keywords": ["rna", "transcript"],
174+
},
175+
}
176+
137177
for _, node_data in nodes:
138-
# Check node type field
139-
node_type = node_data.get("type", "").lower()
140-
if node_type in ("dna", "rna", "protein"):
141-
return node_type
178+
# Priority 1: Check explicit type fields (most reliable)
179+
for field in ["type", "molecule_type"]:
180+
value = node_data.get(field, "").lower()
181+
if value in ("dna", "rna", "protein"):
182+
return value
183+
184+
# Priority 2: Check source_id prefix
185+
source_id = node_data.get("source_id", "").lower()
186+
for mol_type, indicators in type_indicators.items():
187+
if source_id.startswith(indicators["source_prefix"]):
188+
return mol_type
142189

143-
# Check molecule_type in metadata or node data
144-
molecule_type = node_data.get("molecule_type", "").lower()
145-
if molecule_type in ("dna", "rna", "protein"):
146-
return molecule_type
190+
# Priority 3: Check type-specific fields
191+
for mol_type, indicators in type_indicators.items():
192+
if any(key in node_data for key in indicators["fields"]):
193+
# Special check for DNA: need chromosome or genomic_location
194+
if mol_type == "dna" and not any(key in node_data for key in ["chromosome", "genomic_location"]):
195+
continue
196+
return mol_type
147197

148-
# Check for type-specific fields
149-
if "protein" in node_data or "protein_name" in node_data or "protein_caption" in node_data:
150-
return "protein"
151-
if "gene_name" in node_data and "chromosome" in node_data:
152-
return "dna"
153-
if "rna_type" in node_data or "rnacentral_id" in node_data:
154-
return "rna"
198+
# Priority 4: Check description keywords
199+
description = node_data.get("description", "").lower()
200+
for mol_type, indicators in type_indicators.items():
201+
keywords = indicators["description_keywords"]
202+
if any(kw in description for kw in keywords):
203+
# Special check: "protein" in description but not "gene"
204+
if mol_type == "protein" and "gene" in description:
205+
continue
206+
return mol_type
155207

156208
return "unknown"
157209

@@ -182,7 +234,7 @@ async def generate(
182234
# Only attach caption once per batch (from the first relevant node)
183235
caption_attached = False
184236
for node in nodes:
185-
node_data = node[1]
237+
node_id, node_data = node
186238
caption = self._extract_caption(node_data, molecule_type)
187239

188240
if caption and not caption_attached:
@@ -193,6 +245,14 @@ async def generate(
193245
caption_attached = True
194246
break # Only need to attach once per batch
195247

248+
if not caption_attached:
249+
logger.warning(f"No caption extracted for molecule_type={molecule_type}. Node data sample: {dict(list(nodes[0][1].items())[:5]) if nodes else 'No nodes'}")
250+
# Still attach empty captions to maintain format consistency
251+
for qa in qa_pairs.values():
252+
qa.setdefault("dna", "")
253+
qa.setdefault("rna", "")
254+
qa.setdefault("protein", "")
255+
196256
result.update(qa_pairs)
197257
return result
198258

@@ -204,61 +264,71 @@ def format_generation_results(
204264
Format generation results with molecule-specific caption fields.
205265
Supports dna, rna, and protein fields in output.
206266
"""
267+
# Extract QA pairs and molecule captions
268+
qa_items = [
269+
{
270+
"question": v["question"],
271+
"answer": v["answer"],
272+
"dna": v.get("dna", ""),
273+
"rna": v.get("rna", ""),
274+
"protein": v.get("protein", ""),
275+
}
276+
for item in results
277+
for k, v in item.items()
278+
]
279+
280+
# Format based on output format
207281
if output_data_format == "Alpaca":
208-
results = [
282+
return [
209283
{
210-
"instruction": v["question"],
284+
"instruction": qa["question"],
211285
"input": "",
212-
"output": v["answer"],
213-
"dna": v.get("dna", ""),
214-
"rna": v.get("rna", ""),
215-
"protein": v.get("protein", ""),
286+
"output": qa["answer"],
287+
"dna": qa["dna"],
288+
"rna": qa["rna"],
289+
"protein": qa["protein"],
216290
}
217-
for item in results
218-
for k, v in item.items()
291+
for qa in qa_items
219292
]
220293
elif output_data_format == "Sharegpt":
221-
results = [
294+
return [
222295
{
223296
"conversations": [
224297
{
225298
"from": "human",
226299
"value": [
227300
{
228-
"text": v["question"],
229-
"dna": v.get("dna", ""),
230-
"rna": v.get("rna", ""),
231-
"protein": v.get("protein", ""),
301+
"text": qa["question"],
302+
"dna": qa["dna"],
303+
"rna": qa["rna"],
304+
"protein": qa["protein"],
232305
}
233306
],
234307
},
235-
{"from": "gpt", "value": v["answer"]},
308+
{"from": "gpt", "value": qa["answer"]},
236309
]
237310
}
238-
for item in results
239-
for k, v in item.items()
311+
for qa in qa_items
240312
]
241313
elif output_data_format == "ChatML":
242-
results = [
314+
return [
243315
{
244316
"messages": [
245317
{
246318
"role": "user",
247319
"content": [
248320
{
249-
"text": v["question"],
250-
"dna": v.get("dna", ""),
251-
"rna": v.get("rna", ""),
252-
"protein": v.get("protein", ""),
321+
"text": qa["question"],
322+
"dna": qa["dna"],
323+
"rna": qa["rna"],
324+
"protein": qa["protein"],
253325
}
254326
],
255327
},
256-
{"role": "assistant", "content": v["answer"]},
328+
{"role": "assistant", "content": qa["answer"]},
257329
]
258330
}
259-
for item in results
260-
for k, v in item.items()
331+
for qa in qa_items
261332
]
262333
else:
263334
raise ValueError(f"Unknown output data format: {output_data_format}")
264-
return results

graphgen/operators/partition/partition_service.py

Lines changed: 71 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -131,6 +131,7 @@ def _attach_additional_data_to_node(self, batch: tuple) -> tuple:
131131

132132
for node_id, node_data in nodes_data:
133133
entity_type = (node_data.get("entity_type") or "").lower()
134+
134135
if not entity_type:
135136
continue
136137

@@ -139,6 +140,9 @@ def _attach_additional_data_to_node(self, batch: tuple) -> tuple:
139140
for sid in node_data.get("source_id", "").split("<SEP>")
140141
if sid.strip()
141142
]
143+
144+
if not source_ids:
145+
continue
142146

143147
# Handle images
144148
if "image" in entity_type:
@@ -153,5 +157,72 @@ def _attach_additional_data_to_node(self, batch: tuple) -> tuple:
153157
# We'll use the first image chunk found for this node.
154158
node_data["image_data"] = json.loads(image_chunks[0]["content"])
155159
logger.debug("Attached image data to node %s", node_id)
160+
161+
# Handle omics data (protein/dna/rna)
162+
molecule_type = None
163+
if entity_type in ("protein", "dna", "rna"):
164+
molecule_type = entity_type
165+
else:
166+
# Infer from source_id prefix
167+
for sid in source_ids:
168+
sid_lower = sid.lower()
169+
if sid_lower.startswith("protein-"):
170+
molecule_type = "protein"
171+
break
172+
elif sid_lower.startswith("dna-"):
173+
molecule_type = "dna"
174+
break
175+
elif sid_lower.startswith("rna-"):
176+
molecule_type = "rna"
177+
break
178+
179+
if molecule_type:
180+
omics_chunks = [
181+
data
182+
for sid in source_ids
183+
if (data := self.chunk_storage.get_by_id(sid))
184+
]
185+
186+
if not omics_chunks:
187+
logger.warning("No chunks found for node %s (type: %s) with source_ids: %s", node_id, molecule_type, source_ids)
188+
continue
189+
190+
first_chunk = omics_chunks[0]
191+
def get_chunk_value(field: str):
192+
# First check root level of chunk
193+
if field in first_chunk:
194+
return first_chunk[field]
195+
# Then check metadata if it exists and is a dict
196+
chunk_metadata = first_chunk.get("metadata")
197+
if isinstance(chunk_metadata, dict) and field in chunk_metadata:
198+
return chunk_metadata[field]
199+
return None
200+
201+
# Attach sequence if not already present
202+
if "sequence" not in node_data:
203+
sequence = get_chunk_value("sequence")
204+
if sequence:
205+
node_data["sequence"] = sequence
206+
else:
207+
logger.warning("No sequence found in chunk for node %s. Chunk keys: %s", node_id, list(first_chunk.keys())[:15])
208+
209+
# Attach molecule_type if not present
210+
if "molecule_type" not in node_data:
211+
chunk_molecule_type = get_chunk_value("molecule_type")
212+
if chunk_molecule_type:
213+
node_data["molecule_type"] = chunk_molecule_type
214+
215+
# Attach molecule-specific fields
216+
field_mapping = {
217+
"protein": ["protein_name", "gene_names", "organism", "function", "id", "database", "entry_name", "uniprot_id"],
218+
"dna": ["gene_name", "gene_description", "organism", "chromosome", "genomic_location", "function", "gene_type", "id", "database"],
219+
"rna": ["rna_type", "description", "organism", "related_genes", "gene_name", "so_term", "id", "database", "rnacentral_id"],
220+
}
221+
222+
for field in field_mapping.get(molecule_type, []):
223+
if field not in node_data:
224+
value = get_chunk_value(field)
225+
if value:
226+
node_data[field] = value
156227

157228
return nodes_data, edges_data

scripts/generate/generate_protein_qa.sh

Lines changed: 0 additions & 3 deletions
This file was deleted.

0 commit comments

Comments
 (0)