
Commit fae6e7b

HW5c fix: Problem 1 eval deterministically on all tasks
1 parent: d2dedd1

2 files changed (+18, -6 lines)

hw5/meta/point_mass_observed.py (9 additions, 3 deletions)

@@ -18,17 +18,23 @@ class ObservedPointEnv(Env):
     # YOUR CODE SOMEWHERE HERE
     def __init__(self, num_tasks=1):
         self.tasks = [0, 1, 2, 3][:num_tasks]
+        self.task_idx = -1
         self.reset_task()
         self.reset()

         self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(2,))
         self.action_space = spaces.Box(low=-0.1, high=0.1, shape=(2,))

     def reset_task(self, is_evaluation=False):
-        idx = np.random.choice(len(self.tasks))
-        self._task = self.tasks[idx]
+        # for evaluation, cycle deterministically through all tasks
+        if is_evaluation:
+            self.task_idx = (self.task_idx + 1) % len(self.tasks)
+        # during training, sample tasks randomly
+        else:
+            self.task_idx = np.random.randint(len(self.tasks))
+        self._task = self.tasks[self.task_idx]
         goals = [[-1, -1], [-1, 1], [1, -1], [1, 1]]
-        self._goal = np.array(goals[idx])*10
+        self._goal = np.array(goals[self.task_idx])*10

     def reset(self):
         self._state = np.array([0, 0], dtype=np.float32)
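For intuition, here is a minimal, self-contained sketch of the behavior this hunk introduces. TaskSelector is a hypothetical stand-in for ObservedPointEnv, stripped down to the task-selection logic (the real env also defines spaces, step, and reward): with is_evaluation=True, successive reset_task calls now cycle through every goal in a fixed order instead of sampling one at random.

import numpy as np

# Hypothetical reduction of ObservedPointEnv to its task-selection logic.
class TaskSelector:
    def __init__(self, num_tasks=4):
        self.tasks = [0, 1, 2, 3][:num_tasks]
        self.task_idx = -1  # so the first evaluation call lands on task 0

    def reset_task(self, is_evaluation=False):
        if is_evaluation:
            # evaluation: deterministic round-robin over all tasks
            self.task_idx = (self.task_idx + 1) % len(self.tasks)
        else:
            # training: sample a task uniformly at random
            self.task_idx = np.random.randint(len(self.tasks))
        return self.tasks[self.task_idx]

sel = TaskSelector()
print([sel.reset_task(is_evaluation=True) for _ in range(6)])
# -> [0, 1, 2, 3, 0, 1]: every task is visited before any repeats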

hw5/meta/train_policy.py (9 additions, 3 deletions)

@@ -689,8 +689,11 @@ def unpack_sample(data):

         # sample trajectories to fill agent's replay buffer
         print("********** Iteration %i ************"%itr)
-        stats, timesteps_this_batch = agent.sample_trajectories(itr, env, min_timesteps_per_batch)
-        total_timesteps += timesteps_this_batch
+        stats = []
+        for _ in range(num_tasks):
+            s, timesteps_this_batch = agent.sample_trajectories(itr, env, min_timesteps_per_batch)
+            total_timesteps += timesteps_this_batch
+            stats += s

         # compute the log probs, advantages, and returns for all data in agent's buffer
         # store in ppo buffer for use in multiple ppo updates

@@ -720,7 +723,10 @@ def unpack_sample(data):

         # compute validation statistics
         print('Validating...')
-        val_stats, timesteps_this_batch = agent.sample_trajectories(itr, env, min_timesteps_per_batch // 10, is_evaluation=True)
+        val_stats = []
+        for _ in range(num_tasks):
+            vs, timesteps_this_batch = agent.sample_trajectories(itr, env, min_timesteps_per_batch // 10, is_evaluation=True)
+            val_stats += vs

         # save trajectories for viz
         with open("output/{}-epoch{}.pkl".format(exp_name, itr), 'wb') as f:
