1+ from src .algorithms .q_learning import QLearning , QLearnConfig
2+ from src .algorithms .trainer import Trainer
3+ from src .utils .string_distance_calculator import DistanceType
4+ from src .spaces .actions import ActionSuppress , ActionIdentity , ActionGeneralize , ActionTransform
5+ from src .spaces .environment import Environment
6+ from src .spaces .action_space import ActionSpace
7+ from src .datasets .datasets_loaders import MockSubjectsLoader
8+ from src .utils .reward_manager import RewardManager
9+ from src .policies .epsilon_greedy_policy import EpsilonGreedyPolicy , EpsilonDecreaseOption
10+ from src .utils .serial_hierarchy import SerialHierarchy
11+
12+
if __name__ == '__main__':

    # Hyperparameters for the Q-learning run. These constants are the single
    # source of truth below (previously 0.99 / 0.1 were re-hardcoded at the
    # Environment and QLearnConfig call sites, so editing them had no effect).
    EPS = 1.0      # initial epsilon for the epsilon-greedy policy
    GAMMA = 0.99   # discount factor
    ALPHA = 0.1    # learning rate

    # load the dataset
    ds = MockSubjectsLoader()

    # specify the action space. We need to establish how these actions
    # are performed
    action_space = ActionSpace(n=4)

    # One-step generalization hierarchy for the ethnicity column: each raw
    # value maps to the coarser category it is replaced with when the
    # generalize action fires.
    generalization_table = {"Mixed White/Asian": SerialHierarchy(values=["Mixed", ]),
                            "Chinese": SerialHierarchy(values=["Asian", ]),
                            "Indian": SerialHierarchy(values=["Asian", ]),
                            "Mixed White/Black African": SerialHierarchy(values=["Mixed", ]),
                            "Black African": SerialHierarchy(values=["Black", ]),
                            "Asian other": SerialHierarchy(values=["Asian", ]),
                            "Black other": SerialHierarchy(values=["Black", ]),
                            "Mixed White/Black Caribbean": SerialHierarchy(values=["Mixed", ]),
                            "Mixed other": SerialHierarchy(values=["Mixed", ]),
                            "Arab": SerialHierarchy(values=["Asian", ]),
                            "White Irish": SerialHierarchy(values=["White", ]),
                            "Not stated": SerialHierarchy(values=["Not stated"]),
                            "White Gypsy/Traveller": SerialHierarchy(values=["White", ]),
                            "White British": SerialHierarchy(values=["White", ]),
                            "Bangladeshi": SerialHierarchy(values=["Asian", ]),
                            "White other": SerialHierarchy(values=["White", ]),
                            "Black Caribbean": SerialHierarchy(values=["Black", ]),
                            "Pakistani": SerialHierarchy(values=["Asian", ])}

    # Register the four actions: suppress gender to '*', leave salary and
    # education untouched, and generalize ethnicity via the table above.
    action_space.add_many(ActionSuppress(column_name="gender",
                                         suppress_table={"F": SerialHierarchy(values=['*', ]),
                                                         'M': SerialHierarchy(values=['*', ])}),
                          ActionIdentity(column_name="salary"),
                          ActionIdentity(column_name="education"),
                          ActionGeneralize(column_name="ethnicity",
                                           generalization_table=generalization_table))

    # specify the reward manager to use
    reward_manager = RewardManager()

    # create the environment
    env = Environment(data_set=ds, action_space=action_space,
                      gamma=GAMMA, start_column="gender",
                      reward_manager=reward_manager)

    # initialize text distances
    env.initialize_text_distances(distance_type=DistanceType.COSINE)

    # configure the Q-learning algorithm
    algo_config = QLearnConfig()
    algo_config.n_itrs_per_episode = 1000
    algo_config.gamma = GAMMA
    algo_config.alpha = ALPHA
    algo_config.policy = EpsilonGreedyPolicy(eps=EPS, env=env,
                                             decay_op=EpsilonDecreaseOption.INVERSE_STEP)

    agent = QLearning(algo_config=algo_config)

    configuration = {"n_episodes": 10, "update_frequency": 100}

    # create a trainer to train the Q-learning agent
    trainer = Trainer(env=env, agent=agent, configuration=configuration)

    trainer.train()