diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py
index e7725be6d23..226f4e14c1d 100644
--- a/fastdeploy/model_executor/layers/linear.py
+++ b/fastdeploy/model_executor/layers/linear.py
@@ -56,9 +56,6 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
-        split_axis = extra_weight_attrs.get("split_axis")
-        if hasattr(layer, "nranks") and layer.nranks > 0:
-            _set_var_distributed(layer.weight, split_axis=split_axis)
         if self.model_format == "torch" and "output_dim" in extra_weight_attrs:
             extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
 
@@ -882,15 +879,18 @@ def __init__(
         if add_bias:
             assert with_bias, "with_bias must be True when add_bias is True."
 
         assert self.quant_method is not None
-        self.quant_method.create_weights(
-            self,
-            split_axis=0,
+        create_weight_kwargs = dict(
+            layer=self,
             output_dim=None if self.split_token else False,
             weight_loader=(
                 self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
             ),
             model_format=fd_config.model_config.model_format,
         )
+        if self.nranks > 0:
+            create_weight_kwargs["split_axis"] = 0
+            create_weight_kwargs["is_distributed"] = True
+        self.quant_method.create_weights(**create_weight_kwargs)
 
         self.reduce_results = reduce_results
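
Note: below is a minimal sketch (not part of this patch) of how a quant method's
create_weights might consume the new kwargs-based call. The patch moves the
distributed-marking decision from create_weights to the caller, which now passes
split_axis and is_distributed only when nranks > 0; exactly where those flags are
read on the consumer side is an assumption here. _set_var_distributed is the
helper removed from the first hunk.

    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
        layer.weight = layer.create_parameter(
            shape=layer.weight_shape,
            dtype=layer.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        # Assumed consumer-side wiring: the caller decides whether the weight
        # is distributed, so only tensor-parallel layers pass these flags and
        # single-rank layers never reach _set_var_distributed.
        if extra_weight_attrs.get("is_distributed"):
            _set_var_distributed(layer.weight, split_axis=extra_weight_attrs.get("split_axis"))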