Skip to content

Commit 1b54a81

Browse files
committed
Update API
1 parent 90cc550 commit 1b54a81

File tree

10 files changed

+141
-46
lines changed

10 files changed

+141
-46
lines changed

src/algorithms/a2c.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,20 +41,32 @@ def forward(self, x):
4141
return pol_out, val_out
4242

4343

44+
class A2CConfig(object):
45+
"""
46+
Configuration for A2C algorithm
47+
"""
48+
49+
def __init__(self):
50+
self.gamma: float = 0.99
51+
self.tau: float = 1.2
52+
self.n_workers: int = 1
53+
self.n_iterations_per_episode: int = 100
54+
self.optimizer: Optimizer = None
55+
self.loss_function: LossFunction = None
56+
57+
4458
class A2C(Generic[Optimizer]):
4559

46-
def __init__(self, gamma: float, tau: float, n_workers: int,
47-
n_iterations: int, optimizer: Optimizer,
48-
a2c_net: A2CNet, loss_function: LossFunction):
60+
def __init__(self, config: A2CConfig, a2c_net: A2CNet):
4961

50-
self.gamma = gamma
51-
self.tau = tau
52-
self.rewards = []
53-
self.n_workers = n_workers
54-
self.n_iterations = n_iterations
55-
self.optimizer = optimizer
62+
self.gamma = config.gamma
63+
self.tau = config.tau
64+
self.n_workers = config.n_workers
65+
self.n_iterations_per_episode = config.n_iterations_per_episode
66+
self.optimizer = config.optimizer
67+
self.loss_function = config.loss_function
5668
self.a2c_net = a2c_net
57-
self.loss_function = loss_function
69+
self.rewards = []
5870
self.name = "A2C"
5971

6072
def _optimize_model(self):
@@ -80,7 +92,8 @@ def train(self, env: Env) -> None:
8092

8193
observation = time_step.observation
8294

83-
for iteration in range(1, self.n_iterations + 1):
95+
# learn over the episode
96+
for iteration in range(1, self.n_iterations_per_episode + 1):
8497

8598
# select an action
8699
action = self.select_action(env=env, observation=observation)

src/algorithms/trainer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,4 @@ def train(self):
3333
if self.configuration["update_frequency"] % episode == 0:
3434
self.agent.update()
3535

36-
3736
print("{0} Training finished for agent {1}".format(INFO, self.agent.name))

src/datasets/dataset_wrapper.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,12 @@ def sample_column(self):
112112
return self.get_column(col_name=col_names[col_idx])
113113

114114
def apply_transform(self, transform: Transform) -> None:
115-
pass
115+
116+
# get the column
117+
column = self.get_column(col_name=transform.column_name)
118+
column = transform.act(**{"data": column})
119+
self.ds[transform.column_name] = column
120+
116121

117122

118123

src/spaces/action_space.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,16 @@ def __init__(self, n: int) -> None:
99
super(ActionSpace, self).__init__(n=n)
1010
self.actions = []
1111

12+
def __getitem__(self, item):
13+
return self.actions[item]
14+
1215
def add(self, action: ActionBase) -> None:
1316

1417
if len(self.actions) >= self.n:
1518
raise ValueError("Action space is saturated. You cannot add a new action")
1619

20+
# set a valid id for the action
21+
action.idx = len(self.actions)
1722
self.actions.append(action)
1823

1924
def add_may(self, *actions) -> None:

src/spaces/actions.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,11 @@ class ActionBase(metaclass=abc.ABCMeta):
3333
Base class for actions
3434
"""
3535

36-
def __init__(self, action_type: ActionType) -> None:
36+
def __init__(self, column_name: str, action_type: ActionType) -> None:
37+
self.column_name = column_name
3738
self.action_type = action_type
39+
self.idx = None
40+
self.key = (self.column_name, self.action_type)
3841

3942
@abc.abstractmethod
4043
def act(self, **ops) -> None:
@@ -77,8 +80,8 @@ class ActionIdentity(ActionBase):
7780
Implements the identity action
7881
"""
7982

