Skip to content

Commit aa36ce4

Browse files
committed
#13 Add test for info leakage
1 parent 64d43e9 commit aa36ce4

File tree

1 file changed

+110
-0
lines changed

1 file changed

+110
-0
lines changed
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import unittest
2+
import pytest
3+
4+
from src.spaces.environment import Environment, EnvConfig
5+
from src.spaces.action_space import ActionSpace
6+
from src.spaces.actions import ActionSuppress, ActionIdentity, ActionGeneralize
7+
from src.utils.serial_hierarchy import SerialHierarchy
8+
from src.datasets.datasets_loaders import MockSubjectsLoader
9+
from src.utils.reward_manager import RewardManager
10+
from src.utils.string_distance_calculator import DistanceType
11+
from src.datasets.dataset_information_leakage import info_leakage
12+
13+
14+
class TestDatasetInfoLeakage(unittest.TestCase):
15+
16+
def setUp(self) -> None:
17+
"""
18+
Setup the PandasDSWrapper to be used in the tests
19+
:return: None
20+
"""
21+
22+
# load the dataset
23+
self.ds = MockSubjectsLoader()
24+
25+
# specify the action space. We need to establish how these actions
26+
# are performed
27+
self.action_space = ActionSpace(n=4)
28+
29+
self.generalization_table = {"Mixed White/Asian": SerialHierarchy(values=["Mixed", ]),
30+
"Chinese": SerialHierarchy(values=["Asian", ]),
31+
"Indian": SerialHierarchy(values=["Asian", ]),
32+
"Mixed White/Black African": SerialHierarchy(values=["Mixed", ]),
33+
"Black African": SerialHierarchy(values=["Black", ]),
34+
"Asian other": SerialHierarchy(values=["Asian", ]),
35+
"Black other": SerialHierarchy(values=["Black", ]),
36+
"Mixed White/Black Caribbean": SerialHierarchy(values=["Mixed", ]),
37+
"Mixed other": SerialHierarchy(values=["Mixed", ]),
38+
"Arab": SerialHierarchy(values=["Asian", ]),
39+
"White Irish": SerialHierarchy(values=["White", ]),
40+
"Not stated": SerialHierarchy(values=["Not stated"]),
41+
"White Gypsy/Traveller": SerialHierarchy(values=["White", ]),
42+
"White British": SerialHierarchy(values=["White", ]),
43+
"Bangladeshi": SerialHierarchy(values=["Asian", ]),
44+
"White other": SerialHierarchy(values=["White", ]),
45+
"Black Caribbean": SerialHierarchy(values=["Black", ]),
46+
"Pakistani": SerialHierarchy(values=["Asian", ])}
47+
48+
49+
self.action_space.add_many(ActionSuppress(column_name="gender", suppress_table={"F": SerialHierarchy(values=['*', ]),
50+
'M': SerialHierarchy(values=['*', ])}),
51+
ActionIdentity(column_name="salary"), ActionIdentity(column_name="education"),
52+
ActionGeneralize(column_name="ethnicity", generalization_table=self.generalization_table))
53+
self.reward_manager = RewardManager()
54+
55+
def test_info_leakage_1(self):
56+
"""
57+
No distortion is applied on the data set so total distortion
58+
should be zero
59+
"""
60+
61+
env_config = EnvConfig()
62+
env_config.action_space = self.action_space
63+
env_config.reward_manager = self.reward_manager
64+
env_config.data_set = self.ds
65+
env_config.start_column = "gender"
66+
env_config.gamma = 0.99
67+
68+
# create the environment
69+
env = Environment(env_config=env_config)
70+
71+
# initialize text distances
72+
env.initialize_text_distances(distance_type=DistanceType.COSINE)
73+
74+
distances, sum_distances = info_leakage(ds1=env.data_set, ds2=env.start_ds, column_distances=env.column_distances)
75+
76+
# no leakage should exist as no trasformation is applied
77+
self.assertEqual(0.0, sum_distances)
78+
79+
#@pytest.mark.skip(reason="no way of currently testing this")
80+
def test_info_leakage_2(self):
81+
"""
82+
We apply distortion on column gender
83+
"""
84+
85+
env_config = EnvConfig()
86+
env_config.action_space = self.action_space
87+
env_config.reward_manager = self.reward_manager
88+
env_config.data_set = self.ds
89+
env_config.start_column = "gender"
90+
env_config.gamma = 0.99
91+
92+
# create the environment
93+
env = Environment(env_config=env_config)
94+
95+
# initialize text distances
96+
env.initialize_text_distances(distance_type=DistanceType.COSINE)
97+
98+
action = env.action_space.get_action_by_column_name(column_name="gender")
99+
100+
env.step(action=action)
101+
102+
distances, sum_distances = info_leakage(ds1=env.data_set, ds2=env.start_ds,
103+
column_distances=env.column_distances)
104+
105+
# leakage should exist as we suppress the gender column
106+
self.assertNotEqual(0.0, sum_distances)
107+
108+
109+
if __name__ == '__main__':
110+
unittest.main()

0 commit comments

Comments
 (0)