Skip to content

Commit 0d5ed41

Browse files
committed
HW5c fix (nit): make build_rnn more clear, remove unneeded arg
1 parent 5e908c2 commit 0d5ed41

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

hw5/meta/train_policy.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,14 @@ def build_mlp(x, output_size, scope, n_layers, size, activation=tf.tanh, output_
5757
x = tf.layers.dense(inputs=x, units=output_size, activation=output_activation, name='fc{}'.format(i + 1), kernel_regularizer=regularizer, bias_regularizer=regularizer)
5858
return x
5959

60-
def build_rnn(x, h, output_size, scope, n_layers, size, gru_size, activation=tf.tanh, output_activation=None, regularizer=None):
60+
def build_rnn(x, h, output_size, scope, n_layers, size, activation=tf.tanh, output_activation=None, regularizer=None):
6161
"""
6262
builds a gated recurrent neural network
6363
inputs are first embedded by an MLP then passed to a GRU cell
6464
65+
make MLP layers with `size` number of units
66+
make the GRU with `output_size` number of units
67+
6568
arguments:
6669
(see `build_policy()`)
6770
@@ -96,7 +99,7 @@ def build_policy(x, h, output_size, scope, n_layers, size, gru_size, recurrent=T
9699
"""
97100
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
98101
if recurrent:
99-
x, h = build_rnn(x, h, gru_size, scope, n_layers, size, gru_size, activation=activation, output_activation=output_activation)
102+
x, h = build_rnn(x, h, gru_size, scope, n_layers, size, activation=activation, output_activation=output_activation)
100103
else:
101104
x = tf.reshape(x, (-1, x.get_shape()[1]*x.get_shape()[2]))
102105
x = build_mlp(x, gru_size, scope, n_layers + 1, size, activation=activation, output_activation=activation)
@@ -115,7 +118,7 @@ def build_critic(x, h, output_size, scope, n_layers, size, gru_size, recurrent=T
115118
"""
116119
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
117120
if recurrent:
118-
x, h = build_rnn(x, h, gru_size, scope, n_layers, size, gru_size, activation=activation, output_activation=output_activation, regularizer=regularizer)
121+
x, h = build_rnn(x, h, gru_size, scope, n_layers, size, activation=activation, output_activation=output_activation, regularizer=regularizer)
119122
else:
120123
x = tf.reshape(x, (-1, x.get_shape()[1]*x.get_shape()[2]))
121124
x = build_mlp(x, gru_size, scope, n_layers + 1, size, activation=activation, output_activation=activation, regularizer=regularizer)

0 commit comments

Comments
 (0)