80-
def __init__(self) -> None:
81-
super(ActionIdentity, self).__init__(action_type=ActionType.IDENTITY)
83+
def __init__(self, column_name: str) -> None:
84+
super(ActionIdentity, self).__init__(column_name=column_name, action_type=ActionType.IDENTITY)
8285

8386
def act(self, **ops):
8487
"""
@@ -93,8 +96,8 @@ class ActionTransform(ActionBase):
9396
"""
9497
Implements the transform action
9598
"""
96-
def __init__(self):
97-
super(ActionTransform, self).__init__(action_type=ActionType.TRANSFORM)
99+
def __init__(self, column_name: str):
100+
super(ActionTransform, self).__init__(column_name=column_name, action_type=ActionType.TRANSFORM)
98101

99102
def act(self, **ops):
100103
"""
@@ -109,8 +112,8 @@ class ActionSuppress(ActionBase, _WithTable):
109112
"""
110113
Implements the suppress action
111114
"""
112-
def __init__(self, suppress_table=None):
113-
super(ActionSuppress, self).__init__(action_type=ActionType.SUPPRESS)
115+
def __init__(self, column_name: str, suppress_table=None):
116+
super(ActionSuppress, self).__init__(column_name=column_name, action_type=ActionType.SUPPRESS)
114117

115118
if suppress_table is not None:
116119
self.table = suppress_table
@@ -136,36 +139,47 @@ def act(self, **ops) -> None:
136139
move_next(iterators=self.iterators)
137140

138141

139-
class ActionGeneralize(ActionBase):
142+
class ActionGeneralize(ActionBase, _WithTable):
140143
"""
141144
Implements the generalization action
142145
"""
143146

144-
def __init__(self):
145-
super(ActionGeneralize, self).__init__(action_type=ActionType.GENERALIZE)
146-
self.generalization_table = {}
147+
def __init__(self, column_name: str, generalization_table: dict = None):
148+
super(ActionGeneralize, self).__init__(column_name=column_name, action_type=ActionType.GENERALIZE)
149+
150+
if generalization_table is not None:
151+
self.table = generalization_table
152+
153+
# fill in the iterators
154+
self.iterators = [iter(self.table[item]) for item in self.table]
147155

148156
def act(self, **ops):
149157
"""
150158
Perform an action
151159
:return:
152160
"""
161+
162+
# get the values of the column
163+
col_vals = ops['data'].values
164+
153165
# generalize the data given
154-
for item in ops["data"]:
166+
for i, item in enumerate(col_vals):
155167

168+
#print(item)
156169
# How do we update the generalizations?
157-
value = self.generalization_table[item].value
158-
item = value
170+
value = self.table[item].value
171+
col_vals[i] = value
159172

160-
# update the generalization
161-
self._move_next()
173+
ops["data"] = col_vals
162174

163-
def add_generalization(self, key: str, values: HierarchyBase) -> None:
164-
self.generalization_table[key] = values
175+
# update the generalization iterators
176+
# so next time we visit we update according to
177+
# the new values
178+
move_next(iterators=self.iterators)
179+
return ops['data']
165180

166-
def _move_next(self):
181+
def add_generalization(self, key: str, values: HierarchyBase) -> None:
182+
self.table[key] = values
167183

168-
for item in self.generalization_table:
169-
next(self.generalization_table[item])
170184

171185

src/spaces/environment.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import multiprocessing as mp
1313

1414
from src.exceptions.exceptions import Error
15-
from src.spaces.actions import ActionBase
15+
from src.spaces.actions import ActionBase, ActionType
1616
from src.utils.string_distance_calculator import DistanceType, TextDistanceCalculator
1717

1818
DataSet = TypeVar("DataSet")
@@ -169,6 +169,15 @@ def reset(self, **options) -> TimeStep:
169169
observation=self.get_ds_as_tensor().float(), discount=self.gamma)
170170
return self.current_time_step
171171

