Commit dd5f73d

Merge pull request #73 from pockerman/investigate_sarsa_semi_gradient
Investigate sarsa semi gradient
2 parents 8f55d91 + 8ec6f83 commit dd5f73d

12 files changed: +219 -74 lines changed
docs/source/API/epsilon_greedy_policy.rst

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+epsilon_greedy_policy
+=======================
+
+.. automodule:: epsilon_greedy_policy
+
+.. autoclass:: EpsilonDecayOption
+
+.. autoclass:: EpsilonGreedyConfig
+
+.. autoclass:: EpsilonGreedyPolicy
+   :members: __init__, from_config, __str__, __call__, on_state, actions_after_episode

docs/source/API/time_step.rst

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+time_step
+==========
+
+.. automodule:: time_step
+   :members: copy_time_step
+
+.. autoclass:: StepType
+.. autoclass:: TimeStep
+   :members: first, mid, last, done
+
+
+
+
+
+
+
+
+
+
Two image files added (97.9 KB and 98.5 KB).

docs/source/Examples/qlearning_three_columns.rst

Lines changed: 11 additions & 0 deletions
@@ -1,9 +1,20 @@
 Q-learning on a three columns dataset
 =====================================
 
+Overview
+--------
+
+In this example, we use a tabular Q-learning algorithm to anonymize a data set with three columns.
+
+
+
+
 In this simple example we show how to apply QLearning on a dataset with three columns.
 
 
+Code
+----
+
 .. code-block::
 
     import numpy as np

docs/source/Examples/semi_gradient_sarsa_three_columns.rst

Lines changed: 52 additions & 15 deletions
@@ -1,25 +1,33 @@
 Semi-gradient SARSA algorithm
 =============================
 
+Overview
+--------
+
+In this example, we use the episodic semi-gradient SARSA algorithm to anonymize a data set with three columns.
+
+
+Semi-gradient SARSA algorithm
+-----------------------------
+
 In this example, we continue using a three-column data set as in the `Q-learning on a three columns dataset <qlearning_three_columns.html>`_.
-In that example, we used a state aggregation approach to model the overall distortion of the data set in the range :math:`[0, 1]`.
-Herein, we take an alternative approach. We will assume that the column distortion is in the range :math:`\[0, 1]` where the edge points mean no distortion
-and full distortion of the column respectively. For each column, we will use the same approach to discretize the continuous :math:`[0, 1]` range
-into a given number of disjoint bins.
+In that example, we used state aggregation to model the overall distortion of the data set in the range :math:`[0, 1]`.
+Herein, we take an alternative approach. We will assume that the column distortion is in the range :math:`[0, 1]`, where the edge points mean no distortion
+and full distortion of the column respectively. For each column, we will use the same methodology as in `Q-learning on a three columns dataset <qlearning_three_columns.html>`_ to discretize the continuous :math:`[0, 1]` range into a given number of disjoint bins.
 
 Contrary to representing the state-action function :math:`q_{\pi}` using a table as we did in `Q-learning on a three columns dataset <qlearning_three_columns.html>`_, we will assume a functional form for
 it. Specifically, we assume that the state-action function can be approximated by :math:`\hat{q} \approx q_{\pi}` given by
 
 .. math::
    \hat{q}(s, \alpha) = \mathbf{w}^T\mathbf{x}(s, \alpha) = \sum_{i}^{d} w_i x_i(s, \alpha)
 
-where :math:`\mathbf{w}` is the weights vector and :math:`\mathbf{x}(s, \alpha)` is called the feature vector representing state :math:`s` when taking action :math:`\alpha` [1]. For our case the components of the feature vector will be distortions of the three columns when applying action :math:`\alpha` on the data set. Our goal now is to find the components of the weight vector. We can the stochastic gradient descent (or SGD )
-for this [1]. In this case, the update rule is [1]
+where :math:`\mathbf{w}` is the weights vector and :math:`\mathbf{x}(s, \alpha)` is called the feature vector representing state :math:`s` when taking action :math:`\alpha` [1]. We will use `Tile coding`_ to construct :math:`\mathbf{x}(s, \alpha)`. Our goal now is to find the components of the weight vector.
+We can use stochastic gradient descent (or SGD) for this [1]. In this case, the update rule is [1]
 
 .. math::
    \mathbf{w}_{t + 1} = \mathbf{w}_t + \eta\left[U_t - \hat{q}(s_t, \alpha_t, \mathbf{w}_t)\right] \nabla_{\mathbf{w}} \hat{q}(s_t, \alpha_t, \mathbf{w}_t)
 
-where :math:`U_t` for one-step SARSA is given by [1]:
+where :math:`\eta` is the learning rate and :math:`U_t`, for one-step SARSA, is given by [1]:
 
 .. math::
    U_t = R_t + \gamma \hat{q}(s_{t + 1}, \alpha_{t + 1}, \mathbf{w}_t)
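
To make the update rule concrete, the following is a minimal NumPy sketch of the one-step semi-gradient SARSA update for the linear :math:`\hat{q}` above; the function and variable names are illustrative only and are not part of the repository code.

.. code-block::

    import numpy as np

    def semi_gradient_sarsa_update(w, x_t, x_next, reward, eta, gamma, done):
        """One-step semi-gradient SARSA update for a linear q_hat(s, a) = w^T x(s, a)."""
        q_t = w.dot(x_t)                           # q_hat(s_t, a_t, w_t)
        q_next = 0.0 if done else w.dot(x_next)    # q_hat(s_{t+1}, a_{t+1}, w_t)
        target = reward + gamma * q_next           # U_t for one-step SARSA
        # since q_hat is linear in w, grad_w q_hat(s, a) = x(s, a)
        return w + eta * (target - q_t) * x_t

    # toy usage with random feature vectors
    rng = np.random.default_rng(0)
    w = np.zeros(8)
    x_t, x_next = rng.random(8), rng.random(8)
    w = semi_gradient_sarsa_update(w, x_t, x_next, reward=-1.0, eta=0.1, gamma=0.99, done=False)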
@@ -29,20 +37,27 @@ Since, :math:`\hat{q}(s, \alpha)` is a linear function with respect to the weigh
 .. math::
    \nabla_{\mathbf{w}} \hat{q}(s, \alpha) = \mathbf{x}(s, \alpha)
 
-We will use bins to discretize the deformation range for each column in the data set.
-The state vector will contain these deformations. Hence, for the three column data set, the state vector will have three entries, each indicating the distortion of the respective column.
-
 The semi-gradient SARSA algorithm is shown below
 
 .. figure:: images/semi_gradient_sarsa.png
 
    Episodic semi-gradient SARSA algorithm. Image from [1].
 
 
-
-
-Tiling
-------
+Tile coding
+-------------
+
+Since we consider the distortions of all the columns in the data set, we deal with a multi-dimensional continuous space. In this case,
+we can use tile coding to construct :math:`\mathbf{x}(s, \alpha)` [1].
+
+Tile coding is a form of coarse coding for multi-dimensional continuous spaces [1]. In this method, the features are grouped into partitions of the state
+space. Each partition is called a tiling, and each element of the partition is called a
+tile [1]. The following figure shows a 2D state space partitioned into a uniform grid (left).
+If we only use this tiling, we would not have coarse coding but just a case of state aggregation.
+
+In order to apply coarse coding, we use overlapping tilings. In this case, each tiling is offset by a fraction of a tile width [1].
+A simple case with four tilings is shown on the right side of the following figure.
+
 
 We will use a linear function approximation for :math:`\hat{q}`:
 
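For a concrete picture, here is a minimal, self-contained sketch of tile coding in one dimension, assuming uniformly spaced tilings over :math:`[0, 1]` that are offset by a fraction of the tile width. The names and parameters are illustrative and do not correspond to the repository's TiledEnv API.

.. code-block::

    import numpy as np

    def active_tiles(x, n_tilings=4, n_bins=10):
        """Index of the active tile in each of the n_tilings offset tilings."""
        tile_width = 1.0 / n_bins
        tiles = []
        for k in range(n_tilings):
            offset = (k / n_tilings) * tile_width    # each tiling is shifted slightly
            idx = int((x + offset) / tile_width)     # bin index within tiling k
            tiles.append(min(idx, n_bins))           # guard the upper edge
        return tiles

    def feature_vector(x, n_tilings=4, n_bins=10):
        """Binary feature vector with exactly one active feature per tiling."""
        phi = np.zeros(n_tilings * (n_bins + 1))
        for k, idx in enumerate(active_tiles(x, n_tilings, n_bins)):
            phi[k * (n_bins + 1) + idx] = 1.0
        return phi

    eta = 1.0 / 4              # learning rate set to 1 / (number of tilings), see below
    print(active_tiles(0.42))  # the active tile index in each of the four tilings

Because exactly one tile is active per tiling, the number of active features always equals the number of tilings; this is what motivates the choice of learning rate discussed next.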
@@ -53,9 +68,22 @@ We will use a linear function approximation for :math:`\hat{q}`:
    These tilings are offset from one another by a uniform amount in each dimension. Image from [1].
 
 
+One practical advantage of tile coding is that the overall number of features that are active
+at a given instance is the same for any state [1]. Exactly one feature is present in each tiling, so the total number of features present is
+always the same as the number of tilings [1]. This allows the learning rate :math:`\eta` to be set according to
+
+.. math::
+   \eta = \frac{1}{n}
+
+
+where :math:`n` is the number of tilings.
+
+
 Code
 ----
 
+The necessary imports
+
 .. code-block::
 
     import random
@@ -77,6 +105,8 @@ Code
     from src.utils.string_distance_calculator import StringDistanceType
     from src.utils.reward_manager import RewardManager
 
+Next we set some constants
+
 .. code-block::
 
     N_LAYERS = 5
@@ -99,6 +129,8 @@ Code
     REWARD_FACTOR = 0.95
     PUNISH_FACTOR = 2.0
 
+We continue by establishing some helper functions
+
 .. code-block::
 
     def get_ethinicity_hierarchy():
@@ -140,7 +172,6 @@ Code
         ethnicity_hierarchy["White"] = "White"
         return ethnicity_hierarchy
 
-.. code-block::
 
     def load_mock_subjects() -> MockSubjectsLoader:
 
@@ -201,6 +232,8 @@ Code
 
         return env
 
+The driver code brings all elements together
+
 .. code-block::
 
     if __name__ == '__main__':
@@ -231,6 +264,10 @@
         trainer.train()
 
 
+.. figure:: images/semi_gradient_sarsa_3_columns_reward.png
+
+
+.. figure:: images/semi_gradient_sarsa_3_columns_distortion.png
 
 References
 ----------

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 sys.path.append(os.path.abspath("../../src/algorithms/"))
 sys.path.append(os.path.abspath("../../src/exceptions/"))
 sys.path.append(os.path.abspath("../../src/spaces/"))
+sys.path.append(os.path.abspath("../../src/policies/"))
 print(sys.path)
 

docs/source/modules.rst

Lines changed: 2 additions & 1 deletion
@@ -7,6 +7,8 @@ API
 API/actions
 API/state
 API/epsilon_greedy_q_estimator
+API/epsilon_greedy_policy
+API/time_step
 generated/action_space
 generated/q_estimator
 generated/q_learning
@@ -17,6 +19,5 @@ API
 generated/discrete_state_environment
 generated/observation_space
 generated/state
-generated/time_step
 generated/tiled_environment
 

src/algorithms/epsilon_greedy_q_estimator.py

Lines changed: 4 additions & 31 deletions
@@ -18,6 +18,9 @@
 
 @dataclass(init=True, repr=True)
 class EpsilonGreedyQEstimatorConfig(EpsilonGreedyConfig):
+    """Configuration class for EpsilonGreedyQEstimator
+
+    """
     gamma: float = 1.0
     alpha: float = 1.0
     env: Env = None
@@ -29,7 +32,7 @@ class EpsilonGreedyQEstimator(WithEstimatorMixin):
     """
 
     def __init__(self, config: EpsilonGreedyQEstimatorConfig):
-        """Constructor
+        """Constructor. Initialize the estimator with a given configuration
 
         Parameters
         ----------
@@ -71,43 +74,13 @@ def q_hat_value(self, state_action_vec: StateActionVec) -> float:
         -------
         float
 
-
         """
 
         if self.weights is None:
            raise InvalidParamValue(param_name="weights", param_value="None. Have you called initialize?")
 
        return self.weights.dot(state_action_vec)
 
-    """
-    def update_weights(self, total_reward: float, state_action: Action,
-                       state_action_: Action, t: float) -> None:
-
-        Update the weights
-
-        Parameters
-        ----------
-
-        total_reward: The reward observed
-        state_action: The action that led to the reward
-        state_action_:
-        t: The decay factor for alpha
-
-        Returns
-        -------
-
-        None
-
-
-
-        if self.weights is None:
-            raise InvalidParamValue(param_name="weights", param_value="None. Have you called initialize?")
-
-        v1 = self.q_hat_value(state_action_vec=state_action)
-        v2 = self.q_hat_value(state_action_vec=state_action_)
-        self.weights += self.alpha / t * (total_reward + self.gamma * v2 - v1) * state_action
-    """
-
    def on_state(self, state: State) -> Action:
        """Returns the action on the given state
 

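For context, the estimator above combines the linear :math:`\hat{q}` with :math:`\epsilon`-greedy action selection. A minimal illustrative sketch of that selection rule follows; the names are hypothetical and do not reflect the repository's API.

.. code-block::

    import random
    import numpy as np

    def epsilon_greedy_action(weights, feature_fn, state, n_actions, eps):
        """With probability eps pick a random action, otherwise the greedy one under q_hat = w^T x(s, a)."""
        if random.random() < eps:
            return random.randint(0, n_actions - 1)   # explore
        q_values = [weights.dot(feature_fn(state, a)) for a in range(n_actions)]
        return int(np.argmax(q_values))                # exploit
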
src/examples/semi_gradient_sarsa.py

Lines changed: 41 additions & 3 deletions
@@ -16,6 +16,8 @@
 from src.utils.numeric_distance_type import NumericDistanceType
 from src.utils.string_distance_calculator import StringDistanceType
 from src.utils.reward_manager import RewardManager
+from src.utils.plot_utils import plot_running_avg
+from src.utils import INFO
 
 
 N_LAYERS = 5
@@ -144,24 +146,60 @@ def load_discrete_env() -> DiscreteStateEnvironment:
     # set the seed for random engine
     random.seed(42)
 
+    # load the discrete environment
     discrete_env = load_discrete_env()
+
+    # establish the configuration for the Tiled environment
     tiled_env_config = TiledEnvConfig(n_layers=N_LAYERS, n_bins=N_BINS,
                                       env=discrete_env,
                                       column_ranges={"ethnicity": [0.0, 1.0],
                                                      "salary": [0.0, 1.0],
                                                      "diagnosis": [0.0, 1.0]})
+    # create the Tiled environment
     tiled_env = TiledEnv(tiled_env_config)
     tiled_env.create_tiles()
 
-    configuration = {"n_episodes": N_EPISODES, "output_msg_frequency": OUTPUT_MSG_FREQUENCY}
-
+    # agent configuration
     agent_config = SemiGradSARSAConfig(gamma=GAMMA, alpha=ALPHA, n_itrs_per_episode=N_ITRS_PER_EPISODE,
                                        policy=EpsilonGreedyQEstimator(EpsilonGreedyQEstimatorConfig(eps=EPS, n_actions=tiled_env.n_actions,
                                                                       decay_op=EPSILON_DECAY_OPTION,
                                                                       epsilon_decay_factor=EPSILON_DECAY_FACTOR,
-                                                                      env=tiled_env, gamma=GAMMA, alpha=ALPHA)))
+                                                                      env=tiled_env,
+                                                                      gamma=GAMMA,
+                                                                      alpha=ALPHA)))
+    # create the agent
     agent = SemiGradSARSA(agent_config)
 
     # create a trainer to train the Qlearning agent
+    configuration = {"n_episodes": N_EPISODES, "output_msg_frequency": OUTPUT_MSG_FREQUENCY}
     trainer = Trainer(env=tiled_env, agent=agent, configuration=configuration)
+
+    # train the agent
     trainer.train()
+
+    # avg_rewards = trainer.avg_rewards()
+    avg_rewards = trainer.total_rewards
+    plot_running_avg(avg_rewards, steps=100,
+                     xlabel="Episodes", ylabel="Reward",
+                     title="Running reward average over 100 episodes")
+
+    avg_episode_dist = np.array(trainer.total_distortions)
+    print("{0} Max/Min distortion {1}/{2}".format(INFO, np.max(avg_episode_dist), np.min(avg_episode_dist)))
+
+    plot_running_avg(avg_episode_dist, steps=100,
+                     xlabel="Episodes", ylabel="Distortion",
+                     title="Running distortion average over 100 episodes")
+
+    print("=============================================")
+    print("{0} Generating distorted dataset".format(INFO))
+
+    """
+    # Let's play
+    env.reset()
+
+    stop_criterion = IterationControl(n_itrs=10, min_dist=MIN_DISTORTION, max_dist=MAX_DISTORTION)
+    agent.play(env=env, stop_criterion=stop_criterion)
+    env.save_current_dataset(episode_index=-2, save_index=False)
+    """
+    print("{0} Done....".format(INFO))
+    print("=============================================")
