Skip to content

Commit 7c0d5c2

Browse files
committed
add unit test
1 parent 913bcb0 commit 7c0d5c2

File tree

1 file changed

+304
-0
lines changed

1 file changed

+304
-0
lines changed

tests/input/test_ernie_vl_processor.py

Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
11
import unittest
22
from unittest.mock import MagicMock, patch
33

4+
import numpy as np
5+
6+
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
47
from fastdeploy.input.ernie4_5_vl_processor import Ernie4_5_VLProcessor
8+
from fastdeploy.input.ernie4_5_vl_processor.image_preprocessor.image_preprocessor_adaptive import (
9+
AdaptiveImageProcessor,
10+
)
11+
from fastdeploy.input.ernie4_5_vl_processor.process import DataProcessor
12+
from fastdeploy.input.utils import IDS_TYPE_FLAG
513

614

715
class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
@@ -133,5 +141,301 @@ def test_process_request_dict_with_options(self):
133141
self.assertEqual(request_dict["enable_thinking"], False)
134142

135143

144+
class TestDataProcessorTargetMethods(unittest.TestCase):
    """Exercise ``DataProcessor.prompt_token_ids2outputs`` with mocked collaborators.

    The tokenizer, the image preprocessor and ``extract_mm_items`` are replaced
    by mocks in ``setUp``; each test feeds a request dict and checks the
    bookkeeping fields of the returned ``outputs`` dict.
    """

    def setUp(self):
        """Build a ``DataProcessor`` whose tokenizer and image preprocessor are mocks."""
        self.mock_tokenizer = MagicMock(spec=Ernie4_5Tokenizer)
        self.mock_tokenizer.ignored_index = -100
        self.mock_tokenizer.convert_tokens_to_ids.side_effect = self._mock_convert_tokens_to_ids
        self.mock_tokenizer.chat_template = "mock_template"
        self.mock_tokenizer.apply_chat_template.return_value = "User: Hello<|image@placeholder|>"

        def mock_load_tokenizer(dp_instance):
            # Stand-in for DataProcessor._load_tokenizer: inject the mock tokenizer.
            dp_instance.tokenizer = self.mock_tokenizer

        tokenizer_patch = patch.object(
            DataProcessor, "_load_tokenizer", side_effect=mock_load_tokenizer, autospec=True
        )
        image_processor_patch = patch.object(AdaptiveImageProcessor, "from_pretrained")
        # Both patches are only needed while the constructor runs.
        with tokenizer_patch, image_processor_patch as mock_image_preprocessor:
            mock_image_preprocessor.return_value = MagicMock()
            self.data_processor = DataProcessor(
                tokenizer_name="mock_tokenizer",
                image_preprocessor_name="mock_image_preprocessor",
                enable_processor_cache=False,
            )

        # Pin the special-token ids to the values used in the test prompts.
        self.data_processor.image_patch_id = 1001
        self.data_processor.image_start_id = 1002
        self.data_processor.image_end_id = 1003
        self.data_processor.video_start_id = 1004
        self.data_processor.video_end_id = 1005
        self.data_processor.role_prefixes = {"user": "User: ", "assistant": "Assistant: "}
        self.data_processor.enable_processor_cache = False
        # Default: a 7-tuple describing "no multimodal items". Judging by the
        # tests below, slots 0-3 are images, videos, image uuids and video
        # uuids, and slot 6 is the ordered item list; the meaning of slots 4
        # and 5 is not visible here -- confirm against DataProcessor.
        self.data_processor.extract_mm_items = MagicMock(return_value=([], [], [], [], None, [], []))

    @staticmethod
    def _make_mock_image(height=224, width=224):
        """Return a mock exposing ``height``/``width`` whose ``convert()`` returns itself."""
        image = MagicMock()
        image.height = height
        image.width = width
        image.convert.return_value = image
        return image

    def _mock_convert_tokens_to_ids(self, token):
        """Map a known special token to its fixed id; any other token maps to 999."""
        token_id_map = {
            "<|begin_of_sentence|>": 101,
            "<|end_of_sentence|>": 102,
            "</s>": 103,
            "<|IMAGE_PLACEHOLDER|>": 1001,
            "<|IMAGE_START|>": 1002,
            "<|IMAGE_END|>": 1003,
            "<|VIDEO_START|>": 1004,
            "<|VIDEO_END|>": 1005,
        }
        return token_id_map.get(token, 999)

    def test_prompt_token_ids2outputs_only_prompt_token_ids(self):
        """A request carrying only ``prompt_token_ids`` yields text-only outputs."""
        test_prompt_token_ids = [101, 999, 998, 997, 102]
        request = {
            "prompt_token_ids": test_prompt_token_ids,
        }

        outputs = self.data_processor.prompt_token_ids2outputs(request)

        prompt_len = len(test_prompt_token_ids)

        # input_ids is expected to be the flat token list, consistent with the
        # flat token_type_ids expectation below and with the sibling tests.
        self.assertEqual(
            outputs["input_ids"],
            test_prompt_token_ids,
            f"input_ids 不匹配:实际{outputs['input_ids']},预期[{test_prompt_token_ids}]",
        )

        self.assertEqual(outputs["token_type_ids"], [IDS_TYPE_FLAG["text"]] * prompt_len)

        # One [i, i, i] triple per text token.
        expected_position_ids = [[i] * 3 for i in range(prompt_len)]
        self.assertEqual(outputs["position_ids"], expected_position_ids)

        self.assertEqual(outputs["cur_position"], prompt_len)

        self.assertEqual(len(outputs["images"]), 0)
        self.assertEqual(len(outputs["grid_thw"]), 0)
        self.assertEqual(len(outputs["mm_positions"]), 0)
        self.assertEqual(len(outputs["mm_hashes"]), 0)
        self.assertEqual(outputs["video_cnt"], 0)
        self.assertEqual(outputs["num_input_image_tokens"], 0)
        self.assertEqual(outputs["num_input_video_tokens"], 0)

    def test_prompt_token_ids2outputs_with_messages_no_mm(self):
        """Text-only ``messages`` next to ``prompt_token_ids`` still yields text-only outputs."""
        test_prompt_token_ids = [101, 999, 998, 997, 102]
        request = {
            "prompt_token_ids": test_prompt_token_ids,
            "messages": [{"role": "user", "content": "Hello World"}],
        }

        # Explicitly report "no multimodal items" for this request.
        self.data_processor.extract_mm_items.return_value = ([], [], [], [], None, [], [])

        outputs = self.data_processor.prompt_token_ids2outputs(request)

        prompt_len = len(test_prompt_token_ids)

        self.assertEqual(outputs["input_ids"], test_prompt_token_ids)

        self.assertEqual(outputs["token_type_ids"], [IDS_TYPE_FLAG["text"]] * prompt_len)

        expected_position_ids = [[i] * 3 for i in range(prompt_len)]
        self.assertEqual(outputs["position_ids"], expected_position_ids)

        self.assertEqual(outputs["cur_position"], prompt_len)

        self.assertEqual(len(outputs["images"]), 0)
        self.assertEqual(outputs["video_cnt"], 0)
        self.assertEqual(outputs["num_input_image_tokens"], 0)

    def test_prompt_token_ids2outputs_add_image(self):
        """A raw (mock) image between image start/end tokens is recorded as one image."""
        test_prompt_token_ids = [101, 1002, 1001, 1001, 1003, 102]
        mock_img = self._make_mock_image()
        request = {
            "prompt_token_ids": test_prompt_token_ids,
            "messages": [
                {"role": "user", "content": [{"type": "image_url", "image_url": mock_img, "uuid": "img_uuid"}]}
            ],
        }
        self.data_processor.extract_mm_items.return_value = (
            [mock_img],
            [],
            ["img_uuid"],
            [],
            None,
            [],
            [{"type": "image", "data": mock_img}],
        )
        # Canned preprocessor results; the (2, 4) pair is presumably the patch
        # grid (h, w) that determines the image token count -- TODO confirm.
        mock_resize = (None, (2, 4))
        self.data_processor.image_preprocessor.get_smarted_resize.return_value = mock_resize
        mock_preprocess = {"pixel_values": np.random.randn(1, 16, 16, 3), "image_grid_thw": np.array([[2, 4]])}
        self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess

        outputs = self.data_processor.prompt_token_ids2outputs(request)

        text_flag = IDS_TYPE_FLAG["text"]
        image_flag = IDS_TYPE_FLAG["image"]
        self.assertEqual(outputs["input_ids"], [101, 1002, 1001, 1001, 1003, 102])
        self.assertEqual(
            outputs["token_type_ids"],
            [text_flag, text_flag, image_flag, image_flag, text_flag, text_flag],
        )
        self.assertEqual(len(outputs["position_ids"]), 6)
        self.assertEqual(outputs["cur_position"], 6)
        self.assertEqual(len(outputs["images"]), 1)
        self.assertIsNotNone(outputs["images"][0])
        self.assertEqual(outputs["num_input_image_tokens"], 2)
        self.assertEqual(len(outputs["mm_positions"]), 1)
        self.assertEqual(len(outputs["mm_hashes"]), 1)
        self.assertEqual(len(outputs["grid_thw"]), 1)
        self.assertEqual(len(outputs["image_type_ids"]), 1)

    def test_prompt_token_ids2outputs_add_processed_image(self):
        """An image supplied as a ``(data, {"thw": ...})`` tuple keeps its uuid as hash."""
        test_prompt_token_ids = [101, 1002, 1001, 1001, 1003, 102]
        mock_img_data = np.random.randn(8, 28, 28)
        # Presumably the already-preprocessed (cached) image form -- TODO confirm.
        mock_img_cache = (mock_img_data, {"thw": (1, 8, 8)})
        request = {
            "prompt_token_ids": test_prompt_token_ids,
            "messages": [
                {"role": "user", "content": [{"type": "image_url", "image_url": mock_img_cache, "uuid": "img_uuid"}]}
            ],
        }
        self.data_processor.extract_mm_items.return_value = (
            [mock_img_cache],
            [],
            ["img_uuid"],
            [],
            None,
            [],
            [{"type": "image", "data": mock_img_cache}],
        )

        outputs = self.data_processor.prompt_token_ids2outputs(request)

        text_flag = IDS_TYPE_FLAG["text"]
        image_flag = IDS_TYPE_FLAG["image"]
        self.assertEqual(outputs["input_ids"], [101, 1002, 1001, 1001, 1003, 102])
        self.assertEqual(
            outputs["token_type_ids"],
            [text_flag, text_flag, image_flag, image_flag, text_flag, text_flag],
        )
        # NOTE(review): these counts differ from the 6 prompt tokens; they look
        # derived from the cached (1, 8, 8) grid rather than from the number of
        # placeholder tokens -- confirm against DataProcessor.
        self.assertEqual(len(outputs["position_ids"]), 20)
        self.assertEqual(outputs["cur_position"], 8)
        self.assertEqual(len(outputs["images"]), 1)
        self.assertIsNotNone(outputs["images"][0])
        self.assertEqual(len(outputs["mm_positions"]), 1)
        self.assertEqual(outputs["mm_hashes"][0], "img_uuid")
        self.assertEqual(len(outputs["grid_thw"]), 1)
        self.assertEqual(len(outputs["image_type_ids"]), 1)

    def test_prompt_token_ids2outputs_add_video(self):
        """A two-frame (mock) video between video start/end tokens is recorded as one video."""
        test_prompt_token_ids = [101, 1004, 1001, 1001, 1001, 1001, 1005, 102]
        mock_frame1 = self._make_mock_image()
        mock_frame2 = self._make_mock_image()
        frames = [mock_frame1, mock_frame2]
        request = {
            "prompt_token_ids": test_prompt_token_ids,
            "messages": [
                {"role": "user", "content": [{"type": "video_url", "video_url": frames, "uuid": "vid_uuid"}]}
            ],
        }
        self.data_processor.extract_mm_items.return_value = (
            [],
            [frames],
            [],
            ["vid_uuid"],
            None,
            [],
            [{"type": "video", "data": frames}],
        )
        # Skip real decoding: frame loading returns the mock frames directly.
        self.data_processor._load_and_process_video = MagicMock(return_value=frames)
        patches_h, patches_w = 4, 4
        self.data_processor.image_preprocessor.get_smarted_resize.return_value = (None, (patches_h, patches_w))
        mock_preprocess = {
            "pixel_values_videos": np.random.randn(2, patches_h, patches_w, 3),
            "video_grid_thw": np.array([[patches_h, patches_w]] * 2),
        }
        self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess

        outputs = self.data_processor.prompt_token_ids2outputs(request)

        text_flag = IDS_TYPE_FLAG["text"]
        video_flag = IDS_TYPE_FLAG["video"]
        self.assertEqual(outputs["input_ids"], [101, 1004, 1001, 1001, 1001, 1001, 1005, 102])
        self.assertEqual(
            outputs["token_type_ids"],
            [text_flag, text_flag, video_flag, video_flag, video_flag, video_flag, text_flag, text_flag],
        )
        self.assertEqual(len(outputs["position_ids"]), 8)
        self.assertEqual(outputs["cur_position"], 6)
        self.assertEqual(len(outputs["images"]), 1)
        self.assertIsNotNone(outputs["images"][0])
        self.assertEqual(len(outputs["mm_positions"]), 1)
        self.assertEqual(outputs["mm_hashes"][0], "vid_uuid")
        self.assertEqual(len(outputs["grid_thw"]), 1)
        self.assertEqual(len(outputs["image_type_ids"]), 2)
        self.assertEqual(outputs["num_input_video_tokens"], 4)

    def test_prompt_token_ids2outputs_add_processed_video(self):
        """A video supplied as a ``(data, {"thw": ...})`` tuple keeps its uuid as hash."""
        test_prompt_token_ids = [101, 1004, 1001, 1001, 1001, 1001, 1005, 102]
        t, h, w = 2, 4, 4
        # Size the fake pixel buffer from the processor's own merge factors so
        # that it stays consistent with the (t, h, w) grid declared below.
        spatial_conv_size = self.data_processor.spatial_conv_size
        temporal_conv_size = self.data_processor.temporal_conv_size
        token_per_frame = (h // spatial_conv_size) * (w // spatial_conv_size)
        num_tokens = (t // temporal_conv_size) * token_per_frame
        mock_frames_data = np.random.randn(num_tokens * spatial_conv_size**2 * temporal_conv_size, 28, 28)
        mock_frames_cache = (mock_frames_data, {"thw": (t, h, w)})
        request = {
            "prompt_token_ids": test_prompt_token_ids,
            "messages": [
                {"role": "user", "content": [{"type": "video", "data": mock_frames_cache, "uuid": "vid_uuid"}]}
            ],
        }
        self.data_processor.extract_mm_items.return_value = (
            [],
            [mock_frames_cache],
            [],
            ["vid_uuid"],
            None,
            [],
            [{"type": "video", "data": mock_frames_cache}],
        )

        outputs = self.data_processor.prompt_token_ids2outputs(request)

        text_flag = IDS_TYPE_FLAG["text"]
        video_flag = IDS_TYPE_FLAG["video"]
        self.assertEqual(outputs["input_ids"], [101, 1004, 1001, 1001, 1001, 1001, 1005, 102])
        self.assertEqual(
            outputs["token_type_ids"],
            [text_flag, text_flag, video_flag, video_flag, video_flag, video_flag, text_flag, text_flag],
        )
        self.assertEqual(len(outputs["position_ids"]), 8)
        self.assertEqual(outputs["cur_position"], 6)
        self.assertEqual(len(outputs["images"]), 1)
        self.assertIsNotNone(outputs["images"][0])
        self.assertEqual(len(outputs["mm_positions"]), 1)
        self.assertEqual(outputs["mm_hashes"][0], "vid_uuid")
        self.assertEqual(len(outputs["grid_thw"]), 1)
        self.assertEqual(len(outputs["image_type_ids"]), 2)
438+
439+
136440
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)