
 from six.moves import xrange  # pylint: disable=redefined-builtin

+from tensor2tensor.layers import common_attention
 from tensor2tensor.layers import common_layers
 from tensor2tensor.models import transformer
 from tensor2tensor.utils import registry
@@ -49,13 +50,43 @@ def residual_conv(x, repeat, hparams, name, reuse=None):
   return x


-def decompress_step(source, hparams, first_relu, name):
+def attend(x, source, hparams, name):
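+  """Attend x to source with multihead attention."""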
+  with tf.variable_scope(name):
+    x = tf.squeeze(x, axis=2)
+    if len(source.get_shape()) > 3:
+      source = tf.squeeze(source, axis=2)
+    source = common_attention.add_timing_signal_1d(source)
+    y = common_attention.multihead_attention(
+        common_layers.layer_preprocess(x, hparams), source, None,
+        hparams.attention_key_channels or hparams.hidden_size,
+        hparams.attention_value_channels or hparams.hidden_size,
+        hparams.hidden_size, hparams.num_heads,
+        hparams.attention_dropout)
+    res = common_layers.layer_postprocess(x, y, hparams)
+    return tf.expand_dims(res, axis=2)
+
+
+def interleave(x, y, axis=1):
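+  """Stack x and y along a new dimension at axis + 1."""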
+  x = tf.expand_dims(x, axis=axis + 1)
+  y = tf.expand_dims(y, axis=axis + 1)
+  return tf.concat([x, y], axis=axis + 1)
+
+
+def decompress_step(source, c, hparams, first_relu, name):
   """Decompression function."""
   with tf.variable_scope(name):
     shape = tf.shape(source)
-    thicker = common_layers.conv_block(
-        source, hparams.hidden_size * 2, [((1, 1), (1, 1))],
-        first_relu=first_relu, name="decompress_conv")
+    if c is not None:
+      source = attend(source, c, hparams, "decompress_attend")
+    first = common_layers.conv_block(
+        source,
+        hparams.hidden_size, [((1, 1), (3, 1)), ((1, 1), (3, 1))],
+        first_relu=first_relu, padding="SAME", name="decompress_conv1")
+    second = common_layers.conv_block(
+        tf.concat([source, first], axis=3),
+        hparams.hidden_size, [((1, 1), (3, 1)), ((1, 1), (3, 1))],
+        first_relu=first_relu, padding="SAME", name="decompress_conv2")
+    thicker = interleave(first, second)
     return tf.reshape(thicker, [shape[0], shape[1] * 2, 1, hparams.hidden_size])


@@ -71,23 +102,25 @@ def vae(x, hparams, name):
   return z, tf.reduce_mean(kl), mu, log_sigma


-def compress(inputs, hparams, name):
+def compress(x, c, hparams, name):
   """Compress."""
   with tf.variable_scope(name):
     # Run compression by strided convs.
-    cur = inputs
+    cur = x
     for i in xrange(hparams.num_compress_steps):
+      if c is not None:
+        cur = attend(cur, c, hparams, "compress_attend_%d" % i)
       cur = residual_conv(cur, 1, hparams, "compress_rc_%d" % i)
       cur = common_layers.conv_block(
           cur, hparams.hidden_size, [((1, 1), (2, 1))],
           strides=(2, 1), name="compress_%d" % i)
     return cur


-def vae_compress(inputs, hparams, compress_name, decompress_name, reuse=None):
+def vae_compress(x, c, hparams, compress_name, decompress_name, reuse=None):
   """Compress, then VAE."""
   with tf.variable_scope(compress_name, reuse=reuse):
-    cur = compress(inputs, hparams, "compress")
+    cur = compress(x, c, hparams, "compress")
     # Convolve and ReLu to get state.
     cur = common_layers.conv_block(
         cur, hparams.hidden_size, [((1, 1), (1, 1))], name="mid_conv")
@@ -100,7 +133,7 @@ def vae_compress(inputs, hparams, compress_name, decompress_name, reuse=None):
     for i in xrange(hparams.num_compress_steps):
       j = hparams.num_compress_steps - i - 1
       z = residual_conv(z, 1, hparams, "decompress_rc_%d" % j)
-      z = decompress_step(z, hparams, i > 0, "decompress__step_%d" % j)
+      z = decompress_step(z, c, hparams, i > 0, "decompress__step_%d" % j)
     return z, kl_loss, mu, log_sigma


@@ -124,6 +157,13 @@ def dropmask(targets, targets_dropout_max, is_training):
   return targets * keep_mask


+def ffn(x, hparams, name):
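+  """Feed-forward layer wrapped in layer preprocess/postprocess."""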
+  with tf.variable_scope(name):
+    y = transformer.transformer_ffn_layer(
+        common_layers.layer_preprocess(x, hparams), hparams)
+    return common_layers.layer_postprocess(x, y, hparams)
+
+
 def vae_transformer_internal(inputs, targets, target_space, hparams):
   """VAE Transformer, main step used for training."""
   with tf.variable_scope("vae_transformer"):
@@ -140,36 +180,40 @@ def vae_transformer_internal(inputs, targets, target_space, hparams):
     inputs = encode(inputs, target_space, hparams, "input_enc")

     # Dropout targets or swap for zeros 5% of the time.
+    targets_nodrop = targets
     max_prestep = hparams.kl_warmup_steps
     prob_targets = 0.95 if is_training else 1.0
     targets_dropout_max = common_layers.inverse_lin_decay(max_prestep) - 0.01
     targets = dropmask(targets, targets_dropout_max * 0.7, is_training)
     targets = tf.cond(tf.less(tf.random_uniform([]), prob_targets),
                       lambda: targets, lambda: tf.zeros_like(targets))
-
-    # Join targets with inputs, run encoder.
-    # to_encode = common_layers.conv_block(
-    #     tf.expand_dims(tf.concat([targets, inputs], axis=2), axis=2),
-    #     hparams.hidden_size, [((1, 1), (1, 1))],
-    #     first_relu=False, name="join_targets")
-    # to_compress = encode(tf.squeeze(to_encode, axis=2),
-    #                      target_space, hparams, "enc")
+    targets = targets_nodrop

     # Compress and vae.
-    z, kl_loss, _, _ = vae_compress(tf.expand_dims(targets, axis=2), hparams,
-                                    "vae_compress", "vae_decompress")
+    z = tf.get_variable("z", [hparams.hidden_size])
+    z = tf.reshape(z, [1, 1, 1, -1])
+    z = tf.tile(z, [tf.shape(inputs)[0], 1, 1, 1])
+
+    z = attend(z, inputs, hparams, "z_attendsi")
+    z = ffn(z, hparams, "zff2")
+    z = attend(z, targets, hparams, "z_attendst2")
+    z = ffn(z, hparams, "zff3")
+    z, kl_loss, _, _ = vae(z, hparams, name="vae")
+    z = tf.layers.dense(z, hparams.hidden_size, name="z_to_dense")
+
+    # z, kl_loss, _, _ = vae_compress(
+    #     tf.expand_dims(targets, axis=2), tf.expand_dims(inputs, axis=2),
+    #     hparams, "vae_compress", "vae_decompress")

-    # Join z with inputs, run decoder.
-    to_decode = common_layers.conv_block(
-        tf.concat([z, tf.expand_dims(inputs, axis=2)], axis=3),
-        hparams.hidden_size, [((1, 1), (1, 1))], name="join_z")
-    ret = encode(tf.squeeze(to_decode, axis=2), target_space, hparams, "dec")
-    # to_decode = residual_conv(to_decode, 2, hparams, "dec_conv")
-    # ret = tf.squeeze(to_decode, axis=2)
+    decoder_in = tf.squeeze(z, axis=2) + tf.zeros_like(targets)
+    (decoder_input, decoder_self_attention_bias) = (
+        transformer.transformer_prepare_decoder(decoder_in, hparams))
+    ret = transformer.transformer_decoder(
+        decoder_input, inputs, decoder_self_attention_bias, None, hparams)

-    # Randomize decoder inputs..
-    kl_loss *= common_layers.inverse_exp_decay(max_prestep) * 10.0
-    return tf.expand_dims(ret, axis=2), kl_loss
+    kl_loss *= common_layers.inverse_exp_decay(int(max_prestep * 1.5)) * 5.0
+    losses = {"kl": kl_loss}
+    return tf.expand_dims(ret, axis=2), losses


 @registry.register_model
@@ -203,13 +247,15 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
     sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
     samples = tf.concat(sharded_samples, 0)

-    # 2nd step.
-    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
-      features["targets"] = samples
-      sharded_logits, _ = self.model_fn(
-          features, False, last_position_only=last_position_only)
-      sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
-      samples = tf.concat(sharded_samples, 0)
+    # More steps.
+    how_many_more_steps = 20
+    for _ in xrange(how_many_more_steps):
+      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+        features["targets"] = samples
+        sharded_logits, _ = self.model_fn(
+            features, False, last_position_only=last_position_only)
+        sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
+        samples = tf.concat(sharded_samples, 0)

     if inputs_old is not None:  # Restore to not confuse Estimator.
       features["inputs"] = inputs_old
@@ -221,9 +267,10 @@ def transformer_vae_small():
   """Set of hyperparameters."""
   hparams = transformer.transformer_small()
   hparams.batch_size = 2048
+  hparams.learning_rate_warmup_steps = 16000
   hparams.add_hparam("z_size", 128)
   hparams.add_hparam("num_compress_steps", 4)
-  hparams.add_hparam("kl_warmup_steps", 50000)
+  hparams.add_hparam("kl_warmup_steps", 60000)
   return hparams


@@ -233,9 +280,9 @@ def transformer_vae_base():
   hparams = transformer_vae_small()
   hparams.hidden_size = 512
   hparams.filter_size = 2048
-  hparams.attention_dropout = 0.1
-  hparams.relu_dropout = 0.1
-  hparams.dropout = 0.1
-  hparams.num_hidden_layers = 4
+  hparams.attention_dropout = 0.0
+  hparams.relu_dropout = 0.0
+  hparams.dropout = 0.0
+  hparams.num_hidden_layers = 3
   hparams.z_size = 256
   return hparams