|
1 | 1 | import unittest |
2 | 2 | from unittest.mock import MagicMock, patch |
3 | 3 |
|
| 4 | +import numpy as np |
| 5 | + |
| 6 | +from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer |
4 | 7 | from fastdeploy.input.ernie4_5_vl_processor import Ernie4_5_VLProcessor |
| 8 | +from fastdeploy.input.ernie4_5_vl_processor.image_preprocessor.image_preprocessor_adaptive import ( |
| 9 | + AdaptiveImageProcessor, |
| 10 | +) |
| 11 | +from fastdeploy.input.ernie4_5_vl_processor.process import DataProcessor |
| 12 | +from fastdeploy.input.utils import IDS_TYPE_FLAG |
5 | 13 |
|
6 | 14 |
|
7 | 15 | class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase): |
@@ -133,5 +141,301 @@ def test_process_request_dict_with_options(self): |
133 | 141 | self.assertEqual(request_dict["enable_thinking"], False) |
134 | 142 |
|
135 | 143 |
|
| 144 | +class TestDataProcessorTargetMethods(unittest.TestCase): |
| 145 | + def setUp(self): |
| 146 | + self.mock_tokenizer = MagicMock(spec=Ernie4_5Tokenizer) |
| 147 | + self.mock_tokenizer.ignored_index = -100 |
| 148 | + self.mock_tokenizer.convert_tokens_to_ids.side_effect = self._mock_convert_tokens_to_ids |
| 149 | + self.mock_tokenizer.chat_template = "mock_template" |
| 150 | + self.mock_tokenizer.apply_chat_template.return_value = "User: Hello<|image@placeholder|>" |
| 151 | + |
| 152 | + def mock_load_tokenizer(dp_instance): |
| 153 | + dp_instance.tokenizer = self.mock_tokenizer |
| 154 | + |
| 155 | + with patch.object(DataProcessor, "_load_tokenizer", side_effect=mock_load_tokenizer, autospec=True): |
| 156 | + with patch.object(AdaptiveImageProcessor, "from_pretrained") as mock_image_preprocessor: |
| 157 | + mock_image_preprocessor.return_value = MagicMock() |
| 158 | + self.data_processor = DataProcessor( |
| 159 | + tokenizer_name="mock_tokenizer", |
| 160 | + image_preprocessor_name="mock_image_preprocessor", |
| 161 | + enable_processor_cache=False, |
| 162 | + ) |
| 163 | + self.data_processor.image_patch_id = 1001 |
| 164 | + self.data_processor.image_start_id = 1002 |
| 165 | + self.data_processor.image_end_id = 1003 |
| 166 | + self.data_processor.video_start_id = 1004 |
| 167 | + self.data_processor.video_end_id = 1005 |
| 168 | + self.data_processor.role_prefixes = {"user": "User: ", "assistant": "Assistant: "} |
| 169 | + self.data_processor.enable_processor_cache = False |
| 170 | + self.data_processor.extract_mm_items = MagicMock(return_value=([], [], [], [], None, [], [])) |
| 171 | + |
| 172 | + def _mock_convert_tokens_to_ids(self, token): |
| 173 | + token_id_map = { |
| 174 | + "<|begin_of_sentence|>": 101, |
| 175 | + "<|end_of_sentence|>": 102, |
| 176 | + "</s>": 103, |
| 177 | + "<|IMAGE_PLACEHOLDER|>": 1001, |
| 178 | + "<|IMAGE_START|>": 1002, |
| 179 | + "<|IMAGE_END|>": 1003, |
| 180 | + "<|VIDEO_START|>": 1004, |
| 181 | + "<|VIDEO_END|>": 1005, |
| 182 | + } |
| 183 | + return token_id_map.get(token, 999) |
| 184 | + |
| 185 | + def test_prompt_token_ids2outputs_only_prompt_token_ids(self): |
| 186 | + test_prompt_token_ids = [101, 999, 998, 997, 102] |
| 187 | + request = { |
| 188 | + "prompt_token_ids": test_prompt_token_ids, |
| 189 | + } |
| 190 | + |
| 191 | + outputs = self.data_processor.prompt_token_ids2outputs(request) |
| 192 | + |
| 193 | + prompt_len = len(test_prompt_token_ids) |
| 194 | + |
| 195 | + self.assertEqual( |
| 196 | + outputs["input_ids"], |
| 197 | + [test_prompt_token_ids], |
| 198 | + f"input_ids 不匹配:实际{outputs['input_ids']},预期[{test_prompt_token_ids}]", |
| 199 | + ) |
| 200 | + |
| 201 | + self.assertEqual(outputs["token_type_ids"], [IDS_TYPE_FLAG["text"]] * prompt_len) |
| 202 | + |
| 203 | + expected_position_ids = [[i] * 3 for i in range(prompt_len)] |
| 204 | + self.assertEqual(outputs["position_ids"], expected_position_ids) |
| 205 | + |
| 206 | + self.assertEqual(outputs["cur_position"], prompt_len) |
| 207 | + |
| 208 | + self.assertEqual(len(outputs["images"]), 0) |
| 209 | + self.assertEqual(len(outputs["grid_thw"]), 0) |
| 210 | + self.assertEqual(len(outputs["mm_positions"]), 0) |
| 211 | + self.assertEqual(len(outputs["mm_hashes"]), 0) |
| 212 | + self.assertEqual(outputs["video_cnt"], 0) |
| 213 | + self.assertEqual(outputs["num_input_image_tokens"], 0) |
| 214 | + self.assertEqual(outputs["num_input_video_tokens"], 0) |
| 215 | + |
| 216 | + def test_prompt_token_ids2outputs_with_messages_no_mm(self): |
| 217 | + test_prompt_token_ids = [101, 999, 998, 997, 102] |
| 218 | + request = { |
| 219 | + "prompt_token_ids": test_prompt_token_ids, |
| 220 | + "messages": [{"role": "user", "content": "Hello World"}], |
| 221 | + } |
| 222 | + |
| 223 | + self.data_processor.extract_mm_items.return_value = ([], [], [], [], None, [], []) |
| 224 | + |
| 225 | + outputs = self.data_processor.prompt_token_ids2outputs(request) |
| 226 | + |
| 227 | + prompt_len = len(test_prompt_token_ids) |
| 228 | + |
| 229 | + self.assertEqual(outputs["input_ids"], test_prompt_token_ids) |
| 230 | + |
| 231 | + self.assertEqual(outputs["token_type_ids"], [IDS_TYPE_FLAG["text"]] * prompt_len) |
| 232 | + |
| 233 | + expected_position_ids = [[i] * 3 for i in range(prompt_len)] |
| 234 | + self.assertEqual(outputs["position_ids"], expected_position_ids) |
| 235 | + |
| 236 | + self.assertEqual(outputs["cur_position"], prompt_len) |
| 237 | + |
| 238 | + self.assertEqual(len(outputs["images"]), 0) |
| 239 | + self.assertEqual(outputs["video_cnt"], 0) |
| 240 | + self.assertEqual(outputs["num_input_image_tokens"], 0) |
| 241 | + |
| 242 | + def test_prompt_token_ids2outputs_add_image(self): |
| 243 | + test_prompt_token_ids = [101, 1002, 1001, 1001, 1003, 102] |
| 244 | + mock_img = MagicMock() |
| 245 | + mock_img.height = 224 |
| 246 | + mock_img.width = 224 |
| 247 | + mock_img.convert.return_value = mock_img |
| 248 | + request = { |
| 249 | + "prompt_token_ids": test_prompt_token_ids, |
| 250 | + "messages": [ |
| 251 | + {"role": "user", "content": [{"type": "image_url", "image_url": mock_img, "uuid": "img_uuid"}]} |
| 252 | + ], |
| 253 | + } |
| 254 | + self.data_processor.extract_mm_items.return_value = ( |
| 255 | + [mock_img], |
| 256 | + [], |
| 257 | + ["img_uuid"], |
| 258 | + [], |
| 259 | + None, |
| 260 | + [], |
| 261 | + [{"type": "image", "data": mock_img}], |
| 262 | + ) |
| 263 | + mock_resize = (None, (2, 4)) |
| 264 | + self.data_processor.image_preprocessor.get_smarted_resize.return_value = mock_resize |
| 265 | + mock_preprocess = {"pixel_values": np.random.randn(1, 16, 16, 3), "image_grid_thw": np.array([[2, 4]])} |
| 266 | + self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess |
| 267 | + # self.data_processor._compute_3d_positions = MagicMock(return_value=[[i]*3 for i in range(4)]) |
| 268 | + outputs = self.data_processor.prompt_token_ids2outputs(request) |
| 269 | + self.assertEqual(outputs["input_ids"], [101, 1002, 1001, 1001, 1003, 102]) |
| 270 | + self.assertEqual( |
| 271 | + outputs["token_type_ids"], |
| 272 | + [ |
| 273 | + IDS_TYPE_FLAG["text"], |
| 274 | + IDS_TYPE_FLAG["text"], |
| 275 | + IDS_TYPE_FLAG["image"], |
| 276 | + IDS_TYPE_FLAG["image"], |
| 277 | + IDS_TYPE_FLAG["text"], |
| 278 | + IDS_TYPE_FLAG["text"], |
| 279 | + ], |
| 280 | + ) |
| 281 | + self.assertEqual(len(outputs["position_ids"]), 6) |
| 282 | + self.assertEqual(outputs["cur_position"], 6) |
| 283 | + self.assertEqual(len(outputs["images"]), 1) |
| 284 | + self.assertIsNotNone(outputs["images"][0]) |
| 285 | + self.assertEqual(outputs["num_input_image_tokens"], 2) |
| 286 | + self.assertEqual(len(outputs["mm_positions"]), 1) |
| 287 | + self.assertEqual(len(outputs["mm_hashes"]), 1) |
| 288 | + self.assertEqual(len(outputs["grid_thw"]), 1) |
| 289 | + self.assertEqual(len(outputs["image_type_ids"]), 1) |
| 290 | + |
| 291 | + def test_prompt_token_ids2outputs_add_processed_image(self): |
| 292 | + test_prompt_token_ids = [101, 1002, 1001, 1001, 1003, 102] |
| 293 | + mock_img_data = np.random.randn(8, 28, 28) |
| 294 | + mock_img_cache = (mock_img_data, {"thw": (1, 8, 8)}) |
| 295 | + request = { |
| 296 | + "prompt_token_ids": test_prompt_token_ids, |
| 297 | + "messages": [ |
| 298 | + {"role": "user", "content": [{"type": "image_url", "image_url": mock_img_cache, "uuid": "img_uuid"}]} |
| 299 | + ], |
| 300 | + } |
| 301 | + self.data_processor.extract_mm_items.return_value = ( |
| 302 | + [mock_img_cache], |
| 303 | + [], |
| 304 | + ["img_uuid"], |
| 305 | + [], |
| 306 | + None, |
| 307 | + [], |
| 308 | + [{"type": "image", "data": mock_img_cache}], |
| 309 | + ) |
| 310 | + outputs = self.data_processor.prompt_token_ids2outputs(request) |
| 311 | + self.assertEqual(outputs["input_ids"], [101, 1002, 1001, 1001, 1003, 102]) |
| 312 | + self.assertEqual( |
| 313 | + outputs["token_type_ids"], |
| 314 | + [ |
| 315 | + IDS_TYPE_FLAG["text"], |
| 316 | + IDS_TYPE_FLAG["text"], |
| 317 | + IDS_TYPE_FLAG["image"], |
| 318 | + IDS_TYPE_FLAG["image"], |
| 319 | + IDS_TYPE_FLAG["text"], |
| 320 | + IDS_TYPE_FLAG["text"], |
| 321 | + ], |
| 322 | + ) |
| 323 | + self.assertEqual(len(outputs["position_ids"]), 20) |
| 324 | + self.assertEqual(outputs["cur_position"], 8) |
| 325 | + self.assertEqual(len(outputs["images"]), 1) |
| 326 | + self.assertIsNotNone(outputs["images"][0]) |
| 327 | + self.assertEqual(len(outputs["mm_positions"]), 1) |
| 328 | + self.assertEqual(outputs["mm_hashes"][0], "img_uuid") |
| 329 | + self.assertEqual(len(outputs["grid_thw"]), 1) |
| 330 | + self.assertEqual(len(outputs["image_type_ids"]), 1) |
| 331 | + |
| 332 | + def test_prompt_token_ids2outputs_add_video(self): |
| 333 | + test_prompt_token_ids = [101, 1004, 1001, 1001, 1001, 1001, 1005, 102] |
| 334 | + mock_frame1 = MagicMock() |
| 335 | + mock_frame1.height = 224 |
| 336 | + mock_frame1.width = 224 |
| 337 | + mock_frame1.convert.return_value = mock_frame1 |
| 338 | + mock_frame2 = MagicMock() |
| 339 | + mock_frame2.height = 224 |
| 340 | + mock_frame2.width = 224 |
| 341 | + mock_frame2.convert.return_value = mock_frame2 |
| 342 | + frames = [mock_frame1, mock_frame2] |
| 343 | + request = { |
| 344 | + "prompt_token_ids": test_prompt_token_ids, |
| 345 | + "messages": [ |
| 346 | + {"role": "user", "content": [{"type": "video_url", "video_url": frames, "uuid": "vid_uuid"}]} |
| 347 | + ], |
| 348 | + } |
| 349 | + self.data_processor.extract_mm_items.return_value = ( |
| 350 | + [], |
| 351 | + [frames], |
| 352 | + [], |
| 353 | + ["vid_uuid"], |
| 354 | + None, |
| 355 | + [], |
| 356 | + [{"type": "video", "data": frames}], |
| 357 | + ) |
| 358 | + self.data_processor._load_and_process_video = MagicMock(return_value=frames) |
| 359 | + patches_h, patches_w = 4, 4 |
| 360 | + self.data_processor.image_preprocessor.get_smarted_resize.return_value = (None, (patches_h, patches_w)) |
| 361 | + mock_preprocess = { |
| 362 | + "pixel_values_videos": np.random.randn(2, patches_h, patches_w, 3), |
| 363 | + "video_grid_thw": np.array([[patches_h, patches_w]] * 2), |
| 364 | + } |
| 365 | + self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess |
| 366 | + outputs = self.data_processor.prompt_token_ids2outputs(request) |
| 367 | + self.assertEqual(outputs["input_ids"], [101, 1004, 1001, 1001, 1001, 1001, 1005, 102]) |
| 368 | + self.assertEqual( |
| 369 | + outputs["token_type_ids"], |
| 370 | + [ |
| 371 | + IDS_TYPE_FLAG["text"], |
| 372 | + IDS_TYPE_FLAG["text"], |
| 373 | + IDS_TYPE_FLAG["video"], |
| 374 | + IDS_TYPE_FLAG["video"], |
| 375 | + IDS_TYPE_FLAG["video"], |
| 376 | + IDS_TYPE_FLAG["video"], |
| 377 | + IDS_TYPE_FLAG["text"], |
| 378 | + IDS_TYPE_FLAG["text"], |
| 379 | + ], |
| 380 | + ) |
| 381 | + self.assertEqual(len(outputs["position_ids"]), 8) |
| 382 | + self.assertEqual(outputs["cur_position"], 6) |
| 383 | + self.assertEqual(len(outputs["images"]), 1) |
| 384 | + self.assertIsNotNone(outputs["images"][0]) |
| 385 | + self.assertEqual(len(outputs["mm_positions"]), 1) |
| 386 | + self.assertEqual(outputs["mm_hashes"][0], "vid_uuid") |
| 387 | + self.assertEqual(len(outputs["grid_thw"]), 1) |
| 388 | + self.assertEqual(len(outputs["image_type_ids"]), 2) |
| 389 | + self.assertEqual(outputs["num_input_video_tokens"], 4) |
| 390 | + |
| 391 | + def test_prompt_token_ids2outputs_add_processed_video(self): |
| 392 | + test_prompt_token_ids = [101, 1004, 1001, 1001, 1001, 1001, 1005, 102] |
| 393 | + t, h, w = 2, 4, 4 |
| 394 | + spatial_conv_size = self.data_processor.spatial_conv_size |
| 395 | + temporal_conv_size = self.data_processor.temporal_conv_size |
| 396 | + token_per_frame = (h // spatial_conv_size) * (w // spatial_conv_size) |
| 397 | + num_tokens = (t // temporal_conv_size) * token_per_frame |
| 398 | + mock_frames_data = np.random.randn(num_tokens * spatial_conv_size**2 * temporal_conv_size, 28, 28) |
| 399 | + mock_frames_cache = (mock_frames_data, {"thw": (t, h, w)}) |
| 400 | + request = { |
| 401 | + "prompt_token_ids": test_prompt_token_ids, |
| 402 | + "messages": [ |
| 403 | + {"role": "user", "content": [{"type": "video", "data": mock_frames_cache, "uuid": "vid_uuid"}]} |
| 404 | + ], |
| 405 | + } |
| 406 | + self.data_processor.extract_mm_items.return_value = ( |
| 407 | + [], |
| 408 | + [mock_frames_cache], |
| 409 | + [], |
| 410 | + ["vid_uuid"], |
| 411 | + None, |
| 412 | + [], |
| 413 | + [{"type": "video", "data": mock_frames_cache}], |
| 414 | + ) |
| 415 | + outputs = self.data_processor.prompt_token_ids2outputs(request) |
| 416 | + self.assertEqual(outputs["input_ids"], [101, 1004, 1001, 1001, 1001, 1001, 1005, 102]) |
| 417 | + self.assertEqual( |
| 418 | + outputs["token_type_ids"], |
| 419 | + [ |
| 420 | + IDS_TYPE_FLAG["text"], |
| 421 | + IDS_TYPE_FLAG["text"], |
| 422 | + IDS_TYPE_FLAG["video"], |
| 423 | + IDS_TYPE_FLAG["video"], |
| 424 | + IDS_TYPE_FLAG["video"], |
| 425 | + IDS_TYPE_FLAG["video"], |
| 426 | + IDS_TYPE_FLAG["text"], |
| 427 | + IDS_TYPE_FLAG["text"], |
| 428 | + ], |
| 429 | + ) |
| 430 | + self.assertEqual(len(outputs["position_ids"]), 8) |
| 431 | + self.assertEqual(outputs["cur_position"], 6) |
| 432 | + self.assertEqual(len(outputs["images"]), 1) |
| 433 | + self.assertIsNotNone(outputs["images"][0]) |
| 434 | + self.assertEqual(len(outputs["mm_positions"]), 1) |
| 435 | + self.assertEqual(outputs["mm_hashes"][0], "vid_uuid") |
| 436 | + self.assertEqual(len(outputs["grid_thw"]), 1) |
| 437 | + self.assertEqual(len(outputs["image_type_ids"]), 2) |
| 438 | + |
| 439 | + |
136 | 440 | if __name__ == "__main__": |
137 | 441 | unittest.main() |
0 commit comments