Skip to content

Commit 569b436

Browse files
committed
#20 Add function and update API
1 parent bc4044a commit 569b436

File tree

14 files changed

+405
-75
lines changed

14 files changed

+405
-75
lines changed

src/algorithms/q_learning.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
Env = TypeVar('Env')
1212
Policy = TypeVar('Policy')
1313

14-
class QLearnConfig(object):
1514

15+
class QLearnConfig(object):
16+
"""
17+
Configuration for Q-learning
18+
"""
1619
def __init__(self):
1720
self.gamma: float = 1.0
1821
self.alpha: float = 0.1
@@ -21,16 +24,15 @@ def __init__(self):
2124

2225

2326
class QLearning(WithMaxActionMixin):
27+
"""
28+
Q-learning algorithm implementation
29+
"""
2430

2531
def __init__(self, algo_config: QLearnConfig):
2632
super(QLearning, self).__init__()
2733
self.q_table = {}
2834
self.config = algo_config
2935

30-
# monitor performance
31-
self.total_rewards: np.array = None
32-
self.iterations_per_episode = []
33-
3436
@property
3537
def name(self) -> str:
3638
return "QLearn"

src/algorithms/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self, env: Env, agent: Agent, configuration: dir) -> None:
1717
self.agent = agent
1818
self.configuration = configuration
1919
# monitor performance
20-
self.total_rewards: np.array = None
20+
self.total_rewards: np.array = np.zeros(configuration['n_episodes'])
2121
self.iterations_per_episode = []
2222

2323
def actions_before_training(self):

src/apps/qlearning_on_mock.py

Lines changed: 77 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
14
from src.algorithms.q_learning import QLearning, QLearnConfig
25
from src.algorithms.trainer import Trainer
3-
from src.utils.string_distance_calculator import DistanceType
6+
from src.utils.string_distance_calculator import StringDistanceType
47
from src.spaces.actions import ActionSuppress, ActionIdentity, ActionGeneralize, ActionTransform
58
from src.spaces.environment import Environment, EnvConfig
69
from src.spaces.action_space import ActionSpace
@@ -11,45 +14,74 @@
1114
from src.utils.numeric_distance_type import NumericDistanceType
1215

1316

17+
def plot_running_avg(avg_rewards):
18+
19+
running_avg = np.empty(avg_rewards.shape[0])
20+
for t in range(avg_rewards.shape[0]):
21+
running_avg[t] = np.mean(avg_rewards[max(0, t-100) : (t+1)])
22+
plt.plot(running_avg)
23+
plt.xlabel("Number of episodes")
24+
plt.ylabel("Reward")
25+
plt.title("Running average")
26+
plt.show()
27+
28+
def get_ethinicity_hierarchies():
29+
30+
ethnicity_hierarchy = SerialHierarchy()
31+
ethnicity_hierarchy.add("Mixed White/Asian", values=["Mixed", '*'])
32+
ethnicity_hierarchy.add("Chinese", values=["Asian", '*'])
33+
ethnicity_hierarchy.add("Indian", values=["Asian", '*'])
34+
ethnicity_hierarchy.add("Mixed White/Black African", values=["Mixed", '*'])
35+
ethnicity_hierarchy.add("Black African", values=["Black", '*'])
36+
ethnicity_hierarchy.add("Asian other", values=["Asian", "*"])
37+
ethnicity_hierarchy.add("Black other", values=["Black", "*"])
38+
ethnicity_hierarchy.add("Mixed White/Black Caribbean", values=["Mixed", "*"])
39+
ethnicity_hierarchy.add("Mixed other", values=["Mixed", "*"])
40+
ethnicity_hierarchy.add("Arab", values=["Asian", "*"])
41+
ethnicity_hierarchy.add("White Irish", values=["White", "*"])
42+
ethnicity_hierarchy.add("Not stated", values=["Not stated", "*"])
43+
ethnicity_hierarchy.add("White Gypsy/Traveller", values=["White", "*"])
44+
ethnicity_hierarchy.add("White British", values=["White", "*"])
45+
ethnicity_hierarchy.add("Bangladeshi", values=["Asian", "*"])
46+
ethnicity_hierarchy.add("White other", values=["White", "*"])
47+
ethnicity_hierarchy.add("Black Caribbean", values=["Black", "*"])
48+
ethnicity_hierarchy.add("Pakistani", values=["Asian", "*"])
49+
50+
return ethnicity_hierarchy
51+
52+
1453
if __name__ == '__main__':
1554

