@@ -667,14 +667,16 @@ def update_hparams_for_tpu(hparams):
 
 @registry.register_hparams
 def imagetransformer_base_tpu():
-  hparams = imagetransformer_base()
+  """Transformer base params for cifar-10."""
+  hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet()
   update_hparams_for_tpu(hparams)
   hparams.batch_size = 4
   hparams.num_heads = 4   # heads are expensive on tpu
-  hparams.hidden_size = 256
-  hparams.filter_size = 512
-  hparams.num_hidden_layers = 8
-  hparams.sampling_method = "random"
+  hparams.num_decoder_layers = 12
+  hparams.block_length = 128
+  hparams.layer_preprocess_sequence = "none"
+  hparams.layer_postprocess_sequence = "dan"
+  hparams.layer_prepostprocess_dropout = 0.3
   return hparams
 
 
@@ -691,11 +693,16 @@ def imagetransformer_sep_channels_8l_tpu():
 
 @registry.register_hparams
 def imagetransformer_b10l_4h_big_uncond_dr03_tpu():
+  """Small model for tpu cifar 10."""
   hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet()
   update_hparams_for_tpu(hparams)
   hparams.batch_size = 4
   hparams.num_heads = 4   # heads are expensive on tpu
   hparams.num_decoder_layers = 10
+  hparams.block_length = 128
+  hparams.hidden_size = 256
+  hparams.filter_size = 1024
+  hparams.learning_rate = 0.2
   hparams.layer_preprocess_sequence = "none"
   hparams.layer_postprocess_sequence = "dan"
   return hparams
@@ -740,6 +747,8 @@ def imagetransformer_b12l_4h_big_uncond_dr03_tpu():
   hparams.num_heads = 4   # heads are expensive on tpu
   hparams.num_decoder_layers = 12
   hparams.block_length = 128
+  hparams.hidden_size = 512
+  hparams.filter_size = 1024
   hparams.layer_preprocess_sequence = "none"
   hparams.layer_postprocess_sequence = "dan"
   hparams.layer_prepostprocess_dropout = 0.3
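Note: the functions in this diff all follow tensor2tensor's derive-and-override hparams pattern: a registered constructor calls a base constructor, then mutates a handful of fields. The sketch below illustrates that pattern in isolation. It is a minimal, self-contained approximation: `HParams`, `register_hparams`, `update_hparams_for_tpu`, and the base values are simplified stand-ins, not the library's real API or defaults.

# Minimal sketch of the registry + derive-and-override pattern.
# All names and values here are illustrative stand-ins.

_registry = {}

def register_hparams(fn):
  """Register an hparams constructor under its function name."""
  _registry[fn.__name__] = fn
  return fn

class HParams(object):
  """Tiny attribute bag standing in for tf.contrib.training.HParams."""
  def __init__(self, **kwargs):
    self.__dict__.update(kwargs)

def update_hparams_for_tpu(hparams):
  """Placeholder for the library helper; the real TPU tweaks are omitted."""
  return hparams

@register_hparams
def imagetransformer_bas8l_8h_big_uncond_dr03_imgnet():
  # Illustrative base values only, not the library's actual defaults.
  return HParams(batch_size=8, num_heads=8, num_decoder_layers=8,
                 hidden_size=512, filter_size=2048, block_length=256,
                 learning_rate=0.5,
                 layer_preprocess_sequence="n",
                 layer_postprocess_sequence="da",
                 layer_prepostprocess_dropout=0.1)

@register_hparams
def imagetransformer_b10l_4h_big_uncond_dr03_tpu():
  """Small model for tpu cifar 10 (mirrors the diff above)."""
  hparams = _registry["imagetransformer_bas8l_8h_big_uncond_dr03_imgnet"]()
  update_hparams_for_tpu(hparams)
  hparams.batch_size = 4
  hparams.num_heads = 4   # heads are expensive on tpu
  hparams.num_decoder_layers = 10
  hparams.block_length = 128
  hparams.hidden_size = 256
  hparams.filter_size = 1024
  hparams.learning_rate = 0.2
  hparams.layer_preprocess_sequence = "none"
  hparams.layer_postprocess_sequence = "dan"
  return hparams

# A trainer would select a config by its registered name (e.g. via
# a flag such as --hparams_set) and receive the constructed object.
hp = _registry["imagetransformer_b10l_4h_big_uncond_dr03_tpu"]()
print(hp.num_heads, hp.filter_size)  # -> 4 1024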