Skip to content

Commit 6d5be8e

Browse files
committed
Update API
1 parent a26442c commit 6d5be8e

File tree

5 files changed

+167
-24
lines changed

5 files changed

+167
-24
lines changed

src/spaces/actions.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ def act(self, **ops) -> None:
4646
:return:
4747
"""
4848

49+
@abc.abstractmethod
50+
def get_maximum_number_of_transforms(self):
51+
"""
52+
Returns the maximum number of transforms that the action applies
53+
:return:
54+
"""
55+
4956

5057
def move_next(iterators: List) -> None:
5158
"""
@@ -90,6 +97,13 @@ def act(self, **ops):
9097
"""
9198
pass
9299

100+
def get_maximum_number_of_transforms(self):
101+
"""
102+
Returns the maximum number of transforms that the action applies
103+
:return:
104+
"""
105+
return 1
106+
93107

94108
class ActionTransform(ActionBase):
95109

@@ -106,6 +120,13 @@ def act(self, **ops):
106120
"""
107121
pass
108122

123+
def get_maximum_number_of_transforms(self):
124+
"""
125+
Returns the maximum number of transforms that the action applies
126+
:return:
127+
"""
128+
raise NotImplementedError("Method not implemented")
129+
109130

110131
class ActionSuppress(ActionBase, _WithTable):
111132

@@ -138,6 +159,21 @@ def act(self, **ops) -> None:
138159
# update the generalization
139160
move_next(iterators=self.iterators)
140161

162+
def get_maximum_number_of_transforms(self):
163+
"""
164+
Returns the maximum number of transforms that the action applies
165+
:return:
166+
"""
167+
max_transform = 0
168+
169+
for item in self.table:
170+
size = len(self.table[item])
171+
172+
if size > max_transform:
173+
max_transform = size
174+
175+
return max_transform
176+
141177

142178
class ActionGeneralize(ActionBase, _WithTable):
143179
"""
@@ -181,5 +217,21 @@ def act(self, **ops):
181217
def add_generalization(self, key: str, values: HierarchyBase) -> None:
182218
self.table[key] = values
183219

220+
def get_maximum_number_of_transforms(self):
221+
"""
222+
Returns the maximum number of transforms that the action applies
223+
:return:
224+
"""
225+
max_transform = 0
226+
227+
for item in self.table:
228+
size = len(self.table[item])
229+
230+
if size > max_transform:
231+
max_transform = size
232+
233+
return max_transform
234+
235+
184236

185237

src/spaces/environment.py

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313

1414
from src.exceptions.exceptions import Error
1515
from src.spaces.actions import ActionBase, ActionType
16+
from src.spaces.state_space import StateSpace, State
1617
from src.utils.string_distance_calculator import DistanceType, TextDistanceCalculator
1718

1819
DataSet = TypeVar("DataSet")
20+
RewardManager = TypeVar("RewardManager")
1921

2022
_Reward = TypeVar('_Reward')
2123
_Discount = TypeVar('_Discount')
@@ -65,20 +67,37 @@ def last(self) -> bool:
6567
class Environment(object):
6668

6769
def __init__(self, data_set, action_space,
68-
gamma: float, start_column: str, ):
70+
gamma: float, start_column: str, reward_manager: RewardManager):
6971
self.data_set = data_set
7072
self.start_ds = copy.deepcopy(data_set)
7173
self.current_time_step = self.start_ds
7274
self.action_space = action_space
7375
self.gamma = gamma
7476
self.start_column = start_column
7577
self.column_distances = {}
78+
self.state_space = StateSpace()
7679
self.distance_calculator = None
80+
self.reward_manager: RewardManager = reward_manager
81+
82+
# initialize the state space
83+
self.state_space.init_from_environment(env=self)
7784

7885
@property
7986
def n_features(self) -> int:
87+
"""
88+
Returns the number of features in the dataset
89+
:return:
90+
"""
8091
return self.start_ds.n_columns
8192

93+
@property
94+
def feature_names(self) -> list:
95+
"""
96+
Returns the feature names in the dataset
97+
:return:
98+
"""
99+
return self.start_ds.get_columns_names()
100+
82101
@property
83102
def n_examples(self) -> int:
84103
return self.start_ds.n_rows
@@ -99,6 +118,24 @@ def initialize_text_distances(self, distance_type: DistanceType) -> None:
99118
def sample_action(self) -> ActionBase:
100119
return self.action_space.sample_and_get()
101120

121+
def get_column_as_tensor(self, column_name) -> torch.Tensor:
122+
"""
123+
Returns the column in the dataset as a torch tensor
124+
:param column_name:
125+
:return:
126+
"""
127+
data = {}
128+
129+
if self.start_ds.columns[column_name] == str:
130+
131+
numpy_vals = self.column_distances[column_name]
132+
data[column_name] = numpy_vals
133+
else:
134+
data[column_name] = self.data_set.get_column(col_name=column_name).to_numpy()
135+
136+
target_df = pd.DataFrame(data)
137+
return torch.tensor(target_df.to_numpy(), dtype=torch.float64)
138+
102139
def get_ds_as_tensor(self) -> torch.Tensor:
103140

104141
"""
@@ -111,7 +148,6 @@ def get_ds_as_tensor(self) -> torch.Tensor:
111148
for col in col_names:
112149

