Skip to content

Commit ede0cd1

Browse files
authored
Merge pull request #1 from pockerman/add_actor_critic_algorithm
Add actor critic algorithm
2 parents 5e07d3b + 2c25650 commit ede0cd1

27 files changed

+240
-106
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
preprocessor/__pycache__/
1+
src/preprocessor/__pycache__/
2+
src/exceptions/__pycache__/

algorithms/a2c.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

src/algorithms/a2c.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import numpy as np
2+
from typing import TypeVar, Generic
3+
import torch
4+
import torch.nn as nn
5+
import torch.nn.functional as F
6+
7+
# Unconstrained type variables used only as annotation placeholders below.
# They carry no runtime checks; each one just names the role a value plays.

# the environment the agent is trained on
Env = TypeVar("Env")

# the optimizer object held by the agent
Optimizer = TypeVar("Optimizer")

# the loss function object held by the agent
LossFunction = TypeVar("LossFunction")

# an observation of the environment
State = TypeVar("State")

# an action that can be passed to ``env.step``
Action = TypeVar("Action")

# the object returned by ``env.reset`` / ``env.step``
TimeStep = TypeVar("TimeStep")
13+
14+
15+
class A2CNetBase(nn.Module):
    """
    Base class for A2C networks.

    Wraps a user-supplied architecture and delegates the forward
    pass to it without any additional processing.
    """

    def __init__(self, architecture):
        """
        Store the architecture to delegate to

        :param architecture: The callable applied to the input in ``forward``
        """
        super().__init__()
        self.architecture = architecture

    def forward(self, x):
        """
        Apply the wrapped architecture to the input

        :param x: The input handed to ``self.architecture``
        :return: Whatever ``self.architecture(x)`` returns
        """
        output = self.architecture(x)
        return output
26+
27+
28+
class A2CNet(nn.Module):
    """
    Network with a shared part followed by two heads.

    The input first goes through ``common_net``; the result is then
    fed to both ``policy_net`` and ``value_net``.
    """

    def __init__(self, common_net: A2CNetBase, policy_net: A2CNetBase, value_net: A2CNetBase):
        """
        :param common_net: The network applied first to the input
        :param policy_net: The head producing the policy output
        :param value_net: The head producing the value output
        """
        super().__init__()
        self.common_net = common_net
        self.policy_net = policy_net
        self.value_net = value_net

    def forward(self, x):
        """
        Run the shared network and then both heads

        :param x: The input handed to ``common_net``
        :return: The tuple ``(policy output, value output)``
        """
        shared = self.common_net(x)

        # the policy head is evaluated before the value head
        return self.policy_net(shared), self.value_net(shared)
42+
43+
44+
class A2C(Generic[Optimizer]):
    """
    Skeleton of an A2C agent.

    At this stage the action selection simply samples an action from
    the environment, and ``_optimize_model`` and ``update`` are
    placeholders that do nothing.
    """

    def __init__(self, gamma: float, tau: float, n_workers: int,
                 n_iterations: int, optimizer: Optimizer,
                 a2c_net: A2CNet, loss_function: LossFunction):
        """
        Store the configuration of the agent

        :param gamma: Stored as ``self.gamma`` (presumably the discount factor -- TODO confirm)
        :param tau: Stored as ``self.tau`` (not used by the code in this class yet)
        :param n_workers: Stored as ``self.n_workers`` (not used by the code in this class yet)
        :param n_iterations: Maximum number of environment steps performed by one ``train`` call
        :param optimizer: Stored as ``self.optimizer``
        :param a2c_net: The network evaluated on every observation reached while training
        :param loss_function: Stored as ``self.loss_function``
        """

        self.gamma = gamma
        self.tau = tau
        self.rewards = []
        self.n_workers = n_workers
        self.n_iterations = n_iterations
        self.optimizer = optimizer
        self.a2c_net = a2c_net
        self.loss_function = loss_function
        self.name = "A2C"

    def _optimize_model(self):
        """
        Placeholder for the optimization step; currently does nothing
        """
        pass

    def select_action(self, env: Env, observation: State) -> Action:
        """
        Select an action
        :param env: The environment over which the agent is trained
        :param observation: The current observation of the environment
        (currently ignored, the action is sampled from the environment)
        :return: Returns an Action type
        """
        return env.sample_action()

    def update(self):
        """
        Placeholder for the model update; currently does nothing
        """
        pass

    def train(self, env: Env) -> None:
        """
        Step through the environment for at most ``n_iterations`` steps

        The environment is reset first. The loop stops early as soon as
        the environment reports that the returned time step is the last one.

        :param env: The environment over which the agent is trained
        :return: None
        """

        # reset the environment and obtain the first time step
        time_step: TimeStep = env.reset()
        observation = time_step.observation

        for _ in range(self.n_iterations):

            # select an action based on the current observation
            action = self.select_action(env=env, observation=observation)

            # step in the environment according
            # to the selected action
            next_time_step = env.step(action=action)

            # we reached the end of the episode
            if next_time_step.last():
                break

            next_state = next_time_step.observation

            # call the module itself instead of .forward() so that
            # any hooks registered on the network are executed.
            # The outputs are not consumed yet as _optimize_model
            # is still a placeholder
            policy_val, v_val = self.a2c_net(next_state)
            self._optimize_model()

            # advance the current observation; otherwise every action
            # would be selected from the observation returned by reset()
            observation = next_state
