@@ -57,11 +57,14 @@ def build_mlp(x, output_size, scope, n_layers, size, activation=tf.tanh, output_
         x = tf.layers.dense(inputs=x, units=output_size, activation=output_activation, name='fc{}'.format(i + 1), kernel_regularizer=regularizer, bias_regularizer=regularizer)
     return x
 
-def build_rnn(x, h, output_size, scope, n_layers, size, gru_size, activation=tf.tanh, output_activation=None, regularizer=None):
+def build_rnn(x, h, output_size, scope, n_layers, size, activation=tf.tanh, output_activation=None, regularizer=None):
6161 """
6262 builds a gated recurrent neural network
6363 inputs are first embedded by an MLP then passed to a GRU cell
6464
65+ make MLP layers with `size` number of units
66+ make the GRU with `output_size` number of units
67+
6568 arguments:
6669 (see `build_policy()`)
6770
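For reference, a body consistent with the updated signature and docstring might look like the sketch below. Only the signature and docstring come from this commit; the GRUCell/dynamic_rnn body is an assumption based on TF1 conventions:

    def build_rnn(x, h, output_size, scope, n_layers, size,
                  activation=tf.tanh, output_activation=None, regularizer=None):
        # embed each timestep of the (batch, time, features) input with an MLP
        # of n_layers layers, `size` units each (per the docstring above);
        # tf.layers.dense applies over the last axis, so 3-D inputs are fine
        x = build_mlp(x, size, scope, n_layers, size, activation=activation,
                      output_activation=activation, regularizer=regularizer)
        # run a GRU with `output_size` units over the embedded sequence,
        # seeded with the incoming hidden state h
        cell = tf.nn.rnn_cell.GRUCell(output_size)
        x, h = tf.nn.dynamic_rnn(cell, x, initial_state=h)
        return x, h

Note the GRU width is taken from output_size, which is why the call sites below no longer need to pass gru_size a second time.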
@@ -96,7 +99,7 @@ def build_policy(x, h, output_size, scope, n_layers, size, gru_size, recurrent=T
9699 """
97100 with tf .variable_scope (scope , reuse = tf .AUTO_REUSE ):
98101 if recurrent :
99- x , h = build_rnn (x , h , gru_size , scope , n_layers , size , gru_size , activation = activation , output_activation = output_activation )
102+ x , h = build_rnn (x , h , gru_size , scope , n_layers , size , activation = activation , output_activation = output_activation )
100103 else :
101104 x = tf .reshape (x , (- 1 , x .get_shape ()[1 ]* x .get_shape ()[2 ]))
102105 x = build_mlp (x , gru_size , scope , n_layers + 1 , size , activation = activation , output_activation = activation )
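Written with keyword arguments (illustration only, not part of the commit), the fixed call makes the binding explicit:

    x, h = build_rnn(x, h, output_size=gru_size, scope=scope,
                     n_layers=n_layers, size=size,
                     activation=activation, output_activation=output_activation)

Under the old signature this call site supplied gru_size twice: once positionally as output_size and once for the redundant gru_size parameter.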
@@ -115,7 +118,7 @@ def build_critic(x, h, output_size, scope, n_layers, size, gru_size, recurrent=T
115118 """
116119 with tf .variable_scope (scope , reuse = tf .AUTO_REUSE ):
117120 if recurrent :
118- x , h = build_rnn (x , h , gru_size , scope , n_layers , size , gru_size , activation = activation , output_activation = output_activation , regularizer = regularizer )
121+ x , h = build_rnn (x , h , gru_size , scope , n_layers , size , activation = activation , output_activation = output_activation , regularizer = regularizer )
119122 else :
120123 x = tf .reshape (x , (- 1 , x .get_shape ()[1 ]* x .get_shape ()[2 ]))
121124 x = build_mlp (x , gru_size , scope , n_layers + 1 , size , activation = activation , output_activation = activation , regularizer = regularizer )
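The critic path additionally threads regularizer into every dense layer (via the kernel_regularizer/bias_regularizer arguments in build_mlp, shown in the first hunk). A caller might construct it as, say, an L2 penalty; the placeholder names, the scale, and the exact call shape here are hypothetical illustrations:

    # hypothetical usage; l2_regularizer is TF1's tf.contrib.layers API
    regularizer = tf.contrib.layers.l2_regularizer(1e-4)
    value = build_critic(obs_ph, h_ph, 1, 'critic', n_layers, size, gru_size,
                         recurrent=True, regularizer=regularizer)

Here output_size=1 reflects a scalar value estimate.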