@@ -136,7 +136,9 @@ def __init__(
136136 self .video_end = self .VID_END
137137 self .image_patch_id = self .tokenizer .convert_tokens_to_ids ("<|IMAGE_PLACEHOLDER|>" )
138138 self .image_start_id = self .tokenizer .convert_tokens_to_ids (self .image_start )
139+ self .image_end_id = self .tokenizer .convert_tokens_to_ids (self .image_end )
139140 self .video_start_id = self .tokenizer .convert_tokens_to_ids (self .video_start )
141+ self .video_end_id = self .tokenizer .convert_tokens_to_ids (self .video_end )
140142 self .sep_token_id = self .tokenizer .convert_tokens_to_ids (self .sep_token )
141143 self .eos_token_id = self .tokenizer .convert_tokens_to_ids (self .eos_token )
142144
@@ -243,14 +245,7 @@ def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=N
243245
244246 return outputs
245247
246- def request2ids (
247- self , request : Dict [str , Any ], tgts : List [str ] = None
248- ) -> Dict [str , Union [np .ndarray , List [np .ndarray ], None ]]:
249- """
250- Convert chat messages into model inputs.
251- Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
252- """
253-
248+ def extract_mm_items (self , request : Dict [str , Any ]):
254249 messages = parse_chat_messages (request .get ("messages" ))
255250 mm_items = []
256251 for msg in messages :
@@ -273,6 +268,7 @@ def request2ids(
273268 if len (missing_hashes ) > 0 and not self .enable_processor_cache :
274269 raise ValueError ("Missing items cannot be retrieved without processor cache." )
275270
271+ dealer = None
276272 if self .enable_processor_cache :
277273 context = zmq .Context ()
278274 dealer = context .socket (zmq .DEALER )
@@ -295,6 +291,16 @@ def request2ids(
295291 video_uuid .append (item ["uuid" ])
296292 else :
297293 raise ValueError (f"Unsupported multimodal type: { item .get ('type' )} " )
294+ return images , videos , image_uuid , video_uuid , dealer , missing_idx , mm_items
295+
296+ def request2ids (
297+ self , request : Dict [str , Any ], tgts : List [str ] = None
298+ ) -> Dict [str , Union [np .ndarray , List [np .ndarray ], None ]]:
299+ """
300+ Convert chat messages into model inputs.
301+ Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
302+ """
303+ images , videos , image_uuid , video_uuid , dealer , missing_idx , mm_items = self .extract_mm_items (request )
298304
299305 if self .tokenizer .chat_template is None :
300306 raise ValueError ("This model does not support chat template." )
@@ -329,6 +335,115 @@ def request2ids(
329335
330336 return outputs
331337
338+ def prompt_token_ids2outputs (
339+ self , request : Dict [str , Any ], tgts : List [str ] = None
340+ ) -> Dict [str , Union [np .ndarray , List [np .ndarray ], None ]]:
341+ outputs = {
342+ "input_ids" : [],
343+ "token_type_ids" : [],
344+ "position_ids" : [],
345+ "images" : [],
346+ "grid_thw" : [],
347+ "image_type_ids" : [],
348+ "labels" : [],
349+ "cur_position" : 0 ,
350+ "video_cnt" : 0 ,
351+ "num_input_image_tokens" : 0 ,
352+ "num_input_video_tokens" : 0 ,
353+ "mm_positions" : [],
354+ "mm_hashes" : [],
355+ }
356+ prompt_token_ids = request .get ("prompt_token_ids" , [])
357+ prompt_token_ids_len = len (prompt_token_ids )
358+ if not request .get ("messages" ):
359+ outputs ["input_ids" ].extend (prompt_token_ids )
360+ outputs ["token_type_ids" ].extend ([IDS_TYPE_FLAG ["text" ]] * prompt_token_ids_len )
361+ for i in range (prompt_token_ids_len ):
362+ outputs ["position_ids" ].append ([i ] * 3 )
363+ outputs ["cur_position" ] += prompt_token_ids_len
364+ return outputs
365+ images , videos , image_uuid , video_uuid , dealer , missing_idx , mm_items = self .extract_mm_items (request )
366+ st , image_idx , video_idx = 0 , 0 , 0
367+ while st < prompt_token_ids_len :
368+ cur_token_id = prompt_token_ids [st ]
369+ if cur_token_id == self .image_start_id :
370+ if image_idx >= len (images ):
371+ raise ValueError ("prompt token ids has more image placeholder than in messages" )
372+ # append image_start_id
373+ outputs ["input_ids" ].extend ([cur_token_id ])
374+ outputs ["token_type_ids" ].extend ([IDS_TYPE_FLAG ["text" ]])
375+ outputs ["position_ids" ].append ([outputs ["cur_position" ]] * 3 )
376+ outputs ["cur_position" ] += 1
377+ st += 1
378+ # process placeholder token ids
379+ cur_idx = st
380+ while cur_idx < prompt_token_ids_len and prompt_token_ids [cur_idx ] != self .image_end_id :
381+ cur_idx += 1
382+ if cur_idx >= prompt_token_ids_len :
383+ raise ValueError ("image token ids not complete" )
384+ image = images [image_idx ]
385+ uuid = image_uuid [image_idx ] if image_uuid else None
386+ token_len = cur_idx - st
387+ if not isinstance (image , tuple ):
388+ self ._add_image (image , outputs , uuid , token_len )
389+ else :
390+ self ._add_processed_image (image , outputs , uuid , token_len )
391+ image_idx += 1
392+ st = cur_idx
393+ elif cur_token_id == self .video_start_id :
394+ if video_idx >= len (videos ):
395+ raise ValueError ("prompt token ids has more video placeholder than in messages" )
396+ # append video_start_id
397+ outputs ["input_ids" ].extend ([cur_token_id ])
398+ outputs ["token_type_ids" ].extend ([IDS_TYPE_FLAG ["text" ]])
399+ outputs ["position_ids" ].append ([outputs ["cur_position" ]] * 3 )
400+ outputs ["cur_position" ] += 1
401+ st += 1
402+ # process placeholder token ids
403+ cur_idx = st
404+ while cur_idx < prompt_token_ids_len and prompt_token_ids [cur_idx ] != self .video_end_id :
405+ cur_idx += 1
406+ if cur_idx >= prompt_token_ids_len :
407+ raise ValueError ("video token ids not complete" )
408+ video = videos [video_idx ]
409+ uuid = video_uuid [video_idx ] if video_uuid else None
410+ token_len = cur_idx - st
411+ if not isinstance (video , tuple ):
412+ if isinstance (video , dict ):
413+ frames = self ._load_and_process_video (video ["video" ], video )
414+ else :
415+ frames = self ._load_and_process_video (video , {})
416+ self ._add_video (frames , outputs , uuid , token_len )
417+ else :
418+ self ._add_processed_video (video , outputs , uuid , token_len )
419+ video_idx += 1
420+ st = cur_idx
421+ else :
422+ outputs ["input_ids" ].extend ([cur_token_id ])
423+ outputs ["token_type_ids" ].extend ([IDS_TYPE_FLAG ["text" ]])
424+ outputs ["position_ids" ].append ([outputs ["cur_position" ]] * 3 )
425+ outputs ["cur_position" ] += 1
426+ st += 1
427+ if image_idx != len (images ):
428+ raise ValueError ("number of images does not match" )
429+ if video_idx != len (videos ):
430+ raise ValueError ("number of videos does not match" )
431+
432+ if self .enable_processor_cache :
433+ missing_idx = set (missing_idx )
434+ hashes_to_cache , items_to_cache = [], []
435+ for idx in range (len (mm_items )):
436+ if idx in missing_idx :
437+ continue
438+ meta = {}
439+ t , h , w = outputs ["grid_thw" ][idx ][0 ]
440+ meta ["thw" ] = (t , h , w )
441+ hashes_to_cache .append (outputs ["mm_hashes" ][idx ])
442+ items_to_cache .append ((outputs ["images" ][idx ], meta ))
443+ self .update_processor_cache (dealer , hashes_to_cache , items_to_cache )
444+
445+ return outputs
446+
332447 def _add_special_token (self , token : Union [str , int ], outputs : Dict ) -> None :
333448 token_id = token if isinstance (token , int ) else self .tokenizer .convert_tokens_to_ids (token )
334449 outputs ["input_ids" ].append (token_id )
@@ -348,14 +463,16 @@ def _add_text(self, tokens, outputs: Dict) -> None:
348463 outputs ["position_ids" ].append ([start + i ] * 3 )
349464 outputs ["cur_position" ] += len (tokens )
350465
351- def _add_image (self , img , outputs : Dict , uuid : Optional [str ]) -> None :
466+ def _add_image (self , img , outputs : Dict , uuid : Optional [str ], token_len = None ) -> None :
352467 patches_h , patches_w = self .image_preprocessor .get_smarted_resize (
353468 img .height ,
354469 img .width ,
355470 min_pixels = self .image_min_pixels ,
356471 max_pixels = self .image_max_pixels ,
357472 )[1 ]
358473 num_tokens = (patches_h * patches_w ) // (self .spatial_conv_size ** 2 )
474+ if token_len and token_len != num_tokens :
475+ raise ValueError ("image tokens num not match the size" )
359476
360477 outputs ["mm_positions" ].append (ImagePosition (len (outputs ["input_ids" ]), num_tokens ))
361478 outputs ["input_ids" ].extend ([self .image_patch_id ] * num_tokens )
@@ -383,9 +500,13 @@ def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
383500 outputs ["grid_thw" ].append (ret ["image_grid_thw" ])
384501 outputs ["image_type_ids" ].append (0 )
385502
386- def _add_processed_image (self , img_cache : Tuple [np .ndarray , dict ], outputs : Dict , uuid : str ) -> None :
503+ def _add_processed_image (
504+ self , img_cache : Tuple [np .ndarray , dict ], outputs : Dict , uuid : str , token_len = None
505+ ) -> None :
387506 img , meta = img_cache
388507 num_tokens = img .shape [0 ] // (self .spatial_conv_size ** 2 )
508+ if token_len and num_tokens != token_len :
509+ raise ValueError ("image tokens num not match the size" )
389510
390511 outputs ["mm_positions" ].append (ImagePosition (len (outputs ["input_ids" ]), num_tokens ))
391512 outputs ["input_ids" ].extend ([self .image_patch_id ] * num_tokens )
@@ -401,7 +522,7 @@ def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict
401522 outputs ["grid_thw" ].append (np .array ([[1 , h , w ]]))
402523 outputs ["image_type_ids" ].append (0 )
403524
404- def _add_video (self , frames , outputs : Dict , uuid : Optional [str ]) -> None :
525+ def _add_video (self , frames , outputs : Dict , uuid : Optional [str ], token_len = None ) -> None :
405526 patches_h , patches_w = self .image_preprocessor .get_smarted_resize (
406527 frames [0 ].height ,
407528 frames [0 ].width ,
@@ -410,6 +531,8 @@ def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
410531 )[1 ]
411532 num_frames = len (frames )
412533 num_tokens = (num_frames * patches_h * patches_w ) // (self .spatial_conv_size ** 2 * self .temporal_conv_size )
534+ if token_len and num_tokens != token_len :
535+ raise ValueError ("video tokens num not match the size" )
413536
414537 pixel_stack = np .stack ([np .array (f .convert ("RGB" )) for f in frames ], axis = 0 )
415538 ret = self .image_preprocessor .preprocess (
@@ -438,9 +561,13 @@ def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
438561 outputs ["position_ids" ].extend (pos_ids )
439562 outputs ["cur_position" ] = np .max (pos_ids ) + 1
440563
441- def _add_processed_video (self , frames_cache : Tuple [np .ndarray , dict ], outputs : Dict , uuid : str ) -> None :
564+ def _add_processed_video (
565+ self , frames_cache : Tuple [np .ndarray , dict ], outputs : Dict , uuid : str , token_len = None
566+ ) -> None :
442567 frames , meta = frames_cache
443568 num_tokens = frames .shape [0 ] // (self .spatial_conv_size ** 2 * self .temporal_conv_size )
569+ if token_len and num_tokens != token_len :
570+ raise ValueError ("video tokens num not match the size" )
444571
445572 t , h , w = meta ["thw" ]
446573 outputs ["images" ].append (frames )
0 commit comments