# See the License for the specific language governing permissions and
# limitations under the License.

-"""VAE Transformer."""
+"""AE Transformer."""

from __future__ import absolute_import
from __future__ import division

import tensorflow as tf


-def residual_conv(x, repeat, hparams, name, reuse=None):
+def residual_conv(x, repeat, k, hparams, name, reuse=None):
  """A stack of convolution blocks with residual connections."""
  with tf.variable_scope(name, reuse=reuse):
-    k = (3, 1)
    dilations_and_kernels = [((1, 1), k) for _ in xrange(3)]
    for i in xrange(repeat):
      with tf.variable_scope("repeat_%d" % i):
@@ -72,15 +71,19 @@ def interleave(x, y, axis=1):
  return tf.concat([x, y], axis=axis + 1)


-def decompress_step(source, c, hparams, first_relu, name):
+def decompress_step(source, c, hparams, first_relu, is_2d, name):
  """Decompression function."""
  with tf.variable_scope(name):
    shape = tf.shape(source)
    if c is not None:
      source = attend(source, c, hparams, "decompress_attend")
+    multiplier = 4 if is_2d else 2
+    kernel = (1, 1)  # Same kernel in both the 2d and 1d case.
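+    # The conv below widens the channels by `multiplier`; depth_to_space (2d)
+    # or the reshape at the end (1d) then trades those channels back for
+    # spatial size, doubling each compressed dimension.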
    thicker = common_layers.conv_block(
-        source, hparams.hidden_size * 2, [((1, 1), (1, 1))],
+        source, hparams.hidden_size * multiplier, [((1, 1), kernel)],
        first_relu=first_relu, name="decompress_conv")
+    if is_2d:
+      return tf.depth_to_space(thicker, 2)
    return tf.reshape(thicker, [shape[0], shape[1] * 2, 1, hparams.hidden_size])


@@ -90,7 +93,7 @@ def gumbel_sample(shape):
  return -tf.log(-tf.log(uniform_samples))
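+# Note: for U ~ Uniform(0, 1), -log(-log(U)) is a standard Gumbel(0, 1) draw.
+# Adding such noise to logits and taking argmax samples from the softmax over
+# those logits (the Gumbel-max trick), e.g. (illustrative only):
+#   sample = tf.argmax(logits + gumbel_sample(tf.shape(logits)), axis=-1)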


-def dvae(x, hparams, name):
+def dae(x, hparams, name):
  with tf.variable_scope(name):
    m = tf.layers.dense(x, hparams.v_size, name="mask")
    logsm = tf.nn.log_softmax(m)
@@ -128,7 +131,7 @@ def nearest(x, means, hparams):
  _, nearest_idx = tf.nn.top_k(-dist, k=1)
  nearest_hot = tf.one_hot(tf.squeeze(nearest_idx, axis=1), hparams.v_size)
  nearest_hot = tf.reshape(nearest_hot, [tf.shape(x)[0], tf.shape(x)[1],
-                                        1, hparams.v_size])
+                                        tf.shape(x)[2], hparams.v_size])
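+  # The one-hot assignment is wrapped in stop_gradient below: code selection
+  # is discrete and non-differentiable, so gradients reach x and the means
+  # only through the squared-distance loss in kmeans().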
  return tf.stop_gradient(nearest_hot)


@@ -137,21 +140,23 @@ def kmeans(x, means, hparams, name):
    x_means_hot = nearest(x, means, hparams)
    x_means = tf.gather(means, tf.argmax(x_means_hot, axis=-1))
    kl = tf.reduce_sum(tf.square(x - x_means), axis=-1)
-    return x_means_hot, x_means_hot, tf.reduce_mean(kl) * 10.0
+    return x_means_hot, tf.reduce_mean(kl) * 10.0


-def compress(x, c, hparams, name):
+def compress(x, c, is_2d, hparams, name):
  """Compress."""
  with tf.variable_scope(name):
    # Run compression by strided convs.
    cur = x
+    k1 = (3, 3) if is_2d else (3, 1)
+    k2 = (2, 2) if is_2d else (2, 1)
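+    # Each compression step halves the length (1d kernels and strides) or
+    # both spatial dims (2d), so the output is 2**num_compress_steps smaller.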
    for i in xrange(hparams.num_compress_steps):
      if c is not None:
        cur = attend(cur, c, hparams, "compress_attend_%d" % i)
-      cur = residual_conv(cur, 1, hparams, "compress_rc_%d" % i)
+      cur = residual_conv(cur, 1, k1, hparams, "compress_rc_%d" % i)
      cur = common_layers.conv_block(
-          cur, hparams.hidden_size, [((1, 1), (2, 1))],
-          strides=(2, 1), name="compress_%d" % i)
+          cur, hparams.hidden_size, [((1, 1), k2)],
+          strides=k2, name="compress_%d" % i)
    return cur


@@ -188,7 +193,7 @@ def decode(cond_vec, cond_add, gold, c, ed, hparams):
  decoder_input = tf.squeeze(decoder_input, axis=2)
  decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  bias = common_attention.attention_bias_lower_triangle(tf.shape(gold)[1])
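+  # The context c may arrive as 3d [batch, length, depth] or 4d with a dummy
+  # axis 2; only the 4d case needs the squeeze.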
-  if c is not None:
+  if c is not None and len(c.get_shape()) > 3:
    c = tf.squeeze(c, axis=2)
  return transformer.transformer_decoder(decoder_input, c, bias, ed, hparams)

@@ -205,69 +210,62 @@ def expand_batch(x, mul):
  return tf.reshape(cx, res_shape)


-def vae_compress(x, c, ed, hparams, compress_name, decompress_name, reuse=None):
-  """Compress, then VAE."""
-  with tf.variable_scope(compress_name, reuse=reuse):
-    cur = compress(x, None, hparams, "compress")
+def ae_compress(x, is_2d, hparams, name, reuse=None):
+  """Compress, then AE."""
+  with tf.variable_scope(name, reuse=reuse):
+    cur = compress(x, None, is_2d, hparams, "compress")
    # Convolve and ReLu to get state.
    cur = common_layers.conv_block(
        cur, hparams.hidden_size, [((1, 1), (1, 1))], name="mid_conv")
    cur = tf.nn.l2_normalize(cur, dim=3)
    cur_n = hparams.kmeans_lr_factor * cur
    cur_n += (1.0 - hparams.kmeans_lr_factor) * tf.stop_gradient(cur)
    means = tf.get_variable("z_to_dense", [hparams.v_size, hparams.hidden_size])
-    # z, kl_loss, mu, log_sigma = vae(cur, hparams, name="vae")
-    # z_true, z_sample, kl_loss = dvae(cur, hparams, name="dvae")
-    z_true, z_sample, kl_loss = kmeans(cur_n, means, hparams, name="kmeans")
-
-  # Compress context.
-  with tf.variable_scope(compress_name, reuse=reuse):
-    compress_c = compress(c, None, hparams, "compress_context")
-    dec_c = decode(None, compress_c, cur, None, None, hparams)
-    c_z = tf.layers.dense(dec_c, hparams.v_size, name="mask_context")
-    reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits(
-        labels=z_true, logits=c_z)
+    hot, loss = kmeans(cur_n, means, hparams, name="kmeans")
+    # We need a linear layer to undo the l2-normalization.
+    cur = tf.layers.dense(cur, hparams.hidden_size, name="unnormalize")
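+    # Returns the continuous code (after undoing the l2-normalization), the
+    # one-hot k-means assignment, and the scaled quantization loss.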
+    return cur, hot, loss

-  # If not training, use the predicted z instead of the autoregressive one.
-  if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
-    z = tf.one_hot(tf.argmax(c_z, axis=-1), hparams.v_size)

-  with tf.variable_scope(decompress_name, reuse=reuse):
-    # Decompress.
-    z_sample_flat = tf.reshape(z_sample, [-1, hparams.v_size])
-    z = tf.matmul(z_sample_flat, means)
-    z = tf.reshape(z, [tf.shape(z_sample)[0], tf.shape(z_sample)[1],
-                       1, hparams.hidden_size])
+def ae_embed(hot, hparams, name, reuse=None):
+  with tf.variable_scope(name, reuse=reuse):
+    means = tf.get_variable("z_to_dense", [hparams.v_size, hparams.hidden_size])
+    hot_flat = tf.reshape(hot, [-1, hparams.v_size])
+    emb = tf.matmul(hot_flat, means)
+    emb = tf.reshape(emb, [tf.shape(hot)[0], tf.shape(hot)[1],
+                           tf.shape(hot)[2], hparams.hidden_size])
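+    # The "z_to_dense" codebook and "unnormalize" layer are shared with
+    # ae_compress (callers pass reuse=True), so one-hot codes land in the
+    # same dense space the decompressor expects.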
+    return tf.layers.dense(emb, hparams.hidden_size,
+                           name="unnormalize", reuse=reuse)


+def ae_decompress(z, ae, x, is_2d, hparams, name, reuse=None):
+  """Decompress from z, leaking from ae."""
+  with tf.variable_scope(name + "_decompress", reuse=reuse):
    # Leak at the beginning to help train.
-    z = mix(z, cur, hparams.startup_steps)
+    z = mix(z, ae, hparams.startup_steps)
    prob_z = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.8
-    prob_z = prob_z if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0
+    prob_z = prob_z if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN else 1.0
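+    # prob_z ramps toward 0.8 during training, so the continuous ae path is
+    # still taken at least 20% of the time; at inference prob_z is 1.0 and
+    # only the quantized z is used.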
    z = tf.cond(tf.less(tf.random_uniform([]), prob_z),
-                lambda: z, lambda: cur)
-    z = tf.layers.dense(z, hparams.hidden_size, name="unnormalize")
+                lambda: z, lambda: ae)

    # Dropout for better autoencoding.
-    z = tf.nn.dropout(z, keep_prob=0.9)
+    z = tf.nn.dropout(z, keep_prob=1.0 - hparams.z_dropout)

    # Decompress.
    d = z
    for i in xrange(hparams.num_compress_steps):
      j = hparams.num_compress_steps - i - 1
-      d = residual_conv(d, 1, hparams, "decompress_rc_%d" % j)
-      d = decompress_step(d, c, hparams, i > 0, "decompress_step_%d" % j)
+      d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
+      d = decompress_step(d, None, hparams, i > 0, is_2d, "decompress_%d" % j)

    k = 2**hparams.num_compress_steps
    z_batch = tf.reshape(z, [-1, 1, 1, hparams.hidden_size])
    x_batch = tf.reshape(x, [-1, k, 1, hparams.hidden_size])
    d_batch = tf.reshape(d, [-1, k, 1, hparams.hidden_size])
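+    # Fold each length-k target block into the batch dimension so the decoder
+    # below reconstructs every block independently from its single latent.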
-    # dec_batch = decode(z_batch, d_batch, x_batch, None, None, hparams)
-    c = expand_batch(c, tf.shape(x_batch)[0] / tf.shape(x)[0])
-    ed = expand_batch(ed, tf.shape(x_batch)[0] / tf.shape(x)[0])
-    dec_batch = decode(z_batch, d_batch, x_batch, c, ed, hparams)
+    dec_batch = decode(z_batch, d_batch, x_batch, None, None, hparams)
    z = tf.reshape(dec_batch, [-1, tf.shape(x)[1], 1, hparams.hidden_size])

-  return z, kl_loss, reconstruct_loss
+  return z


def ffn(x, hparams, name):
@@ -277,35 +275,42 @@ def ffn(x, hparams, name):
    return common_layers.layer_postprocess(x, y, hparams)


-def vae_transformer_internal(inputs, targets, target_space, hparams):
-  """VAE Transformer, main step used for training."""
-  with tf.variable_scope("vae_transformer"):
-    # Prepare inputs, targets, and k.
-    inputs = common_layers.flatten4d3d(inputs)
-    input_len = tf.shape(inputs)[1]  # Double input size to cover targets.
-    inputs = tf.pad(inputs, [[0, 0], [0, input_len], [0, 0]])
-    inputs.set_shape([None, None, hparams.hidden_size])
-    targets = common_layers.flatten4d3d(targets)
+def ae_transformer_internal(inputs, targets, target_space, hparams):
+  """AE Transformer, main step used for training."""
+  with tf.variable_scope("ae_transformer"):
+    # Prepare inputs, targets, k.
    k = 2**hparams.num_compress_steps
-    inputs, targets = common_layers.pad_to_same_length(
-        inputs, targets, final_length_divisible_by=k)
-    inputs, ed_bias = encode(inputs, target_space, hparams, "input_enc")
-
-    # Compress and vae.
-    z, kl, r = vae_compress(tf.expand_dims(targets, axis=2),
-                            tf.expand_dims(inputs, axis=2),
-                            ed_bias, hparams, "vae_compress", "vae_decompress")
+    _, targets = common_layers.pad_to_same_length(
+        targets, targets, final_length_divisible_by=k)
+    inputs = common_layers.flatten4d3d(inputs)
+    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
+
+    # Compress and ae.
+    ae, hot, kl = ae_compress(targets, False, hparams, "ae")
+    emb = ae_embed(hot, hparams, "ae", reuse=True)
+
+    # Compress context and run autoregressive decoder on emb-hot.
+    dec_c = decode(None, None, emb, inputs, ed, hparams)
+    c_z = tf.layers.dense(dec_c, hparams.v_size, name="mask_context")
+    reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits(
+        labels=hot, logits=c_z)
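+    # This is the latent prediction model: the decoder, conditioned on the
+    # encoded inputs, is trained to predict each discrete target code.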
+    # If not training, use the predicted z instead of the autoencoded one.
+    if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
+      hot = tf.one_hot(tf.argmax(c_z, axis=-1), hparams.v_size)
+
+    # Decompress, pass for ae loss.
+    z = ae_decompress(emb, ae, targets, False, hparams, "ae")
    kl *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 0.5))
-    r *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 0.5))
-    losses = {"kl": kl, "reconstruction": r}
+    reconstruct_loss *= common_layers.inverse_exp_decay(hparams.startup_steps)
+    losses = {"kl": kl, "reconstruction": reconstruct_loss}
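+    # Both penalties fade in: kl over startup_steps / 2, reconstruction over
+    # startup_steps, so early training is dominated by plain autoencoding.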
    return z, losses


@registry.register_model
-class TransformerVAE(t2t_model.T2TModel):
+class TransformerAE(t2t_model.T2TModel):

  def model_fn_body(self, features):
-    return vae_transformer_internal(
+    return ae_transformer_internal(
        features["inputs"], features["targets"], features["target_space_id"],
        self._hparams)

@@ -348,7 +353,7 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,


@registry.register_hparams
-def transformer_vae_small():
+def transformer_ae_small():
  """Set of hyperparameters."""
  hparams = transformer.transformer_small()
  hparams.batch_size = 2048
@@ -358,19 +363,20 @@ def transformer_vae_small():
  hparams.add_hparam("num_compress_steps", 4)
  hparams.add_hparam("kl_warmup_steps", 60000)
  hparams.add_hparam("startup_steps", 30000)
+  hparams.add_hparam("kmeans_lr_factor", 0.002)
+  hparams.add_hparam("z_dropout", 0.1)
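+  # kmeans_lr_factor scales the k-means gradient reaching the encoder (the
+  # rest is stop_gradient'ed in ae_compress); z_dropout is the dropout rate
+  # applied to z before decompression.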
  return hparams


@registry.register_hparams
-def transformer_vae_base():
+def transformer_ae_base():
  """Set of hyperparameters."""
-  hparams = transformer_vae_small()
+  hparams = transformer_ae_small()
  hparams.hidden_size = 512
  hparams.filter_size = 2048
  hparams.attention_dropout = 0.0
  hparams.relu_dropout = 0.0
  hparams.dropout = 0.0
  hparams.num_hidden_layers = 4
-  hparams.kmeans_lr_factor = 0.002
  hparams.z_size = 256
  return hparams