@@ -71,35 +71,47 @@ def parse_response(response: str) -> Any:
7171 def _extract_caption (node_data : dict , molecule_type : str ) -> Optional [dict ]:
7272 """
7373 Extract molecule-specific caption information from node data.
74-
74+
7575 :param node_data: Node data dictionary
7676 :param molecule_type: Type of molecule ("dna", "rna", or "protein")
7777 :return: Caption dictionary or None
7878 """
7979 molecule_type_lower = molecule_type .lower ()
80-
80+
8181 # Check if there's already a caption field (e.g., protein_caption, dna_caption, rna_caption)
8282 caption_key = f"{ molecule_type_lower } _caption"
8383 if caption_key in node_data and node_data [caption_key ]:
8484 if isinstance (node_data [caption_key ], list ) and len (node_data [caption_key ]) > 0 :
85- return node_data [caption_key ][0 ] if isinstance (node_data [caption_key ][0 ], dict ) else node_data [caption_key ]
86- elif isinstance (node_data [caption_key ], dict ):
85+ caption_val = node_data [caption_key ]
86+ return caption_val [0 ] if isinstance (caption_val [0 ], dict ) else caption_val
87+ if isinstance (node_data [caption_key ], dict ):
8788 return node_data [caption_key ]
88-
89+
8990 # Field mappings for each molecule type
9091 field_mapping = {
91- "protein" : ["protein_name" , "gene_names" , "organism" , "function" , "sequence" , "id" , "database" , "entry_name" , "uniprot_id" ],
92- "dna" : ["gene_name" , "gene_description" , "organism" , "chromosome" , "genomic_location" , "function" , "gene_type" , "id" , "database" , "sequence" ],
93- "rna" : ["rna_type" , "description" , "organism" , "related_genes" , "gene_name" , "so_term" , "id" , "database" , "rnacentral_id" , "sequence" ],
92+ "protein" : [
93+ "protein_name" , "gene_names" , "organism" , "function" ,
94+ "sequence" , "id" , "database" , "entry_name" , "uniprot_id"
95+ ],
96+ "dna" : [
97+ "gene_name" , "gene_description" , "organism" , "chromosome" ,
98+ "genomic_location" , "function" , "gene_type" , "id" ,
99+ "database" , "sequence"
100+ ],
101+ "rna" : [
102+ "rna_type" , "description" , "organism" , "related_genes" ,
103+ "gene_name" , "so_term" , "id" , "database" ,
104+ "rnacentral_id" , "sequence"
105+ ],
94106 }
95-
107+
96108 # Extract fields based on molecule type
97109 caption = {}
98110 caption_fields = field_mapping .get (molecule_type_lower , [])
99111 for field in caption_fields :
100112 if field in node_data and node_data [field ]:
101113 caption [field ] = node_data [field ]
102-
114+
103115 # Special handling for protein: check search results and existing protein field
104116 if molecule_type_lower == "protein" :
105117 # Check for search result data (from UniProt search)
@@ -121,40 +133,44 @@ def _extract_caption(node_data: dict, molecule_type: str) -> Optional[dict]:
121133 search_caption = {k : v for k , v in search_caption .items () if v }
122134 if search_caption :
123135 return search_caption
124-
136+
125137 # Merge with existing protein field if present
126138 if "protein" in node_data and node_data ["protein" ]:
127139 existing_protein = node_data ["protein" ]
128140 if isinstance (existing_protein , list ) and len (existing_protein ) > 0 :
129- existing_protein = existing_protein [0 ] if isinstance (existing_protein [0 ], dict ) else existing_protein
141+ existing_protein = (
142+ existing_protein [0 ]
143+ if isinstance (existing_protein [0 ], dict )
144+ else existing_protein
145+ )
130146 if isinstance (existing_protein , dict ):
131147 for key , value in existing_protein .items ():
132148 if key not in caption and value :
133149 caption [key ] = value
134150 # Ensure sequence from node_data takes precedence
135151 if "sequence" in node_data and node_data ["sequence" ]:
136152 caption ["sequence" ] = node_data ["sequence" ]
137-
153+
138154 # Fallback to description if no caption found
139155 if not caption and "description" in node_data :
140156 description = node_data ["description" ]
141157 if isinstance (description , str ) and len (description ) > 10 :
142158 caption ["description" ] = description
143-
159+
144160 return caption if caption else None
145161
146162 @staticmethod
147163 def _detect_molecule_type (nodes : list [tuple [str , dict ]]) -> str :
148164 """
149165 Detect molecule type from nodes.
150166 Priority: Check node type, then check metadata, then check node data fields.
151-
167+
152168 :param nodes: List of (node_id, node_data) tuples
153169 :return: Detected molecule type ("dna", "rna", "protein", or "unknown")
154170 """
155171 if not nodes :
156172 return "unknown"
157-
173+
158174 # Type indicators for each molecule type
159175 type_indicators = {
160176 "protein" : {
@@ -173,28 +189,28 @@ def _detect_molecule_type(nodes: list[tuple[str, dict]]) -> str:
173189 "description_keywords" : ["rna" , "transcript" ],
174190 },
175191 }
176-
192+
177193 for _ , node_data in nodes :
178194 # Priority 1: Check explicit type fields (most reliable)
179195 for field in ["type" , "molecule_type" ]:
180196 value = node_data .get (field , "" ).lower ()
181197 if value in ("dna" , "rna" , "protein" ):
182198 return value
183-
199+
184200 # Priority 2: Check source_id prefix
185201 source_id = node_data .get ("source_id" , "" ).lower ()
186202 for mol_type , indicators in type_indicators .items ():
187203 if source_id .startswith (indicators ["source_prefix" ]):
188204 return mol_type
189-
205+
190206 # Priority 3: Check type-specific fields
191207 for mol_type , indicators in type_indicators .items ():
192208 if any (key in node_data for key in indicators ["fields" ]):
193209 # Special check for DNA: need chromosome or genomic_location
194210 if mol_type == "dna" and not any (key in node_data for key in ["chromosome" , "genomic_location" ]):
195211 continue
196212 return mol_type
197-
213+
198214 # Priority 4: Check description keywords
199215 description = node_data .get ("description" , "" ).lower ()
200216 for mol_type , indicators in type_indicators .items ():
@@ -204,7 +220,7 @@ def _detect_molecule_type(nodes: list[tuple[str, dict]]) -> str:
204220 if mol_type == "protein" and "gene" in description :
205221 continue
206222 return mol_type
207-
223+
208224 return "unknown"
209225
210226 async def generate (
@@ -216,51 +232,57 @@ async def generate(
216232 """
217233 Generate QAs based on a given batch.
218234 Automatically extracts and attaches molecule-specific caption information.
219-
235+
220236 :param batch
221237 :return: QA pairs with attached molecule captions
222238 """
223239 result = {}
224240 prompt = self .build_prompt (batch )
225241 response = await self .llm_client .generate_answer (prompt )
226242 qa_pairs = self .parse_response (response ) # generate one or more QA pairs
227-
243+
228244 nodes , _ = batch
229-
245+
230246 # Detect molecule type from nodes
231247 molecule_type = self ._detect_molecule_type (nodes )
232-
248+
233249 # Extract captions for all molecule types from nodes
234250 captions = {"dna" : None , "rna" : None , "protein" : None }
235251 caption_attached = False
236-
252+
237253 for node in nodes :
238- node_id , node_data = node
239-
254+ _ , node_data = node
255+
240256 # Check for pre-extracted captions (from partition_service)
241257 for mol_type in ["dna" , "rna" , "protein" ]:
242258 caption_key = f"{ mol_type } _caption"
243259 if caption_key in node_data and node_data [caption_key ]:
244260 captions [mol_type ] = node_data [caption_key ]
245261 caption_attached = True
246-
262+
247263 # If no pre-extracted captions, extract from node_data using the detected molecule_type
248264 if not caption_attached :
249265 caption = self ._extract_caption (node_data , molecule_type )
250266 if caption :
251267 captions [molecule_type ] = caption
252268 caption_attached = True
253269 break # Only need to extract once per batch
254-
270+
255271 # Attach all captions to QA pairs
256272 for qa in qa_pairs .values ():
257273 qa ["dna" ] = captions ["dna" ] if captions ["dna" ] else ""
258274 qa ["rna" ] = captions ["rna" ] if captions ["rna" ] else ""
259275 qa ["protein" ] = captions ["protein" ] if captions ["protein" ] else ""
260-
276+
261277 if not caption_attached :
262- logger .warning (f"No caption extracted for molecule_type={ molecule_type } . Node data sample: { dict (list (nodes [0 ][1 ].items ())[:5 ]) if nodes else 'No nodes' } " )
263-
278+ node_sample = (
279+ dict (list (nodes [0 ][1 ].items ())[:5 ]) if nodes else 'No nodes'
280+ )
281+ logger .warning (
282+ "No caption extracted for molecule_type=%s. Node data sample: %s" ,
283+ molecule_type , node_sample
284+ )
285+
264286 result .update (qa_pairs )
265287 return result
266288
@@ -284,7 +306,7 @@ def format_generation_results(
284306 for item in results
285307 for k , v in item .items ()
286308 ]
287-
309+
288310 # Format based on output format
289311 if output_data_format == "Alpaca" :
290312 return [
@@ -298,7 +320,7 @@ def format_generation_results(
298320 }
299321 for qa in qa_items
300322 ]
301- elif output_data_format == "Sharegpt" :
323+ if output_data_format == "Sharegpt" :
302324 return [
303325 {
304326 "conversations" : [
@@ -318,7 +340,7 @@ def format_generation_results(
318340 }
319341 for qa in qa_items
320342 ]
321- elif output_data_format == "ChatML" :
343+ if output_data_format == "ChatML" :
322344 return [
323345 {
324346 "messages" : [
0 commit comments