55
66from src .spaces .environment import Environment
77from src .spaces .action_space import ActionSpace
8+ from src .spaces .actions import ActionSuppress , ActionGeneralize
89from src .exceptions .exceptions import Error
10+ from src .utils .default_hierarchy import DefaultHierarchy
911from src .utils .string_distance_calculator import DistanceType
1012from src .datasets .dataset_wrapper import PandasDSWrapper
1113
@@ -33,7 +35,7 @@ def setUp(self) -> None:
3335 "drop_na" : True ,
3436 "change_col_vals" : {"diagnosis" : [('N' , 0 )]}})
3537
36- # @pytest.mark.skip(reason="no way of currently testing this")
38+ @pytest .mark .skip (reason = "no way of currently testing this" )
3739 def test_prepare_column_states_throw_Error (self ):
3840 # specify the action space. We need to establish how these actions
3941 # are performed
@@ -45,7 +47,7 @@ def test_prepare_column_states_throw_Error(self):
4547 with pytest .raises (Error ):
4648 env .prepare_column_states ()
4749
48- # @pytest.mark.skip(reason="no way of currently testing this")
50+ @pytest .mark .skip (reason = "no way of currently testing this" )
4951 def test_prepare_column_states (self ):
5052 # specify the action space. We need to establish how these actions
5153 # are performed
@@ -57,6 +59,7 @@ def test_prepare_column_states(self):
5759 env .initialize_text_distances (distance_type = DistanceType .COSINE )
5860 env .prepare_column_states ()
5961
62+ @pytest .mark .skip (reason = "no way of currently testing this" )
6063 def test_get_numeric_ds (self ):
6164 # specify the action space. We need to establish how these actions
6265 # are performed
@@ -74,12 +77,49 @@ def test_get_numeric_ds(self):
7477 shape0 = tensor .size (dim = 0 )
7578 shape1 = tensor .size (dim = 1 )
7679
77- self .assertEqual (shape0 , env .start_ds .n_rows () )
78- self .assertEqual (shape1 , env .start_ds .n_columns () )
80+ self .assertEqual (shape0 , env .start_ds .n_rows )
81+ self .assertEqual (shape1 , env .start_ds .n_columns )
7982
def test_apply_action(self):
    """Apply a generalization action and check the resulting column values.

    Registers a single ``ActionGeneralize`` on the ``ethnicity`` column,
    applies it through the environment, and verifies that the unique
    values left in that column are exactly the generalized groups.
    """
    # Specify the action space. We need to establish how these actions
    # are performed.
    action_space = ActionSpace(n=1)

    # Map every raw ethnicity value to the coarser group it generalizes to.
    ethnicity_groups = {
        "Mixed White/Asian": "Mixed",
        "Chinese": "Asian",
        "Indian": "Asian",
        "Mixed White/Black African": "Mixed",
        "Black African": "Black",
        "Asian other": "Asian",
        "Black other": "Black",
        "Mixed White/Black Caribbean": "Mixed",
        "Mixed other": "Mixed",
        "Arab": "Asian",
        "White Irish": "White",
        "Not stated": "Not stated",
        "White Gypsy/Traveller": "White",
        "White British": "White",
        "Bangladeshi": "Asian",
        "White other": "White",
        "Black Caribbean": "Black",
        "Pakistani": "Asian",
    }

    # Each raw value gets its own hierarchy object holding a single level.
    generalization_table = {
        ethnicity: DefaultHierarchy(values=[group])
        for ethnicity, group in ethnicity_groups.items()
    }

    action_space.add(ActionGeneralize(column_name="ethnicity",
                                      generalization_table=generalization_table))

    # Create the environment that the action will be applied to.
    env = Environment(data_set=self.ds, action_space=action_space,
                      gamma=0.99, start_column="gender")

    # This will update the environment.
    env.apply_action(action=action_space[0])

    # Test that the ethnicity column has been changed: collect the unique
    # values now present in the column.
    unique_col_vals = env.data_set.get_column_unique_values(col_name="ethnicity")

    # Compare the actual values, not just how many there are; a length-only
    # check would pass even if the column held five wrong values.
    expected_groups = ["Mixed", "Asian", "Not stated", "White", "Black"]
    self.assertCountEqual(expected_groups, unique_col_vals)
83123
# Run the test suite when this module is executed as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments