from src.algorithms.q_learning import QLearning, QLearnConfig
from src.algorithms.trainer import Trainer
from src.utils.string_distance_calculator import DistanceType
from src.spaces.actions import ActionSuppress, ActionIdentity, ActionGeneralize, ActionTransform
from src.spaces.environment import Environment, EnvConfig
from src.spaces.action_space import ActionSpace
from src.datasets.datasets_loaders import MockSubjectsLoader
from src.utils.reward_manager import RewardManager
from src.policies.epsilon_greedy_policy import EpsilonGreedyPolicy, EpsilonDecreaseOption
from src.utils.serial_hierarchy import SerialHierarchy
from src.utils.numeric_distance_type import NumericDistanceType


if __name__ == '__main__':
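
    # hyperparameters: EPS is the initial epsilon of the exploration policy,
    # GAMMA the discount factor and ALPHA the learning rate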
    EPS = 1.0
    GAMMA = 0.99
    ALPHA = 0.1

    # load the dataset
    ds = MockSubjectsLoader()
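    # the mock dataset is expected to provide the gender, salary, education
    # and ethnicity columns that are manipulated below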

    # specify the action space; it holds the four actions defined below,
    # one per column of the dataset
    action_space = ActionSpace(n=4)
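
    # generalization hierarchy for the ethnicity column: each detailed
    # value is mapped to a coarser category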
    generalization_table = {"Mixed White/Asian": SerialHierarchy(values=["Mixed", ]),
                            "Chinese": SerialHierarchy(values=["Asian", ]),
                            "Indian": SerialHierarchy(values=["Asian", ]),
                            "Mixed White/Black African": SerialHierarchy(values=["Mixed", ]),
                            "Black African": SerialHierarchy(values=["Black", ]),
                            "Asian other": SerialHierarchy(values=["Asian", ]),
                            "Black other": SerialHierarchy(values=["Black", ]),
                            "Mixed White/Black Caribbean": SerialHierarchy(values=["Mixed", ]),
                            "Mixed other": SerialHierarchy(values=["Mixed", ]),
                            "Arab": SerialHierarchy(values=["Asian", ]),
                            "White Irish": SerialHierarchy(values=["White", ]),
                            "Not stated": SerialHierarchy(values=["Not stated"]),
                            "White Gypsy/Traveller": SerialHierarchy(values=["White", ]),
                            "White British": SerialHierarchy(values=["White", ]),
                            "Bangladeshi": SerialHierarchy(values=["Asian", ]),
                            "White other": SerialHierarchy(values=["White", ]),
                            "Black Caribbean": SerialHierarchy(values=["Black", ]),
                            "Pakistani": SerialHierarchy(values=["Asian", ])}
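
    # register the four actions: suppress the gender values to '*', leave
    # salary and education unchanged (identity) and generalize ethnicity
    # using the table above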
    action_space.add_many(ActionSuppress(column_name="gender",
                                         suppress_table={"F": SerialHierarchy(values=["*", ]),
                                                         "M": SerialHierarchy(values=["*", ])}),
                          ActionIdentity(column_name="salary"),
                          ActionIdentity(column_name="education"),
                          ActionGeneralize(column_name="ethnicity",
                                           generalization_table=generalization_table))
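
    # per-column average distortion constraints; these are handed to the
    # RewardManager below to drive the reward calculation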
    average_distortion_constraint = {"salary": [0.0, 0.0, 0.0], "education": [0.0, 0.0, 0.0],
                                     "ethnicity": [3.0, 1.0, -1.0], "gender": [4.0, 1.0, -1.0]}

    # specify the reward manager to use
    reward_manager = RewardManager(average_distortion_constraint=average_distortion_constraint)
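
    # environment configuration: the starting column, the action space, the
    # reward manager, the dataset, the discount factor and the metric used
    # to measure distortion of numeric columns (L2)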
    env_config = EnvConfig()
    env_config.start_column = "gender"
    env_config.action_space = action_space
    env_config.reward_manager = reward_manager
    env_config.data_set = ds
    env_config.gamma = GAMMA
    env_config.numeric_column_distortion_metric_type = NumericDistanceType.L2

    # create the environment
    env = Environment(env_config=env_config)

    # initialize the text distances used for the categorical columns,
    # based on the cosine distance
    env.initialize_text_distances(distance_type=DistanceType.COSINE)
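
    # Q-learning configuration: iterations per episode, discount factor,
    # learning rate and an epsilon-greedy exploration policy with the
    # INVERSE_STEP decay schedule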
    algo_config = QLearnConfig()
    algo_config.n_itrs_per_episode = 1000
    algo_config.gamma = GAMMA
    algo_config.alpha = ALPHA
    algo_config.policy = EpsilonGreedyPolicy(eps=EPS, env=env,
                                             decay_op=EpsilonDecreaseOption.INVERSE_STEP)

    agent = QLearning(algo_config=algo_config)
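
    # trainer configuration: number of training episodes and how often
    # output messages are emitted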
    configuration = {"n_episodes": 10, "output_msg_frequency": 100}

    # create a trainer to train the Q-learning agent
    trainer = Trainer(env=env, agent=agent, configuration=configuration)

    trainer.train()
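
    # Possible follow-up (hypothetical sketch): inspect what was learned.
    # NOTE: the attribute name ``q_table`` below is an assumption and may not
    # match the actual QLearning API in this repository.
    # for state, action_values in agent.q_table.items():
    #     print(state, action_values)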