Skip to content

Commit 90a06a1

Browse files
author
Atif Ahmed
committed
Adding logits layer weights and bias
1 parent e0a2da2 commit 90a06a1

File tree

1 file changed

+5
-3
lines changed
  • texar/torch/modules/pretrained

1 file changed

+5
-3
lines changed

texar/torch/modules/pretrained/bert.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,9 @@ class PretrainedBERTMixin(PretrainedMixin, ABC):
167167

168168
# BERT for MS-MARCO
169169
'bert-msmarco-base':
170-
_BERT_MSMARCO_PATH + '1cyUrhs7JaCJTTu-DjFUqP6Bs4f8a6JTX/view',
170+
_BERT_MSMARCO_PATH + '1cyUrhs7JaCJTTu-DjFUqP6Bs4f8a6JTX/',
171171
'bert-msmarco-large':
172-
_BERT_MSMARCO_PATH + '1crlASTMlsihALlkabAQP6JTYIZwC1Wm8/view'
172+
_BERT_MSMARCO_PATH + '1crlASTMlsihALlkabAQP6JTYIZwC1Wm8/'
173173
}
174174
_MODEL2CKPT = {
175175
# Standard BERT
@@ -325,7 +325,9 @@ def _init_from_checkpoint(self, pretrained_model_name: str,
325325
}
326326
pooler_map = {
327327
'bert/pooler/dense/bias': 'pooler.0.bias',
328-
'bert/pooler/dense/kernel': 'pooler.0.weight'
328+
'bert/pooler/dense/kernel': 'pooler.0.weight',
329+
'output_bias': '_logits_layer.bias',
330+
'output_weights': '_logits_layer.weight',
329331
}
330332
tf_path = os.path.abspath(os.path.join(
331333
cache_dir, self._MODEL2CKPT[pretrained_model_name]))

0 commit comments

Comments
 (0)