1655
EPS = 1.0
1756
GAMMA = 0.99
1857
ALPHA = 0.1
58+
N_EPISODES = 100
1959

2060
# load the dataset
2161
ds = MockSubjectsLoader()
2262

63+
# generalization table for the ethnicity column
64+
ethinicity_table = get_ethinicity_hierarchies()
65+
2366
# specify the action space. We need to establish how these actions
2467
# are performed
25-
action_space = ActionSpace(n=4)
26-
27-
generalization_table = {"Mixed White/Asian": SerialHierarchy(values=["Mixed", ]),
28-
"Chinese": SerialHierarchy(values=["Asian", ]),
29-
"Indian": SerialHierarchy(values=["Asian", ]),
30-
"Mixed White/Black African": SerialHierarchy(values=["Mixed", ]),
31-
"Black African": SerialHierarchy(values=["Black", ]),
32-
"Asian other": SerialHierarchy(values=["Asian", ]),
33-
"Black other": SerialHierarchy(values=["Black", ]),
34-
"Mixed White/Black Caribbean": SerialHierarchy(values=["Mixed", ]),
35-
"Mixed other": SerialHierarchy(values=["Mixed", ]),
36-
"Arab": SerialHierarchy(values=["Asian", ]),
37-
"White Irish": SerialHierarchy(values=["White", ]),
38-
"Not stated": SerialHierarchy(values=["Not stated"]),
39-
"White Gypsy/Traveller": SerialHierarchy(values=["White", ]),
40-
"White British": SerialHierarchy(values=["White", ]),
41-
"Bangladeshi": SerialHierarchy(values=["Asian", ]),
42-
"White other": SerialHierarchy(values=["White", ]),
43-
"Black Caribbean": SerialHierarchy(values=["Black", ]),
44-
"Pakistani": SerialHierarchy(values=["Asian", ])}
45-
68+
action_space = ActionSpace(n=5)
4669
action_space.add_many(ActionSuppress(column_name="gender", suppress_table={"F": SerialHierarchy(values=['*', ]),
4770
'M': SerialHierarchy(values=['*', ])}),
48-
ActionIdentity(column_name="salary"), ActionIdentity(column_name="education"),
49-
ActionGeneralize(column_name="ethnicity", generalization_table=generalization_table))
50-
71+
ActionIdentity(column_name="salary"),
72+
ActionIdentity(column_name="education"),
73+
ActionGeneralize(column_name="ethnicity", generalization_table=ethinicity_table),
74+
ActionSuppress(column_name="preventative_treatment",
75+
suppress_table={"No": SerialHierarchy(values=['Maybe', '*']),
76+
'Yes': SerialHierarchy(values=['Maybe', '*']),
77+
"NA": SerialHierarchy(values=['Maybe', '*']),
78+
"Maybe": SerialHierarchy(values=['*', '*'])
79+
}))
80+
81+
# average distirtion
5182
average_distortion_constraint = {"salary": [0.0, 0.0, 0.0], "education": [0.0, 0.0, 0.0],
52-
"ethnicity": [3.0, 1.0, -1.0], "gender": [4.0, 1.0, -1.0]}
83+
"ethnicity": [3.0, 1.0, -1.0], "gender": [4.0, 1.0, -1.0],
84+
"preventative_treatment": [4.0, 1.0, -1.0]}
5385

5486
# specify the reward manager to use
5587
reward_manager = RewardManager(average_distortion_constraint=average_distortion_constraint)
@@ -66,20 +98,36 @@
6698
env = Environment(env_config=env_config)
6799

68100
# initialize text distances
69-
env.initialize_text_distances(distance_type=DistanceType.COSINE)
101+
env.initialize_text_distances(distance_type=StringDistanceType.COSINE)
70102

71103
algo_config = QLearnConfig()
72-
algo_config.n_itrs_per_episode = 1000
104+
algo_config.n_itrs_per_episode = 10
73105
algo_config.gamma = 0.99
74106
algo_config.alpha = 0.1
75107
algo_config.policy = EpsilonGreedyPolicy(eps=EPS, env=env,
76108
decay_op=EpsilonDecreaseOption.INVERSE_STEP)
77109

78110
agent = QLearning(algo_config=algo_config)
79111

80-
configuration = {"n_episodes": 10, "output_msg_frequency": 100}
112+
configuration = {"n_episodes": N_EPISODES, "output_msg_frequency": 10}
81113

82114
# create a trainer to train the A2C agent
83115
trainer = Trainer(env=env, agent=agent, configuration=configuration)
84116

