Skip to content

Commit e9e1559

Browse files
committed
Add example documentation
1 parent 87eecf3 commit e9e1559

File tree

2 files changed

+213
-0
lines changed

2 files changed

+213
-0
lines changed
111 KB
Loading
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
Semi-gradient SARSA algorithm
2+
=============================
3+
4+
In this example, we continue using a three-column data set as in the `Q-learning on a three columns dataset <qlearning_three_columns.html>`_.
5+
In that example, we used a state aggregation approach to model the overall distortion of the data set in the range :math:`\[0, 1]`.
6+
In this example, we take an alternative approach. We will use bins to discretize the deformation range for each column in the data set.
7+
The state vector will contain these deformations. Hence, for the three column data set, the state vector will have three entries,
8+
each indicating the distortion of the respective column.
9+
10+
11+
12+
13+
The semi-gradient SARSA algorithm is shown below
14+
15+
.. figure:: images/semi_gradient_sarsa.png
16+
17+
Episodic semi-gradient SARSA algorithm. Image from [1]
18+
19+
20+
21+
22+
Tiling
23+
------
24+
25+
We will use a linear function approximation for :math:`\hat{q}`:
26+
27+
.. math::
28+
\hat{q} = \mathbf{w}^T\mathbf{x}
29+
30+
31+
Code
32+
----
33+
34+
.. code-block::
35+
36+
import random
37+
from pathlib import Path
38+
import numpy as np
39+
40+
from src.algorithms.semi_gradient_sarsa import SemiGradSARSAConfig, SemiGradSARSA
41+
from src.utils.serial_hierarchy import SerialHierarchy
42+
from src.spaces.tiled_environment import TiledEnv, TiledEnvConfig, Layer
43+
from src.spaces.discrete_state_environment import DiscreteStateEnvironment
44+
from src.datasets.datasets_loaders import MockSubjectsLoader, MockSubjectsData
45+
from src.spaces.action_space import ActionSpace
46+
from src.spaces.actions import ActionIdentity, ActionStringGeneralize, ActionNumericBinGeneralize
47+
from src.algorithms.trainer import Trainer
48+
from src.policies.epsilon_greedy_policy import EpsilonDecayOption
49+
from src.algorithms.epsilon_greedy_q_estimator import EpsilonGreedyQEstimatorConfig, EpsilonGreedyQEstimator
50+
from src.utils.distortion_calculator import DistortionCalculationType, DistortionCalculator
51+
from src.utils.numeric_distance_type import NumericDistanceType
52+
from src.utils.string_distance_calculator import StringDistanceType
53+
from src.utils.reward_manager import RewardManager
54+
55+
.. code-block::
56+
57+
N_LAYERS = 5
58+
N_BINS = 10
59+
N_EPISODES = 1000
60+
OUTPUT_MSG_FREQUENCY = 100
61+
GAMMA = 0.99
62+
ALPHA = 0.1
63+
N_ITRS_PER_EPISODE = 30
64+
EPS = 1.0
65+
EPSILON_DECAY_OPTION = EpsilonDecayOption.CONSTANT_RATE #.INVERSE_STEP
66+
EPSILON_DECAY_FACTOR = 0.01
67+
MAX_DISTORTION = 0.7
68+
MIN_DISTORTION = 0.3
69+
OUT_OF_MAX_BOUND_REWARD = -1.0
70+
OUT_OF_MIN_BOUND_REWARD = -1.0
71+
IN_BOUNDS_REWARD = 5.0
72+
N_ROUNDS_BELOW_MIN_DISTORTION = 10
73+
SAVE_DISTORTED_SETS_DIR = "/home/alex/qi3/drl_anonymity/src/examples/semi_grad_sarsa/distorted_set"
74+
REWARD_FACTOR = 0.95
75+
PUNISH_FACTOR = 2.0
76+
77+
.. code-block::
78+
79+
def get_ethinicity_hierarchy():
80+
ethnicity_hierarchy = SerialHierarchy(values={})
81+
82+
ethnicity_hierarchy["Mixed White/Asian"] = "White/Asian"
83+
ethnicity_hierarchy["White/Asian"] = "Mixed"
84+
85+
ethnicity_hierarchy["Chinese"] = "Asian"
86+
ethnicity_hierarchy["Indian"] = "Asian"
87+
ethnicity_hierarchy["Mixed White/Black African"] = "White/Black"
88+
ethnicity_hierarchy["White/Black"] = "Mixed"
89+
90+
ethnicity_hierarchy["Black African"] = "African"
91+
ethnicity_hierarchy["African"] = "Black"
92+
ethnicity_hierarchy["Asian other"] = "Asian"
93+
ethnicity_hierarchy["Black other"] = "Black"
94+
ethnicity_hierarchy["Mixed White/Black Caribbean"] = "White/Black"
95+
ethnicity_hierarchy["White/Black"] = "Mixed"
96+
97+
ethnicity_hierarchy["Mixed other"] = "Mixed"
98+
ethnicity_hierarchy["Arab"] = "Asian"
99+
ethnicity_hierarchy["White Irish"] = "Irish"
100+
ethnicity_hierarchy["Irish"] = "European"
101+
ethnicity_hierarchy["Not stated"] = "Not stated"
102+
ethnicity_hierarchy["White Gypsy/Traveller"] = "White"
103+
ethnicity_hierarchy["White British"] = "British"
104+
ethnicity_hierarchy["British"] = "European"
105+
ethnicity_hierarchy["Bangladeshi"] = "Asian"
106+
ethnicity_hierarchy["White other"] = "White"
107+
ethnicity_hierarchy["Black Caribbean"] = "Caribbean"
108+
ethnicity_hierarchy["Caribbean"] = "Black"
109+
ethnicity_hierarchy["Pakistani"] = "Asian"
110+
111+
ethnicity_hierarchy["European"] = "European"
112+
ethnicity_hierarchy["Mixed"] = "Mixed"
113+
ethnicity_hierarchy["Asian"] = "Asian"
114+
ethnicity_hierarchy["Black"] = "Black"
115+
ethnicity_hierarchy["White"] = "White"
116+
return ethnicity_hierarchy
117+
118+
.. code-block::
119+
120+
def load_mock_subjects() -> MockSubjectsLoader:
121+
122+
mock_data = MockSubjectsData(FILENAME=Path("../../data/mocksubjects.csv"),
123+
COLUMNS_TYPES={"ethnicity": str, "salary": float, "diagnosis": int},
124+
FEATURES_DROP_NAMES=["NHSno", "given_name",
125+
"surname", "dob"] + ["preventative_treatment",
126+
"gender", "education", "mutation_status"],
127+
NORMALIZED_COLUMNS=["salary"])
128+
129+
ds = MockSubjectsLoader(mock_data)
130+
131+
assert ds.n_columns == 3, "Invalid number of columns {0} not equal to 3".format(ds.n_columns)
132+
133+
return ds
134+
135+
136+
def load_discrete_env() -> DiscreteStateEnvironment:
137+
138+
mock_ds = load_mock_subjects()
139+
140+
# create bins for the salary generalization
141+
unique_salary = mock_ds.get_column_unique_values(col_name="salary")
142+
unique_salary.sort()
143+
144+
# modify slightly the max value because
145+
# we get out of bounds for the maximum salary
146+
bins = np.linspace(unique_salary[0], unique_salary[-1] + 1, N_BINS)
147+
148+
action_space = ActionSpace(n=5)
149+
action_space.add_many(ActionIdentity(column_name="ethnicity"),
150+
ActionStringGeneralize(column_name="ethnicity",
151+
generalization_table=get_ethinicity_hierarchy()),
152+
ActionIdentity(column_name="salary"),
153+
ActionNumericBinGeneralize(column_name="salary", generalization_table=bins),
154+
ActionIdentity(column_name="diagnosis"))
155+
156+
action_space.shuffle()
157+
158+
env = DiscreteStateEnvironment.from_options(data_set=mock_ds,
159+
action_space=action_space,
160+
distortion_calculator=DistortionCalculator(
161+
numeric_column_distortion_metric_type=NumericDistanceType.L2_AVG,
162+
string_column_distortion_metric_type=StringDistanceType.COSINE_NORMALIZE,
163+
dataset_distortion_type=DistortionCalculationType.SUM),
164+
reward_manager=RewardManager(bounds=(MIN_DISTORTION, MAX_DISTORTION),
165+
out_of_max_bound_reward=OUT_OF_MAX_BOUND_REWARD,
166+
out_of_min_bound_reward=OUT_OF_MIN_BOUND_REWARD,
167+
in_bounds_reward=IN_BOUNDS_REWARD),
168+
gamma=GAMMA,
169+
reward_factor=REWARD_FACTOR,
170+
punish_factor=PUNISH_FACTOR,
171+
min_distortion=MIN_DISTORTION, max_distortion=MAX_DISTORTION,
172+
n_rounds_below_min_distortion=N_ROUNDS_BELOW_MIN_DISTORTION,
173+
distorted_set_path=Path(SAVE_DISTORTED_SETS_DIR),
174+
n_states=N_LAYERS * Layer.n_tiles_per_action(N_BINS,
175+
mock_ds.n_columns))
176+
177+
return env
178+
179+
.. code-block::
180+
181+
if __name__ == '__main__':
182+
183+
# set the seed for random engine
184+
random.seed(42)
185+
186+
discrete_env = load_discrete_env()
187+
tiled_env_config = TiledEnvConfig(n_layers=N_LAYERS, n_bins=N_BINS,
188+
env=discrete_env,
189+
column_ranges={"ethnicity": [0.0, 1.0],
190+
"salary": [0.0, 1.0],
191+
"diagnosis": [0.0, 1.0]})
192+
tiled_env = TiledEnv(tiled_env_config)
193+
tiled_env.create_tiles()
194+
195+
configuration = {"n_episodes": N_EPISODES, "output_msg_frequency": OUTPUT_MSG_FREQUENCY}
196+
197+
agent_config = SemiGradSARSAConfig(gamma=GAMMA, alpha=ALPHA, n_itrs_per_episode=N_ITRS_PER_EPISODE,
198+
policy=EpsilonGreedyQEstimator(EpsilonGreedyQEstimatorConfig(eps=EPS, n_actions=tiled_env.n_actions,
199+
decay_op=EPSILON_DECAY_OPTION,
200+
epsilon_decay_factor=EPSILON_DECAY_FACTOR,
201+
env=tiled_env, gamma=GAMMA, alpha=ALPHA)))
202+
agent = SemiGradSARSA(agent_config)
203+
204+
# create a trainer to train the Qlearning agent
205+
trainer = Trainer(env=tiled_env, agent=agent, configuration=configuration)
206+
trainer.train()
207+
208+
209+
210+
References
211+
----------
212+
213+
1. Sutton and Barto, Reinforcement Learning

0 commit comments

Comments
 (0)