Skip to content

Commit 014c558

Browse files
author
Atif Ahmed
committed
Making the PretrainedMixin work for both encoder and classifier
1 parent 90a06a1 commit 014c558

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

texar/torch/modules/classifiers/bert_classifier.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ def __init__(self,
7171

7272
super().__init__(hparams=hparams)
7373

74+
self.load_pretrained_config(pretrained_model_name, cache_dir)
75+
7476
# Create the underlying encoder
7577
encoder_hparams = dict_fetch(hparams,
7678
self._ENCODER_CLASS.default_hparams())
@@ -120,6 +122,8 @@ def __init__(self,
120122
(self.num_classes <= 0 and
121123
self._hparams.encoder.dim == 1)
122124

125+
self.init_pretrained_weights(class_type='classifier')
126+
123127
@staticmethod
124128
def default_hparams():
125129
r"""Returns a dictionary of hyperparameters with default values.

texar/torch/modules/pretrained/bert.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -326,12 +326,20 @@ def _init_from_checkpoint(self, pretrained_model_name: str,
326326
pooler_map = {
327327
'bert/pooler/dense/bias': 'pooler.0.bias',
328328
'bert/pooler/dense/kernel': 'pooler.0.weight',
329+
}
330+
classifier_map = {
329331
'output_bias': '_logits_layer.bias',
330332
'output_weights': '_logits_layer.weight',
331333
}
334+
global_prefix_map = {
335+
'classifier': '_encoder.'
336+
}
332337
tf_path = os.path.abspath(os.path.join(
333338
cache_dir, self._MODEL2CKPT[pretrained_model_name]))
334339

340+
class_type = kwargs.get('class_type', 'encoder')
341+
global_prefix = global_prefix_map.get(class_type, '')
342+
335343
# Load weights from TF model
336344
init_vars = tf.train.list_variables(tf_path)
337345
tfnames, arrays = [], []
@@ -351,13 +359,14 @@ def _init_from_checkpoint(self, pretrained_model_name: str,
351359
continue
352360

353361
if name in global_tensor_map:
354-
v_name = global_tensor_map[name]
362+
v_name = global_prefix + global_tensor_map[name]
355363
pointer = self._name_to_variable(v_name)
356364
assert pointer.shape == array.shape
357365
pointer.data = torch.from_numpy(array)
358366
idx += 1
359367
elif name in pooler_map:
360-
pointer = self._name_to_variable(pooler_map[name])
368+
pointer = self._name_to_variable(global_prefix +
369+
pooler_map[name])
361370
if name.endswith('bias'):
362371
assert pointer.shape == array.shape
363372
pointer.data = torch.from_numpy(array)
@@ -367,6 +376,13 @@ def _init_from_checkpoint(self, pretrained_model_name: str,
367376
assert pointer.shape == array_t.shape
368377
pointer.data = torch.from_numpy(array_t)
369378
idx += 1
379+
elif name in classifier_map:
380+
if class_type != 'classifier':
381+
continue
382+
pointer = self._name_to_variable(classifier_map[name])
383+
assert pointer.shape == array.shape
384+
pointer.data = torch.from_numpy(array)
385+
idx += 1
370386
else:
371387
# here name is the TensorFlow variable name
372388
name_tmp = name.split("/")
@@ -375,12 +391,14 @@ def _init_from_checkpoint(self, pretrained_model_name: str,
375391
name_tmp = "/".join(name_tmp[3:])
376392
if name_tmp in layer_tensor_map:
377393
v_name = layer_tensor_map[name_tmp].format(layer_no)
378-
pointer = self._name_to_variable(py_prefix + v_name)
394+
pointer = self._name_to_variable(global_prefix +
395+
py_prefix + v_name)
379396
assert pointer.shape == array.shape
380397
pointer.data = torch.from_numpy(array)
381398
elif name_tmp in layer_transpose_map:
382399
v_name = layer_transpose_map[name_tmp].format(layer_no)
383-
pointer = self._name_to_variable(py_prefix + v_name)
400+
pointer = self._name_to_variable(global_prefix +
401+
py_prefix + v_name)
384402
array_t = np.transpose(array)
385403
assert pointer.shape == array_t.shape
386404
pointer.data = torch.from_numpy(array_t)

0 commit comments

Comments
 (0)