Commit f9fb8fa

Merge pull request #21 from pockerman/add_q_learning_algorithm
Add Q-learning algorithm
2 parents (20035cf, 6a8bcb0), commit f9fb8fa

26 files changed (+1066 / -781 lines)

src/algorithms/anonymity_a2c_ray.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 import ray.rllib.agents.a3c as a3c
 from ray.tune.logger import pretty_print
 from ray.rllib.env.env_context import EnvContext
-from src.spaces.environment import TimeStep, StepType
+from src.spaces.discrete_state_environment import TimeStep, StepType
 from src.spaces.observation_space import ObsSpace

src/algorithms/q_learning.py

Lines changed: 21 additions & 14 deletions
@@ -11,8 +11,11 @@
 Env = TypeVar('Env')
 Policy = TypeVar('Policy')
 
-class QLearnConfig(object):
 
+class QLearnConfig(object):
+    """
+    Configuration for Q-learning
+    """
     def __init__(self):
         self.gamma: float = 1.0
         self.alpha: float = 0.1
@@ -21,16 +24,15 @@ def __init__(self):
 
 
 class QLearning(WithMaxActionMixin):
+    """
+    Q-learning algorithm implementation
+    """
 
     def __init__(self, algo_config: QLearnConfig):
         super(QLearning, self).__init__()
         self.q_table = {}
         self.config = algo_config
 
-        # monitor performance
-        self.total_rewards: np.array = None
-        self.iterations_per_episode = []
-
     @property
     def name(self) -> str:
         return "QLearn"
@@ -40,8 +42,8 @@ def actions_before_training(self, env: Env, **options):
         if self.config.policy is None:
             raise InvalidParamValue(param_name="policy", param_value="None")
 
-        for state in range(env.observation_space.n):
-            for action in range(env.action_space.n):
+        for state in range(1, env.n_states):
+            for action in range(env.n_actions):
                 self.q_table[state, action] = 0.0
 
     def actions_after_episode_ends(self, **options):
@@ -57,8 +59,9 @@ def actions_after_episode_ends(self, **options):
     def train(self, env: Env, **options) -> tuple:
 
         # episode score
-        episode_score = 0  # initialize score
+        episode_score = 0
         counter = 0
+        total_distortion = 0
 
         time_step = env.reset()
         state = time_step.observation
@@ -70,24 +73,28 @@ def train(self, env: Env, **options) -> tuple:
 
             action = env.get_action(action_idx)
 
+            if action.action_type.name == "GENERALIZE" and action.column_name == "salary":
+                print("Attempt to generalize salary")
+            else:
+                print(action.action_type.name, " on ", action.column_name)
+
             # take action A, observe R, S'
             next_time_step = env.step(action)
             next_state = next_time_step.observation
             reward = next_time_step.reward
 
-            next_state_id = next_state.state_id if next_state is not None else None
-
             # add reward to agent's score
-            episode_score += next_time_step.reward
-            self._update_Q_table(state=state.state_id, action=action_idx, reward=reward,
-                                 next_state=next_state_id, n_actions=env.action_space.n)
+            episode_score += reward
+            self._update_Q_table(state=state, action=action_idx, reward=reward,
+                                 next_state=next_state, n_actions=env.n_actions)
             state = next_state  # S <- S'
             counter += 1
+            total_distortion += next_time_step.info["total_distortion"]
 
             if next_time_step.last():
                 break
 
-        return episode_score, counter
+        return episode_score, total_distortion, counter
 
 
     def _update_Q_table(self, state: int, action: int, n_actions: int, reward: float, next_state: int = None) -> None:
         """

src/algorithms/trainer.py

Lines changed: 28 additions & 7 deletions
@@ -17,10 +17,26 @@ def __init__(self, env: Env, agent: Agent, configuration: dir) -> None:
         self.agent = agent
         self.configuration = configuration
         # monitor performance
-        self.total_rewards: np.array = None
+        self.total_rewards: np.array = np.zeros(configuration['n_episodes'])
         self.iterations_per_episode = []
+        self.total_distortions = []
+
+    def avg_rewards(self) -> np.array:
+        """
+        Returns the average reward per episode
+        :return:
+        """
+        avg = np.zeros(self.configuration['n_episodes'])
+
+        for i in range(self.total_rewards.shape[0]):
+            avg[i] = self.total_rewards[i] / self.iterations_per_episode[i]
+        return avg
 
     def actions_before_training(self):
+        """
+        Any actions to perform before training begins
+        :return:
+        """
         self.total_rewards: np.array = np.zeros(self.configuration['n_episodes'])
         self.iterations_per_episode = []
 
@@ -29,27 +45,32 @@ def actions_before_training(self):
     def actions_after_episode_ends(self, **options):
         self.agent.actions_after_episode_ends(**options)
 
+        if options["episode_idx"] % self.configuration['output_msg_frequency'] == 0:
+            if self.env.config.distorted_set_path is not None:
+                self.env.save_current_dataset(options["episode_idx"])
+
     def train(self):
 
         print("{0} Training agent {1}".format(INFO, self.agent.name))
         self.actions_before_training()
 
         for episode in range(0, self.configuration["n_episodes"]):
-            print("INFO: Episode {0}/{1}".format(episode, self.configuration["n_episodes"]))
+            print("{0} On episode {1}/{2}".format(INFO, episode, self.configuration["n_episodes"]))
 
             # reset the environment
             ignore = self.env.reset()
 
             # train for a number of iterations
-            episode_score, n_itrs = self.agent.train(self.env)
+            episode_score, total_distortion, n_itrs = self.agent.train(self.env)
 
-            if episode % self.configuration['output_msg_frequency'] == 0:
-                print("{0}: On episode {1} training finished with "
-                      "{2} iterations. Total reward={3}".format(INFO, episode, n_itrs, episode_score))
+            print("{0} Episode score={1}, episode total distortion {2}".format(INFO, episode_score, total_distortion / n_itrs))
+
+            #if episode % self.configuration['output_msg_frequency'] == 0:
+            print("{0} Episode finished after {1} iterations".format(INFO, n_itrs))
 
             self.iterations_per_episode.append(n_itrs)
             self.total_rewards[episode] = episode_score
-
+            self.total_distortions.append(total_distortion)
             self.actions_after_episode_ends(**{"episode_idx": episode})
 
         print("{0} Training finished for agent {1}".format(INFO, self.agent.name))

src/apps/qlearning_on_mock.py

Lines changed: 79 additions & 31 deletions
@@ -1,8 +1,11 @@
+import matplotlib.pyplot as plt
+import numpy as np
+
 from src.algorithms.q_learning import QLearning, QLearnConfig
 from src.algorithms.trainer import Trainer
-from src.utils.string_distance_calculator import DistanceType
-from src.spaces.actions import ActionSuppress, ActionIdentity, ActionGeneralize, ActionTransform
-from src.spaces.environment import Environment, EnvConfig
+from src.utils.string_distance_calculator import StringDistanceType
+from src.spaces.actions import ActionSuppress, ActionIdentity, ActionStringGeneralize, ActionTransform
+from src.spaces.discrete_state_environment import Environment, EnvConfig
 from src.spaces.action_space import ActionSpace
 from src.datasets.datasets_loaders import MockSubjectsLoader
 from src.utils.reward_manager import RewardManager
@@ -11,45 +14,74 @@
 from src.utils.numeric_distance_type import NumericDistanceType
 
 
+def plot_running_avg(avg_rewards):
+
+    running_avg = np.empty(avg_rewards.shape[0])
+    for t in range(avg_rewards.shape[0]):
+        running_avg[t] = np.mean(avg_rewards[max(0, t-100): (t+1)])
+    plt.plot(running_avg)
+    plt.xlabel("Number of episodes")
+    plt.ylabel("Reward")
+    plt.title("Running average")
+    plt.show()
+
+def get_ethinicity_hierarchies():
+
+    ethnicity_hierarchy = SerialHierarchy()
+    ethnicity_hierarchy.add("Mixed White/Asian", values=["Mixed", '*'])
+    ethnicity_hierarchy.add("Chinese", values=["Asian", '*'])
+    ethnicity_hierarchy.add("Indian", values=["Asian", '*'])
+    ethnicity_hierarchy.add("Mixed White/Black African", values=["Mixed", '*'])
+    ethnicity_hierarchy.add("Black African", values=["Black", '*'])
+    ethnicity_hierarchy.add("Asian other", values=["Asian", "*"])
+    ethnicity_hierarchy.add("Black other", values=["Black", "*"])
+    ethnicity_hierarchy.add("Mixed White/Black Caribbean", values=["Mixed", "*"])
+    ethnicity_hierarchy.add("Mixed other", values=["Mixed", "*"])
+    ethnicity_hierarchy.add("Arab", values=["Asian", "*"])
+    ethnicity_hierarchy.add("White Irish", values=["White", "*"])
+    ethnicity_hierarchy.add("Not stated", values=["Not stated", "*"])
+    ethnicity_hierarchy.add("White Gypsy/Traveller", values=["White", "*"])
+    ethnicity_hierarchy.add("White British", values=["White", "*"])
+    ethnicity_hierarchy.add("Bangladeshi", values=["Asian", "*"])
+    ethnicity_hierarchy.add("White other", values=["White", "*"])
+    ethnicity_hierarchy.add("Black Caribbean", values=["Black", "*"])
+    ethnicity_hierarchy.add("Pakistani", values=["Asian", "*"])
+
+    return ethnicity_hierarchy
+
+
 if __name__ == '__main__':
 
     EPS = 1.0
     GAMMA = 0.99
     ALPHA = 0.1
+    N_EPISODES = 100
 
     # load the dataset
     ds = MockSubjectsLoader()
 
+    # generalization table for the ethnicity column
+    ethinicity_table = get_ethinicity_hierarchies()
+
     # specify the action space. We need to establish how these actions
     # are performed
-    action_space = ActionSpace(n=4)
-
-    generalization_table = {"Mixed White/Asian": SerialHierarchy(values=["Mixed", ]),
-                            "Chinese": SerialHierarchy(values=["Asian", ]),
-                            "Indian": SerialHierarchy(values=["Asian", ]),
-                            "Mixed White/Black African": SerialHierarchy(values=["Mixed", ]),
-                            "Black African": SerialHierarchy(values=["Black", ]),
-                            "Asian other": SerialHierarchy(values=["Asian", ]),
-                            "Black other": SerialHierarchy(values=["Black", ]),
-                            "Mixed White/Black Caribbean": SerialHierarchy(values=["Mixed", ]),
-                            "Mixed other": SerialHierarchy(values=["Mixed", ]),
-                            "Arab": SerialHierarchy(values=["Asian", ]),
-                            "White Irish": SerialHierarchy(values=["White", ]),
-                            "Not stated": SerialHierarchy(values=["Not stated"]),
-                            "White Gypsy/Traveller": SerialHierarchy(values=["White", ]),
-                            "White British": SerialHierarchy(values=["White", ]),
-                            "Bangladeshi": SerialHierarchy(values=["Asian", ]),
-                            "White other": SerialHierarchy(values=["White", ]),
-                            "Black Caribbean": SerialHierarchy(values=["Black", ]),
-                            "Pakistani": SerialHierarchy(values=["Asian", ])}
-
+    action_space = ActionSpace(n=5)
     action_space.add_many(ActionSuppress(column_name="gender", suppress_table={"F": SerialHierarchy(values=['*', ]),
                                                                                'M': SerialHierarchy(values=['*', ])}),
-                          ActionIdentity(column_name="salary"), ActionIdentity(column_name="education"),
-                          ActionGeneralize(column_name="ethnicity", generalization_table=generalization_table))
-
+                          ActionIdentity(column_name="salary"),
+                          ActionIdentity(column_name="education"),
+                          ActionStringGeneralize(column_name="ethnicity", generalization_table=ethinicity_table),
+                          ActionSuppress(column_name="preventative_treatment",
+                                         suppress_table={"No": SerialHierarchy(values=['Maybe', '*']),
+                                                         'Yes': SerialHierarchy(values=['Maybe', '*']),
+                                                         "NA": SerialHierarchy(values=['Maybe', '*']),
+                                                         "Maybe": SerialHierarchy(values=['*', '*'])
+                                                         }))
+
+    # average distortion
     average_distortion_constraint = {"salary": [0.0, 0.0, 0.0], "education": [0.0, 0.0, 0.0],
-                                     "ethnicity": [3.0, 1.0, -1.0], "gender": [4.0, 1.0, -1.0]}
+                                     "ethnicity": [3.0, 1.0, -1.0], "gender": [4.0, 1.0, -1.0],
+                                     "preventative_treatment": [4.0, 1.0, -1.0]}
 
     # specify the reward manager to use
     reward_manager = RewardManager(average_distortion_constraint=average_distortion_constraint)
@@ -66,20 +98,36 @@
     env = Environment(env_config=env_config)
 
     # initialize text distances
-    env.initialize_text_distances(distance_type=DistanceType.COSINE)
+    env.initialize_text_distances(distance_type=StringDistanceType.COSINE)
 
     algo_config = QLearnConfig()
-    algo_config.n_itrs_per_episode = 1000
+    algo_config.n_itrs_per_episode = 10
     algo_config.gamma = 0.99
     algo_config.alpha = 0.1
     algo_config.policy = EpsilonGreedyPolicy(eps=EPS, env=env,
                                              decay_op=EpsilonDecreaseOption.INVERSE_STEP)
 
    agent = QLearning(algo_config=algo_config)
 
-    configuration = {"n_episodes": 10, "output_msg_frequency": 100}
+    configuration = {"n_episodes": N_EPISODES, "output_msg_frequency": 10}
 
     # create a trainer to train the A2C agent
     trainer = Trainer(env=env, agent=agent, configuration=configuration)
 
     trainer.train()
+
+    # get the state space
+    state_space = env.state_space
+
+    for state in state_space:
+        print("Column {0} history {1}".format(state, state_space[state].history))
+
+    total_reward = trainer.total_rewards
+    episodes = [episode for episode in range(N_EPISODES)]
+
+    plt.plot(episodes, total_reward)
+    plt.xlabel("Episodes")
+    plt.ylabel("Reward")
+    plt.show()
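plot_running_avg is defined above but is not called anywhere in the hunks shown. One possible wiring, using the avg_rewards helper added to Trainer in this commit; this call is a suggestion and not part of the diff:

    # hypothetical follow-up at the end of the __main__ block:
    # plot the 100-episode running average of the per-episode average reward
    plot_running_avg(trainer.avg_rewards())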