113150
if self.start_ds.columns[col] == str:
114-
#print("col: {0} type {1}".format(col, self.start_ds.get_column_type(col_name=col)))
115151
numpy_vals = self.column_distances[col]
116152
data[col] = numpy_vals
117153
else:
@@ -195,28 +231,22 @@ def step(self, action: ActionBase) -> TimeStep:
195231
`action` will be ignored.
196232
"""
197233

234+
# apply the action
198235
self.apply_action(action=action)
199236

200-
# if the action is identity don't bother
201-
# doing anything
202-
#if action.action_type == ActionType.IDENTITY:
203-
# return TimeStep(step_type=StepType.MID, reward=0.0,
204-
# observation=self.get_ds_as_tensor().float(), discount=self.gamma)
205-
206-
# apply the transform of the data set
207-
#self.data_set.apply_column_transform(transform=action)
237+
# update the state space
238+
self.state_space.update_state(state_name=action.column_name, status=action.action_type)
208239

209240
# perform the action on the data set
210241
self.prepare_column_states()
211242

212243
# calculate the information leakage and establish the reward
213244
# to return to the agent
245+
reward = self.reward_manager.get_state_reward(self.state_space, action)
214246

215-
return TimeStep(step_type=StepType.MID, reward=0.0,
216-
observation=self.get_ds_as_tensor().float(), discount=self.gamma)
217-
218-
219-
247+
return TimeStep(step_type=StepType.MID, reward=reward,
248+
observation=self.get_column_as_tensor(column_name=action.column_name).float(),
249+
discount=self.gamma)
220250

221251

222252
class MultiprocessEnv(object):

src/tests/test_environment.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from src.utils.serial_hierarchy import SerialHierarchy
1111
from src.utils.string_distance_calculator import DistanceType
1212
from src.datasets.dataset_wrapper import PandasDSWrapper
13+
from src.utils.reward_manager import RewardManager
1314

1415

1516
class TestEnvironment(unittest.TestCase):
@@ -20,6 +21,9 @@ def setUp(self) -> None:
2021
:return: None
2122
"""
2223

24+
# specify the reward manager to use
25+
self.reward_manager = RewardManager()
26+
2327
# read the data
2428
filename = Path("../../data/mocksubjects.csv")
2529

@@ -35,7 +39,26 @@ def setUp(self) -> None:
3539
"drop_na": True,
3640
"change_col_vals": {"diagnosis": [('N', 0)]}})
3741

38-
#@pytest.mark.skip(reason="no way of currently testing this")
42+
self.generalization_table = {"Mixed White/Asian": SerialHierarchy(values=["Mixed", ]),
43+
"Chinese": SerialHierarchy(values=["Asian", ]),
44+
"Indian": SerialHierarchy(values=["Asian", ]),
45+
"Mixed White/Black African": SerialHierarchy(values=["Mixed", ]),
46+
"Black African": SerialHierarchy(values=["Black", ]),
47+
"Asian other": SerialHierarchy(values=["Asian", ]),
48+
"Black other": SerialHierarchy(values=["Black", ]),
49+
"Mixed White/Black Caribbean": SerialHierarchy(values=["Mixed", ]),
50+
"Mixed other": SerialHierarchy(values=["Mixed", ]),
51+
"Arab": SerialHierarchy(values=["Asian", ]),
52+
"White Irish": SerialHierarchy(values=["White", ]),
53+
"Not stated": SerialHierarchy(values=["Not stated"]),
54+
"White Gypsy/Traveller": SerialHierarchy(values=["White", ]),
55+
"White British": SerialHierarchy(values=["White", ]),
56+
"Bangladeshi": SerialHierarchy(values=["Asian", ]),
57+
"White other": SerialHierarchy(values=["White", ]),
58+
"Black Caribbean": SerialHierarchy(values=["Black", ]),
59+
"Pakistani": SerialHierarchy(values=["Asian", ])}
60+
61+
@pytest.mark.skip(reason="no way of currently testing this")
3962
def test_prepare_column_states_throw_Error(self):
4063
# specify the action space. We need to establish how these actions
4164
# are performed
@@ -47,7 +70,7 @@ def test_prepare_column_states_throw_Error(self):
4770
with pytest.raises(Error):
4871
env.prepare_column_states()
4972

