Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions MinAtar/FAME.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import numpy as np
import pickle
import itertools
import random
import torch
import torch.nn.functional as F
import torch.optim as optim
import copy
from model import *
Expand All @@ -11,6 +13,7 @@
from argparse import ArgumentParser
from configparser import ConfigParser
from tqdm import tqdm

import os
from torch.optim.lr_scheduler import ExponentialLR
from scipy import stats
Expand Down Expand Up @@ -43,6 +46,8 @@

### one-vs-one hypothesis test
parser.add_argument('--use_ttest', type=int, default=0, help="one-vs-one hypothesis test:, 1: on, 0: off, default off")
parser.add_argument('--p_explore', type=float, default=0.0, help="p_explore for guided exploration")



args = parser.parse_args()
Expand All @@ -59,10 +64,14 @@
print("torch.cuda.is_available():", torch.cuda.is_available())
print("CUDA_VISIBLE_DEVICES:", os.getenv("CUDA_VISIBLE_DEVICES"))
print("device_count:", torch.cuda.device_count())
torch.cuda.init()
device = torch.device(f"cuda:{args.gpu}")
torch.cuda.set_device(device)
print("device_count:", torch.cuda.device_count())
if torch.cuda.is_available():
torch.cuda.init()
device = torch.device(f"cuda:{args.gpu}")
torch.cuda.set_device(device)
print("device_count:", torch.cuda.device_count())
else:
device = torch.device("cpu")


print(args)
num_envs = int(args.t_steps / args.switch) # number of environments
Expand Down Expand Up @@ -157,11 +166,13 @@ def get_action(c_obs, LEARNER): # take action by the fast learner with the envi
action = curr_Q_vals.max(1)[1].item()
return action

def get_action_exploration(c_obs, LEARNER): # when p_explore > 0, use the expert learner to guide the exploration
def get_action_exploration(c_obs, Fast_Learner, Meta_Learner, p_explore): # when p_explore > 0, use the expert learner to guide the exploration

c_obs = np.moveaxis(c_obs, 2, 0)
c_obs = torch.tensor(c_obs, dtype=torch.float).to(device)
with torch.no_grad():
curr_Q_vals = LEARNER(c_obs.unsqueeze(0))
curr_Q_vals = Fast_Learner(c_obs.unsqueeze(0))

if np.random.random() <= epsilon:
action = env.action_space.sample()
else:
Expand Down Expand Up @@ -220,7 +231,8 @@ def get_action_exploration(c_obs, LEARNER): # when p_explore > 0, use the exper
exp_replay_meta = expReplay_Meta(max_size=args.size_meta, batch_size=args.batch_size, device=device)


returns_array = np.zeros(args.t_steps)
returns_array = np.zeros(args.t_steps + args.detection_step * 2 + 1)


avg_return = 0
epi_return = 0
Expand Down Expand Up @@ -302,7 +314,8 @@ def in_intervals(x):
avereward_fast.append(epi_return_fast)
epi_return_fast = 0
max_step = 0
returns_array[step] = copy.copy(avg_return)
returns_array[step-1] = copy.copy(avg_return)

pbar.update(1)
if Num_detection_fast > 0:
if len(avereward_fast) == 0:
Expand Down Expand Up @@ -334,7 +347,8 @@ def in_intervals(x):
avereward_meta.append(epi_return_meta)
epi_return_meta = 0
max_step = 0
returns_array[step] = copy.copy(avg_return)
returns_array[step-1] = copy.copy(avg_return)

pbar.update(1)

if Num_detection_meta > 0:
Expand Down