Skip to content

Commit a418d7b

Browse files
authored
[CI] Add Unittest (#5187)
* add test * Delete tests/model_executor/test_w4afp8.py * Rename test_utils.py to test_tool_parsers_utils.py * add test * add test * fix platforms * Delete tests/cache_manager/test_platforms.py * dont change Removed copyright notice and license information.
1 parent 717da50 commit a418d7b

File tree

5 files changed

+856
-30
lines changed

5 files changed

+856
-30
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""
2+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
17+
import unittest
18+
19+
from partial_json_parser.core.options import Allow
20+
21+
from fastdeploy.entrypoints.openai.tool_parsers import utils
22+
23+
24+
class TestPartialJsonUtils(unittest.TestCase):
25+
"""Unit test suite for partial JSON utility functions."""
26+
27+
def test_find_common_prefix(self):
28+
"""Test common prefix detection between two strings."""
29+
string1 = '{"fruit": "ap"}'
30+
string2 = '{"fruit": "apple"}'
31+
self.assertEqual(utils.find_common_prefix(string1, string2), '{"fruit": "ap')
32+
33+
def test_find_common_suffix(self):
34+
"""Test common suffix detection between two strings."""
35+
string1 = '{"fruit": "ap"}'
36+
string2 = '{"fruit": "apple"}'
37+
self.assertEqual(utils.find_common_suffix(string1, string2), '"}')
38+
39+
def test_extract_intermediate_diff(self):
40+
"""Test extraction of intermediate difference between current and old strings."""
41+
old_string = '{"fruit": "ap"}'
42+
current_string = '{"fruit": "apple"}'
43+
self.assertEqual(utils.extract_intermediate_diff(current_string, old_string), "ple")
44+
45+
def test_find_all_indices(self):
46+
"""Test finding all occurrence indices of a substring in a string."""
47+
target_string = "banana"
48+
substring = "an"
49+
self.assertEqual(utils.find_all_indices(target_string, substring), [1, 3])
50+
51+
def test_partial_json_loads_complete(self):
52+
"""Test partial_json_loads with a complete JSON string."""
53+
input_json = '{"a": 1, "b": 2}'
54+
parse_flags = Allow.ALL
55+
parsed_obj, parsed_length = utils.partial_json_loads(input_json, parse_flags)
56+
self.assertEqual(parsed_obj, {"a": 1, "b": 2})
57+
self.assertEqual(parsed_length, len(input_json))
58+
59+
def test_is_complete_json(self):
60+
"""Test JSON completeness check."""
61+
self.assertTrue(utils.is_complete_json('{"a": 1}'))
62+
self.assertFalse(utils.is_complete_json('{"a": 1'))
63+
64+
def test_consume_space(self):
65+
"""Test whitespace consumption from the start of a string."""
66+
input_string = " \t\nabc"
67+
# 3 spaces + 1 tab + 1 newline = 5 whitespace characters
68+
first_non_whitespace_idx = utils.consume_space(0, input_string)
69+
self.assertEqual(first_non_whitespace_idx, 5)
70+
71+
72+
if __name__ == "__main__":
73+
unittest.main()
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""
2+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
17+
import unittest
18+
from unittest.mock import MagicMock, patch
19+
20+
import paddle
21+
22+
from fastdeploy.model_executor.layers.quantization.tensor_wise_fp8 import (
23+
TensorWiseFP8Config,
24+
TensorWiseFP8LinearMethod,
25+
)
26+
27+
28+
# Dummy classes for test
29+
class DummyLayer:
30+
"""Dummy linear layer for test purposes"""
31+
32+
def __init__(self):
33+
self.weight_shape = [4, 8]
34+
self.weight_key = "weight"
35+
self.weight_scale_key = "weight_scale"
36+
self.act_scale_key = "act_scale"
37+
self.weight_dtype = "float32"
38+
self.weight = MagicMock() # Mock weight to avoid dtype copy errors
39+
40+
def create_parameter(self, shape, dtype, is_bias=False, default_initializer=None):
41+
"""Mock parameter creation"""
42+
return MagicMock()
43+
44+
45+
class DummyFusedMoE:
46+
"""Dummy FusedMoE class for patching"""
47+
48+
pass
49+
50+
51+
class TestTensorWiseFP8Config(unittest.TestCase):
52+
"""Test suite for TensorWiseFP8Config"""
53+
54+
def test_get_quant_method_linear(self):
55+
"""Verify linear layer returns TensorWiseFP8LinearMethod"""
56+
cfg = TensorWiseFP8Config()
57+
layer = DummyLayer()
58+
method = cfg.get_quant_method(layer)
59+
self.assertIsInstance(method, TensorWiseFP8LinearMethod)
60+
61+
def test_get_quant_method_moe(self):
62+
"""Verify FusedMoE layer returns valid quant method"""
63+
cfg = TensorWiseFP8Config()
64+
layer = DummyFusedMoE()
65+
with patch("fastdeploy.model_executor.layers.moe.FusedMoE", DummyFusedMoE):
66+
method = cfg.get_quant_method(layer)
67+
self.assertTrue(hasattr(method, "quant_config"))
68+
69+
70+
class TestTensorWiseFP8LinearMethod(unittest.TestCase):
71+
"""Test suite for TensorWiseFP8LinearMethod"""
72+
73+
def setUp(self):
74+
"""Initialize test fixtures"""
75+
self.layer = DummyLayer()
76+
self.method = TensorWiseFP8LinearMethod(TensorWiseFP8Config())
77+
# Initialize scales to avoid apply errors
78+
self.method.act_scale = 1.0
79+
self.method.total_scale = 1.0
80+
81+
def test_create_weights(self):
82+
"""Verify weight dtype is set to float8_e4m3fn"""
83+
self.method.create_weights(self.layer)
84+
self.assertEqual(self.layer.weight_dtype, "float8_e4m3fn")
85+
86+
def test_process_prequanted_weights(self):
87+
"""Verify prequantized weights and scales are processed correctly"""
88+
self.layer.weight.copy_ = MagicMock()
89+
state_dict = {
90+
"weight": paddle.randn([8, 4]),
91+
"weight_scale": paddle.to_tensor([0.5], dtype="float32"),
92+
"act_scale": paddle.to_tensor([2.0], dtype="float32"),
93+
}
94+
self.method.process_prequanted_weights(self.layer, state_dict)
95+
self.assertAlmostEqual(self.method.act_scale, 2.0)
96+
self.assertAlmostEqual(self.method.total_scale, 1.0)
97+
self.layer.weight.copy_.assert_called_once()
98+
99+
@patch("fastdeploy.model_executor.ops.gpu.fused_hadamard_quant_fp8", autospec=True)
100+
@patch("fastdeploy.model_executor.ops.gpu.cutlass_fp8_fp8_half_gemm_fused", autospec=True)
101+
def test_apply(self, mock_gemm, mock_quant):
102+
"""Verify apply method executes with mocked ops"""
103+
mock_quant.side_effect = lambda x, scale: x
104+
mock_gemm.side_effect = lambda x, w, **kwargs: x
105+
x = paddle.randn([4, 8])
106+
out = self.method.apply(self.layer, x)
107+
self.assertTrue((out == x).all())
108+
109+
110+
if __name__ == "__main__":
111+
unittest.main()

0 commit comments

Comments
 (0)