diff --git a/MinAtar/FAME.py b/MinAtar/FAME.py
index 39f0d94..35a9172 100644
--- a/MinAtar/FAME.py
+++ b/MinAtar/FAME.py
@@ -2,7 +2,9 @@
 import numpy as np
 import pickle
 import itertools
+import random
 import torch
+import torch.nn.functional as F
 import torch.optim as optim
 import copy
 from model import *
@@ -11,6 +13,7 @@
 from argparse import ArgumentParser
 from configparser import ConfigParser
 from tqdm import tqdm
+
 import os
 from torch.optim.lr_scheduler import ExponentialLR
 from scipy import stats
@@ -43,6 +46,8 @@
 
 ### one-vs-one hypothesis test
 parser.add_argument('--use_ttest', type=int, default=0, help="one-vs-one hypothesis test:, 1: on, 0: off, default off")
+parser.add_argument('--p_explore', type=float, default=0.0, help="p_explore for guided exploration")
+
 
 args = parser.parse_args()
 
@@ -59,10 +64,14 @@
 print("torch.cuda.is_available():", torch.cuda.is_available())
 print("CUDA_VISIBLE_DEVICES:", os.getenv("CUDA_VISIBLE_DEVICES"))
 print("device_count:", torch.cuda.device_count())
-torch.cuda.init()
-device = torch.device(f"cuda:{args.gpu}")
-torch.cuda.set_device(device)
-print("device_count:", torch.cuda.device_count())
+if torch.cuda.is_available():
+    torch.cuda.init()
+    device = torch.device(f"cuda:{args.gpu}")
+    torch.cuda.set_device(device)
+    print("device_count:", torch.cuda.device_count())
+else:
+    device = torch.device("cpu")
+
 
 print(args)
 num_envs = int(args.t_steps / args.switch) # number of environments
@@ -157,11 +166,13 @@ def get_action(c_obs, LEARNER): # take action by the fast learner with the envi
         action = curr_Q_vals.max(1)[1].item()
     return action
 
 
-def get_action_exploration(c_obs, LEARNER): # when p_explore > 0, use the expert learner to guide the exploration
+def get_action_exploration(c_obs, Fast_Learner, Meta_Learner, p_explore): # when p_explore > 0, use the expert learner to guide the exploration
+    c_obs = np.moveaxis(c_obs, 2, 0)
     c_obs = torch.tensor(c_obs, dtype=torch.float).to(device)
     with torch.no_grad():
-        curr_Q_vals = LEARNER(c_obs.unsqueeze(0))
+        curr_Q_vals = Fast_Learner(c_obs.unsqueeze(0))
+
     if np.random.random() <= epsilon:
         action = env.action_space.sample()
     else:
@@ -220,7 +231,8 @@ def get_action_exploration(c_obs, LEARNER): # when p_explore > 0, use the exper
 
 exp_replay_meta = expReplay_Meta(max_size=args.size_meta, batch_size=args.batch_size, device=device)
 
-returns_array = np.zeros(args.t_steps)
+returns_array = np.zeros(args.t_steps + args.detection_step * 2 + 1)
+
 avg_return = 0
 epi_return = 0
 
@@ -302,7 +314,8 @@ def in_intervals(x):
             avereward_fast.append(epi_return_fast)
             epi_return_fast = 0
             max_step = 0
-            returns_array[step] = copy.copy(avg_return)
+            returns_array[step-1] = copy.copy(avg_return)
+            pbar.update(1)
 
         if Num_detection_fast > 0:
             if len(avereward_fast) == 0:
@@ -334,7 +347,8 @@ def in_intervals(x):
             avereward_meta.append(epi_return_meta)
             epi_return_meta = 0
             max_step = 0
-            returns_array[step] = copy.copy(avg_return)
+            returns_array[step-1] = copy.copy(avg_return)
+            pbar.update(1)
 
         if Num_detection_meta > 0: