Skip to content

Commit e653596

Browse files
committed
add unit test
1 parent 7c0d5c2 commit e653596

File tree

1 file changed

+123
-0
lines changed

1 file changed

+123
-0
lines changed

tests/input/test_ernie_vl_processor.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,129 @@ def test_prompt_token_ids2outputs_add_processed_video(self):
436436
self.assertEqual(len(outputs["grid_thw"]), 1)
437437
self.assertEqual(len(outputs["image_type_ids"]), 2)
438438

439+
def test_prompt_token_ids2outputs_add_image_token_len_mismatch(self):
440+
test_prompt_token_ids = [101, 1002, 1001, 1001, 1001, 1003, 102]
441+
mock_img = MagicMock()
442+
mock_img.height = 224
443+
mock_img.width = 224
444+
mock_img.convert.return_value = mock_img
445+
request = {
446+
"prompt_token_ids": test_prompt_token_ids,
447+
"messages": [
448+
{"role": "user", "content": [{"type": "image_url", "image_url": mock_img, "uuid": "img_uuid"}]}
449+
],
450+
}
451+
self.data_processor.extract_mm_items.return_value = (
452+
[mock_img],
453+
[],
454+
["img_uuid"],
455+
[],
456+
None,
457+
[],
458+
[{"type": "image", "data": mock_img}],
459+
)
460+
patches_h, patches_w = 8, 8
461+
self.data_processor.image_preprocessor.get_smarted_resize.return_value = (None, (patches_h, patches_w))
462+
mock_preprocess = {
463+
"pixel_values": np.random.randn(1, patches_h, patches_w, 3),
464+
"image_grid_thw": np.array([[patches_h, patches_w]]),
465+
}
466+
self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess
467+
with self.assertRaises(ValueError) as ctx:
468+
self.data_processor.prompt_token_ids2outputs(request)
469+
self.assertIn("image tokens num not match the size", str(ctx.exception))
470+
471+
def test_prompt_token_ids2outputs_add_processed_image_token_len_mismatch(self):
472+
test_prompt_token_ids = [101, 1002, 1001, 1001, 1003, 102]
473+
spatial_conv_size = self.data_processor.spatial_conv_size
474+
num_tokens = 4
475+
mock_img_data = np.random.randn(num_tokens * (spatial_conv_size**2), 28, 28)
476+
mock_img_cache = (mock_img_data, {"thw": (1, 8, 8)})
477+
request = {
478+
"prompt_token_ids": test_prompt_token_ids,
479+
"messages": [
480+
{"role": "user", "content": [{"type": "image_url", "image_url": mock_img_cache, "uuid": "img_uuid"}]}
481+
],
482+
}
483+
self.data_processor.extract_mm_items.return_value = (
484+
[mock_img_cache],
485+
[],
486+
["img_uuid"],
487+
[],
488+
None,
489+
[],
490+
[{"type": "image", "data": mock_img_cache}],
491+
)
492+
with self.assertRaises(ValueError) as ctx:
493+
self.data_processor.prompt_token_ids2outputs(request)
494+
self.assertIn("image tokens num not match the size", str(ctx.exception))
495+
496+
def test_prompt_token_ids2outputs_add_video_token_len_mismatch(self):
497+
test_prompt_token_ids = [101, 1004, 1001, 1001, 1005, 102]
498+
mock_frame1 = MagicMock()
499+
mock_frame1.height = 224
500+
mock_frame1.width = 224
501+
mock_frame1.convert.return_value = mock_frame1
502+
mock_frame2 = MagicMock()
503+
mock_frame2.height = 224
504+
mock_frame2.width = 224
505+
mock_frame2.convert.return_value = mock_frame2
506+
frames = [mock_frame1, mock_frame2]
507+
request = {
508+
"prompt_token_ids": test_prompt_token_ids,
509+
"messages": [
510+
{"role": "user", "content": [{"type": "video_url", "video_url": frames, "uuid": "vid_uuid"}]}
511+
],
512+
}
513+
self.data_processor.extract_mm_items.return_value = (
514+
[],
515+
[frames],
516+
[],
517+
["vid_uuid"],
518+
None,
519+
[],
520+
[{"type": "video", "data": frames}],
521+
)
522+
self.data_processor._load_and_process_video = MagicMock(return_value=frames)
523+
patches_h, patches_w = 8, 8
524+
self.data_processor.image_preprocessor.get_smarted_resize.return_value = (None, (patches_h, patches_w))
525+
mock_preprocess = {
526+
"pixel_values_videos": np.random.randn(2, patches_h, patches_w, 3),
527+
"video_grid_thw": np.array([[patches_h, patches_w]] * 2),
528+
}
529+
self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess
530+
with self.assertRaises(ValueError) as ctx:
531+
self.data_processor.prompt_token_ids2outputs(request)
532+
self.assertIn("video tokens num not match the size", str(ctx.exception))
533+
534+
def test_prompt_token_ids2outputs_add_processed_video_token_len_mismatch(self):
535+
test_prompt_token_ids = [101, 1004, 1001, 1005, 102]
536+
t, h, w = 2, 8, 8
537+
spatial_conv_size = self.data_processor.spatial_conv_size
538+
temporal_conv_size = self.data_processor.temporal_conv_size
539+
540+
num_tokens = 4
541+
mock_frames_data = np.random.randn(num_tokens * spatial_conv_size**2 * temporal_conv_size, 28, 28)
542+
mock_frames_cache = (mock_frames_data, {"thw": (t, h, w)})
543+
request = {
544+
"prompt_token_ids": test_prompt_token_ids,
545+
"messages": [
546+
{"role": "user", "content": [{"type": "video", "data": mock_frames_cache, "uuid": "vid_uuid"}]}
547+
],
548+
}
549+
self.data_processor.extract_mm_items.return_value = (
550+
[],
551+
[mock_frames_cache],
552+
[],
553+
["vid_uuid"],
554+
None,
555+
[],
556+
[{"type": "video", "data": mock_frames_cache}],
557+
)
558+
with self.assertRaises(ValueError) as ctx:
559+
self.data_processor.prompt_token_ids2outputs(request)
560+
self.assertIn("video tokens num not match the size", str(ctx.exception))
561+
439562

440563
if __name__ == "__main__":
441564
unittest.main()

0 commit comments

Comments
 (0)