Skip to content

Commit 64d43e9

Browse files
committed
#13 Update API
1 parent 453398c commit 64d43e9

File tree

13 files changed

+204
-87
lines changed

13 files changed

+204
-87
lines changed

README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,9 @@ to use the reinforcement learning paradigm in order to train agents to perform t
1616
places this into a persepctive
1717

1818

19-
![RL anonymity paradigm](images/general_concept.png "Reinforcement learning anonymity schematics")
19+
![RL anonymity paradigm](images/general_concept.png "Reinforcement learning anonymity schematics")
20+
21+
## Dependencies
22+
23+
## Documentation
2024

src/algorithms/trainer.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,6 @@ def train(self):
5050
self.iterations_per_episode.append(n_itrs)
5151
self.total_rewards[episode] = episode_score
5252

53-
# is it time to update the model?
54-
if self.configuration["update_frequency"] % episode == 0:
55-
self.agent.update()
56-
5753
self.actions_after_episode_ends(**{"episode_idx": episode})
5854

5955
print("{0} Training finished for agent {1}".format(INFO, self.agent.name))

src/apps/qlearning_on_mock.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@
6565

6666
agent = QLearning(algo_config=algo_config)
6767

68-
configuration = {"n_episodes": 10, "update_frequency": 100}
68+
configuration = {"n_episodes": 10, "output_msg_frequency": 100}
6969

7070
# create a trainer to train the A2C agent
7171
trainer = Trainer(env=env, agent=agent, configuration=configuration)
7272

73-
trainer.train()
73+
trainer.train()
Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,46 @@
11
"""
22
Utilities for calculating the information leakage
33
for a dataset
4-
"""
4+
"""
5+
import numpy as np
6+
from typing import TypeVar
7+
from src.exceptions.exceptions import InvalidSchemaException
8+
from src.datasets.dataset_distances import lp_distance
9+
10+
DataSet = TypeVar("DataSet")
11+
12+
13+
def info_leakage(ds1: DataSet, ds2: DataSet, column_distances: dict = None, p=None) -> tuple:
14+
"""
15+
Returns the information leakage between the two data sets
16+
:param ds1:
17+
:param ds2:
18+
:param column_dists: A dictionary that holds numeric distances to use if a column
19+
is of type string
20+
:return:
21+
"""
22+
23+
if ds1.schema != ds2.schema:
24+
raise InvalidSchemaException(message="Invalid schema for datasets")
25+
26+
if column_distances is None:
27+
return lp_distance(ds1=ds1, ds2=ds2, p=p)
28+
29+
distances = {}
30+
cols = ds1.get_columns_names()
31+
for col in cols:
32+
33+
if col in column_distances:
34+
# get the total distortion of the column
35+
distances[col] = column_distances[col]
36+
else:
37+
38+
val1 = ds1.get_column(col_name=col)
39+
val2 = ds2.get_column(col_name=col)
40+
distances[col] = np.linalg.norm(val1 - val2, ord=p)
41+
42+
sum_distances = sum(distances.values())
43+
return distances, sum_distances
44+
45+
46+

src/exceptions/exceptions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,15 @@ def __str__(self):
2323
return self.message
2424

2525

26+
class InvalidSchemaException(Exception):
27+
def __init__(self, message: str) -> None:
28+
self.message = message
29+
30+
def __str__(self):
31+
return self.message
32+
33+
34+
35+
2636

2737

src/policies/deterministic_policy.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import numpy as np
22
from typing import TypeVar
33

4-
from src.policies.policy_adaptor_base import PolicyAdaptorBase
5-
64
PolicyBase = TypeVar('PolicyBase')
75

86

9-
class DeterministicAdaptorPolicy(PolicyAdaptorBase):
7+
class DeterministicAdaptorPolicy(object):
108

119
"""
1210
Update a policy by choosing the best action

src/policies/epsilon_greedy_policy.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
UserDefinedDecreaseMethod = TypeVar('UserDefinedDecreaseMethod')
1313
Env = TypeVar("Env")
14+
QTable = TypeVar("QTable")
1415

1516

