
Commit c278ee2

#13 Update API
1 parent 07f755a commit c278ee2

File tree

src/apps/qlearning_on_mock.py
src/datasets/dataset_information_leakage.py
src/spaces/environment.py
src/utils/reward_manager.py

4 files changed: +89 additions, -15 deletions

src/apps/qlearning_on_mock.py

Lines changed: 16 additions & 4 deletions

@@ -2,12 +2,13 @@
 from src.algorithms.trainer import Trainer
 from src.utils.string_distance_calculator import DistanceType
 from src.spaces.actions import ActionSuppress, ActionIdentity, ActionGeneralize, ActionTransform
-from src.spaces.environment import Environment
+from src.spaces.environment import Environment, EnvConfig
 from src.spaces.action_space import ActionSpace
 from src.datasets.datasets_loaders import MockSubjectsLoader
 from src.utils.reward_manager import RewardManager
 from src.policies.epsilon_greedy_policy import EpsilonGreedyPolicy, EpsilonDecreaseOption
 from src.utils.serial_hierarchy import SerialHierarchy
+from src.utils.numeric_distance_type import NumericDistanceType


 if __name__ == '__main__':
@@ -47,12 +48,23 @@
                                ActionIdentity(column_name="salary"), ActionIdentity(column_name="education"),
                                ActionGeneralize(column_name="ethnicity", generalization_table=generalization_table))

+    average_distortion_constraint = {"salary": [0.0, 0.0, 0.0], "education": [0.0, 0.0, 0.0],
+                                     "ethnicity": [3.0, 1.0, -1.0], "gender": [4.0, 1.0, -1.0]}
+
     # specify the reward manager to use
-    reward_manager = RewardManager()
+    reward_manager = RewardManager(average_distortion_constraint=average_distortion_constraint)
+
+    env_config = EnvConfig()
+    env_config.start_column = "gender"
+    env_config.action_space = action_space
+    env_config.reward_manager = reward_manager
+    env_config.data_set = ds
+    env_config.gamma = 0.99
+    env_config.numeric_column_distortion_metric_type = NumericDistanceType.L2

     # create the environment
-    env = Environment(data_set=ds, action_space=action_space,
-                      gamma=0.99, start_column="gender", reward_manager=reward_manager)
+    env = Environment(env_config=env_config)
+
     # initialize text distances
     env.initialize_text_distances(distance_type=DistanceType.COSINE)
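Note: the NumericDistanceType enum imported above is not part of this diff. From its uses in this commit it must expose at least INVALID, L1, and L2. A plausible minimal sketch of src/utils/numeric_distance_type.py, not the actual file:

    from enum import Enum

    # Hypothetical reconstruction: only the INVALID, L1 and L2 members
    # are confirmed by the code in this commit; the values are arbitrary.
    class NumericDistanceType(Enum):
        INVALID = 0
        L1 = 1
        L2 = 2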

src/datasets/dataset_information_leakage.py

Lines changed: 23 additions & 1 deletion

@@ -4,10 +4,22 @@
 """
 import numpy as np
 from typing import TypeVar
-from src.exceptions.exceptions import InvalidSchemaException
+from src.exceptions.exceptions import InvalidSchemaException, Error
 from src.datasets.dataset_distances import lp_distance
+from src.utils import numeric_distance_type

 DataSet = TypeVar("DataSet")
+State = TypeVar("State")
+
+
+def state_leakage(state1: State, state2: State, dist_type: numeric_distance_type.NumericDistanceType) -> float:
+
+    if dist_type == numeric_distance_type.NumericDistanceType.L2:
+        return _l2_state_leakage(state1=state1, state2=state2)
+    elif dist_type == numeric_distance_type.NumericDistanceType.L1:
+        return _l1_state_leakage(state1=state1, state2=state2)
+
+    raise Error("Invalid distance type {0}".format(dist_type.name))


 def info_leakage(ds1: DataSet, ds2: DataSet, column_distances: dict = None, p=None) -> tuple:
@@ -43,4 +55,14 @@ def info_leakage(ds1: DataSet, ds2: DataSet, column_distances: dict = None, p=None)
     return distances, sum_distances


+def _l2_state_leakage(state1: State, state2: State) -> float:
+    return np.linalg.norm(state1 - state2, ord=None)
+
+def _l1_state_leakage(state1: State, state2: State) -> float:
+    return np.linalg.norm(state1 - state2, ord=1)
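To see what the two new helpers compute, here is a minimal standalone check, assuming the states are plain 1-D numeric numpy arrays of column values:

    import numpy as np

    state1 = np.array([1.0, 2.0, 3.0])
    state2 = np.array([1.0, 0.0, 0.0])

    # ord=None on a 1-D array is the Euclidean (L2) norm
    print(np.linalg.norm(state1 - state2, ord=None))  # sqrt(4 + 9) ~ 3.606
    # ord=1 is the sum of absolute differences (L1)
    print(np.linalg.norm(state1 - state2, ord=1))     # 2 + 3 = 5.0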

src/spaces/environment.py

Lines changed: 25 additions & 6 deletions

@@ -15,6 +15,8 @@
 from src.spaces.actions import ActionBase, ActionType
 from src.spaces.state_space import StateSpace, State
 from src.utils.string_distance_calculator import DistanceType, TextDistanceCalculator
+from src.utils.numeric_distance_type import NumericDistanceType
+from src.datasets.dataset_information_leakage import state_leakage

 DataSet = TypeVar("DataSet")
 RewardManager = TypeVar("RewardManager")
@@ -77,6 +79,7 @@ def __init__(self):
         self.average_distortion_constraint: float = 0
         self.start_column: str = "None_Column"
         self.gamma: float = 0.99
+        self.numeric_column_distortion_metric_type: NumericDistanceType = NumericDistanceType.INVALID


 class Environment(object):
@@ -99,6 +102,7 @@ def __init__(self, env_config: EnvConfig):
         self.state_space = StateSpace()
         self.distance_calculator = None
         self.reward_manager: RewardManager = env_config.reward_manager
+        self.numeric_column_distortion_metric_type = env_config.numeric_column_distortion_metric_type

         # initialize the state space
         self.state_space.init_from_environment(env=self)
@@ -219,15 +223,26 @@ def prepare_column_state(self, column_name):
         start_column = self.start_ds.get_column(col_name=column_name)

         row_count = 0
-        print("Distance {0} ".format(self.distance_calculator.calculate(txt1="".join(current_column.values),
-                                                                        txt2="".join(start_column.values))))

+        # join the column to calculate the distance
         self.column_distances[column_name] = self.distance_calculator.calculate(txt1="".join(current_column.values),
                                                                                 txt2="".join(start_column.values))
-        #for item1, item2 in zip(current_column.values, start_column.values):
-        #    #self.column_distances[column_name][row_count] = self.distance_calculator.calculate(txt1=item1, txt2=item2)

-        #    row_count += 1
+    def get_state_distortion(self, state_name) -> float:
+        """
+        Returns the distortion for the state with the given name
+        :param state_name:
+        :return:
+        """
+        if self.start_ds.columns[state_name] == str:
+            return self.column_distances[state_name]
+        else:
+            current_column = self.data_set.get_column(col_name=state_name)
+            start_column = self.start_ds.get_column(col_name=state_name)
+
+            return state_leakage(state1=current_column,
+                                 state2=start_column, dist_type=self.numeric_column_distortion_metric_type)

     def prepare_columns_state(self):
         """
@@ -299,6 +314,7 @@ def apply_action(self, action: ActionBase):
         :return:
         """

+        # nothing to act on identity
         if action.action_type == ActionType.IDENTITY:
             return

@@ -333,14 +349,17 @@ def step(self, action: ActionBase) -> TimeStep:
         # update the state space
         self.state_space.update_state(state_name=action.column_name, status=action.action_type)

+        # prepare the column state. We only do work
+        # if the column is a string
         self.prepare_column_state(column_name=action.column_name)

         # perform the action on the data set
         #self.prepare_columns_state()

         # calculate the information leakage and establish the reward
         # to return to the agent
-        reward = self.reward_manager.get_state_reward(self.state_space, action)
+        state_distortion = self.get_state_distortion(state_name=action.column_name)
+        reward = self.reward_manager.get_state_reward(action.column_name, action, state_distortion)

         # what is the next state? maybe do it randomly?
         # or select the next column in the dataset
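The new get_state_distortion dispatches on the column's type: string columns reuse the text distance cached by prepare_column_state, while numeric columns are compared via state_leakage. A standalone sketch of that dispatch, with the environment's internals replaced by plain dicts (all names here are illustrative, not the repo's API):

    import numpy as np

    # stand-ins for start_ds.columns and the cached text distances
    column_types = {"gender": str, "salary": float}
    column_distances = {"gender": 0.12}  # e.g. a cached cosine distance

    def state_distortion(name, current, start):
        if column_types[name] == str:
            # string column: reuse the cached text distance
            return column_distances[name]
        # numeric column: L2 leakage between current and start values
        return float(np.linalg.norm(np.asarray(current) - np.asarray(start), ord=None))

    print(state_distortion("gender", None, None))       # 0.12
    print(state_distortion("salary", [100.0], [90.0]))  # 10.0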

src/utils/reward_manager.py

Lines changed: 25 additions & 4 deletions

@@ -2,13 +2,34 @@
 Various utilities to handle reward assignment
 """

+from typing import TypeVar
+
+
+State = TypeVar("State")
+Action = TypeVar("Action")
+

 class RewardManager(object):
     """
     Helper class to assign rewards
     """
-    def __init__(self) -> None:
-        pass
+    def __init__(self, average_distortion_constraint: dict) -> None:
+        self.average_distortion_constraint: dict = average_distortion_constraint
+
+    def get_state_reward(self, state_name: str, action: Action, state_distortion: float) -> float:
+        """
+        Returns the reward associated with the applied action
+        :param state_name:
+        :param action:
+        :param state_distortion:
+        :return:
+        """
+
+        if state_name not in self.average_distortion_constraint:
+            raise KeyError("state {0} does not exist".format(state_name))
+
+        state_rewards = self.average_distortion_constraint[state_name]
+
+        if state_distortion < state_rewards[0]:
+            return state_rewards[1]

-    def get_state_reward(self, *options) -> float:
-        return 0.0
+        return state_rewards[2]
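A short usage sketch of the new reward rule, reusing one of the constraint triples set in src/apps/qlearning_on_mock.py. The layout [distortion threshold, reward when below, reward otherwise] is inferred from the code above; action is unused by this method, so None is passed purely for illustration:

    from src.utils.reward_manager import RewardManager

    manager = RewardManager(average_distortion_constraint={"ethnicity": [3.0, 1.0, -1.0]})

    # distortion under the 3.0 threshold -> second element of the triple
    assert manager.get_state_reward("ethnicity", None, 2.5) == 1.0
    # at or above the threshold -> third element
    assert manager.get_state_reward("ethnicity", None, 3.5) == -1.0
    # an unknown column name raises KeyError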
