Skip to content

Commit 34f59d9

Browse files
authored
[RL] Fix missing is_distributed attribute (#5150)
* fix
* update
1 parent 6ca2651 commit 34f59d9

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

fastdeploy/model_executor/layers/linear.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,6 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
5656
is_bias=False,
5757
default_initializer=paddle.nn.initializer.Constant(0),
5858
)
59-
split_axis = extra_weight_attrs.get("split_axis")
60-
if hasattr(layer, "nranks") and layer.nranks > 0:
61-
_set_var_distributed(layer.weight, split_axis=split_axis)
6259

6360
if self.model_format == "torch" and "output_dim" in extra_weight_attrs:
6461
extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
@@ -882,15 +879,18 @@ def __init__(
882879
if add_bias:
883880
assert with_bias, "with_bias must be True when add_bias is True."
884881
assert self.quant_method is not None
885-
self.quant_method.create_weights(
886-
self,
887-
split_axis=0,
882+
create_weight_kwargs = dict(
883+
layer=self,
888884
output_dim=None if self.split_token else False,
889885
weight_loader=(
890886
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
891887
),
892888
model_format=fd_config.model_config.model_format,
893889
)
890+
if self.nranks > 0:
891+
create_weight_kwargs["split_axis"] = 0
892+
create_weight_kwargs["is_distributed"] = True
893+
self.quant_method.create_weights(**create_weight_kwargs)
894894

895895
self.reduce_results = reduce_results
896896

0 commit comments

Comments (0)