From 0b88556772d4d0708156fabcb517552eda12e823 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Tue, 25 Nov 2025 10:58:45 +0000
Subject: [PATCH] create param list for scales list

Signed-off-by: yiliu30
---
 .../torch/algorithms/fp8_quant/_core/scale_handler.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_handler.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_handler.py
index 5589e90e715..27a7ac56fc6 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_handler.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_handler.py
@@ -51,7 +51,7 @@ def create_scale_tensor(orig_scales, scale_format):
     if isinstance(orig_scales, (torch.Tensor, float)):
         return scale_creation_func(orig_scales)
     elif isinstance(orig_scales, list):
-        return [scale_creation_func(x) for x in orig_scales]
+        return torch.nn.ParameterList([scale_creation_func(x) for x in orig_scales])
     else:
         raise ValueError("unexpected scale format value {}".format(scale_format))
 
@@ -78,6 +78,8 @@ def scale_to_scalar(scale):
 def get_scale_dtype(scale):
     if isinstance(scale, torch.Tensor):  # tensor case
         return scale.dtype
+    if isinstance(scale, torch.nn.ParameterList):  # ParameterList case
+        return scale[0].dtype
     elif isinstance(scale, float):  # already scalar case
         return type(scale).__name__
     elif scale is None:  # possible dynamic scalar case
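
Note on the change (a minimal sketch, not part of the patch): wrapping the per-tensor
scales in torch.nn.ParameterList instead of a plain Python list means each entry is
registered on the owning nn.Module, so the scales appear in state_dict() and follow
module.to(device/dtype) calls, which tensors held in a plain list do not. The
QuantWrapper class and raw_scales argument below are illustrative names, not taken
from neural_compressor.

    import torch

    class QuantWrapper(torch.nn.Module):
        # Illustrative module, not from neural_compressor.
        def __init__(self, raw_scales):
            super().__init__()
            # Registered module state: saved in state_dict() and moved by .to()
            self.scales = torch.nn.ParameterList(
                [torch.nn.Parameter(torch.tensor(s), requires_grad=False) for s in raw_scales]
            )

    if __name__ == "__main__":
        m = QuantWrapper([0.5, 0.25])
        print(list(m.state_dict().keys()))  # ['scales.0', 'scales.1']
        m = m.to(torch.bfloat16)
        print(m.scales[0].dtype)            # torch.bfloat16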