Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 6be28ff

Browse files
Niki Parmarlukaszkaiser
authored andcommitted
Update base hparams for tpu
PiperOrigin-RevId: 195761510
1 parent 59e4a82 commit 6be28ff

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

tensor2tensor/models/image_transformer.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -667,14 +667,16 @@ def update_hparams_for_tpu(hparams):
667667

668668
@registry.register_hparams
669669
def imagetransformer_base_tpu():
670-
hparams = imagetransformer_base()
670+
"""Transformer base params for cifar-10."""
671+
hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet()
671672
update_hparams_for_tpu(hparams)
672673
hparams.batch_size = 4
673674
hparams.num_heads = 4 # heads are expensive on tpu
674-
hparams.hidden_size = 256
675-
hparams.filter_size = 512
676-
hparams.num_hidden_layers = 8
677-
hparams.sampling_method = "random"
675+
hparams.num_decoder_layers = 12
676+
hparams.block_length = 128
677+
hparams.layer_preprocess_sequence = "none"
678+
hparams.layer_postprocess_sequence = "dan"
679+
hparams.layer_prepostprocess_dropout = 0.3
678680
return hparams
679681

680682

@@ -691,11 +693,16 @@ def imagetransformer_sep_channels_8l_tpu():
691693

692694
@registry.register_hparams
693695
def imagetransformer_b10l_4h_big_uncond_dr03_tpu():
696+
"""Small model for tpu cifar 10."""
694697
hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet()
695698
update_hparams_for_tpu(hparams)
696699
hparams.batch_size = 4
697700
hparams.num_heads = 4 # heads are expensive on tpu
698701
hparams.num_decoder_layers = 10
702+
hparams.block_length = 128
703+
hparams.hidden_size = 256
704+
hparams.filter_size = 1024
705+
hparams.learning_rate = 0.2
699706
hparams.layer_preprocess_sequence = "none"
700707
hparams.layer_postprocess_sequence = "dan"
701708
return hparams
@@ -740,6 +747,8 @@ def imagetransformer_b12l_4h_big_uncond_dr03_tpu():
740747
hparams.num_heads = 4 # heads are expensive on tpu
741748
hparams.num_decoder_layers = 12
742749
hparams.block_length = 128
750+
hparams.hidden_size = 512
751+
hparams.filter_size = 1024
743752
hparams.layer_preprocess_sequence = "none"
744753
hparams.layer_postprocess_sequence = "dan"
745754
hparams.layer_prepostprocess_dropout = 0.3

0 commit comments

Comments
 (0)