Commit 2d78759
[Feature] The 45VL supports prompt_token_ids + messages input. (#5148)
* support prompt_token_ids + messages
* fix bug
* refactor code structure
* support cache mm items
* refactor code structure
* delete test cases
* modify unit test
* add unit test
* add unit test
* fix append
* add check for messages
1 parent 66e096d commit 2d78759

File tree: 4 files changed, +601 -21 lines
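For orientation: a request on the new path carries both fields at once. A hypothetical payload sketch (the concrete token ids and URL are illustrative; only the prompt_token_ids and messages field names come from this commit):

# Hypothetical chat request dict; 9001/9002 stand in for the image start/end
# special-token ids, 8000 for the image placeholder id.
request = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "Describe the image."},
            ],
        }
    ],
    # Pre-tokenized prompt: text ids plus one image span. The placeholder run
    # between the start/end ids must match the image's computed token count.
    "prompt_token_ids": [101, 9001, 8000, 8000, 9002, 102],
}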

fastdeploy/entrypoints/openai/protocol.py

Lines changed: 1 addition & 4 deletions

@@ -671,10 +671,7 @@ def to_dict_for_infer(self, request_id=None):
         if request_id is not None:
             req_dict["request_id"] = request_id

-        if "prompt_token_ids" in req_dict:
-            if "messages" in req_dict:
-                del req_dict["messages"]
-        else:
+        if "prompt_token_ids" not in req_dict or not req_dict["prompt_token_ids"]:
             # If disable_chat_template is set, then the first message in messages will be used as the prompt.
             assert (
                 len(req_dict["messages"]) > 0
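The old guard deleted messages whenever prompt_token_ids was present; the new guard only falls back to the chat-template path when prompt_token_ids is absent or empty, so both fields travel to the input processor together. A minimal sketch of the new behavior (req_dict is a stand-in for the real request dict):

# Both fields now survive to_dict_for_infer; a sketch, not the real method.
req_dict = {"prompt_token_ids": [101, 102], "messages": [{"role": "user", "content": "hi"}]}
if "prompt_token_ids" not in req_dict or not req_dict["prompt_token_ids"]:
    assert len(req_dict["messages"]) > 0  # chat-template fallback
assert "messages" in req_dict  # previously deleted whenever token ids were given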

fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py

Lines changed: 10 additions & 2 deletions

@@ -219,7 +219,13 @@ def process_request_dict(self, request, max_model_len=None):
             bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
             request["bad_words_token_ids"] = bad_words_token_ids

-        if request.get("prompt"):
+        if request.get("prompt_token_ids"):
+            messages = request.get("messages")
+            if messages:
+                self._check_mm_limits(messages)
+            request.setdefault("enable_thinking", True)
+            outputs = self.ernie4_5_processor.prompt_token_ids2outputs(request)
+        elif request.get("prompt"):
             multimodal_data = request.get("multimodal_data")
             if multimodal_data is None:
                 multimodal_data = {}

@@ -256,7 +262,9 @@ def process_request_dict(self, request, max_model_len=None):
             self.append_completion_tokens(outputs, request["completion_token_ids"])

         outputs = self.pack_outputs(outputs)
-        request["prompt_token_ids"] = outputs["input_ids"].tolist()
+        request["prompt_token_ids"] = (
+            outputs["input_ids"].tolist() if "prompt_token_ids" not in request else request["prompt_token_ids"]
+        )
         request["prompt_token_ids_len"] = len(request["prompt_token_ids"])
         request["multimodal_inputs"] = outputs
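The processor now dispatches on which inputs are present: a pre-tokenized prompt short-circuits tokenization while messages are still consulted for multimodal limits, and a caller-supplied prompt_token_ids is left untouched at the end. A toy stand-in for the routing (names and the final fallback branch are assumptions; the real method also tokenizes, appends completion tokens, and packs outputs):

# Routing sketch with stand-in names; only the prompt_token_ids branch order
# is taken from the diff.
def route_request(request: dict) -> str:
    if request.get("prompt_token_ids"):
        # New path: placeholder spans in the ids are matched against messages.
        return "prompt_token_ids2outputs"
    elif request.get("prompt"):
        return "text2ids"
    else:
        return "request2ids"  # assumed chat-template fallback

assert route_request({"prompt_token_ids": [1, 2]}) == "prompt_token_ids2outputs"
assert route_request({"prompt": "hello"}) == "text2ids"
assert route_request({"messages": []}) == "request2ids"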

fastdeploy/input/ernie4_5_vl_processor/process.py

Lines changed: 139 additions & 12 deletions

@@ -136,7 +136,9 @@ def __init__(
         self.video_end = self.VID_END
         self.image_patch_id = self.tokenizer.convert_tokens_to_ids("<|IMAGE_PLACEHOLDER|>")
         self.image_start_id = self.tokenizer.convert_tokens_to_ids(self.image_start)
+        self.image_end_id = self.tokenizer.convert_tokens_to_ids(self.image_end)
         self.video_start_id = self.tokenizer.convert_tokens_to_ids(self.video_start)
+        self.video_end_id = self.tokenizer.convert_tokens_to_ids(self.video_end)
         self.sep_token_id = self.tokenizer.convert_tokens_to_ids(self.sep_token)
         self.eos_token_id = self.tokenizer.convert_tokens_to_ids(self.eos_token)
@@ -243,14 +245,7 @@ def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=None):

         return outputs

-    def request2ids(
-        self, request: Dict[str, Any], tgts: List[str] = None
-    ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
-        """
-        Convert chat messages into model inputs.
-        Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
-        """
-
+    def extract_mm_items(self, request: Dict[str, Any]):
         messages = parse_chat_messages(request.get("messages"))
         mm_items = []
         for msg in messages:

@@ -273,6 +268,7 @@ def request2ids(
         if len(missing_hashes) > 0 and not self.enable_processor_cache:
             raise ValueError("Missing items cannot be retrieved without processor cache.")

+        dealer = None
         if self.enable_processor_cache:
             context = zmq.Context()
             dealer = context.socket(zmq.DEALER)

@@ -295,6 +291,16 @@ def request2ids(
                 video_uuid.append(item["uuid"])
             else:
                 raise ValueError(f"Unsupported multimodal type: {item.get('type')}")
+        return images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items
+
+    def request2ids(
+        self, request: Dict[str, Any], tgts: List[str] = None
+    ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
+        """
+        Convert chat messages into model inputs.
+        Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
+        """
+        images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items = self.extract_mm_items(request)

         if self.tokenizer.chat_template is None:
             raise ValueError("This model does not support chat template.")
@@ -329,6 +335,115 @@ def request2ids(

         return outputs

+    def prompt_token_ids2outputs(
+        self, request: Dict[str, Any], tgts: List[str] = None
+    ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
+        outputs = {
+            "input_ids": [],
+            "token_type_ids": [],
+            "position_ids": [],
+            "images": [],
+            "grid_thw": [],
+            "image_type_ids": [],
+            "labels": [],
+            "cur_position": 0,
+            "video_cnt": 0,
+            "num_input_image_tokens": 0,
+            "num_input_video_tokens": 0,
+            "mm_positions": [],
+            "mm_hashes": [],
+        }
+        prompt_token_ids = request.get("prompt_token_ids", [])
+        prompt_token_ids_len = len(prompt_token_ids)
+        if not request.get("messages"):
+            outputs["input_ids"].extend(prompt_token_ids)
+            outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * prompt_token_ids_len)
+            for i in range(prompt_token_ids_len):
+                outputs["position_ids"].append([i] * 3)
+            outputs["cur_position"] += prompt_token_ids_len
+            return outputs
+        images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items = self.extract_mm_items(request)
+        st, image_idx, video_idx = 0, 0, 0
+        while st < prompt_token_ids_len:
+            cur_token_id = prompt_token_ids[st]
+            if cur_token_id == self.image_start_id:
+                if image_idx >= len(images):
+                    raise ValueError("prompt token ids contain more image placeholders than messages")
+                # append image_start_id
+                outputs["input_ids"].extend([cur_token_id])
+                outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]])
+                outputs["position_ids"].append([outputs["cur_position"]] * 3)
+                outputs["cur_position"] += 1
+                st += 1
+                # process placeholder token ids
+                cur_idx = st
+                while cur_idx < prompt_token_ids_len and prompt_token_ids[cur_idx] != self.image_end_id:
+                    cur_idx += 1
+                if cur_idx >= prompt_token_ids_len:
+                    raise ValueError("image token ids are not complete")
+                image = images[image_idx]
+                uuid = image_uuid[image_idx] if image_uuid else None
+                token_len = cur_idx - st
+                if not isinstance(image, tuple):
+                    self._add_image(image, outputs, uuid, token_len)
+                else:
+                    self._add_processed_image(image, outputs, uuid, token_len)
+                image_idx += 1
+                st = cur_idx
+            elif cur_token_id == self.video_start_id:
+                if video_idx >= len(videos):
+                    raise ValueError("prompt token ids contain more video placeholders than messages")
+                # append video_start_id
+                outputs["input_ids"].extend([cur_token_id])
+                outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]])
+                outputs["position_ids"].append([outputs["cur_position"]] * 3)
+                outputs["cur_position"] += 1
+                st += 1
+                # process placeholder token ids
+                cur_idx = st
+                while cur_idx < prompt_token_ids_len and prompt_token_ids[cur_idx] != self.video_end_id:
+                    cur_idx += 1
+                if cur_idx >= prompt_token_ids_len:
+                    raise ValueError("video token ids are not complete")
+                video = videos[video_idx]
+                uuid = video_uuid[video_idx] if video_uuid else None
+                token_len = cur_idx - st
+                if not isinstance(video, tuple):
+                    if isinstance(video, dict):
+                        frames = self._load_and_process_video(video["video"], video)
+                    else:
+                        frames = self._load_and_process_video(video, {})
+                    self._add_video(frames, outputs, uuid, token_len)
+                else:
+                    self._add_processed_video(video, outputs, uuid, token_len)
+                video_idx += 1
+                st = cur_idx
+            else:
+                outputs["input_ids"].extend([cur_token_id])
+                outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]])
+                outputs["position_ids"].append([outputs["cur_position"]] * 3)
+                outputs["cur_position"] += 1
+                st += 1
+        if image_idx != len(images):
+            raise ValueError("number of images does not match")
+        if video_idx != len(videos):
+            raise ValueError("number of videos does not match")
+
+        if self.enable_processor_cache:
+            missing_idx = set(missing_idx)
+            hashes_to_cache, items_to_cache = [], []
+            for idx in range(len(mm_items)):
+                if idx in missing_idx:
+                    continue
+                meta = {}
+                t, h, w = outputs["grid_thw"][idx][0]
+                meta["thw"] = (t, h, w)
+                hashes_to_cache.append(outputs["mm_hashes"][idx])
+                items_to_cache.append((outputs["images"][idx], meta))
+            self.update_processor_cache(dealer, hashes_to_cache, items_to_cache)
+
+        return outputs
+
     def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None:
         token_id = token if isinstance(token, int) else self.tokenizer.convert_tokens_to_ids(token)
         outputs["input_ids"].append(token_id)
@@ -348,14 +463,16 @@ def _add_text(self, tokens, outputs: Dict) -> None:
             outputs["position_ids"].append([start + i] * 3)
         outputs["cur_position"] += len(tokens)

-    def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
+    def _add_image(self, img, outputs: Dict, uuid: Optional[str], token_len=None) -> None:
         patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
             img.height,
             img.width,
             min_pixels=self.image_min_pixels,
             max_pixels=self.image_max_pixels,
         )[1]
         num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2)
+        if token_len and token_len != num_tokens:
+            raise ValueError("image token count does not match the image size")

         outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
         outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
@@ -383,9 +500,13 @@ def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
         outputs["grid_thw"].append(ret["image_grid_thw"])
         outputs["image_type_ids"].append(0)

-    def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
+    def _add_processed_image(
+        self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str, token_len=None
+    ) -> None:
         img, meta = img_cache
         num_tokens = img.shape[0] // (self.spatial_conv_size**2)
+        if token_len and num_tokens != token_len:
+            raise ValueError("image token count does not match the image size")

         outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
         outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
@@ -401,7 +522,7 @@ def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
         outputs["grid_thw"].append(np.array([[1, h, w]]))
         outputs["image_type_ids"].append(0)

-    def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
+    def _add_video(self, frames, outputs: Dict, uuid: Optional[str], token_len=None) -> None:
         patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
             frames[0].height,
             frames[0].width,

@@ -410,6 +531,8 @@ def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
         )[1]
         num_frames = len(frames)
         num_tokens = (num_frames * patches_h * patches_w) // (self.spatial_conv_size**2 * self.temporal_conv_size)
+        if token_len and num_tokens != token_len:
+            raise ValueError("video token count does not match the video size")

         pixel_stack = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
         ret = self.image_preprocessor.preprocess(
@@ -438,9 +561,13 @@ def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
         outputs["position_ids"].extend(pos_ids)
         outputs["cur_position"] = np.max(pos_ids) + 1

-    def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
+    def _add_processed_video(
+        self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str, token_len=None
+    ) -> None:
         frames, meta = frames_cache
         num_tokens = frames.shape[0] // (self.spatial_conv_size**2 * self.temporal_conv_size)
+        if token_len and num_tokens != token_len:
+            raise ValueError("video token count does not match the video size")

         t, h, w = meta["thw"]
         outputs["images"].append(frames)
