
Commit 46f518c

T2T Team authored and Ryan Sepassi committed
Hparam specification for solving Librispeech with Transformer
PiperOrigin-RevId: 185972342
1 parent 8548dab · commit 46f518c

2 files changed: 37 additions & 0 deletions

tensor2tensor/data_generators/librispeech.py

Lines changed: 7 additions & 0 deletions
@@ -185,3 +185,10 @@ def add_librispeech_hparams(hparams):
   hparams.train_steps = 5000000
   hparams.num_hidden_layers = 4
   return hparams
+
+
+def set_librispeech_length_hparams(hparams):
+  hparams.max_length = 1650 * 80  # this limits inputs[1] * inputs[2]
+  hparams.max_input_seq_length = 1650
+  hparams.max_target_seq_length = 350
+  return hparams
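
For orientation, here is a minimal usage sketch (not part of the commit) showing the new helper applied to an existing hparams set; it assumes tensor2tensor at this commit is importable, uses transformer_base from tensor2tensor/models/transformer.py as the starting point, and reads the factor of 80 as the per-frame feature dimension of the spectrogram inputs.

# Sketch only: apply the Librispeech length limits to a base hparams set.
# Assumes tensor2tensor (at this commit) is installed and importable.
from tensor2tensor.data_generators import librispeech
from tensor2tensor.models import transformer

hparams = transformer.transformer_base()
hparams = librispeech.set_librispeech_length_hparams(hparams)

print(hparams.max_length)             # 132000, i.e. 1650 * 80
print(hparams.max_input_seq_length)   # 1650
print(hparams.max_target_seq_length)  # 350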

tensor2tensor/models/transformer.py

Lines changed: 30 additions & 0 deletions
@@ -31,6 +31,7 @@
 
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
+from tensor2tensor.data_generators import librispeech
 from tensor2tensor.layers import common_attention
 from tensor2tensor.layers import common_hparams
 from tensor2tensor.layers import common_layers

@@ -1423,3 +1424,32 @@ def transformer_lm_tpu_1():
   hparams.hidden_size = 2048
   hparams.filter_size = 8192
   return hparams
+
+
+@registry.register_hparams
+def transformer_librispeech():
+  """Hparams for training ASR model on Librispeech."""
+  hparams = transformer_base()
+
+  hparams.num_heads = 4
+  hparams.filter_size = 1024
+  hparams.hidden_size = 256
+  hparams.num_encoder_layers = 5
+  hparams.num_decoder_layers = 3
+  hparams.learning_rate = 0.15
+  hparams.batch_size = 6000000
+
+  librispeech.set_librispeech_length_hparams(hparams)
+  return hparams
+
+
+@registry.register_hparams
+def transformer_librispeech_tpu():
+  """Hparams for training ASR model on Librispeech on TPU."""
+  hparams = transformer_librispeech()
+  update_hparams_for_tpu(hparams)
+
+  hparams.batch_size = 32
+  librispeech.set_librispeech_length_hparams(hparams)
+  return hparams
+
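
As a rough usage sketch (again assuming tensor2tensor at this commit is importable), the two registered hparams sets can be built directly from the module; since they are registered via @registry.register_hparams, they should also be selectable by their function names (e.g. transformer_librispeech) wherever an hparams-set name is expected.

# Sketch only: construct the newly registered ASR hparams sets.
from tensor2tensor.models import transformer

hparams = transformer.transformer_librispeech()
print(hparams.num_encoder_layers, hparams.num_decoder_layers)  # 5 3
print(hparams.hidden_size, hparams.num_heads)                  # 256 4
print(hparams.max_input_seq_length)                            # 1650, set by the librispeech helper

# TPU variant: same model shape, TPU-adjusted settings, much smaller batch_size.
tpu_hparams = transformer.transformer_librispeech_tpu()
print(tpu_hparams.batch_size)  # 32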
