@@ -41,7 +41,7 @@ def _rev_layer_forward(xs, f, g):
     y1 = x1 + f(x2)
   with tf.variable_scope("g"):
     y2 = x2 + g(y1)
-  return (y1, y2)
+  return tf.tuple([y1, y2])
 
 
 def _rev_layer_backward(ys, grad_ys, f, g, f_vars, g_vars):
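The forward pass now returns through tf.tuple rather than a plain Python tuple, so both outputs are materialized before either can be consumed. A minimal sketch of the behavior, assuming TF 1.x graph mode (the tensors a and b are hypothetical):

import tensorflow as tf  # TF 1.x graph mode assumed

a = tf.constant(1.0)
b = a * 2.0
# tf.tuple returns tensors with the same values, but each output carries a
# control dependency on every input, so neither a_out nor b_out can be
# consumed downstream until both a and b have been computed.
a_out, b_out = tf.tuple([a, b])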
@@ -65,17 +65,26 @@ def _rev_layer_backward(ys, grad_ys, f, g, f_vars, g_vars):
 
   # Compute gradients wrt to inputs
   # dL/dy2 * dG(y1)/y1
-  grad_gy1_y2 = tf.gradients(gy1, y1_stop, grad_y2)[0]
+  grad_gy1_y2 = tf.gradients(gy1, y1_stop, grad_y2, gate_gradients=True)[0]
   grad_x1 = grad_y1 + grad_gy1_y2
-  grad_x2 = (tf.gradients(fx2, x2_stop, grad_y1)[0] + grad_y2 + tf.gradients(
-      fx2, x2_stop, grad_gy1_y2)[0])
+  grad_x2 = (
+      tf.gradients(fx2, x2_stop, grad_y1, gate_gradients=True)[0] + grad_y2 +
+      tf.gradients(fx2, x2_stop, grad_gy1_y2, gate_gradients=True)[0])
 
   # Compute gradients wrt to vars in f and g
-  grad_g_vars = tf.gradients(gy1, g_vars, grad_y2)
-  grad_f_y1 = tf.gradients(fx2, f_vars, grad_y1)
-  grad_f_y2 = tf.gradients(fx2, f_vars, grad_gy1_y2)
+  grad_g_vars = tf.gradients(gy1, g_vars, grad_y2, gate_gradients=True)
+  grad_f_y1 = tf.gradients(fx2, f_vars, grad_y1, gate_gradients=True)
+  grad_f_y2 = tf.gradients(fx2, f_vars, grad_gy1_y2, gate_gradients=True)
   grad_f_vars = [tf.add_n(grads) for grads in zip(grad_f_y1, grad_f_y2)]
 
+  # Put returns in a tuple to ensure a constant memory budget (i.e. don't want
+  # the subsequent layer to start computing and consuming memory based on a
+  # subset of these values).
+  outs = tf.tuple([x1, x2, grad_x1, grad_x2] + grad_f_vars + grad_g_vars)
+  x1, x2, grad_x1, grad_x2 = outs[:4]
+  grad_f_vars = outs[4:4 + len(grad_f_vars)]
+  grad_g_vars = outs[4 + len(grad_f_vars):]
+
   return (x1, x2), (grad_x1, grad_x2), grad_f_vars, grad_g_vars
 
 
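The gate_gradients=True flag presumably serves the same constant-memory goal as the tf.tuple calls: per the tf.gradients documentation, it adds a tuple around the gradients computed for each op, so all of them exist before any downstream gradient op may consume one. A minimal sketch of the flag, assuming TF 1.x graph mode (x and loss are hypothetical names):

import tensorflow as tf  # TF 1.x graph mode assumed

x = tf.get_variable("x", shape=[2], initializer=tf.ones_initializer())
loss = tf.reduce_sum(x * x)
# gate_gradients=True groups each op's gradients before they are fed to
# downstream gradient ops, trading some parallelism for a more predictable
# memory schedule.
grad_x = tf.gradients(loss, [x], gate_gradients=True)[0]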