Skip to content

Commit d2dedd1

Browse files
committed
HW5c fix: the critic architecture should match the policy (pass `recurrent=self.recurrent` to `build_critic`).
Also remove a debugging print statement.
1 parent ee7e7a4 commit d2dedd1

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

hw5/meta/train_policy.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -308,7 +308,7 @@ def build_computation_graph(self):
308308

309309
# PPO critic update
310310
critic_regularizer = tf.contrib.layers.l2_regularizer(1e-3) if self.l2reg else None
311-
self.critic_prediction = tf.squeeze(build_critic(self.sy_ob_no, self.sy_hidden, 1, 'critic_network', n_layers=self.n_layers, size=self.size, gru_size=self.gru_size, regularizer=critic_regularizer))
311+
self.critic_prediction = tf.squeeze(build_critic(self.sy_ob_no, self.sy_hidden, 1, 'critic_network', n_layers=self.n_layers, size=self.size, gru_size=self.gru_size, recurrent=self.recurrent, regularizer=critic_regularizer))
312312
self.sy_target_n = tf.placeholder(shape=[None], name="critic_target", dtype=tf.float32)
313313
self.critic_loss = tf.losses.mean_squared_error(self.sy_target_n, self.critic_prediction)
314314
self.critic_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='critic_network')
@@ -715,7 +715,6 @@ def unpack_sample(data):
715715

716716
log_probs = agent.sess.run(agent.sy_lp_n,
717717
feed_dict={agent.sy_ob_no: ob_no, agent.sy_hidden: hidden, agent.sy_ac_na: ac_na})
718-
print('new log prob', log_probs.shape)
719718

720719
agent.update_parameters(ob_no, hidden, ac_na, fixed_log_probs, q_n, adv_n)
721720

0 commit comments

Comments
 (0)