@@ -106,6 +106,8 @@ def __init__(
         params_dtype: str = "bfloat16",
         prefix="",
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+        org_num_embeddings: int | None = None,
+        general=False,
     ) -> None:
         """
         Initialize the VocabParallelEmbedding layer for the model.
@@ -132,17 +134,23 @@ def __init__(
         self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
         self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
         self.params_dtype: str = params_dtype
-        self.padding_size = padding_size

-        self.org_vocab_size = num_embeddings
+        self.general = general  # used for general Embedding
         self.num_embeddings = num_embeddings
-        num_added_embeddings = num_embeddings - self.org_vocab_size
+        self.padding_size = padding_size
+        if self.general:
+            self.org_vocab_size = num_embeddings
+            self.num_embeddings_padded = num_embeddings
+            self.org_vocab_size_padded = num_embeddings
+        else:
+            self.org_vocab_size = org_num_embeddings or num_embeddings
+            num_added_embeddings = num_embeddings - self.org_vocab_size

-        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, self.padding_size)
-        self.num_embeddings_padded = pad_vocab_size(
-            self.org_vocab_size_padded + num_added_embeddings, self.padding_size
-        )
-        assert self.org_vocab_size_padded <= self.num_embeddings_padded
+            self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, self.padding_size)
+            self.num_embeddings_padded = pad_vocab_size(
+                self.org_vocab_size_padded + num_added_embeddings, self.padding_size
+            )
+            assert self.org_vocab_size_padded <= self.num_embeddings_padded
         self.shard_indices = self._get_indices(
             self.num_embeddings_padded,
             self.org_vocab_size_padded,
@@ -152,9 +160,6 @@ def __init__(
             self.world_size,
         )

-        if num_embeddings % self.world_size != 0:
-            self.num_embeddings_padded = pad_vocab_size(num_embeddings, self.padding_size)
-
         if not self.column_cut:
             self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
                 self.num_embeddings_padded,
@@ -188,7 +193,7 @@ def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         Args:
             state_dict (dict): A dictionary containing the checkpoint weights and biases.
         """
-        if self.tie_word_embeddings:
+        if self.tie_word_embeddings and not self.general:
            weight_tensor = get_tensor(state_dict[self.prefix + ".weight"]).astype(paddle.get_default_dtype())
        else:
            weight_tensor = get_tensor(state_dict.pop(self.prefix + ".weight")).astype(paddle.get_default_dtype())
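For reference, a minimal standalone sketch of how the new `general` branch changes the padded sizes. The `pad_vocab_size` helper below is an illustrative stand-in (assumed to round up to the next multiple of `padding_size`), and the vocab sizes in the example are made up; only the branch structure mirrors the diff.

```python
def pad_vocab_size(vocab_size: int, padding_size: int) -> int:
    # Illustrative stand-in for the module's pad_vocab_size helper:
    # assumed to round vocab_size up to the next multiple of padding_size.
    return ((vocab_size + padding_size - 1) // padding_size) * padding_size


def padded_sizes(num_embeddings: int, org_num_embeddings: int | None = None,
                 general: bool = False, padding_size: int = 64) -> tuple[int, int]:
    """Mirror the branch added in __init__: returns (org_vocab_size_padded, num_embeddings_padded)."""
    if general:
        # General (non-vocab) embeddings keep their exact size; no padding is applied.
        return num_embeddings, num_embeddings
    org_vocab_size = org_num_embeddings or num_embeddings
    num_added_embeddings = num_embeddings - org_vocab_size
    org_vocab_size_padded = pad_vocab_size(org_vocab_size, padding_size)
    num_embeddings_padded = pad_vocab_size(org_vocab_size_padded + num_added_embeddings, padding_size)
    assert org_vocab_size_padded <= num_embeddings_padded
    return org_vocab_size_padded, num_embeddings_padded


# Example: a 100000-token base vocab with 8 added tokens, padding_size=64.
print(padded_sizes(100008, org_num_embeddings=100000))  # (100032, 100096)
print(padded_sizes(100008, general=True))               # (100008, 100008)
```

The `general` flag also feeds the `load_state_dict` change: when it is set, the weight is always popped from the state dict rather than read in place, which suggests a general embedding is never weight-tied to the LM head even if `tie_word_embeddings` is enabled.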