85117
trainer.train()
118+
119+
# get the state space
120+
state_space = env.state_space
121+
122+
for state in state_space:
123+
print("Column {0} history {1}".format(state, state_space[state].history))
124+
125+
total_reward = trainer.total_rewards
126+
episodes = [episode for episode in range(N_EPISODES)]
127+
128+
plt.plot(episodes, total_reward)
129+
plt.xlabel("Episodes")
130+
plt.ylabel("Reward")
131+
plt.show()
132+
133+

src/datasets/dataset_wrapper.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ def read(self, filename: Path, **options) -> None:
7474
features_drop_names=options["features_drop_names"],
7575
names=options["names"])
7676

77-
if "change_col_vals" in options:
77+
if "change_col_vals" in options and \
78+
options["change_col_vals"] is not None and \
79+
len(options["change_col_vals"]) != 0:
7880
self.ds = replace(ds=self.ds, options=options["change_col_vals"])
7981

8082
# try to cast to the data types

src/exceptions/exceptions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ def __str__(self):
3131
return self.message
3232

3333

34+
class InvalidStateException(Exception):
35+
def __init__(self, type_name: str, state_type: str) -> None:
36+
self.message = "Type= {0} is not in state= {1}".format(type_name, state_type)
37+
38+
def __str__(self):
39+
return self.message
40+
3441

3542

3643

src/spaces/action_space.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
import numpy as np
7+
import random
78
from gym.spaces.discrete import Discrete
89
from src.spaces.actions import ActionBase
910

@@ -37,6 +38,13 @@ def __setitem__(self, key: int, value: ActionBase) -> None:
3738
"""
3839
self.actions[key] = value
3940

41+
def shuffle(self) -> None:
42+
"""
43+
Randomly shuffle the actions in the space
44+
:return:
45+
"""
46+
random.shuffle(self.actions)
47+
4048
def get_action_by_column_name(self, column_name: str) -> ActionBase:
4149
"""
4250
Get the action that corresponds to the column with

src/spaces/actions.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,15 @@ class ActionType(enum.IntEnum):
2222
Defines the status of an Action
2323
"""
2424

25+
INVALID_TYPE = -1
2526
TRANSFORM = 0
2627
SUPPRESS = 1
2728
GENERALIZE = 2
2829
IDENTITY = 3
30+
RESTORE = 4
31+
32+
def invalid(self) -> bool:
33+
return self is ActionType.RESTORE
2934

3035
def transform(self) -> bool:
3136
return self is ActionType.TRANSFORM
@@ -39,6 +44,9 @@ def generalize(self) -> bool:
3944
def identity(self) -> bool:
4045
return self is ActionType.IDENTITY
4146

47+
def restore(self) -> bool:
48+
return self is ActionType.RESTORE
49+
4250

4351
class ActionBase(metaclass=abc.ABCMeta):
4452
"""
@@ -122,6 +130,46 @@ def reinitialize(self) -> None:
122130
self.called = False
123131

124132

133+
class ActionRestore(ActionBase, WithHierarchyTable):
134+
"""
135+
Implements the restore action
136+
"""
137+
138+
def __init__(self, column_name: str, restore_table):
139+
super(ActionRestore, self).__init__(column_name=column_name, action_type=ActionType.RESTORE)
140+
self.table = restore_table
141+
142+
def act(self, **ops):
143+
"""
144+
Perform an action
145+
:return:
146+
"""
147+
pass
148+
149+
def get_maximum_number_of_transforms(self):
150+
"""
151+
Returns the maximum number of transforms that the action applies
152+
:return:
153+
"""
154+
raise NotImplementedError("Method not implemented")
155+
156+
def is_exhausted(self) -> bool:
157+
"""
158+
Returns true if the action has exhausted all its
159+
transforms
160+
:return:
161+
"""
162+
raise NotImplementedError("Method not implemented")
163+
164+
def reinitialize(self) -> None:
165+
"""
166+
Reinitialize the action to the state when the
167+
constructor is called
168+
:return:
169+
"""
170+
raise NotImplementedError("Method not implemented")
171+
172+
125173
class ActionTransform(ActionBase):
126174
"""
127175
Implements the transform action
@@ -183,6 +231,7 @@ def act(self, **ops) -> None:
183231

184232
# generalize the data given
185233
for i, item in enumerate(ops["data"]):
234+
186235
value = self.table[item].value
187236
col_vals[i] = value
188237

0 commit comments

Comments
 (0)