|
| 1 | +""" |
| 2 | +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +""" |
| 16 | + |
| 17 | +import unittest |
| 18 | +from unittest.mock import MagicMock, patch |
| 19 | + |
| 20 | +import paddle |
| 21 | + |
| 22 | +from fastdeploy.model_executor.layers.quantization.tensor_wise_fp8 import ( |
| 23 | + TensorWiseFP8Config, |
| 24 | + TensorWiseFP8LinearMethod, |
| 25 | +) |
| 26 | + |
| 27 | + |
| 28 | +# Dummy classes for test |
| 29 | +class DummyLayer: |
| 30 | + """Dummy linear layer for test purposes""" |
| 31 | + |
| 32 | + def __init__(self): |
| 33 | + self.weight_shape = [4, 8] |
| 34 | + self.weight_key = "weight" |
| 35 | + self.weight_scale_key = "weight_scale" |
| 36 | + self.act_scale_key = "act_scale" |
| 37 | + self.weight_dtype = "float32" |
| 38 | + self.weight = MagicMock() # Mock weight to avoid dtype copy errors |
| 39 | + |
| 40 | + def create_parameter(self, shape, dtype, is_bias=False, default_initializer=None): |
| 41 | + """Mock parameter creation""" |
| 42 | + return MagicMock() |
| 43 | + |
| 44 | + |
| 45 | +class DummyFusedMoE: |
| 46 | + """Dummy FusedMoE class for patching""" |
| 47 | + |
| 48 | + pass |
| 49 | + |
| 50 | + |
| 51 | +class TestTensorWiseFP8Config(unittest.TestCase): |
| 52 | + """Test suite for TensorWiseFP8Config""" |
| 53 | + |
| 54 | + def test_get_quant_method_linear(self): |
| 55 | + """Verify linear layer returns TensorWiseFP8LinearMethod""" |
| 56 | + cfg = TensorWiseFP8Config() |
| 57 | + layer = DummyLayer() |
| 58 | + method = cfg.get_quant_method(layer) |
| 59 | + self.assertIsInstance(method, TensorWiseFP8LinearMethod) |
| 60 | + |
| 61 | + def test_get_quant_method_moe(self): |
| 62 | + """Verify FusedMoE layer returns valid quant method""" |
| 63 | + cfg = TensorWiseFP8Config() |
| 64 | + layer = DummyFusedMoE() |
| 65 | + with patch("fastdeploy.model_executor.layers.moe.FusedMoE", DummyFusedMoE): |
| 66 | + method = cfg.get_quant_method(layer) |
| 67 | + self.assertTrue(hasattr(method, "quant_config")) |
| 68 | + |
| 69 | + |
| 70 | +class TestTensorWiseFP8LinearMethod(unittest.TestCase): |
| 71 | + """Test suite for TensorWiseFP8LinearMethod""" |
| 72 | + |
| 73 | + def setUp(self): |
| 74 | + """Initialize test fixtures""" |
| 75 | + self.layer = DummyLayer() |
| 76 | + self.method = TensorWiseFP8LinearMethod(TensorWiseFP8Config()) |
| 77 | + # Initialize scales to avoid apply errors |
| 78 | + self.method.act_scale = 1.0 |
| 79 | + self.method.total_scale = 1.0 |
| 80 | + |
| 81 | + def test_create_weights(self): |
| 82 | + """Verify weight dtype is set to float8_e4m3fn""" |
| 83 | + self.method.create_weights(self.layer) |
| 84 | + self.assertEqual(self.layer.weight_dtype, "float8_e4m3fn") |
| 85 | + |
| 86 | + def test_process_prequanted_weights(self): |
| 87 | + """Verify prequantized weights and scales are processed correctly""" |
| 88 | + self.layer.weight.copy_ = MagicMock() |
| 89 | + state_dict = { |
| 90 | + "weight": paddle.randn([8, 4]), |
| 91 | + "weight_scale": paddle.to_tensor([0.5], dtype="float32"), |
| 92 | + "act_scale": paddle.to_tensor([2.0], dtype="float32"), |
| 93 | + } |
| 94 | + self.method.process_prequanted_weights(self.layer, state_dict) |
| 95 | + self.assertAlmostEqual(self.method.act_scale, 2.0) |
| 96 | + self.assertAlmostEqual(self.method.total_scale, 1.0) |
| 97 | + self.layer.weight.copy_.assert_called_once() |
| 98 | + |
| 99 | + @patch("fastdeploy.model_executor.ops.gpu.fused_hadamard_quant_fp8", autospec=True) |
| 100 | + @patch("fastdeploy.model_executor.ops.gpu.cutlass_fp8_fp8_half_gemm_fused", autospec=True) |
| 101 | + def test_apply(self, mock_gemm, mock_quant): |
| 102 | + """Verify apply method executes with mocked ops""" |
| 103 | + mock_quant.side_effect = lambda x, scale: x |
| 104 | + mock_gemm.side_effect = lambda x, w, **kwargs: x |
| 105 | + x = paddle.randn([4, 8]) |
| 106 | + out = self.method.apply(self.layer, x) |
| 107 | + self.assertTrue((out == x).all()) |
| 108 | + |
| 109 | + |
| 110 | +if __name__ == "__main__": |
| 111 | + unittest.main() |
0 commit comments