172+
def apply_action(self, action: ActionBase):
173+
"""
174+
Apply the action on the environment
175+
:param action: The action to apply on the environment
176+
:return:
177+
"""
178+
# apply the transform of the data set
179+
self.data_set.apply_transform(transform=action)
180+
172181
def step(self, action: ActionBase) -> TimeStep:
173182
"""
174183
@@ -182,11 +191,21 @@ def step(self, action: ActionBase) -> TimeStep:
182191
`action` will be ignored.
183192
"""
184193

194+
# if the action is identity don't bother
195+
# doing anything
196+
if action.action_type == ActionType.IDENTITY:
197+
return TimeStep(step_type=StepType.MID, reward=0.0,
198+
observation=self.get_ds_as_tensor().float(), discount=self.gamma)
199+
200+
# apply the transform of the data set
185201
self.data_set.apply_transform(transform=action)
186202

187203
# perform the action on the data set
188204
self.prepare_column_states()
189205

206+
# calculate the information leakage and establish the reward
207+
# to return to the agent
208+
190209
return TimeStep(step_type=StepType.MID, reward=0.0,
191210
observation=self.get_ds_as_tensor().float(), discount=self.gamma)
192211

src/tests/test_actions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22

33
from src.utils.default_hierarchy import DefaultHierarchy
4-
from src.spaces import ActionSuppress
4+
from src.spaces.actions import ActionSuppress
55

66

77
class TestActions(unittest.TestCase):
@@ -11,7 +11,7 @@ def test_suppress_action_creation(self):
1111
suppress_table = {"test": DefaultHierarchy(values=["test", "tes*", "te**", "t***", "****"]),
1212
"do_not_test": DefaultHierarchy(values=["do_not_test", "do_not_tes*", "do_not_te**", "do_not_t***", "do_not_****"])}
1313

14-
suppress_action = ActionSuppress(suppress_table=suppress_table)
14+
suppress_action = ActionSuppress(column_name="none", suppress_table=suppress_table)
1515

1616
self.assertEqual(len(suppress_action.table), 2, "Invalid table size")
1717

@@ -23,7 +23,7 @@ def test_suppress_action_act(self):
2323
"do_not_test": DefaultHierarchy(values=["do_not_test", "do_not_tes*",
2424
"do_not_te**", "do_not_t***", "do_not_****"])}
2525

26-
suppress_action = ActionSuppress(suppress_table=suppress_table)
26+
suppress_action = ActionSuppress(column_name="none", suppress_table=suppress_table)
2727

2828
suppress_action.act(**{"data": data})
2929

src/tests/test_environment.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66
from src.spaces.environment import Environment
77
from src.spaces.action_space import ActionSpace
8+
from src.spaces.actions import ActionSuppress, ActionGeneralize
89
from src.exceptions.exceptions import Error
10+
from src.utils.default_hierarchy import DefaultHierarchy
911
from src.utils.string_distance_calculator import DistanceType
1012
from src.datasets.dataset_wrapper import PandasDSWrapper
1113

@@ -33,7 +35,7 @@ def setUp(self) -> None:
3335
"drop_na": True,
3436
"change_col_vals": {"diagnosis": [('N', 0)]}})
3537

36-
#@pytest.mark.skip(reason="no way of currently testing this")
38+
@pytest.mark.skip(reason="no way of currently testing this")
3739
def test_prepare_column_states_throw_Error(self):
3840
# specify the action space. We need to establish how these actions
3941
# are performed
@@ -45,7 +47,7 @@ def test_prepare_column_states_throw_Error(self):
4547
with pytest.raises(Error):
4648
env.prepare_column_states()
4749

48-
#@pytest.mark.skip(reason="no way of currently testing this")
50+
@pytest.mark.skip(reason="no way of currently testing this")
4951
def test_prepare_column_states(self):
5052
# specify the action space. We need to establish how these actions
5153
# are performed
@@ -57,6 +59,7 @@ def test_prepare_column_states(self):
5759
env.initialize_text_distances(distance_type=DistanceType.COSINE)
5860
env.prepare_column_states()
5961

