@@ -77,7 +77,8 @@ def run_binary_modules(modules, cur1, cur2, hparams):
   """Run binary modules."""
   selection_var = tf.get_variable("selection", [len(modules)],
                                   initializer=tf.zeros_initializer())
-  inv_t = 100.0 * common_layers.inverse_exp_decay(100000, min_value=0.01)
+  inv_t = 100.0 * common_layers.inverse_exp_decay(
+      hparams.anneal_until, min_value=0.01)
   selected_weights = tf.nn.softmax(selection_var * inv_t)
   all_res = [modules[n](cur1, cur2, hparams) for n in xrange(len(modules))]
   all_res = tf.concat([tf.expand_dims(r, axis=0) for r in all_res], axis=0)
@@ -89,7 +90,8 @@ def run_unary_modules_basic(modules, cur, hparams):
   """Run unary modules."""
   selection_var = tf.get_variable("selection", [len(modules)],
                                   initializer=tf.zeros_initializer())
-  inv_t = 100.0 * common_layers.inverse_exp_decay(100000, min_value=0.01)
+  inv_t = 100.0 * common_layers.inverse_exp_decay(
+      hparams.anneal_until, min_value=0.01)
   selected_weights = tf.nn.softmax(selection_var * inv_t)
   all_res = [modules[n](cur, hparams) for n in xrange(len(modules))]
   all_res = tf.concat([tf.expand_dims(r, axis=0) for r in all_res], axis=0)
@@ -109,7 +111,8 @@ def run_unary_modules_sample(modules, cur, hparams, k):
                      lambda: tf.zeros_like(cur),
                      lambda i=n: modules[i](cur, hparams))
              for n in xrange(len(modules))]
-  inv_t = 100.0 * common_layers.inverse_exp_decay(100000, min_value=0.01)
+  inv_t = 100.0 * common_layers.inverse_exp_decay(
+      hparams.anneal_until, min_value=0.01)
   selected_weights = tf.nn.softmax(selection_var * inv_t - 1e9 * (1.0 - to_run))
   all_res = tf.concat([tf.expand_dims(r, axis=0) for r in all_res], axis=0)
   res = all_res * tf.reshape(selected_weights, [-1, 1, 1, 1, 1])
@@ -122,6 +125,14 @@ def run_unary_modules(modules, cur, hparams):
   return run_unary_modules_sample(modules, cur, hparams, 4)
 
 
+def batch_deviation(x):
+  """Average deviation of the batch."""
+  x_mean = tf.reduce_mean(x, axis=[0], keep_dims=True)
+  x_variance = tf.reduce_mean(
+      tf.square(x - x_mean), axis=[0], keep_dims=True)
+  return tf.reduce_mean(tf.sqrt(x_variance))
+
+
 @registry.register_model
 class BlueNet(t2t_model.T2TModel):
 
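A note on the new batch_deviation helper: it takes the variance of each
feature position across the batch axis, then averages the per-feature
standard deviations into one scalar. A minimal NumPy sketch of the same
computation (batch_deviation_np is our illustrative name, not part of the
change):

    import numpy as np

    def batch_deviation_np(x):
      # Per-feature variance across the batch axis (axis 0), then the
      # square root, averaged into a single scalar.
      x_mean = x.mean(axis=0, keepdims=True)
      x_variance = np.square(x - x_mean).mean(axis=0, keepdims=True)
      return np.sqrt(x_variance).mean()

    x = np.random.randn(8, 4, 4, 3)  # a batch of 8 small feature maps
    print(batch_deviation_np(x))     # close to 1.0 for unit-variance input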
@@ -153,14 +164,15 @@ def run_unary(x, name):
         with tf.variable_scope("conv"):
           x = run_unary_modules(conv_modules, x, hparams)
           x.set_shape(x_shape)
-      return x
+      return tf.nn.dropout(x, 1.0 - hparams.dropout), batch_deviation(x)
 
-    cur1, cur2 = inputs, inputs
+    cur1, cur2, extra_loss = inputs, inputs, 0.0
     cur_shape = inputs.get_shape()
     for i in xrange(hparams.num_hidden_layers):
       with tf.variable_scope("layer_%d" % i):
-        cur1 = run_unary(cur1, "unary1")
-        cur2 = run_unary(cur2, "unary2")
+        cur1, loss1 = run_unary(cur1, "unary1")
+        cur2, loss2 = run_unary(cur2, "unary2")
+        extra_loss += (loss1 + loss2) / float(hparams.num_hidden_layers)
         with tf.variable_scope("binary1"):
           next1 = run_binary_modules(binary_modules, cur1, cur2, hparams)
           next1.set_shape(cur_shape)
@@ -169,7 +181,9 @@ def run_unary(x, name):
           next2.set_shape(cur_shape)
         cur1, cur2 = next1, next2
 
-    return cur1
+    anneal = common_layers.inverse_exp_decay(hparams.anneal_until)
+    extra_loss *= hparams.batch_deviation_loss_factor * anneal
+    return cur1, extra_loss
 
 
 @registry.register_hparams
@@ -185,7 +199,7 @@ def bluenet_base():
   hparams.num_hidden_layers = 8
   hparams.kernel_height = 3
   hparams.kernel_width = 3
-  hparams.learning_rate_decay_scheme = "exp50k"
+  hparams.learning_rate_decay_scheme = "exp10k"
   hparams.learning_rate = 0.05
   hparams.learning_rate_warmup_steps = 3000
   hparams.initializer_gain = 1.0
@@ -196,6 +210,8 @@ def bluenet_base():
   hparams.optimizer_adam_beta1 = 0.85
   hparams.optimizer_adam_beta2 = 0.997
   hparams.add_hparam("imagenet_use_2d", True)
+  hparams.add_hparam("anneal_until", 40000)
+  hparams.add_hparam("batch_deviation_loss_factor", 0.001)
   return hparams
 
 
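For context on the anneal_until change: assuming inverse_exp_decay behaves as
in tensor2tensor's common_layers (an exponential ramp from min_value at step 0
to 1.0 at the given step), inv_t acts as an inverse softmax temperature that
rises from 1.0 to 100.0 over hparams.anneal_until steps, so module selection
starts as a soft mixture and ends near a one-hot choice. A standalone NumPy
sketch of that schedule (stand-in code, not the library's):

    import numpy as np

    def inverse_exp_decay(step, max_step, min_value=0.01):
      # Ramps from min_value at step 0 to 1.0 at max_step, then stays at 1.0.
      inv_base = np.exp(np.log(min_value) / float(max_step))
      return inv_base ** max(float(max_step) - step, 0.0)

    def selection_weights(selection_var, step, anneal_until=40000):
      # inv_t: ~1.0 early (soft mixture), ~100.0 after annealing (near argmax).
      inv_t = 100.0 * inverse_exp_decay(step, anneal_until)
      logits = selection_var * inv_t
      e = np.exp(logits - logits.max())
      return e / e.sum()

    v = np.array([0.3, 0.1, -0.2])
    print(selection_weights(v, step=0))      # soft: all modules contribute
    print(selection_weights(v, step=40000))  # concentrated on one module

The same decay, without the 100x factor, also scales the batch-deviation loss
in the model body, so that regularizer fades in over the first anneal_until
steps of training.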