
Commit 20035cf

Merge pull request #15 from pockerman/add_q_learning_algorithm
Add q learning algorithm
2 parents 61f884b + bc4044a commit 20035cf

26 files changed: +1029 additions, -111 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -5,3 +5,4 @@ src/tests/.pytest_cache/
 src/spaces/__pycache__/
 src/__pycache__/
 src/algorithms/__pycache__/
+src/policies/__pycache__/

README.md

Lines changed: 5 additions & 1 deletion
@@ -16,5 +16,9 @@ to use the reinforcement learning paradigm in order to train agents to perform t
 places this into a persepctive


-![RL anonymity paradigm](images/general_concept.png "Reinforcement learning anonymity schematics")
+![RL anonymity paradigm](images/general_concept.png "Reinforcement learning anonymity schematics")
+
+## Dependencies
+
+## Documentation


doc/env_concept.md

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+# Environment concept

src/algorithms/q_learning.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+"""
+Simple Q-learning algorithm
+"""
+
+import numpy as np
+from typing import TypeVar
+
+from src.exceptions.exceptions import InvalidParamValue
+from src.utils.mixins import WithMaxActionMixin
+
+Env = TypeVar('Env')
+Policy = TypeVar('Policy')
+
+class QLearnConfig(object):
+
+    def __init__(self):
+        self.gamma: float = 1.0
+        self.alpha: float = 0.1
+        self.n_itrs_per_episode: int = 100
+        self.policy: Policy = None
+
+
+class QLearning(WithMaxActionMixin):
+
+    def __init__(self, algo_config: QLearnConfig):
+        super(QLearning, self).__init__()
+        self.q_table = {}
+        self.config = algo_config
+
+        # monitor performance
+        self.total_rewards: np.array = None
+        self.iterations_per_episode = []
+
+    @property
+    def name(self) -> str:
+        return "QLearn"
+
+    def actions_before_training(self, env: Env, **options):
+
+        if self.config.policy is None:
+            raise InvalidParamValue(param_name="policy", param_value="None")
+
+        for state in range(env.observation_space.n):
+            for action in range(env.action_space.n):
+                self.q_table[state, action] = 0.0
+
+    def actions_after_episode_ends(self, **options):
+        """
+        Execute any actions the algorithm needs after
+        an episode ends, e.g. update the behaviour policy
+        :param options:
+        :return:
+        """
+
+        self.config.policy.actions_after_episode(options['episode_idx'])
+
+    def train(self, env: Env, **options) -> tuple:
+
+        # episode score
+        episode_score = 0  # initialize score
+        counter = 0
+
+        time_step = env.reset()
+        state = time_step.observation
+
+        for itr in range(self.config.n_itrs_per_episode):
+
+            # epsilon-greedy action selection
+            action_idx = self.config.policy(q_func=self.q_table, state=state)
+
+            action = env.get_action(action_idx)
+
+            # take action A, observe R, S'
+            next_time_step = env.step(action)
+            next_state = next_time_step.observation
+            reward = next_time_step.reward
+
+            next_state_id = next_state.state_id if next_state is not None else None
+
+            # add reward to agent's score
+            episode_score += next_time_step.reward
+            self._update_Q_table(state=state.state_id, action=action_idx, reward=reward,
+                                 next_state=next_state_id, n_actions=env.action_space.n)
+            state = next_state  # S <- S'
+            counter += 1
+
+            if next_time_step.last():
+                break
+
+        return episode_score, counter
+
+    def _update_Q_table(self, state: int, action: int, n_actions: int, reward: float, next_state: int = None) -> None:
+        """
+        Update the Q-value for the given state-action pair
+        """
+
+        # estimate in Q-table (for current state, action pair)
+        q_s = self.q_table[state, action]
+
+        # value of next state; zero if the next state is terminal
+        Qsa_next = \
+            self.q_table[next_state, self.max_action(next_state, n_actions=n_actions)] if next_state is not None else 0
+        # construct TD target
+        target = reward + (self.config.gamma * Qsa_next)
+
+        # get updated value
+        new_value = q_s + (self.config.alpha * (target - q_s))
+        self.q_table[state, action] = new_value
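
For reference, the _update_Q_table method above implements the standard tabular Q-learning update, with the bootstrap term taken to be zero when next_state is None (i.e. the episode has terminated):

$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$$

where $\alpha$ is QLearnConfig.alpha and $\gamma$ is QLearnConfig.gamma.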

src/algorithms/trainer.py

Lines changed: 26 additions & 7 deletions
@@ -2,8 +2,9 @@
 Trainer
 """
 
-from src.utils import INFO
+import numpy as np
 from typing import TypeVar
+from src.utils import INFO
 
 Env = TypeVar("Env")
 Agent = TypeVar("Agent")
@@ -15,22 +16,40 @@ def __init__(self, env: Env, agent: Agent, configuration: dir) -> None:
         self.env = env
         self.agent = agent
         self.configuration = configuration
+        # monitor performance
+        self.total_rewards: np.array = None
+        self.iterations_per_episode = []
+
+    def actions_before_training(self):
+        self.total_rewards: np.array = np.zeros(self.configuration['n_episodes'])
+        self.iterations_per_episode = []
+
+        self.agent.actions_before_training(self.env)
+
+    def actions_after_episode_ends(self, **options):
+        self.agent.actions_after_episode_ends(**options)
 
     def train(self):
 
         print("{0} Training agent {1}".format(INFO, self.agent.name))
+        self.actions_before_training()
 
-        for episode in range(1, self.configuration["max_n_episodes"] + 1):
-            print("INFO: Episode {0}/{1}".format(episode, self.configuration["max_n_episodes"]))
+        for episode in range(0, self.configuration["n_episodes"]):
+            print("INFO: Episode {0}/{1}".format(episode, self.configuration["n_episodes"]))
 
             # reset the environment
             ignore = self.env.reset()
 
             # train for a number of iterations
-            self.agent.train(self.env)
+            episode_score, n_itrs = self.agent.train(self.env)
+
+            if episode % self.configuration['output_msg_frequency'] == 0:
+                print("{0}: On episode {1} training finished with "
+                      "{2} iterations. Total reward={3}".format(INFO, episode, n_itrs, episode_score))
+
+            self.iterations_per_episode.append(n_itrs)
+            self.total_rewards[episode] = episode_score
 
-            # is it time to update the model?
-            if self.configuration["update_frequency"] % episode == 0:
-                self.agent.update()
+            self.actions_after_episode_ends(**{"episode_idx": episode})
 
         print("{0} Training finished for agent {1}".format(INFO, self.agent.name))

src/apps/__init__.py

Whitespace-only changes.

src/apps/qlearning_on_mock.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+from src.algorithms.q_learning import QLearning, QLearnConfig
+from src.algorithms.trainer import Trainer
+from src.utils.string_distance_calculator import DistanceType
+from src.spaces.actions import ActionSuppress, ActionIdentity, ActionGeneralize, ActionTransform
+from src.spaces.environment import Environment, EnvConfig
+from src.spaces.action_space import ActionSpace
+from src.datasets.datasets_loaders import MockSubjectsLoader
+from src.utils.reward_manager import RewardManager
+from src.policies.epsilon_greedy_policy import EpsilonGreedyPolicy, EpsilonDecreaseOption
+from src.utils.serial_hierarchy import SerialHierarchy
+from src.utils.numeric_distance_type import NumericDistanceType
+
+
+if __name__ == '__main__':
+
+    EPS = 1.0
+    GAMMA = 0.99
+    ALPHA = 0.1
+
+    # load the dataset
+    ds = MockSubjectsLoader()
+
+    # specify the action space. We need to establish how these actions
+    # are performed
+    action_space = ActionSpace(n=4)
+
+    generalization_table = {"Mixed White/Asian": SerialHierarchy(values=["Mixed", ]),
+                            "Chinese": SerialHierarchy(values=["Asian", ]),
+                            "Indian": SerialHierarchy(values=["Asian", ]),
+                            "Mixed White/Black African": SerialHierarchy(values=["Mixed", ]),
+                            "Black African": SerialHierarchy(values=["Black", ]),
+                            "Asian other": SerialHierarchy(values=["Asian", ]),
+                            "Black other": SerialHierarchy(values=["Black", ]),
+                            "Mixed White/Black Caribbean": SerialHierarchy(values=["Mixed", ]),
+                            "Mixed other": SerialHierarchy(values=["Mixed", ]),
+                            "Arab": SerialHierarchy(values=["Asian", ]),
+                            "White Irish": SerialHierarchy(values=["White", ]),
+                            "Not stated": SerialHierarchy(values=["Not stated"]),
+                            "White Gypsy/Traveller": SerialHierarchy(values=["White", ]),
+                            "White British": SerialHierarchy(values=["White", ]),
+                            "Bangladeshi": SerialHierarchy(values=["Asian", ]),
+                            "White other": SerialHierarchy(values=["White", ]),
+                            "Black Caribbean": SerialHierarchy(values=["Black", ]),
+                            "Pakistani": SerialHierarchy(values=["Asian", ])}
+
+    action_space.add_many(ActionSuppress(column_name="gender", suppress_table={"F": SerialHierarchy(values=['*', ]),
+                                                                               'M': SerialHierarchy(values=['*', ])}),
+                          ActionIdentity(column_name="salary"), ActionIdentity(column_name="education"),
+                          ActionGeneralize(column_name="ethnicity", generalization_table=generalization_table))
+
+    average_distortion_constraint = {"salary": [0.0, 0.0, 0.0], "education": [0.0, 0.0, 0.0],
+                                     "ethnicity": [3.0, 1.0, -1.0], "gender": [4.0, 1.0, -1.0]}
+
+    # specify the reward manager to use
+    reward_manager = RewardManager(average_distortion_constraint=average_distortion_constraint)
+
+    env_config = EnvConfig()
+    env_config.start_column = "gender"
+    env_config.action_space = action_space
+    env_config.reward_manager = reward_manager
+    env_config.data_set = ds
+    env_config.gamma = 0.99
+    env_config.numeric_column_distortion_metric_type = NumericDistanceType.L2
+
+    # create the environment
+    env = Environment(env_config=env_config)
+
+    # initialize text distances
+    env.initialize_text_distances(distance_type=DistanceType.COSINE)
+
+    algo_config = QLearnConfig()
+    algo_config.n_itrs_per_episode = 1000
+    algo_config.gamma = 0.99
+    algo_config.alpha = 0.1
+    algo_config.policy = EpsilonGreedyPolicy(eps=EPS, env=env,
+                                             decay_op=EpsilonDecreaseOption.INVERSE_STEP)
+
+    agent = QLearning(algo_config=algo_config)
+
+    configuration = {"n_episodes": 10, "output_msg_frequency": 100}
+
+    # create a trainer to train the Q-learning agent
+    trainer = Trainer(env=env, agent=agent, configuration=configuration)
+
+    trainer.train()
Lines changed: 65 additions & 1 deletion
@@ -1,4 +1,68 @@
 """
 Utilities for calculating the information leakage
 for a dataset
-"""
+"""
+import numpy as np
+from typing import TypeVar
+from src.exceptions.exceptions import InvalidSchemaException, Error
+from src.datasets.dataset_distances import lp_distance
+from src.utils import numeric_distance_type
+
+DataSet = TypeVar("DataSet")
+State = TypeVar("State")
+
+
+def state_leakage(state1: State, state2: State, dist_type: numeric_distance_type.NumericDistanceType) -> float:
+
+    if dist_type == numeric_distance_type.NumericDistanceType.L2:
+        return _l2_state_leakage(state1=state1, state2=state2)
+    elif dist_type == numeric_distance_type.NumericDistanceType.L1:
+        return _l1_state_leakage(state1=state1, state2=state2)
+
+    raise Error("Invalid distance type {0}".format(dist_type.name))
+
+
+def info_leakage(ds1: DataSet, ds2: DataSet, column_distances: dict = None, p=None) -> tuple:
+    """
+    Returns the information leakage between the two data sets
+    :param ds1:
+    :param ds2:
+    :param column_distances: A dictionary that holds numeric distances to use if a column
+    is of type string
+    :return:
+    """
+
+    if ds1.schema != ds2.schema:
+        raise InvalidSchemaException(message="Invalid schema for datasets")
+
+    if column_distances is None:
+        return lp_distance(ds1=ds1, ds2=ds2, p=p)
+
+    distances = {}
+    cols = ds1.get_columns_names()
+    for col in cols:
+
+        if col in column_distances:
+            # get the total distortion of the column
+            distances[col] = column_distances[col]
+        else:
+
+            val1 = ds1.get_column(col_name=col)
+            val2 = ds2.get_column(col_name=col)
+            distances[col] = np.linalg.norm(val1 - val2, ord=p)
+
+    sum_distances = sum(distances.values())
+    return distances, sum_distances
+
+
+def _l2_state_leakage(state1: State, state2: State) -> float:
+    return np.linalg.norm(state1 - state2, ord=None)
+
+def _l1_state_leakage(state1: State, state2: State) -> float:
+    return np.linalg.norm(state1 - state2, ord=1)
+
+
+
+
+
+
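
In the per-column branch above, info_leakage reports, for each column not covered by the column_distances overrides, the l_p distance between the two versions of that column, together with the sum over all columns:

$$d_c = \lVert v_c^{(1)} - v_c^{(2)} \rVert_p, \qquad L(D_1, D_2) = \sum_c d_c$$

where $v_c^{(i)}$ denotes column $c$ of dataset $D_i$. Columns that do appear in column_distances contribute their pre-computed distortion instead, and when no overrides are supplied at all the function falls back to lp_distance over the whole datasets.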

src/datasets/dataset_wrapper.py

Lines changed: 13 additions & 0 deletions
@@ -80,7 +80,20 @@ def read(self, filename: Path, **options) -> None:
         # try to cast to the data types
         self.ds = change_column_types(ds=self.ds, column_types=self.columns)
 
+    def sample_column_name(self) -> str:
+        """
+        Samples a name from the columns
+        :return: a column name
+        """
+        names = self.get_columns_names()
+        return np.random.choice(names)
+
     def set_columns_to_type(self, col_name_types) -> None:
+        """
+        Set the types of the columns
+        :param col_name_types:
+        :return:
+        """
         self.ds.astype(dtype=col_name_types)
 
     def attach_column_hierarchy(self, col_name: str, hierarchy: HierarchyBase):

src/exceptions/exceptions.py

Lines changed: 20 additions & 0 deletions
@@ -15,3 +15,23 @@ def __str__(self):
         return self.message
 
 
+class InvalidParamValue(Exception):
+    def __init__(self, param_name: str, param_value: str):
+        self.message = "Parameter {0} has invalid value {1}".format(param_name, param_value)
+
+    def __str__(self):
+        return self.message
+
+
+class InvalidSchemaException(Exception):
+    def __init__(self, message: str) -> None:
+        self.message = message
+
+    def __str__(self):
+        return self.message
+
+
+
+
+
+