62+
@pytest.mark.skip(reason="no way of currently testing this")
6063
def test_get_numeric_ds(self):
6164
# specify the action space. We need to establish how these actions
6265
# are performed
@@ -74,12 +77,49 @@ def test_get_numeric_ds(self):
7477
shape0 = tensor.size(dim=0)
7578
shape1 = tensor.size(dim=1)
7679

77-
self.assertEqual(shape0, env.start_ds.n_rows())
78-
self.assertEqual(shape1, env.start_ds.n_columns())
80+
self.assertEqual(shape0, env.start_ds.n_rows)
81+
self.assertEqual(shape1, env.start_ds.n_columns)
7982

83+
def test_apply_action(self):
84+
# specify the action space. We need to establish how these actions
85+
# are performed
86+
action_space = ActionSpace(n=1)
87+
88+
generalization_table = {"Mixed White/Asian": DefaultHierarchy(values=["Mixed", ]),
89+
"Chinese": DefaultHierarchy(values=["Asian", ]),
90+
"Indian": DefaultHierarchy(values=["Asian", ]),
91+
"Mixed White/Black African": DefaultHierarchy(values=["Mixed", ]),
92+
"Black African": DefaultHierarchy(values=["Black", ]),
93+
"Asian other": DefaultHierarchy(values=["Asian", ]),
94+
"Black other": DefaultHierarchy(values=["Black", ]),
95+
"Mixed White/Black Caribbean": DefaultHierarchy(values=["Mixed", ]),
96+
"Mixed other": DefaultHierarchy(values=["Mixed", ]),
97+
"Arab": DefaultHierarchy(values=["Asian", ]),
98+
"White Irish": DefaultHierarchy(values=["White", ]),
99+
"Not stated": DefaultHierarchy(values=["Not stated"]),
100+
"White Gypsy/Traveller": DefaultHierarchy(values=["White", ]),
101+
"White British": DefaultHierarchy(values=["White", ]),
102+
"Bangladeshi": DefaultHierarchy(values=["Asian", ]),
103+
"White other": DefaultHierarchy(values=["White", ]),
104+
"Black Caribbean": DefaultHierarchy(values=["Black", ]),
105+
"Pakistani": DefaultHierarchy(values=["Asian", ])}
106+
107+
action_space.add(ActionGeneralize(column_name="ethnicity", generalization_table=generalization_table))
108+
109+
# create the environment and
110+
env = Environment(data_set=self.ds, action_space=action_space, gamma=0.99, start_column="gender")
111+
112+
# this will update the environment
113+
env.apply_action(action=action_space[0])
80114

115+
# test that the ethnicity column has been changed
116+
# get the unique values for the ethnicity column
117+
unique_col_vals = env.data_set.get_column_unique_values(col_name="ethnicity")
81118

119+
print(unique_col_vals)
82120

121+
unique_vals = ["Mixed", "Asian", "Not stated", "White", "Black"]
122+
self.assertEqual(len(unique_vals), len(unique_col_vals))
83123

84124
if __name__ == '__main__':
85125
unittest.main()

src/utils/hierarchy_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33
from typing import TypeVar
44

55

6-
HierarchyBase = TypeVar("HierarchyBase")
6+
#HierarchyBase = TypeVar("HierarchyBase")
77

88

99
class HierarchyBase(metaclass=abc.ABCMeta):
1010

1111
def __init__(self):
1212
pass
1313

14-
@abc.abstractmethod
15-
def read_from(self, filename: Path) -> HierarchyBase:
14+
#@abc.abstractmethod
15+
#def read_from(self, filename: Path) -> HierarchyBase:
1616
"""
1717
Reads the values of the hierarchy from the file
1818
:param filename: The file to read the values of the hierarchy

src/utils/string_distance_calculator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import textdistance
33
import enum
4-
from src.exceptions import Error
4+
from src.exceptions.exceptions import Error
55

66

77
class DistanceType(enum.IntEnum):

0 commit comments

Comments
 (0)