50-
#@pytest.mark.skip(reason="no way of currently testing this")
73+
@pytest.mark.skip(reason="no way of currently testing this")
5174
def test_prepare_column_states(self):
5275
# specify the action space. We need to establish how these actions
5376
# are performed
@@ -59,14 +82,15 @@ def test_prepare_column_states(self):
5982
env.initialize_text_distances(distance_type=DistanceType.COSINE)
6083
env.prepare_column_states()
6184

62-
#@pytest.mark.skip(reason="no way of currently testing this")
85+
@pytest.mark.skip(reason="no way of currently testing this")
6386
def test_get_numeric_ds(self):
6487
# specify the action space. We need to establish how these actions
6588
# are performed
6689
action_space = ActionSpace(n=1)
6790

6891
# create the environment and
69-
env = Environment(data_set=self.ds, action_space=action_space, gamma=0.99, start_column="gender")
92+
env = Environment(data_set=self.ds, action_space=action_space, gamma=0.99,
93+
start_column="gender", reward_manager=self.reward_manager)
7094

7195
env.initialize_text_distances(distance_type=DistanceType.COSINE)
7296
env.prepare_column_states()
@@ -85,6 +109,7 @@ def test_apply_action(self):
85109
# are performed
86110
action_space = ActionSpace(n=1)
87111

112+
"""
88113
generalization_table = {"Mixed White/Asian": SerialHierarchy(values=["Mixed", ]),
89114
"Chinese": SerialHierarchy(values=["Asian", ]),
90115
"Indian": SerialHierarchy(values=["Asian", ]),
@@ -103,11 +128,13 @@ def test_apply_action(self):
103128
"White other": SerialHierarchy(values=["White", ]),
104129
"Black Caribbean": SerialHierarchy(values=["Black", ]),
105130
"Pakistani": SerialHierarchy(values=["Asian", ])}
131+
"""
106132

107-
action_space.add(ActionGeneralize(column_name="ethnicity", generalization_table=generalization_table))
133+
action_space.add(ActionGeneralize(column_name="ethnicity", generalization_table=self.generalization_table))
108134

109135
# create the environment and
110-
env = Environment(data_set=self.ds, action_space=action_space, gamma=0.99, start_column="gender")
136+
env = Environment(data_set=self.ds, action_space=action_space,
137+
gamma=0.99, start_column="gender", reward_manager=self.reward_manager)
111138

112139
# this will update the environment
113140
env.apply_action(action=action_space[0])
@@ -116,10 +143,29 @@ def test_apply_action(self):
116143
# get the unique values for the ethnicity column
117144
unique_col_vals = env.data_set.get_column_unique_values(col_name="ethnicity")
118145

119-
print(unique_col_vals)
120-
121146
unique_vals = ["Mixed", "Asian", "Not stated", "White", "Black"]
122147
self.assertEqual(len(unique_vals), len(unique_col_vals))
148+
self.assertEqual(unique_vals, unique_col_vals)
149+
150+
def test_step(self):
151+
# specify the action space. We need to establish how these actions
152+
# are performed
153+
action_space = ActionSpace(n=1)
154+
action_space.add(ActionGeneralize(column_name="ethnicity", generalization_table=self.generalization_table))
155+
156+
# create the environment and
157+
env = Environment(data_set=self.ds, action_space=action_space,
158+
gamma=0.99, start_column="gender", reward_manager=self.reward_manager)
159+
160+
action = env.sample_action()
161+
162+
# this will update the environment
163+
time_step = env.step(action=action)
164+
165+
166+
167+
168+
123169

124170
if __name__ == '__main__':
125171
unittest.main()

src/utils/hierarchy_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@
1212
class HierarchyBase(metaclass=abc.ABCMeta):
1313

1414
def __init__(self):
15-
pass
15+
pass
16+

src/utils/serial_hierarchy.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ def __next__(self):
3535

3636
raise StopIteration
3737

38+
def __len__(self):
39+
"""
40+
Returns the total number of items in the iterator
41+
:return:
42+
"""
43+
return len(self.values)
44+
3845

3946
class SerialHierarchy(HierarchyBase):
4047

@@ -68,3 +75,10 @@ def value(self) -> Any:
6875
"""
6976
return self.iterator.at
7077

78+
def __len__(self):
79+
"""
80+
Returns the size of the hierarchy
81+
:return:
82+
"""
83+
return len(self.iterator)
84+

0 commit comments

Comments
 (0)