1010from src .utils .serial_hierarchy import SerialHierarchy
1111from src .utils .string_distance_calculator import DistanceType
1212from src .datasets .dataset_wrapper import PandasDSWrapper
13+ from src .utils .reward_manager import RewardManager
1314
1415
1516class TestEnvironment (unittest .TestCase ):
@@ -20,6 +21,9 @@ def setUp(self) -> None:
2021 :return: None
2122 """
2223
24+ # specify the reward manager to use
25+ self .reward_manager = RewardManager ()
26+
2327 # read the data
2428 filename = Path ("../../data/mocksubjects.csv" )
2529
@@ -35,7 +39,26 @@ def setUp(self) -> None:
3539 "drop_na" : True ,
3640 "change_col_vals" : {"diagnosis" : [('N' , 0 )]}})
3741
38- #@pytest.mark.skip(reason="no way of currently testing this")
42+ self .generalization_table = {"Mixed White/Asian" : SerialHierarchy (values = ["Mixed" , ]),
43+ "Chinese" : SerialHierarchy (values = ["Asian" , ]),
44+ "Indian" : SerialHierarchy (values = ["Asian" , ]),
45+ "Mixed White/Black African" : SerialHierarchy (values = ["Mixed" , ]),
46+ "Black African" : SerialHierarchy (values = ["Black" , ]),
47+ "Asian other" : SerialHierarchy (values = ["Asian" , ]),
48+ "Black other" : SerialHierarchy (values = ["Black" , ]),
49+ "Mixed White/Black Caribbean" : SerialHierarchy (values = ["Mixed" , ]),
50+ "Mixed other" : SerialHierarchy (values = ["Mixed" , ]),
51+ "Arab" : SerialHierarchy (values = ["Asian" , ]),
52+ "White Irish" : SerialHierarchy (values = ["White" , ]),
53+ "Not stated" : SerialHierarchy (values = ["Not stated" ]),
54+ "White Gypsy/Traveller" : SerialHierarchy (values = ["White" , ]),
55+ "White British" : SerialHierarchy (values = ["White" , ]),
56+ "Bangladeshi" : SerialHierarchy (values = ["Asian" , ]),
57+ "White other" : SerialHierarchy (values = ["White" , ]),
58+ "Black Caribbean" : SerialHierarchy (values = ["Black" , ]),
59+ "Pakistani" : SerialHierarchy (values = ["Asian" , ])}
60+
61+ @pytest .mark .skip (reason = "no way of currently testing this" )
3962 def test_prepare_column_states_throw_Error (self ):
4063 # specify the action space. We need to establish how these actions
4164 # are performed
@@ -47,7 +70,7 @@ def test_prepare_column_states_throw_Error(self):
4770 with pytest .raises (Error ):
4871 env .prepare_column_states ()
4972
50- # @pytest.mark.skip(reason="no way of currently testing this")
73+ @pytest .mark .skip (reason = "no way of currently testing this" )
5174 def test_prepare_column_states (self ):
5275 # specify the action space. We need to establish how these actions
5376 # are performed
@@ -59,14 +82,15 @@ def test_prepare_column_states(self):
5982 env .initialize_text_distances (distance_type = DistanceType .COSINE )
6083 env .prepare_column_states ()
6184
62- # @pytest.mark.skip(reason="no way of currently testing this")
85+ @pytest .mark .skip (reason = "no way of currently testing this" )
6386 def test_get_numeric_ds (self ):
6487 # specify the action space. We need to establish how these actions
6588 # are performed
6689 action_space = ActionSpace (n = 1 )
6790
6891 # create the environment and
69- env = Environment (data_set = self .ds , action_space = action_space , gamma = 0.99 , start_column = "gender" )
92+ env = Environment (data_set = self .ds , action_space = action_space , gamma = 0.99 ,
93+ start_column = "gender" , reward_manager = self .reward_manager )
7094
7195 env .initialize_text_distances (distance_type = DistanceType .COSINE )
7296 env .prepare_column_states ()
@@ -85,6 +109,7 @@ def test_apply_action(self):
85109 # are performed
86110 action_space = ActionSpace (n = 1 )
87111
112+ """
88113 generalization_table = {"Mixed White/Asian": SerialHierarchy(values=["Mixed", ]),
89114 "Chinese": SerialHierarchy(values=["Asian", ]),
90115 "Indian": SerialHierarchy(values=["Asian", ]),
@@ -103,11 +128,13 @@ def test_apply_action(self):
103128 "White other": SerialHierarchy(values=["White", ]),
104129 "Black Caribbean": SerialHierarchy(values=["Black", ]),
105130 "Pakistani": SerialHierarchy(values=["Asian", ])}
131+ """
106132
107- action_space .add (ActionGeneralize (column_name = "ethnicity" , generalization_table = generalization_table ))
133+ action_space .add (ActionGeneralize (column_name = "ethnicity" , generalization_table = self . generalization_table ))
108134
109135 # create the environment and
110- env = Environment (data_set = self .ds , action_space = action_space , gamma = 0.99 , start_column = "gender" )
136+ env = Environment (data_set = self .ds , action_space = action_space ,
137+ gamma = 0.99 , start_column = "gender" , reward_manager = self .reward_manager )
111138
112139 # this will update the environment
113140 env .apply_action (action = action_space [0 ])
@@ -116,10 +143,29 @@ def test_apply_action(self):
116143 # get the unique values for the ethnicity column
117144 unique_col_vals = env .data_set .get_column_unique_values (col_name = "ethnicity" )
118145
119- print (unique_col_vals )
120-
121146 unique_vals = ["Mixed" , "Asian" , "Not stated" , "White" , "Black" ]
122147 self .assertEqual (len (unique_vals ), len (unique_col_vals ))
148+ self .assertEqual (unique_vals , unique_col_vals )
149+
150+ def test_step (self ):
151+ # specify the action space. We need to establish how these actions
152+ # are performed
153+ action_space = ActionSpace (n = 1 )
154+ action_space .add (ActionGeneralize (column_name = "ethnicity" , generalization_table = self .generalization_table ))
155+
156+ # create the environment and
157+ env = Environment (data_set = self .ds , action_space = action_space ,
158+ gamma = 0.99 , start_column = "gender" , reward_manager = self .reward_manager )
159+
160+ action = env .sample_action ()
161+
162+ # this will update the environment
163+ time_step = env .step (action = action )
164+
165+
166+
167+
168+
123169
124170if __name__ == '__main__' :
125171 unittest .main ()
0 commit comments