1617
class EpsilonDecreaseOption(Enum):
@@ -41,7 +42,13 @@ def __init__(self, env: Env, eps: float,
4142
self._epsilon_decay_factor = epsilon_decay_factor
4243
self.user_defined_decrease_method: UserDefinedDecreaseMethod = user_defined_decrease_method
4344

44-
def __call__(self, q_func: Any, state: Any) -> int:
45+
def __call__(self, q_func: QTable, state: Any) -> int:
46+
"""
47+
Execute the policy
48+
:param q_func:
49+
:param state:
50+
:return:
51+
"""
4552

4653
# select greedy action with probability epsilon
4754
if random.random() > self._eps:

src/spaces/action_space.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,28 @@ def __setitem__(self, key: int, value: ActionBase) -> None:
3737
"""
3838
self.actions[key] = value
3939

40-
def add(self, action: ActionBase) -> None:
40+
def get_action_by_column_name(self, column_name: str) -> ActionBase:
4141
"""
42-
Add a new action in the space
43-
:param action:
44-
:return:
42+
Get the action that corresponds to the column with
43+
the given name. Raises ValueError if such an action does not
44+
exist
45+
:param column_name: The column name to look for
46+
:return: The action that corresponds to this name
4547
"""
4648

49+
for action in self.actions:
50+
if action.column_name == column_name:
51+
return action
52+
53+
raise ValueError("No action exists for column={0}".format(column_name))
54+
55+
def add(self, action: ActionBase) -> None:
56+
"""
57+
Add a new action in the space. Throws ValueError if the action space
58+
is full
59+
:param action: the action to add
60+
:return: None
61+
"""
4762
if len(self.actions) >= self.n:
4863
raise ValueError("Action space is saturated. You cannot add a new action")
4964

@@ -69,24 +84,32 @@ def sample_and_get(self) -> ActionBase:
6984
return self.actions[action_idx]
7085

7186
def get_non_exhausted_actions(self) -> list:
72-
87+
"""
88+
Returns a list of actions that have not exhausted the
89+
transformations that apply on a column.
90+
:return: list of actions. List may be empty. Client code should handle this
91+
"""
7392
actions_ = []
74-
7593
for action in self.actions:
7694
if not action.is_exhausted():
7795
actions_.append(action)
7896

7997
return actions_
8098

8199
def sample_and_get_non_exhausted(self) -> ActionBase:
82-
100+
"""
101+
Sample an action from the non exhausted actions
102+
:return: A non-exhausted action
103+
"""
83104
actions = self.get_non_exhausted_actions()
84105
return np.random.choice(actions)
85106

86-
def is_exhausted(self):
87-
107+
def is_exhausted(self) -> bool:
108+
"""
109+
Returns true if all the actions in the space are exhausted
110+
:return:
111+
"""
88112
finished = True
89-
90113
for action in self.actions:
91114
if not action.is_exhausted():
92115
return False
@@ -99,5 +122,5 @@ def reset(self) -> None:
99122
:return:
100123
"""
101124
for action in self.actions:
102-
action.reinit()
125+
action.reinitialize()
103126

src/spaces/actions.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def is_exhausted(self) -> bool:
7474
"""
7575

7676
@abc.abstractmethod
77-
def reinit(self) -> None:
77+
def reinitialize(self) -> None:
7878
"""
7979
Reinitialize the action to the state when the
8080
constructor is called
@@ -113,7 +113,7 @@ def is_exhausted(self) -> bool:
113113
"""
114114
return self.called
115115

116-
def reinit(self) -> None:
116+
def reinitialize(self) -> None:
117117
"""
118118
Reinitialize the action to the state when the
119119
constructor is called
@@ -123,7 +123,6 @@ def reinit(self) -> None:
123123

124124

125125
class ActionTransform(ActionBase):
126-
127126
"""
128127
Implements the transform action
129128
"""
@@ -152,7 +151,7 @@ def is_exhausted(self) -> bool:
152151
"""
153152
raise NotImplementedError("Method not implemented")
154153

155-
def reinit(self) -> None:
154+
def reinitialize(self) -> None:
156155
"""
157156
Reinitialize the action to the state when the
158157
constructor is called
@@ -162,7 +161,6 @@ def reinit(self) -> None:
162161

163162

164163
class ActionSuppress(ActionBase, WithHierarchyTable):
165-
166164
"""
167165
Implements the suppress action
168166
"""
@@ -219,7 +217,7 @@ def is_exhausted(self) -> bool:
219217
"""
220218
return self.finished()
221219

222-
def reinit(self) -> None:
220+
def reinitialize(self) -> None:
223221
"""
224222
Reinitialize the action to the state when the
225223
constructor is called
@@ -291,7 +289,7 @@ def is_exhausted(self) -> bool:
291289
"""
292290
return self.finished()
293291

294-
def reinit(self) -> None:
292+
def reinitialize(self) -> None:
295293
"""
296294
Reinitialize the action to the state when the
297295
constructor is called

0 commit comments

Comments
 (0)