
Commit 0a27ad6

Merge branch 'bugfix/fix_id_list' of https://github.com/kxz2002/FastDeploy into bugfix/fix_id_list
2 parents: 8fed8b4 + e531d73

File tree: 14 files changed, +1953 −115 lines


custom_ops/xpu_ops/src/ops/moe_expert_ffn.cc

Lines changed: 1 addition & 1 deletion
@@ -441,7 +441,7 @@ std::vector<paddle::Tensor> MoeExpertFFN(
     const std::string& quant_method,
     const int hadamard_blocksize,
     const int valid_token_num) {
-  if (ffn_in.numel() == 0) {
+  if (ffn_in.numel() == 0 || valid_token_num == 0) {
     paddle::Tensor ffn2_out =
         paddle::empty_like(ffn_in, paddle::DataType::BFLOAT16);
     return {ffn2_out};
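
The guard now also short-circuits when the scheduler reports zero valid tokens, so the XPU kernel never launches on an all-padding batch. A minimal Python sketch of the same pattern; `moe_expert_ffn_ref` is a hypothetical name, not a FastDeploy API:

# A minimal sketch of the early-return guard this hunk adds, in Python;
# `moe_expert_ffn_ref` is a hypothetical name, not a FastDeploy API.
import paddle

def moe_expert_ffn_ref(ffn_in: paddle.Tensor, valid_token_num: int) -> paddle.Tensor:
    # No work to do: the input is empty or the scheduler counted zero valid
    # tokens, so return an uninitialized bfloat16 tensor with the input's
    # shape, matching the C++ early return above.
    if ffn_in.numel() == 0 or valid_token_num == 0:
        return paddle.empty_like(ffn_in, dtype=paddle.bfloat16)
    raise NotImplementedError("real expert FFN dispatch happens here")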

fastdeploy/config.py

Lines changed: 0 additions & 4 deletions
@@ -1628,10 +1628,6 @@ def postprocess(self):
         else:
             self.scheduler_config.max_num_batched_tokens = self.model_config.max_model_len

-        self.scheduler_config.max_chunk_len = (
-            self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_extra_num_batched_tokens
-        )
-
         if self.long_prefill_token_threshold == 0:
             self.long_prefill_token_threshold = int(self.model_config.max_model_len * 0.04)
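
With the cached `max_chunk_len` gone, consumers derive the token budget where they need it. A hedged sketch of the equivalent inline computation, using the defaults from fastdeploy/scheduler/config.py and a stand-in config object:

# Equivalent inline computation of the removed max_chunk_len field;
# `cfg` is a stand-in for scheduler_config with its default values.
from types import SimpleNamespace

cfg = SimpleNamespace(max_num_batched_tokens=2048, max_extra_num_batched_tokens=16384)
max_chunk_tokens = cfg.max_num_batched_tokens + cfg.max_extra_num_batched_tokens
assert max_chunk_tokens == 18432  # matches the comment removed from scheduler/config.py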

fastdeploy/engine/common_engine.py

Lines changed: 5 additions & 1 deletion
@@ -1125,7 +1125,7 @@ def _process_prefilled_requests():
                     # received the request sent by the client
                     waiting_request_outputs.append(req_output)
                     continue
-
+                req_output.finished = False
                 ready_request_outputs.append(req_output)
                 self.llm_logger.debug(f"there are enough resource for prefilled request: {req_output.request_id}")

@@ -1145,6 +1145,8 @@ def _process_prefilled_requests():
                     self.resource_manager.pre_recycle_resource(request_id)
                     if request_id in self.token_processor.tokens_counter:
                         del self.token_processor.tokens_counter[request_id]
+                    req_output.finished = True
+                    self.scheduler.put_results([req_output])
                     continue
                 if req_output.error_code != 200:
                     self.llm_logger.warning(

@@ -1156,6 +1158,8 @@ def _process_prefilled_requests():
                     self.scheduler.put_results([req_output])
                     continue
                 self.token_processor.tokens_counter[request_id] = 1
+                if envs.FD_ENABLE_INTERNAL_ADAPTER:  # first token sent by D instance
+                    self.scheduler.put_results([req_output])
                 self.resource_manager.add_prefilled_request(req_output)
                 self.llm_logger.debug(f"add prefilled request success, {request_id}")
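
Condensed, the routing these hunks produce looks like the sketch below. Method names come from the diff; the collaborators are passed in explicitly so the control flow is visible, and `recycle` is a hypothetical stand-in for the branch condition in the surrounding code:

# Condensed sketch of the result routing after this change (not the actual
# FastDeploy method): recycled requests are now marked finished and their
# results shipped back, and under FD_ENABLE_INTERNAL_ADAPTER the first token
# from the D instance is forwarded immediately.
def route_prefilled(req_output, recycle, scheduler, resource_manager,
                    tokens_counter, internal_adapter_enabled):
    request_id = req_output.request_id
    if recycle:  # the recycle branch shown above
        resource_manager.pre_recycle_resource(request_id)
        tokens_counter.pop(request_id, None)
        req_output.finished = True           # new: mark the recycled request terminal
        scheduler.put_results([req_output])  # new: unblock the waiting client
        return
    tokens_counter[request_id] = 1
    if internal_adapter_enabled:             # new: first token already sent by the D instance
        scheduler.put_results([req_output])
    resource_manager.add_prefilled_request(req_output)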

fastdeploy/scheduler/config.py

Lines changed: 0 additions & 1 deletion
@@ -270,7 +270,6 @@ def __init__(self, args):
         self.name = "local"  # "local" for LocalScheduler or "global" for GlobalScheduler
         self.max_num_batched_tokens = 2048  # base token_num for text inputs
         self.max_extra_num_batched_tokens = 16384  # extra token_num for multimodal inputs
-        self.max_chunk_len = 18432  # max supported token_num = max_num_batched_tokens + max_extra_num_batched_tokens
         self.max_num_seqs = 34
         self.splitwise_role = "mixed"
         self.config = None

fastdeploy/spec_decode/mtp.py

Lines changed: 7 additions & 1 deletion
@@ -355,7 +355,13 @@ def _init_model_inputs(self):
             self.target_model_inputs["decoder_tile_ids_per_batch"]
         )
         self.model_inputs["target_hidden_states"] = paddle.full(
-            [self.fd_config.scheduler_config.max_chunk_len, self.model_config.hidden_size], 0, dtype="bfloat16"
+            [
+                self.fd_config.scheduler_config.max_num_batched_tokens
+                + self.fd_config.scheduler_config.max_extra_num_batched_tokens,
+                self.model_config.hidden_size,
+            ],
+            0,
+            dtype="bfloat16",
         )

         tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
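
The buffer is sized for the full schedulable budget: one row per token (base plus multimodal extra), one column per hidden unit. A standalone sketch of the allocation with placeholder sizes:

# Standalone sketch of the allocation above; the sizes are the scheduler
# defaults and a placeholder hidden_size, not values read from a real config.
import paddle

max_tokens = 2048 + 16384  # max_num_batched_tokens + max_extra_num_batched_tokens
hidden_size = 4096         # placeholder; the real value comes from model_config
target_hidden_states = paddle.full([max_tokens, hidden_size], 0, dtype="bfloat16")
assert tuple(target_hidden_states.shape) == (18432, 4096)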

scripts/extract_mtp_weight_from_safetensor.py

Lines changed: 36 additions & 0 deletions
@@ -17,7 +17,9 @@
 import argparse
 import json
 import os
+import re

+import numpy as np
 import paddle
 from paddleformers.transformers.model_utils import shard_checkpoint
 from paddleformers.utils.env import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME

@@ -46,6 +48,28 @@ def parse_args():
     return parser.parse_args()


+def dtype_byte_size(dtype):
+    """
+    Returns the size (in bytes) occupied by one parameter of type `dtype`.
+
+    Example:
+
+    ```py
+    >>> dtype_byte_size(paddle.float32)
+    4
+    ```
+    """
+    if str(dtype) in {"paddle.bool", "bool"}:
+        return 1 / 8
+    if str(dtype) in {"paddle.float8_e4m3fn", "paddle.float8_e5m2", "float8_e4m3fn", "float8_e5m2"}:
+        return 1
+    bit_search = re.search(r"[^\d](\d+)$", str(dtype))
+    if bit_search is None:
+        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+    bit_size = int(bit_search.groups()[0])
+    return bit_size // 8
+
+
 def extract_mtp_weights(input_dir: str) -> dict:
     """
     Load all MTP-related weights from safetensors files in input_dir.

@@ -103,6 +127,18 @@ def save_safetensors(state_dict: dict, output_dir: str):
         logger.info(f"Saving shard: {save_path}")
         safe_save_file(shard, save_path, metadata={"format": "np"})

+    # If only one shard is returned, SAFE_WEIGHTS_INDEX_NAME will be null
+    if len(shards) == 1:
+        logger.info("Generate index file for single shard")
+        weight_size = 0
+        for key, weight in shards["model.safetensors"].items():
+            weight_size += np.prod(weight.shape) * dtype_byte_size(weight.dtype)
+
+        index = {
+            "metadata": {"total_size": int(weight_size)},
+            "weight_map": {k: "model.safetensors" for k in shards["model.safetensors"].keys()},
+        }
+
     index_path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)
     with open(index_path, "w", encoding="utf-8") as f:
         json.dump(index, f, indent=2)
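
A quick sanity check of the new `dtype_byte_size` helper, assuming it is importable from the script above; the results follow directly from its bit-width parsing:

# Spot-checking dtype_byte_size; results follow from the regex on the dtype
# string (e.g. "paddle.float32" -> 32 bits -> 4 bytes).
import paddle

assert dtype_byte_size(paddle.float32) == 4
assert dtype_byte_size(paddle.bfloat16) == 2
assert dtype_byte_size(paddle.bool) == 1 / 8  # booleans are counted as one bit each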

scripts/run_ci_xpu.sh

Lines changed: 3 additions & 2 deletions
@@ -44,8 +44,9 @@ echo "uninstall org"
 python -m pip uninstall paddlepaddle-xpu -y
 python -m pip uninstall fastdeploy-xpu -y

-python -m pip install paddlepaddle-xpu -i https://www.paddlepaddle.org.cn/packages/nightly/xpu-p800/
-
+# python -m pip install paddlepaddle-xpu -i https://www.paddlepaddle.org.cn/packages/nightly/xpu-p800/
+# Temporarily pin the paddle version because EP parallelism errors out on the nightly build
+python -m pip install https://paddle-whl.bj.bcebos.com/nightly/xpu-p800/paddlepaddle-xpu/paddlepaddle_xpu-3.3.0.dev20251123-cp310-cp310-linux_x86_64.whl
 echo "build whl"
 bash custom_ops/xpu_ops/download_dependencies.sh develop
 export CLANG_PATH=$(pwd)/custom_ops/xpu_ops/third_party/xtdk
Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import unittest
+
+from partial_json_parser.core.options import Allow
+
+from fastdeploy.entrypoints.openai.tool_parsers import utils
+
+
+class TestPartialJsonUtils(unittest.TestCase):
+    """Unit test suite for partial JSON utility functions."""
+
+    def test_find_common_prefix(self):
+        """Test common prefix detection between two strings."""
+        string1 = '{"fruit": "ap"}'
+        string2 = '{"fruit": "apple"}'
+        self.assertEqual(utils.find_common_prefix(string1, string2), '{"fruit": "ap')
+
+    def test_find_common_suffix(self):
+        """Test common suffix detection between two strings."""
+        string1 = '{"fruit": "ap"}'
+        string2 = '{"fruit": "apple"}'
+        self.assertEqual(utils.find_common_suffix(string1, string2), '"}')
+
+    def test_extract_intermediate_diff(self):
+        """Test extraction of intermediate difference between current and old strings."""
+        old_string = '{"fruit": "ap"}'
+        current_string = '{"fruit": "apple"}'
+        self.assertEqual(utils.extract_intermediate_diff(current_string, old_string), "ple")
+
+    def test_find_all_indices(self):
+        """Test finding all occurrence indices of a substring in a string."""
+        target_string = "banana"
+        substring = "an"
+        self.assertEqual(utils.find_all_indices(target_string, substring), [1, 3])
+
+    def test_partial_json_loads_complete(self):
+        """Test partial_json_loads with a complete JSON string."""
+        input_json = '{"a": 1, "b": 2}'
+        parse_flags = Allow.ALL
+        parsed_obj, parsed_length = utils.partial_json_loads(input_json, parse_flags)
+        self.assertEqual(parsed_obj, {"a": 1, "b": 2})
+        self.assertEqual(parsed_length, len(input_json))
+
+    def test_is_complete_json(self):
+        """Test JSON completeness check."""
+        self.assertTrue(utils.is_complete_json('{"a": 1}'))
+        self.assertFalse(utils.is_complete_json('{"a": 1'))
+
+    def test_consume_space(self):
+        """Test whitespace consumption from the start of a string."""
+        input_string = "   \t\nabc"
+        # 3 spaces + 1 tab + 1 newline = 5 whitespace characters
+        first_non_whitespace_idx = utils.consume_space(0, input_string)
+        self.assertEqual(first_non_whitespace_idx, 5)
+
+
+if __name__ == "__main__":
+    unittest.main()
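
Taken together, these helpers support streaming tool-call parsing: diff a newly extended JSON snapshot against the previous one and emit only the delta. A hedged usage sketch built only from the behaviors the tests above pin down; the streaming framing is illustrative, not FastDeploy's actual parser loop:

# Usage sketch combining the helpers exercised above; values mirror the test
# fixtures.
from fastdeploy.entrypoints.openai.tool_parsers import utils

previous_snapshot = '{"fruit": "ap"}'
current_snapshot = '{"fruit": "apple"}'

if utils.is_complete_json(current_snapshot):
    # Emit only the newly streamed characters: "ple"
    delta = utils.extract_intermediate_diff(current_snapshot, previous_snapshot)
    print(delta)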

0 commit comments
