@@ -86,42 +86,60 @@ def _extract_caption(node_data: dict, molecule_type: str) -> Optional[dict]:
8686 elif isinstance (node_data [caption_key ], dict ):
8787 return node_data [caption_key ]
8888
89- # Extract from metadata or node data based on molecule type
89+ # Field mappings for each molecule type
90+ field_mapping = {
91+ "protein" : ["protein_name" , "gene_names" , "organism" , "function" , "sequence" , "id" , "database" , "entry_name" , "uniprot_id" ],
92+ "dna" : ["gene_name" , "gene_description" , "organism" , "chromosome" , "genomic_location" , "function" , "gene_type" , "id" , "database" , "sequence" ],
93+ "rna" : ["rna_type" , "description" , "organism" , "related_genes" , "gene_name" , "so_term" , "id" , "database" , "rnacentral_id" , "sequence" ],
94+ }
95+
96+ # Extract fields based on molecule type
9097 caption = {}
98+ caption_fields = field_mapping .get (molecule_type_lower , [])
99+ for field in caption_fields :
100+ if field in node_data and node_data [field ]:
101+ caption [field ] = node_data [field ]
91102
103+ # Special handling for protein: check search results and existing protein field
92104 if molecule_type_lower == "protein" :
93- # Extract protein-specific fields
94- if "protein" in node_data and node_data ["protein" ]:
95- if isinstance (node_data ["protein" ], list ) and len (node_data ["protein" ]) > 0 :
96- return node_data ["protein" ][0 ] if isinstance (node_data ["protein" ][0 ], dict ) else node_data ["protein" ]
97- elif isinstance (node_data ["protein" ], dict ):
98- return node_data ["protein" ]
105+ # Check for search result data (from UniProt search)
106+ if "_search_results" in node_data :
107+ search_results = node_data ["_search_results" ]
108+ if isinstance (search_results , list ) and len (search_results ) > 0 :
109+ first_result = search_results [0 ]
110+ if isinstance (first_result , dict ):
111+ search_caption = {
112+ "id" : first_result .get ("id" , "" ),
113+ "protein_name" : first_result .get ("protein_name" , "" ),
114+ "gene_names" : first_result .get ("gene_names" , []),
115+ "organism" : first_result .get ("organism" , "" ),
116+ "function" : first_result .get ("function" , []),
117+ "sequence" : node_data .get ("sequence" ) or first_result .get ("sequence" , "" ),
118+ "database" : "UniProt"
119+ }
120+ # Remove empty fields and return if any data exists
121+ search_caption = {k : v for k , v in search_caption .items () if v }
122+ if search_caption :
123+ return search_caption
99124
100- # Fallback: extract from node data fields
101- caption_fields = ["protein_name" , "gene_names" , "organism" , "function" , "sequence" , "id" , "database" ]
102- for field in caption_fields :
103- if field in node_data :
104- caption [field ] = node_data [field ]
105-
106- elif molecule_type_lower == "dna" :
107- # Extract DNA-specific fields
108- caption_fields = [
109- "gene_name" , "gene_description" , "organism" , "chromosome" ,
110- "genomic_location" , "function" , "gene_type" , "id" , "database"
111- ]
112- for field in caption_fields :
113- if field in node_data :
114- caption [field ] = node_data [field ]
125+ # Merge with existing protein field if present
126+ if "protein" in node_data and node_data ["protein" ]:
127+ existing_protein = node_data ["protein" ]
128+ if isinstance (existing_protein , list ) and len (existing_protein ) > 0 :
129+ existing_protein = existing_protein [0 ] if isinstance (existing_protein [0 ], dict ) else existing_protein
130+ if isinstance (existing_protein , dict ):
131+ for key , value in existing_protein .items ():
132+ if key not in caption and value :
133+ caption [key ] = value
134+ # Ensure sequence from node_data takes precedence
135+ if "sequence" in node_data and node_data ["sequence" ]:
136+ caption ["sequence" ] = node_data ["sequence" ]
115137
116- elif molecule_type_lower == "rna" :
117- # Extract RNA-specific fields
118- caption_fields = [
119- "rna_type" , "description" , "organism" , "related_genes" ,
120- "gene_name" , "so_term" , "id" , "database" , "rnacentral_id"
121- ]
122- for field in caption_fields :
123- if field in node_data :
124- caption [field ] = node_data [field ]
138+ # Fallback to description if no caption found
139+ if not caption and "description" in node_data :
140+ description = node_data ["description" ]
141+ if isinstance (description , str ) and len (description ) > 10 :
142+ caption ["description" ] = description
125143
126144 return caption if caption else None
127145
@@ -134,24 +152,58 @@ def _detect_molecule_type(nodes: list[tuple[str, dict]]) -> str:
134152 :param nodes: List of (node_id, node_data) tuples
135153 :return: Detected molecule type ("dna", "rna", "protein", or "unknown")
136154 """
155+ if not nodes :
156+ return "unknown"
157+
158+ # Type indicators for each molecule type
159+ type_indicators = {
160+ "protein" : {
161+ "fields" : ["protein_name" , "uniprot_id" , "entry_name" , "protein_caption" ],
162+ "source_prefix" : "protein-" ,
163+ "description_keywords" : ["protein" ],
164+ },
165+ "dna" : {
166+ "fields" : ["gene_name" , "chromosome" , "genomic_location" ],
167+ "source_prefix" : "dna-" ,
168+ "description_keywords" : ["gene" , "dna" , "chromosome" ],
169+ },
170+ "rna" : {
171+ "fields" : ["rna_type" , "rnacentral_id" ],
172+ "source_prefix" : "rna-" ,
173+ "description_keywords" : ["rna" , "transcript" ],
174+ },
175+ }
176+
137177 for _ , node_data in nodes :
138- # Check node type field
139- node_type = node_data .get ("type" , "" ).lower ()
140- if node_type in ("dna" , "rna" , "protein" ):
141- return node_type
178+ # Priority 1: Check explicit type fields (most reliable)
179+ for field in ["type" , "molecule_type" ]:
180+ value = node_data .get (field , "" ).lower ()
181+ if value in ("dna" , "rna" , "protein" ):
182+ return value
183+
184+ # Priority 2: Check source_id prefix
185+ source_id = node_data .get ("source_id" , "" ).lower ()
186+ for mol_type , indicators in type_indicators .items ():
187+ if source_id .startswith (indicators ["source_prefix" ]):
188+ return mol_type
142189
143- # Check molecule_type in metadata or node data
144- molecule_type = node_data .get ("molecule_type" , "" ).lower ()
145- if molecule_type in ("dna" , "rna" , "protein" ):
146- return molecule_type
190+ # Priority 3: Check type-specific fields
191+ for mol_type , indicators in type_indicators .items ():
192+ if any (key in node_data for key in indicators ["fields" ]):
193+ # Special check for DNA: need chromosome or genomic_location
194+ if mol_type == "dna" and not any (key in node_data for key in ["chromosome" , "genomic_location" ]):
195+ continue
196+ return mol_type
147197
148- # Check for type-specific fields
149- if "protein" in node_data or "protein_name" in node_data or "protein_caption" in node_data :
150- return "protein"
151- if "gene_name" in node_data and "chromosome" in node_data :
152- return "dna"
153- if "rna_type" in node_data or "rnacentral_id" in node_data :
154- return "rna"
198+ # Priority 4: Check description keywords
199+ description = node_data .get ("description" , "" ).lower ()
200+ for mol_type , indicators in type_indicators .items ():
201+ keywords = indicators ["description_keywords" ]
202+ if any (kw in description for kw in keywords ):
203+ # Special check: "protein" in description but not "gene"
204+ if mol_type == "protein" and "gene" in description :
205+ continue
206+ return mol_type
155207
156208 return "unknown"
157209
@@ -182,7 +234,7 @@ async def generate(
182234 # Only attach caption once per batch (from the first relevant node)
183235 caption_attached = False
184236 for node in nodes :
185- node_data = node [ 1 ]
237+ node_id , node_data = node
186238 caption = self ._extract_caption (node_data , molecule_type )
187239
188240 if caption and not caption_attached :
@@ -193,6 +245,14 @@ async def generate(
193245 caption_attached = True
194246 break # Only need to attach once per batch
195247
248+ if not caption_attached :
249+ logger .warning (f"No caption extracted for molecule_type={ molecule_type } . Node data sample: { dict (list (nodes [0 ][1 ].items ())[:5 ]) if nodes else 'No nodes' } " )
250+ # Still attach empty captions to maintain format consistency
251+ for qa in qa_pairs .values ():
252+ qa .setdefault ("dna" , "" )
253+ qa .setdefault ("rna" , "" )
254+ qa .setdefault ("protein" , "" )
255+
196256 result .update (qa_pairs )
197257 return result
198258
@@ -204,61 +264,71 @@ def format_generation_results(
204264 Format generation results with molecule-specific caption fields.
205265 Supports dna, rna, and protein fields in output.
206266 """
267+ # Extract QA pairs and molecule captions
268+ qa_items = [
269+ {
270+ "question" : v ["question" ],
271+ "answer" : v ["answer" ],
272+ "dna" : v .get ("dna" , "" ),
273+ "rna" : v .get ("rna" , "" ),
274+ "protein" : v .get ("protein" , "" ),
275+ }
276+ for item in results
277+ for k , v in item .items ()
278+ ]
279+
280+ # Format based on output format
207281 if output_data_format == "Alpaca" :
208- results = [
282+ return [
209283 {
210- "instruction" : v ["question" ],
284+ "instruction" : qa ["question" ],
211285 "input" : "" ,
212- "output" : v ["answer" ],
213- "dna" : v . get ( "dna" , "" ) ,
214- "rna" : v . get ( "rna" , "" ) ,
215- "protein" : v . get ( "protein" , "" ) ,
286+ "output" : qa ["answer" ],
287+ "dna" : qa [ "dna" ] ,
288+ "rna" : qa [ "rna" ] ,
289+ "protein" : qa [ "protein" ] ,
216290 }
217- for item in results
218- for k , v in item .items ()
291+ for qa in qa_items
219292 ]
220293 elif output_data_format == "Sharegpt" :
221- results = [
294+ return [
222295 {
223296 "conversations" : [
224297 {
225298 "from" : "human" ,
226299 "value" : [
227300 {
228- "text" : v ["question" ],
229- "dna" : v . get ( "dna" , "" ) ,
230- "rna" : v . get ( "rna" , "" ) ,
231- "protein" : v . get ( "protein" , "" ) ,
301+ "text" : qa ["question" ],
302+ "dna" : qa [ "dna" ] ,
303+ "rna" : qa [ "rna" ] ,
304+ "protein" : qa [ "protein" ] ,
232305 }
233306 ],
234307 },
235- {"from" : "gpt" , "value" : v ["answer" ]},
308+ {"from" : "gpt" , "value" : qa ["answer" ]},
236309 ]
237310 }
238- for item in results
239- for k , v in item .items ()
311+ for qa in qa_items
240312 ]
241313 elif output_data_format == "ChatML" :
242- results = [
314+ return [
243315 {
244316 "messages" : [
245317 {
246318 "role" : "user" ,
247319 "content" : [
248320 {
249- "text" : v ["question" ],
250- "dna" : v . get ( "dna" , "" ) ,
251- "rna" : v . get ( "rna" , "" ) ,
252- "protein" : v . get ( "protein" , "" ) ,
321+ "text" : qa ["question" ],
322+ "dna" : qa [ "dna" ] ,
323+ "rna" : qa [ "rna" ] ,
324+ "protein" : qa [ "protein" ] ,
253325 }
254326 ],
255327 },
256- {"role" : "assistant" , "content" : v ["answer" ]},
328+ {"role" : "assistant" , "content" : qa ["answer" ]},
257329 ]
258330 }
259- for item in results
260- for k , v in item .items ()
331+ for qa in qa_items
261332 ]
262333 else :
263334 raise ValueError (f"Unknown output data format: { output_data_format } " )
264- return results
0 commit comments