Skip to content

Commit e6816d3

Browse files
authored
Merge pull request #9 from pockerman/add_actor_critic_algorithm
Add actor critic algorithm
2 parents 10e6e39 + b1039e3 commit e6816d3

16 files changed

+276
-70
lines changed

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,7 @@
11
src/preprocessor/__pycache__/
22
src/exceptions/__pycache__/
3+
src/utils/__pycache__/
4+
src/tests/.pytest_cache/
5+
src/spaces/__pycache__/
6+
src/__pycache__/
7+
src/algorithms/__pycache__/

src/__init__.py

Whitespace-only changes.

src/algorithms/a2c.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,20 +41,32 @@ def forward(self, x):
4141
return pol_out, val_out
4242

4343

44+
class A2CConfig(object):
45+
"""
46+
Configuration for A2C algorithm
47+
"""
48+
49+
def __init__(self):
50+
self.gamma: float = 0.99
51+
self.tau: float = 1.2
52+
self.n_workers: int = 1
53+
self.n_iterations_per_episode: int = 100
54+
self.optimizer: Optimizer = None
55+
self.loss_function: LossFunction = None
56+
57+
4458
class A2C(Generic[Optimizer]):
4559

46-
def __init__(self, gamma: float, tau: float, n_workers: int,
47-
n_iterations: int, optimizer: Optimizer,
48-
a2c_net: A2CNet, loss_function: LossFunction):
60+
def __init__(self, config: A2CConfig, a2c_net: A2CNet):
4961

50-
self.gamma = gamma
51-
self.tau = tau
52-
self.rewards = []
53-
self.n_workers = n_workers
54-
self.n_iterations = n_iterations
55-
self.optimizer = optimizer
62+
self.gamma = config.gamma
63+
self.tau = config.tau
64+
self.n_workers = config.n_workers
65+
self.n_iterations_per_episode = config.n_iterations_per_episode
66+
self.optimizer = config.optimizer
67+
self.loss_function = config.loss_function
5668
self.a2c_net = a2c_net
57-
self.loss_function = loss_function
69+
self.rewards = []
5870
self.name = "A2C"
5971

6072
def _optimize_model(self):
@@ -80,7 +92,8 @@ def train(self, env: Env) -> None:
8092

8193
observation = time_step.observation
8294

83-
for iteration in range(1, self.n_iterations + 1):
95+
# learn over the episode
96+
for iteration in range(1, self.n_iterations_per_episode + 1):
8497

8598
# select an action
8699
action = self.select_action(env=env, observation=observation)

src/algorithms/trainer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,4 @@ def train(self):
3333
if self.configuration["update_frequency"] % episode == 0:
3434
self.agent.update()
3535

36-
3736
print("{0} Training finished for agent {1}".format(INFO, self.agent.name))

src/datasets/dataset_distances.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""
2+
Various utilities to calculate the distance
3+
between two datasets. All distance metrics work
4+
accumulatively
5+
"""
6+
7+
from typing import TypeVar
8+
import numpy as np
9+
10+
DataSet = TypeVar("DataSet")
11+
12+
13+
def lp_distance(ds1: DataSet, ds2: DataSet, p=None):
14+
"""
15+
Compute the Lp norms between the respective columns in the given data sets.
16+
This means that the two datasets should have the same schema. It is
17+
up to the application to ensure that the calculation is meaningful
18+
:param ds1: Dataset 1
19+
:param ds2: Dataset 2
20+
:param p: The order of the norm to calculate
21+
:return: A tuple with the per-column Lp-norms (dict keyed by column name) and their sum
22+
"""
23+
24+
assert ds1.schema == ds2.schema, "Invalid schema for datasets"
25+
26+
distances = {}
27+
cols = ds1.get_columns_names()
28+
for col in cols:
29+
30+
val1 = ds1.get_column(col_name=col)
31+
val2 = ds2.get_column(col_name=col)
32+
distances[col] = np.linalg.norm(val1 - val2, ord=p)
33+
34+
return distances, sum(distances.values())
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""
2+
Utilities for calculating the information leakage
3+
for a dataset
4+
"""

src/datasets/dataset_wrapper.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,19 @@ def sample_column(self):
111111
col_idx = np.random.choice(col_names, 1)
112112
return self.get_column(col_name=col_names[col_idx])
113113

