+import matplotlib.pyplot as plt
+import numpy as np
+
 from src.algorithms.q_learning import QLearning, QLearnConfig
 from src.algorithms.trainer import Trainer
-from src.utils.string_distance_calculator import DistanceType
-from src.spaces.actions import ActionSuppress, ActionIdentity, ActionGeneralize, ActionTransform
-from src.spaces.environment import Environment, EnvConfig
+from src.utils.string_distance_calculator import StringDistanceType
+from src.spaces.actions import ActionSuppress, ActionIdentity, ActionStringGeneralize, ActionTransform
+from src.spaces.discrete_state_environment import Environment, EnvConfig
 from src.spaces.action_space import ActionSpace
 from src.datasets.datasets_loaders import MockSubjectsLoader
 from src.utils.reward_manager import RewardManager
 from src.utils.numeric_distance_type import NumericDistanceType


+def plot_running_avg(avg_rewards):
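+    """Plot the running average of the per-episode rewards over a trailing 100-episode window."""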
+
+    running_avg = np.empty(avg_rewards.shape[0])
+    for t in range(avg_rewards.shape[0]):
+        running_avg[t] = np.mean(avg_rewards[max(0, t - 100): (t + 1)])
+    plt.plot(running_avg)
+    plt.xlabel("Number of episodes")
+    plt.ylabel("Reward")
+    plt.title("Running average")
+    plt.show()
+
+def get_ethnicity_hierarchies():
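+    """Build the SerialHierarchy that generalizes each ethnicity value to a broader group and finally to '*'."""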
+
+    ethnicity_hierarchy = SerialHierarchy()
+    ethnicity_hierarchy.add("Mixed White/Asian", values=["Mixed", "*"])
+    ethnicity_hierarchy.add("Chinese", values=["Asian", "*"])
+    ethnicity_hierarchy.add("Indian", values=["Asian", "*"])
+    ethnicity_hierarchy.add("Mixed White/Black African", values=["Mixed", "*"])
+    ethnicity_hierarchy.add("Black African", values=["Black", "*"])
+    ethnicity_hierarchy.add("Asian other", values=["Asian", "*"])
+    ethnicity_hierarchy.add("Black other", values=["Black", "*"])
+    ethnicity_hierarchy.add("Mixed White/Black Caribbean", values=["Mixed", "*"])
+    ethnicity_hierarchy.add("Mixed other", values=["Mixed", "*"])
+    ethnicity_hierarchy.add("Arab", values=["Asian", "*"])
+    ethnicity_hierarchy.add("White Irish", values=["White", "*"])
+    ethnicity_hierarchy.add("Not stated", values=["Not stated", "*"])
+    ethnicity_hierarchy.add("White Gypsy/Traveller", values=["White", "*"])
+    ethnicity_hierarchy.add("White British", values=["White", "*"])
+    ethnicity_hierarchy.add("Bangladeshi", values=["Asian", "*"])
+    ethnicity_hierarchy.add("White other", values=["White", "*"])
+    ethnicity_hierarchy.add("Black Caribbean", values=["Black", "*"])
+    ethnicity_hierarchy.add("Pakistani", values=["Asian", "*"])
+
+    return ethnicity_hierarchy
+
+
 if __name__ == '__main__':

     EPS = 1.0
     GAMMA = 0.99
     ALPHA = 0.1
+    N_EPISODES = 100

     # load the dataset
     ds = MockSubjectsLoader()

+    # generalization table for the ethnicity column
+    ethnicity_table = get_ethnicity_hierarchies()
+
     # specify the action space. We need to establish how these actions
     # are performed
-    action_space = ActionSpace(n=4)
-
-    generalization_table = {"Mixed White/Asian": SerialHierarchy(values=["Mixed", ]),
-                            "Chinese": SerialHierarchy(values=["Asian", ]),
-                            "Indian": SerialHierarchy(values=["Asian", ]),
-                            "Mixed White/Black African": SerialHierarchy(values=["Mixed", ]),
-                            "Black African": SerialHierarchy(values=["Black", ]),
-                            "Asian other": SerialHierarchy(values=["Asian", ]),
-                            "Black other": SerialHierarchy(values=["Black", ]),
-                            "Mixed White/Black Caribbean": SerialHierarchy(values=["Mixed", ]),
-                            "Mixed other": SerialHierarchy(values=["Mixed", ]),
-                            "Arab": SerialHierarchy(values=["Asian", ]),
-                            "White Irish": SerialHierarchy(values=["White", ]),
-                            "Not stated": SerialHierarchy(values=["Not stated"]),
-                            "White Gypsy/Traveller": SerialHierarchy(values=["White", ]),
-                            "White British": SerialHierarchy(values=["White", ]),
-                            "Bangladeshi": SerialHierarchy(values=["Asian", ]),
-                            "White other": SerialHierarchy(values=["White", ]),
-                            "Black Caribbean": SerialHierarchy(values=["Black", ]),
-                            "Pakistani": SerialHierarchy(values=["Asian", ])}
-
+    action_space = ActionSpace(n=5)
     action_space.add_many(ActionSuppress(column_name="gender", suppress_table={"F": SerialHierarchy(values=['*', ]),
                                                                                'M': SerialHierarchy(values=['*', ])}),
-                          ActionIdentity(column_name="salary"), ActionIdentity(column_name="education"),
-                          ActionGeneralize(column_name="ethnicity", generalization_table=generalization_table))
-
+                          ActionIdentity(column_name="salary"),
+                          ActionIdentity(column_name="education"),
+                          ActionStringGeneralize(column_name="ethnicity", generalization_table=ethnicity_table),
+                          ActionSuppress(column_name="preventative_treatment",
+                                         suppress_table={"No": SerialHierarchy(values=['Maybe', '*']),
+                                                         'Yes': SerialHierarchy(values=['Maybe', '*']),
+                                                         "NA": SerialHierarchy(values=['Maybe', '*']),
+                                                         "Maybe": SerialHierarchy(values=['*', '*'])}))
+
+    # average distortion constraint for each column
     average_distortion_constraint = {"salary": [0.0, 0.0, 0.0], "education": [0.0, 0.0, 0.0],
-                                     "ethnicity": [3.0, 1.0, -1.0], "gender": [4.0, 1.0, -1.0]}
+                                     "ethnicity": [3.0, 1.0, -1.0], "gender": [4.0, 1.0, -1.0],
+                                     "preventative_treatment": [4.0, 1.0, -1.0]}

     # specify the reward manager to use
     reward_manager = RewardManager(average_distortion_constraint=average_distortion_constraint)

     env = Environment(env_config=env_config)

     # initialize text distances
-    env.initialize_text_distances(distance_type=DistanceType.COSINE)
+    env.initialize_text_distances(distance_type=StringDistanceType.COSINE)

     algo_config = QLearnConfig()
-    algo_config.n_itrs_per_episode = 1000
+    algo_config.n_itrs_per_episode = 10
     algo_config.gamma = 0.99
     algo_config.alpha = 0.1
     algo_config.policy = EpsilonGreedyPolicy(eps=EPS, env=env,
                                              decay_op=EpsilonDecreaseOption.INVERSE_STEP)

     agent = QLearning(algo_config=algo_config)

-    configuration = {"n_episodes": 10, "output_msg_frequency": 100}
+    configuration = {"n_episodes": N_EPISODES, "output_msg_frequency": 10}

     # create a trainer to train the Q-learning agent
     trainer = Trainer(env=env, agent=agent, configuration=configuration)

     trainer.train()
+
+    # get the state space
+    state_space = env.state_space
+
+    for state in state_space:
+        print("Column {0} history {1}".format(state, state_space[state].history))
+
+    total_reward = trainer.total_rewards
+    episodes = list(range(N_EPISODES))
+
+    plt.plot(episodes, total_reward)
+    plt.xlabel("Episodes")
+    plt.ylabel("Reward")
+    plt.show()
+
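+    # Note: plot_running_avg above is otherwise unused; as a sketch, it could also be applied
+    # to the collected rewards (assuming trainer.total_rewards is array-like):
+    plot_running_avg(np.asarray(total_reward))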