1+ import unittest
2+
3+ import unittest
4+ from pathlib import Path
5+
6+ import pytest
7+
8+ from src .spaces .environment import Environment
9+ from src .spaces .action_space import ActionSpace
10+ from src .spaces .actions import ActionSuppress , ActionGeneralize
11+ from src .exceptions .exceptions import Error
12+ from src .utils .serial_hierarchy import SerialHierarchy
13+ from src .utils .string_distance_calculator import DistanceType
14+ from src .datasets .dataset_wrapper import PandasDSWrapper
15+ from src .spaces .state_space import StateSpace , State
16+
17+ class TestStateSpace (unittest .TestCase ):
18+
19+ def setUp (self ) -> None :
20+ """
21+ Setup the PandasDSWrapper to be used in the tests
22+ :return: None
23+ """
24+
25+ # read the data
26+ filename = Path ("../../data/mocksubjects.csv" )
27+
28+ cols_types = {"gender" : str , "ethnicity" : str , "education" : int ,
29+ "salary" : int , "diagnosis" : int , "preventative_treatment" : str ,
30+ "mutation_status" : int , }
31+
32+ self .ds = PandasDSWrapper (columns = cols_types )
33+ self .ds .read (filename = filename , ** {"features_drop_names" : ["NHSno" , "given_name" , "surname" , "dob" ],
34+ "names" : ["NHSno" , "given_name" , "surname" , "gender" ,
35+ "dob" , "ethnicity" , "education" , "salary" ,
36+ "mutation_status" , "preventative_treatment" , "diagnosis" ],
37+ "drop_na" : True ,
38+ "change_col_vals" : {"diagnosis" : [('N' , 0 )]}})
39+
40+ def test_creation (self ):
41+
42+ action_space = ActionSpace (n = 3 )
43+
44+ generalization_table = {"Mixed White/Asian" : SerialHierarchy (values = ["Mixed" , ]),
45+ "Chinese" : SerialHierarchy (values = ["Asian" , ]),
46+ "Indian" : SerialHierarchy (values = ["Asian" , ]),
47+ "Mixed White/Black African" : SerialHierarchy (values = ["Mixed" , ]),
48+ "Black African" : SerialHierarchy (values = ["Black" , ]),
49+ "Asian other" : SerialHierarchy (values = ["Asian" , ]),
50+ "Black other" : SerialHierarchy (values = ["Black" , ]),
51+ "Mixed White/Black Caribbean" : SerialHierarchy (values = ["Mixed" , ]),
52+ "Mixed other" : SerialHierarchy (values = ["Mixed" , ]),
53+ "Arab" : SerialHierarchy (values = ["Asian" , ]),
54+ "White Irish" : SerialHierarchy (values = ["White" , ]),
55+ "Not stated" : SerialHierarchy (values = ["Not stated" ]),
56+ "White Gypsy/Traveller" : SerialHierarchy (values = ["White" , ]),
57+ "White British" : SerialHierarchy (values = ["White" , ]),
58+ "Bangladeshi" : SerialHierarchy (values = ["Asian" , ]),
59+ "White other" : SerialHierarchy (values = ["White" , ]),
60+ "Black Caribbean" : SerialHierarchy (values = ["Black" , ]),
61+ "Pakistani" : SerialHierarchy (values = ["Asian" , ])}
62+
63+ action_space .add (ActionGeneralize (column_name = "ethnicity" , generalization_table = generalization_table ))
64+
65+ # create the environment from the given dataset
66+ env = Environment (data_set = self .ds , action_space = action_space , gamma = 0.99 , start_column = "gender" )
67+
68+ state_space = StateSpace ()
69+ state_space .init_from_environment (env = env )
70+
71+ print (state_space .states .keys ())
72+
73+ self .assertEqual (env .n_features , state_space .n )
74+
75+
76+ if __name__ == '__main__' :
77+ unittest .main ()
0 commit comments