@@ -436,6 +436,129 @@ def test_prompt_token_ids2outputs_add_processed_video(self):
436436 self .assertEqual (len (outputs ["grid_thw" ]), 1 )
437437 self .assertEqual (len (outputs ["image_type_ids" ]), 2 )
438438
439+ def test_prompt_token_ids2outputs_add_image_token_len_mismatch (self ):
440+ test_prompt_token_ids = [101 , 1002 , 1001 , 1001 , 1001 , 1003 , 102 ]
441+ mock_img = MagicMock ()
442+ mock_img .height = 224
443+ mock_img .width = 224
444+ mock_img .convert .return_value = mock_img
445+ request = {
446+ "prompt_token_ids" : test_prompt_token_ids ,
447+ "messages" : [
448+ {"role" : "user" , "content" : [{"type" : "image_url" , "image_url" : mock_img , "uuid" : "img_uuid" }]}
449+ ],
450+ }
451+ self .data_processor .extract_mm_items .return_value = (
452+ [mock_img ],
453+ [],
454+ ["img_uuid" ],
455+ [],
456+ None ,
457+ [],
458+ [{"type" : "image" , "data" : mock_img }],
459+ )
460+ patches_h , patches_w = 8 , 8
461+ self .data_processor .image_preprocessor .get_smarted_resize .return_value = (None , (patches_h , patches_w ))
462+ mock_preprocess = {
463+ "pixel_values" : np .random .randn (1 , patches_h , patches_w , 3 ),
464+ "image_grid_thw" : np .array ([[patches_h , patches_w ]]),
465+ }
466+ self .data_processor .image_preprocessor .preprocess .return_value = mock_preprocess
467+ with self .assertRaises (ValueError ) as ctx :
468+ self .data_processor .prompt_token_ids2outputs (request )
469+ self .assertIn ("image tokens num not match the size" , str (ctx .exception ))
470+
471+ def test_prompt_token_ids2outputs_add_processed_image_token_len_mismatch (self ):
472+ test_prompt_token_ids = [101 , 1002 , 1001 , 1001 , 1003 , 102 ]
473+ spatial_conv_size = self .data_processor .spatial_conv_size
474+ num_tokens = 4
475+ mock_img_data = np .random .randn (num_tokens * (spatial_conv_size ** 2 ), 28 , 28 )
476+ mock_img_cache = (mock_img_data , {"thw" : (1 , 8 , 8 )})
477+ request = {
478+ "prompt_token_ids" : test_prompt_token_ids ,
479+ "messages" : [
480+ {"role" : "user" , "content" : [{"type" : "image_url" , "image_url" : mock_img_cache , "uuid" : "img_uuid" }]}
481+ ],
482+ }
483+ self .data_processor .extract_mm_items .return_value = (
484+ [mock_img_cache ],
485+ [],
486+ ["img_uuid" ],
487+ [],
488+ None ,
489+ [],
490+ [{"type" : "image" , "data" : mock_img_cache }],
491+ )
492+ with self .assertRaises (ValueError ) as ctx :
493+ self .data_processor .prompt_token_ids2outputs (request )
494+ self .assertIn ("image tokens num not match the size" , str (ctx .exception ))
495+
496+ def test_prompt_token_ids2outputs_add_video_token_len_mismatch (self ):
497+ test_prompt_token_ids = [101 , 1004 , 1001 , 1001 , 1005 , 102 ]
498+ mock_frame1 = MagicMock ()
499+ mock_frame1 .height = 224
500+ mock_frame1 .width = 224
501+ mock_frame1 .convert .return_value = mock_frame1
502+ mock_frame2 = MagicMock ()
503+ mock_frame2 .height = 224
504+ mock_frame2 .width = 224
505+ mock_frame2 .convert .return_value = mock_frame2
506+ frames = [mock_frame1 , mock_frame2 ]
507+ request = {
508+ "prompt_token_ids" : test_prompt_token_ids ,
509+ "messages" : [
510+ {"role" : "user" , "content" : [{"type" : "video_url" , "video_url" : frames , "uuid" : "vid_uuid" }]}
511+ ],
512+ }
513+ self .data_processor .extract_mm_items .return_value = (
514+ [],
515+ [frames ],
516+ [],
517+ ["vid_uuid" ],
518+ None ,
519+ [],
520+ [{"type" : "video" , "data" : frames }],
521+ )
522+ self .data_processor ._load_and_process_video = MagicMock (return_value = frames )
523+ patches_h , patches_w = 8 , 8
524+ self .data_processor .image_preprocessor .get_smarted_resize .return_value = (None , (patches_h , patches_w ))
525+ mock_preprocess = {
526+ "pixel_values_videos" : np .random .randn (2 , patches_h , patches_w , 3 ),
527+ "video_grid_thw" : np .array ([[patches_h , patches_w ]] * 2 ),
528+ }
529+ self .data_processor .image_preprocessor .preprocess .return_value = mock_preprocess
530+ with self .assertRaises (ValueError ) as ctx :
531+ self .data_processor .prompt_token_ids2outputs (request )
532+ self .assertIn ("video tokens num not match the size" , str (ctx .exception ))
533+
534+ def test_prompt_token_ids2outputs_add_processed_video_token_len_mismatch (self ):
535+ test_prompt_token_ids = [101 , 1004 , 1001 , 1005 , 102 ]
536+ t , h , w = 2 , 8 , 8
537+ spatial_conv_size = self .data_processor .spatial_conv_size
538+ temporal_conv_size = self .data_processor .temporal_conv_size
539+
540+ num_tokens = 4
541+ mock_frames_data = np .random .randn (num_tokens * spatial_conv_size ** 2 * temporal_conv_size , 28 , 28 )
542+ mock_frames_cache = (mock_frames_data , {"thw" : (t , h , w )})
543+ request = {
544+ "prompt_token_ids" : test_prompt_token_ids ,
545+ "messages" : [
546+ {"role" : "user" , "content" : [{"type" : "video" , "data" : mock_frames_cache , "uuid" : "vid_uuid" }]}
547+ ],
548+ }
549+ self .data_processor .extract_mm_items .return_value = (
550+ [],
551+ [mock_frames_cache ],
552+ [],
553+ ["vid_uuid" ],
554+ None ,
555+ [],
556+ [{"type" : "video" , "data" : mock_frames_cache }],
557+ )
558+ with self .assertRaises (ValueError ) as ctx :
559+ self .data_processor .prompt_token_ids2outputs (request )
560+ self .assertIn ("video tokens num not match the size" , str (ctx .exception ))
561+
439562
440563if __name__ == "__main__" :
441564 unittest .main ()
0 commit comments