114-
def apply_transform(self, transform: Transform) -> None:
115-
pass
114+
def apply_column_transform(self, column_name: str, transform: Transform) -> None:
115+
"""
116+
Apply the given transformation on the underlying dataset
117+
:param column_name: The column to transform
118+
:param transform: The transformation to apply
119+
:return: None
120+
"""
121+
122+
# get the column
123+
column = self.get_column(col_name=column_name)
124+
column = transform.act(**{"data": column})
125+
self.ds[transform.column_name] = column
126+
116127

117128

118129

src/spaces/action_space.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,68 @@
1+
"""
2+
ActionSpace class. This is a wrapper to the discrete
3+
actions in the actions.py module
4+
"""
5+
16
from gym.spaces.discrete import Discrete
27
from src.spaces.actions import ActionBase
38

49

510
class ActionSpace(Discrete):
11+
"""
12+
ActionSpace class models a discrete action space of size n
13+
"""
614

715
def __init__(self, n: int) -> None:
816

917
super(ActionSpace, self).__init__(n=n)
18+
19+
# the list of actions the space contains
1020
self.actions = []
1121

22+
def __getitem__(self, item) -> ActionBase:
23+
"""
24+
Returns the item-th action
25+
:param item: The index of the action to return
26+
:return: An action object
27+
"""
28+
return self.actions[item]
29+
30+
def __setitem__(self, key: int, value: ActionBase) -> None:
31+
"""
32+
Update the key-th Action with the new value
33+
:param key: The index to the action to update
34+
:param value: The new action
35+
:return: None
36+
"""
37+
self.actions[key] = value
38+
1239
def add(self, action: ActionBase) -> None:
40+
"""
41+
Add a new action in the space
42+
:param action: The action to add
43+
:return: None
44+
"""
1345

1446
if len(self.actions) >= self.n:
1547
raise ValueError("Action space is saturated. You cannot add a new action")
1648

49+
# set a valid id for the action
50+
action.idx = len(self.actions)
1751
self.actions.append(action)
1852

19-
def add_may(self, *actions) -> None:
53+
def add_many(self, *actions) -> None:
54+
"""
55+
Add many actions in one go
56+
:param actions: List of actions to add
57+
:return: None
58+
"""
2059
for a in actions:
2160
self.add(action=a)
2261

2362
def sample_and_get(self) -> ActionBase:
24-
63+
"""
64+
Sample the space and return an action to the application
65+
:return: The sampled action
66+
"""
2567
action_idx = self.sample()
2668
return self.actions[action_idx]

src/spaces/actions.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,11 @@ class ActionBase(metaclass=abc.ABCMeta):
3333
Base class for actions
3434
"""
3535

36-
def __init__(self, action_type: ActionType) -> None:
36+
def __init__(self, column_name: str, action_type: ActionType) -> None:
37+
self.column_name = column_name
3738
self.action_type = action_type
39+
self.idx = None
40+
self.key = (self.column_name, self.action_type)
3841

3942
@abc.abstractmethod
4043
def act(self, **ops) -> None:
@@ -77,8 +80,8 @@ class ActionIdentity(ActionBase):
7780
Implements the identity action
7881
"""
7982

80-
def __init__(self) -> None:
81-
super(ActionIdentity, self).__init__(action_type=ActionType.IDENTITY)
83+
def __init__(self, column_name: str) -> None:
84+
super(ActionIdentity, self).__init__(column_name=column_name, action_type=ActionType.IDENTITY)
8285

8386
def act(self, **ops):
8487
"""
@@ -93,8 +96,8 @@ class ActionTransform(ActionBase):
9396
"""
9497
Implements the transform action
9598
"""
96-
def __init__(self):
97-
super(ActionTransform, self).__init__(action_type=ActionType.TRANSFORM)
99+
def __init__(self, column_name: str):
100+
super(ActionTransform, self).__init__(column_name=column_name, action_type=ActionType.TRANSFORM)
98101

99102
def act(self, **ops):
100103
"""
@@ -109,8 +112,8 @@ class ActionSuppress(ActionBase, _WithTable):
109112
"""
110113
Implements the suppress action
111114
"""
112-
def __init__(self, suppress_table=None):
113-
super(ActionSuppress, self).__init__(action_type=ActionType.SUPPRESS)
115+
def __init__(self, column_name: str, suppress_table=None):
116+
super(ActionSuppress, self).__init__(column_name=column_name, action_type=ActionType.SUPPRESS)
114117

