@@ -326,12 +326,20 @@ def _init_from_checkpoint(self, pretrained_model_name: str,
326326 pooler_map = {
327327 'bert/pooler/dense/bias' : 'pooler.0.bias' ,
328328 'bert/pooler/dense/kernel' : 'pooler.0.weight' ,
329+ }
330+ classifier_map = {
329331 'output_bias' : '_logits_layer.bias' ,
330332 'output_weights' : '_logits_layer.weight' ,
331333 }
334+ global_prefix_map = {
335+ 'classifier' : '_encoder.'
336+ }
332337 tf_path = os .path .abspath (os .path .join (
333338 cache_dir , self ._MODEL2CKPT [pretrained_model_name ]))
334339
340+ class_type = kwargs .get ('class_type' , 'encoder' )
341+ global_prefix = global_prefix_map .get (class_type , '' )
342+
335343 # Load weights from TF model
336344 init_vars = tf .train .list_variables (tf_path )
337345 tfnames , arrays = [], []
@@ -351,13 +359,14 @@ def _init_from_checkpoint(self, pretrained_model_name: str,
351359 continue
352360
353361 if name in global_tensor_map :
354- v_name = global_tensor_map [name ]
362+ v_name = global_prefix + global_tensor_map [name ]
355363 pointer = self ._name_to_variable (v_name )
356364 assert pointer .shape == array .shape
357365 pointer .data = torch .from_numpy (array )
358366 idx += 1
359367 elif name in pooler_map :
360- pointer = self ._name_to_variable (pooler_map [name ])
368+ pointer = self ._name_to_variable (global_prefix +
369+ pooler_map [name ])
361370 if name .endswith ('bias' ):
362371 assert pointer .shape == array .shape
363372 pointer .data = torch .from_numpy (array )
@@ -367,6 +376,13 @@ def _init_from_checkpoint(self, pretrained_model_name: str,
367376 assert pointer .shape == array_t .shape
368377 pointer .data = torch .from_numpy (array_t )
369378 idx += 1
379+ elif name in classifier_map :
380+ if class_type != 'classifier' :
381+ continue
382+ pointer = self ._name_to_variable (classifier_map [name ])
383+ assert pointer .shape == array .shape
384+ pointer .data = torch .from_numpy (array )
385+ idx += 1
370386 else :
371387 # here name is the TensorFlow variable name
372388 name_tmp = name .split ("/" )
@@ -375,12 +391,14 @@ def _init_from_checkpoint(self, pretrained_model_name: str,
375391 name_tmp = "/" .join (name_tmp [3 :])
376392 if name_tmp in layer_tensor_map :
377393 v_name = layer_tensor_map [name_tmp ].format (layer_no )
378- pointer = self ._name_to_variable (py_prefix + v_name )
394+ pointer = self ._name_to_variable (global_prefix +
395+ py_prefix + v_name )
379396 assert pointer .shape == array .shape
380397 pointer .data = torch .from_numpy (array )
381398 elif name_tmp in layer_transpose_map :
382399 v_name = layer_transpose_map [name_tmp ].format (layer_no )
383- pointer = self ._name_to_variable (py_prefix + v_name )
400+ pointer = self ._name_to_variable (global_prefix +
401+ py_prefix + v_name )
384402 array_t = np .transpose (array )
385403 assert pointer .shape == array_t .shape
386404 pointer .data = torch .from_numpy (array_t )
0 commit comments