Skip to content

Commit d49ef0e

Browse files
committed
Fix review comments. Update API
1 parent 4e3d5bc commit d49ef0e

File tree

10 files changed

+137
-57
lines changed

10 files changed

+137
-57
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
src/preprocessor/__pycache__/
22
src/exceptions/__pycache__/
3+
src/utils/__pycache__/
4+
src/tests/.pytest_cache/
5+
src/spaces/__pycache__/

src/datasets/dataset_distances.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@
1111

1212

1313
def lp_distance(ds1: DataSet, ds2: DataSet, p=None):
14+
"""
15+
Compute the Lp norms between the respective columns in the given data sets.
16+
This means that the two datasets should have the same schema. It is
17+
up to the application to ensure that the calculation is meaningless
18+
:param ds1: Dataset 1
19+
:param ds2: Dataset 2
20+
:param p: The order of the norm to calculate
21+
:return: The calculated Lp-norm
22+
"""
1423

1524
assert ds1.schema == ds2.schema, "Invalid schema for datasets"
1625

src/datasets/dataset_wrapper.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,16 @@ def sample_column(self):
111111
col_idx = np.random.choice(col_names, 1)
112112
return self.get_column(col_name=col_names[col_idx])
113113

114-
def apply_transform(self, transform: Transform) -> None:
114+
def apply_column_transform(self, column_name: str, transform: Transform) -> None:
115+
"""
116+
Apply the given transformation on the underlying dataset
117+
:param column_name: The column to transform
118+
:param transform: The transformation to apply
119+
:return: None
120+
"""
115121

116122
# get the column
117-
column = self.get_column(col_name=transform.column_name)
123+
column = self.get_column(col_name=column_name)
118124
column = transform.act(**{"data": column})
119125
self.ds[transform.column_name] = column
120126

src/spaces/action_space.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,47 @@
1+
"""
2+
ActionSpace class. This is a wrapper to the discrete
3+
actions in the actions.py module
4+
"""
5+
16
from gym.spaces.discrete import Discrete
27
from src.spaces.actions import ActionBase
38

49

510
class ActionSpace(Discrete):
11+
"""
12+
ActionSpace class models a discrete action space of size n
13+
"""
614

715
def __init__(self, n: int) -> None:
816

917
super(ActionSpace, self).__init__(n=n)
18+
19+
# the list of actions the space contains
1020
self.actions = []
1121

12-
def __getitem__(self, item):
22+
def __getitem__(self, item) -> ActionBase:
23+
"""
24+
Returns the item-th action
25+
:param item: The index of the action to return
26+
:return: An action obeject
27+
"""
1328
return self.actions[item]
1429

30+
def __setitem__(self, key: int, value: ActionBase) -> None:
31+
"""
32+
Update the key-th Action with the new value
33+
:param key: The index to the action to update
34+
:param value: The new action
35+
:return: None
36+
"""
37+
self.actions[key] = value
38+
1539
def add(self, action: ActionBase) -> None:
40+
"""
41+
Add a new action in the space
42+
:param action:
43+
:return:
44+
"""
1645

1746
if len(self.actions) >= self.n:
1847
raise ValueError("Action space is saturated. You cannot add a new action")
@@ -21,11 +50,19 @@ def add(self, action: ActionBase) -> None:
2150
action.idx = len(self.actions)
2251
self.actions.append(action)
2352

24-
def add_may(self, *actions) -> None:
53+
def add_many(self, *actions) -> None:
54+
"""
55+
Add many actions in one go
56+
:param actions: List of actions to add
57+
:return: None
58+
"""
2559
for a in actions:
2660
self.add(action=a)
2761

2862
def sample_and_get(self) -> ActionBase:
29-
63+
"""
64+
Sample the space and return an action to the application
65+
:return: The sampled action
66+
"""
3067
action_idx = self.sample()
3168
return self.actions[action_idx]