99+
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import ray.rllib.agents.a3c as a3c
77
from ray.tune.logger import pretty_print
88
from ray.rllib.env.env_context import EnvContext
9-
from spaces.environment import TimeStep, StepType
10-
from spaces.observation_space import ObsSpace
9+
from src.spaces.environment import TimeStep, StepType
10+
from src.spaces.observation_space import ObsSpace
1111

1212

1313
class DataSetEnv(gym.Env):
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Trainer
33
"""
44

5-
from utils import INFO
5+
from src.utils import INFO
66
from typing import TypeVar
77

88
Env = TypeVar("Env")
@@ -29,4 +29,9 @@ def train(self):
2929
# train for a number of iterations
3030
self.agent.train(self.env)
3131

32+
# is it time to update the model?
33+
if self.configuration["update_frequency"] % episode == 0:
34+
self.agent.update()
35+
36+
3237
print("{0} Training finished for agent {1}".format(INFO, self.agent.name))
Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
import pandas as pd
55
import numpy as np
66

7-
from preprocessor.cleanup_utils import read_csv, replace, change_column_types
7+
from src.preprocessor.cleanup_utils import read_csv, replace, change_column_types
88

99
DS = TypeVar("DS")
1010
HierarchyBase = TypeVar('HierarchyBase')
11+
Transform = TypeVar("Transform")
1112

1213

1314
class DSWrapper(Generic[DS], metaclass=abc.ABCMeta):
@@ -42,21 +43,26 @@ def __init__(self, columns: dir) -> None:
4243
# on each column in the dataset
4344
self.column_hierarchy = {}
4445

46+
@property
4547
def n_rows(self) -> int:
4648
"""
4749
Returns the number of rows of the data set
4850
:return:
4951
"""
50-
5152
return self.ds.shape[0]
5253

54+
@property
5355
def n_columns(self) -> int:
5456
"""
5557
Returns the number of columns of the data set
5658
:return:
5759
"""
5860
return self.ds.shape[1]
5961

62+
@property
63+
def schema(self) -> dict:
64+
return pd.io.json.build_table_schema(self.ds)
65+
6066
def read(self, filename: Path, **options) -> None:
6167
"""
6268
Load a data set from a file
@@ -72,7 +78,7 @@ def read(self, filename: Path, **options) -> None:
7278
self.ds = replace(ds=self.ds, options=options["change_col_vals"])
7379

7480
# try to cast to the data types
75-
self.ds = change_column_types(ds=self.ds, column_types=self.columns)
81+
self.ds = change_column_types(ds=self.ds, column_types=self.columns)
7682

7783
def set_columns_to_type(self, col_name_types) -> None:
7884
self.ds.astype(dtype=col_name_types)
@@ -88,7 +94,6 @@ def get_column_unique_values(self, col_name: str):
8894

8995
col = self.get_column(col_name=col_name)
9096
vals = col.values.ravel()
91-
9297
return pd.unique(vals)
9398

9499
def get_columns_types(self):
@@ -106,3 +111,9 @@ def sample_column(self):
106111
col_idx = np.random.choice(col_names, 1)
107112
return self.get_column(col_name=col_names[col_idx])
108113

114+
def apply_transform(self, transform: Transform) -> None:
115+
pass
116+
117+
118+
119+

src/datasets/datasets_loaders.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from pathlib import Path
2+
from src.datasets.dataset_wrapper import PandasDSWrapper
3+
4+
5+
class MockSubjectsLoader(PandasDSWrapper):
    """
    Data set wrapper preconfigured for the mock subjects CSV file.

    Constructing an instance immediately reads ``FILENAME`` using
    the class-level options below.
    """

    # column name -> type, handed to the base class as ``columns``
    DEFAULT_COLUMNS = {
        "gender": str,
        "ethnicity": str,
        "education": int,
        "salary": int,
        "diagnosis": int,
        "preventative_treatment": str,
        "mutation_status": int,
    }

    # NOTE: relative path, so it is resolved against the current working directory
    FILENAME = Path("../../data/mocksubjects.csv")

    # passed to ``read`` as ``features_drop_names``
    # (presumably the columns removed after loading -- TODO confirm)
    FEATURES_DROP_NAMES = ["NHSno", "given_name", "surname", "dob"]

    # passed to ``read`` as ``names``
    # (presumably the column names of the CSV file -- TODO confirm)
    NAMES = [
        "NHSno", "given_name", "surname", "gender",
        "dob", "ethnicity", "education", "salary",
        "mutation_status", "preventative_treatment", "diagnosis",
    ]

    # passed to ``read`` as ``drop_na``
    DROP_NA = True

    # passed to ``read`` as ``change_col_vals``
    CHANGE_COLS_VALS = {"diagnosis": [('N', 0)]}

    def __init__(self):
        """
        Initialize the wrapper with the default columns and load the file
        """
        super().__init__(columns=MockSubjectsLoader.DEFAULT_COLUMNS)

        read_options = {
            "features_drop_names": MockSubjectsLoader.FEATURES_DROP_NAMES,
            "names": MockSubjectsLoader.NAMES,
            "drop_na": MockSubjectsLoader.DROP_NA,
            "change_col_vals": MockSubjectsLoader.CHANGE_COLS_VALS,
        }
        self.read(filename=MockSubjectsLoader.FILENAME, **read_options)

0 commit comments

Comments
 (0)