Skip to content

Commit 453398c

Browse files
committed
#13 Add mock example
1 parent 72e98cc commit 453398c

File tree

1 file changed

+73
-0
lines changed

1 file changed

+73
-0
lines changed

src/apps/qlearning_on_mock.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from src.algorithms.q_learning import QLearning, QLearnConfig
2+
from src.algorithms.trainer import Trainer
3+
from src.utils.string_distance_calculator import DistanceType
4+
from src.spaces.actions import ActionSuppress, ActionIdentity, ActionGeneralize, ActionTransform
5+
from src.spaces.environment import Environment
6+
from src.spaces.action_space import ActionSpace
7+
from src.datasets.datasets_loaders import MockSubjectsLoader
8+
from src.utils.reward_manager import RewardManager
9+
from src.policies.epsilon_greedy_policy import EpsilonGreedyPolicy, EpsilonDecreaseOption
10+
from src.utils.serial_hierarchy import SerialHierarchy
11+
12+
13+
if __name__ == '__main__':
    # Example script: trains a tabular Q-learning agent on the mock subjects
    # dataset, using one action per column (suppress, identity, identity,
    # generalize).

    # Hyper-parameters, defined once so the environment and the agent
    # configuration cannot drift apart.
    EPS = 1.0      # initial epsilon handed to the epsilon-greedy policy
    GAMMA = 0.99   # passed as `gamma` to both the environment and the agent config
    ALPHA = 0.1    # passed as `alpha` to the agent config

    # load the dataset
    ds = MockSubjectsLoader()

    # specify the action space. We need to establish how these actions
    # are performed. Four actions are registered below, matching n=4.
    action_space = ActionSpace(n=4)

    # Maps every raw ethnicity value to the single value it is generalized to.
    # Insertion order is kept identical to the original literal table.
    ethnicity_parents = {
        "Mixed White/Asian": "Mixed",
        "Chinese": "Asian",
        "Indian": "Asian",
        "Mixed White/Black African": "Mixed",
        "Black African": "Black",
        "Asian other": "Asian",
        "Black other": "Black",
        "Mixed White/Black Caribbean": "Mixed",
        "Mixed other": "Mixed",
        "Arab": "Asian",
        "White Irish": "White",
        "Not stated": "Not stated",
        "White Gypsy/Traveller": "White",
        "White British": "White",
        "Bangladeshi": "Asian",
        "White other": "White",
        "Black Caribbean": "Black",
        "Pakistani": "Asian",
    }

    # One independent SerialHierarchy per raw value (no instance is shared
    # between keys), each holding exactly one generalization step.
    generalization_table = {
        raw_value: SerialHierarchy(values=[parent])
        for raw_value, parent in ethnicity_parents.items()
    }

    # Both gender values are suppressed to the wildcard '*'.
    suppress_table = {
        "F": SerialHierarchy(values=['*']),
        'M': SerialHierarchy(values=['*']),
    }

    action_space.add_many(
        ActionSuppress(column_name="gender", suppress_table=suppress_table),
        ActionIdentity(column_name="salary"),
        ActionIdentity(column_name="education"),
        ActionGeneralize(column_name="ethnicity",
                         generalization_table=generalization_table),
    )

    # specify the reward manager to use
    reward_manager = RewardManager()

    # create the environment
    env = Environment(data_set=ds, action_space=action_space,
                      gamma=GAMMA, start_column="gender",
                      reward_manager=reward_manager)

    # initialize text distances
    env.initialize_text_distances(distance_type=DistanceType.COSINE)

    # configure the Q-learning algorithm
    algo_config = QLearnConfig()
    algo_config.n_itrs_per_episode = 1000
    algo_config.gamma = GAMMA
    algo_config.alpha = ALPHA
    algo_config.policy = EpsilonGreedyPolicy(
        eps=EPS, env=env, decay_op=EpsilonDecreaseOption.INVERSE_STEP)

    agent = QLearning(algo_config=algo_config)

    configuration = {"n_episodes": 10, "update_frequency": 100}

    # create a trainer to train the Q-learning agent
    trainer = Trainer(env=env, agent=agent, configuration=configuration)

    trainer.train()

0 commit comments

Comments
 (0)