src/spaces/environment.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,12 @@ def apply_action(self, action: ActionBase):
175175
:param action: The action to apply on the environment
176176
:return:
177177
"""
178+
179+
if action.action_type == ActionType.IDENTITY:
180+
return
181+
178182
# apply the transform of the data set
179-
self.data_set.apply_transform(transform=action)
183+
self.data_set.apply_column_transform(column_name=action.column_name, transform=action)
180184

181185
def step(self, action: ActionBase) -> TimeStep:
182186
"""
@@ -191,14 +195,16 @@ def step(self, action: ActionBase) -> TimeStep:
191195
`action` will be ignored.
192196
"""
193197

198+
self.apply_action(action=action)
199+
194200
# if the action is identity don't bother
195201
# doing anything
196-
if action.action_type == ActionType.IDENTITY:
197-
return TimeStep(step_type=StepType.MID, reward=0.0,
198-
observation=self.get_ds_as_tensor().float(), discount=self.gamma)
202+
#if action.action_type == ActionType.IDENTITY:
203+
# return TimeStep(step_type=StepType.MID, reward=0.0,
204+
# observation=self.get_ds_as_tensor().float(), discount=self.gamma)
199205

200206
# apply the transform of the data set
201-
self.data_set.apply_transform(transform=action)
207+
#self.data_set.apply_column_transform(transform=action)
202208

203209
# perform the action on the data set
204210
self.prepare_column_states()

src/tests/test_actions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import unittest
22

3-
from src.utils.default_hierarchy import DefaultHierarchy
3+
from src.utils.serial_hierarchy import SerialHierarchy
44
from src.spaces.actions import ActionSuppress
55

66

77
class TestActions(unittest.TestCase):
88

99
def test_suppress_action_creation(self):
1010

11-
suppress_table = {"test": DefaultHierarchy(values=["test", "tes*", "te**", "t***", "****"]),
12-
"do_not_test": DefaultHierarchy(values=["do_not_test", "do_not_tes*", "do_not_te**", "do_not_t***", "do_not_****"])}
11+
suppress_table = {"test": SerialHierarchy(values=["test", "tes*", "te**", "t***", "****"]),
12+
"do_not_test": SerialHierarchy(values=["do_not_test", "do_not_tes*", "do_not_te**", "do_not_t***", "do_not_****"])}
1313

1414
suppress_action = ActionSuppress(column_name="none", suppress_table=suppress_table)
1515

@@ -19,8 +19,8 @@ def test_suppress_action_act(self):
1919

2020
data = ["test", "do_not_test", "invalid"]
2121

22-
suppress_table = {"test": DefaultHierarchy(values=["test", "tes*", "te**", "t***", "****"]),
23-
"do_not_test": DefaultHierarchy(values=["do_not_test", "do_not_tes*",
22+
suppress_table = {"test": SerialHierarchy(values=["test", "tes*", "te**", "t***", "****"]),
23+
"do_not_test": SerialHierarchy(values=["do_not_test", "do_not_tes*",
2424
"do_not_te**", "do_not_t***", "do_not_****"])}
2525

2626
suppress_action = ActionSuppress(column_name="none", suppress_table=suppress_table)

src/tests/test_default_hierarchy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import unittest
2-
from src.utils.default_hierarchy import DefaultHierarchy
2+
from src.utils.serial_hierarchy import SerialHierarchy
33

44

55
class TestDefaultHierarchy(unittest.TestCase):
66

77
def test_iteration(self):
88
values = ["test", "tes*", "te**", "t***", "****"]
9-
d_hierarchy = DefaultHierarchy(values=values)
9+
d_hierarchy = SerialHierarchy(values=values)
1010

1111
self.assertEqual(d_hierarchy.value, values[0], "Invalid hierarchy value")
1212
next(d_hierarchy.__iter__())

src/tests/test_environment.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from src.spaces.action_space import ActionSpace
88
from src.spaces.actions import ActionSuppress, ActionGeneralize
99
from src.exceptions.exceptions import Error
10-
from src.utils.default_hierarchy import DefaultHierarchy
10+
from src.utils.serial_hierarchy import SerialHierarchy
1111
from src.utils.string_distance_calculator import DistanceType
1212
from src.datasets.dataset_wrapper import PandasDSWrapper
1313

@@ -35,7 +35,7 @@ def setUp(self) -> None:
3535
"drop_na": True,
3636
"change_col_vals": {"diagnosis": [('N', 0)]}})
3737

38-
@pytest.mark.skip(reason="no way of currently testing this")
38+
#@pytest.mark.skip(reason="no way of currently testing this")
3939
def test_prepare_column_states_throw_Error(self):
4040
# specify the action space. We need to establish how these actions
4141
# are performed
@@ -47,7 +47,7 @@ def test_prepare_column_states_throw_Error(self):
4747
with pytest.raises(Error):
4848
env.prepare_column_states()
4949

50-
@pytest.mark.skip(reason="no way of currently testing this")
50+
#@pytest.mark.skip(reason="no way of currently testing this")
5151
def test_prepare_column_states(self):
5252
# specify the action space. We need to establish how these actions
5353
# are performed
@@ -59,7 +59,7 @@ def test_prepare_column_states(self):
5959
env.initialize_text_distances(distance_type=DistanceType.COSINE)
6060
env.prepare_column_states()
6161

62-
@pytest.mark.skip(reason="no way of currently testing this")
62+
#@pytest.mark.skip(reason="no way of currently testing this")
6363
def test_get_numeric_ds(self):
6464
# specify the action space. We need to establish how these actions
6565
# are performed
@@ -85,24 +85,24 @@ def test_apply_action(self):
8585
# are performed
8686
action_space = ActionSpace(n=1)
8787

88-
generalization_table = {"Mixed White/Asian": DefaultHierarchy(values=["Mixed", ]),
89-
"Chinese": DefaultHierarchy(values=["Asian", ]),
90-
"Indian": DefaultHierarchy(values=["Asian", ]),
91-
"Mixed White/Black African": DefaultHierarchy(values=["Mixed", ]),
92-
"Black African": DefaultHierarchy(values=["Black", ]),
93-
"Asian other": DefaultHierarchy(values=["Asian", ]),
94-
"Black other": DefaultHierarchy(values=["Black", ]),
95-
"Mixed White/Black Caribbean": DefaultHierarchy(values=["Mixed", ]),
96-
"Mixed other": DefaultHierarchy(values=["Mixed", ]),
97-
"Arab": DefaultHierarchy(values=["Asian", ]),
98-
"White Irish": DefaultHierarchy(values=["White", ]),
99-
"Not stated": DefaultHierarchy(values=["Not stated"]),
100-
"White Gypsy/Traveller": DefaultHierarchy(values=["White", ]),
101-
"White British": DefaultHierarchy(values=["White", ]),
102-
"Bangladeshi": DefaultHierarchy(values=["Asian", ]),
103-
"White other": DefaultHierarchy(values=["White", ]),
104-
"Black Caribbean": DefaultHierarchy(values=["Black", ]),
105-
"Pakistani": DefaultHierarchy(values=["Asian", ])}
88+
generalization_table = {"Mixed White/Asian": SerialHierarchy(values=["Mixed", ]),
89+
"Chinese": SerialHierarchy(values=["Asian", ]),
90+
"Indian": SerialHierarchy(values=["Asian", ]),
91+
"Mixed White/Black African": SerialHierarchy(values=["Mixed", ]),
92+
"Black African": SerialHierarchy(values=["Black", ]),
93+
"Asian other": SerialHierarchy(values=["Asian", ]),
94+
"Black other": SerialHierarchy(values=["Black", ]),
95+
"Mixed White/Black Caribbean": SerialHierarchy(values=["Mixed", ]),
96+
"Mixed other": SerialHierarchy(values=["Mixed", ]),
97+
"Arab": SerialHierarchy(values=["Asian", ]),
98+
"White Irish": SerialHierarchy(values=["White", ]),
99+
"Not stated": SerialHierarchy(values=["Not stated"]),
100+
"White Gypsy/Traveller": SerialHierarchy(values=["White", ]),
101+
"White British": SerialHierarchy(values=["White", ]),
102+
"Bangladeshi": SerialHierarchy(values=["Asian", ]),
103+
"White other": SerialHierarchy(values=["White", ]),
104+
"Black Caribbean": SerialHierarchy(values=["Black", ]),
105+
"Pakistani": SerialHierarchy(values=["Asian", ])}
106106

107107
action_space.add(ActionGeneralize(column_name="ethnicity", generalization_table=generalization_table))
108108

src/utils/hierarchy_base.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,15 @@
1-
import abc
2-
from pathlib import Path
3-
from typing import TypeVar
4-
1+
"""
2+
HierarchyBase. A hierarchy represents as series of transformations
3+
that can be applied on data. For example assume that the
4+
data field has the value 'foo' a hierarchy of transformations then may be
5+
the following list ['fo*', 'f**', '***']. If this hierarchy is fully applied
6+
on 'foo' then 'foo' will be completely suppressed
7+
"""
58

6-
#HierarchyBase = TypeVar("HierarchyBase")
9+
import abc
710

811

912
class HierarchyBase(metaclass=abc.ABCMeta):
1013

1114
def __init__(self):
12-
pass
13-
14-
#@abc.abstractmethod
15-
#def read_from(self, filename: Path) -> HierarchyBase:
16-
"""
17-
Reads the values of the hierarchy from the file
18-
:param filename: The file to read the values of the hierarchy
19-
:return: None
20-
"""
15+
pass
Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
1+
"""
2+
A SerialHierarchy represents a hierarchy of transformations
3+
that are applied one after the other
4+
"""
5+
16
from typing import List, Any
27
from src.utils.hierarchy_base import HierarchyBase
38

49

5-
class DefaultHierarchyIterator(object):
10+
class SerialtHierarchyIterator(object):
11+
"""
12+
SerialtHierarchyIterator class. Helper class to iterate over a
13+
SerialHierarchy object
14+
"""
615

716
def __init__(self, values: List):
817
self.current_position = 0
@@ -27,11 +36,23 @@ def __next__(self):
2736
raise StopIteration
2837

2938

30-
class DefaultHierarchy(HierarchyBase):
39+
class SerialHierarchy(HierarchyBase):
3140

41+
"""
42+
A SerialHierarchy represents a hierarchy of transformations
43+
that are applied one after the other. Applications should explicitly
44+
provide the list of the ensuing transformations. For example assume that the
45+
data field has the value 'foo' then values
46+
the following list ['fo*', 'f**', '***']
47+
"""
3248
def __init__(self, values: List) -> None:
33-
super(DefaultHierarchy, self).__init__()
34-
self.iterator = DefaultHierarchyIterator(values=values)
49+
"""
50+
Constructor. Initialize the hierarchy by passing the
51+
list of the ensuing transformations.
52+
:param values:
53+
"""
54+
super(SerialHierarchy, self).__init__()
55+
self.iterator = SerialtHierarchyIterator(values=values)
3556

3657
def __iter__(self):
3758
"""
@@ -42,5 +63,8 @@ def __iter__(self):
4263

4364
@property
4465
def value(self) -> Any:
66+
"""
67+
:return: the current value the hierarchy assumes
68+
"""
4569
return self.iterator.at
4670

0 commit comments

Comments
 (0)