115118
if suppress_table is not None:
116119
self.table = suppress_table
@@ -136,36 +139,47 @@ def act(self, **ops) -> None:
136139
move_next(iterators=self.iterators)
137140

138141

139-
class ActionGeneralize(ActionBase):
142+
class ActionGeneralize(ActionBase, _WithTable):
140143
"""
141144
Implements the generalization action
142145
"""
143146

144-
def __init__(self):
145-
super(ActionGeneralize, self).__init__(action_type=ActionType.GENERALIZE)
146-
self.generalization_table = {}
147+
def __init__(self, column_name: str, generalization_table: dict = None):
148+
super(ActionGeneralize, self).__init__(column_name=column_name, action_type=ActionType.GENERALIZE)
149+
150+
if generalization_table is not None:
151+
self.table = generalization_table
152+
153+
# fill in the iterators
154+
self.iterators = [iter(self.table[item]) for item in self.table]
147155

148156
def act(self, **ops):
149157
"""
150158
Perform an action
151159
:return: The generalized column values
152160
"""
161+
162+
# get the values of the column
163+
col_vals = ops['data'].values
164+
153165
# generalize the data given
154-
for item in ops["data"]:
166+
for i, item in enumerate(col_vals):
155167

168+
#print(item)
156169
# How do we update the generalizations?
157-
value = self.generalization_table[item].value
158-
item = value
170+
value = self.table[item].value
171+
col_vals[i] = value
159172

160-
# update the generalization
161-
self._move_next()
173+
ops["data"] = col_vals
162174

163-
def add_generalization(self, key: str, values: HierarchyBase) -> None:
164-
self.generalization_table[key] = values
175+
# update the generalization iterators
176+
# so next time we visit we update according to
177+
# the new values
178+
move_next(iterators=self.iterators)
179+
return ops['data']
165180

166-
def _move_next(self):
181+
def add_generalization(self, key: str, values: HierarchyBase) -> None:
182+
self.table[key] = values
167183

168-
for item in self.generalization_table:
169-
next(self.generalization_table[item])
170184

171185

src/spaces/environment.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import multiprocessing as mp
1313

1414
from src.exceptions.exceptions import Error
15-
from src.spaces.actions import ActionBase
15+
from src.spaces.actions import ActionBase, ActionType
1616
from src.utils.string_distance_calculator import DistanceType, TextDistanceCalculator
1717

1818
DataSet = TypeVar("DataSet")
@@ -169,6 +169,19 @@ def reset(self, **options) -> TimeStep:
169169
observation=self.get_ds_as_tensor().float(), discount=self.gamma)
170170
return self.current_time_step
171171

172+
def apply_action(self, action: ActionBase):
173+
"""
174+
Apply the action on the environment
175+
:param action: The action to apply on the environment
176+
:return:
177+
"""
178+
179+
if action.action_type == ActionType.IDENTITY:
180+
return
181+
182+
# apply the transform to the data set
183+
self.data_set.apply_column_transform(column_name=action.column_name, transform=action)
184+
172185
def step(self, action: ActionBase) -> TimeStep:
173186
"""
174187
@@ -182,11 +195,23 @@ def step(self, action: ActionBase) -> TimeStep:
182195
`action` will be ignored.
183196
"""
184197

185-
self.data_set.apply_transform(transform=action)
198+
self.apply_action(action=action)
199+
200+
# if the action is identity don't bother
201+
# doing anything
202+
#if action.action_type == ActionType.IDENTITY:
203+
# return TimeStep(step_type=StepType.MID, reward=0.0,
204+
# observation=self.get_ds_as_tensor().float(), discount=self.gamma)
205+
206+
# apply the transform of the data set
207+
#self.data_set.apply_column_transform(transform=action)
186208

187209
# perform the action on the data set
188210
self.prepare_column_states()
189211

212+
# calculate the information leakage and establish the reward
213+
# to return to the agent
214+
190215
return TimeStep(step_type=StepType.MID, reward=0.0,
191216
observation=self.get_ds_as_tensor().float(), discount=self.gamma)
192217

0 commit comments

Comments
 (0)