Skip to content

Commit abd2e28

Browse files
committed
Add semi-gradient SARSA algo
1 parent f0f14dd commit abd2e28

File tree

3 files changed

+267
-0
lines changed

3 files changed

+267
-0
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""Module epsilon_greedy_q_estimator
2+
3+
"""
4+
5+
from src.utils.mixins import WithEstimatorMixin
6+
7+
class EpsilonGreedyQEstimator(WithEstimatorMixin):
    """Q-value estimator meant to be used with an epsilon-greedy policy.

    NOTE(review): placeholder implementation — it currently only
    initializes the WithEstimatorMixin base; no estimation logic yet.
    """

    def __init__(self) -> None:
        """Initialize the estimator by delegating to the mixin base."""
        # Python 3 zero-argument super() replaces the legacy
        # super(EpsilonGreedyQEstimator, self) form; same behavior,
        # less repetition and no risk of naming the wrong class.
        super().__init__()
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
"""Module semi_gradient_sarsa. Implements
2+
episodic semi-gradient SARSA for estimating the state-action
3+
value function. The implementation follows the algorithm
4+
at page 244 in the book by Sutton and Barto: Reinforcement Learning An Introduction
5+
second edition 2020
6+
7+
"""
8+
9+
from dataclasses import dataclass
10+
from typing import TypeVar
11+
12+
from src.utils.mixins import WithMaxActionMixin, WithQTableMixinBase, WithEstimatorMixin
13+
from src.utils.episode_info import EpisodeInfo
14+
from src.exceptions.exceptions import InvalidParamValue
15+
16+
Policy = TypeVar('Policy')
17+
Env = TypeVar('Env')
18+
State = TypeVar('State')
19+
Action = TypeVar('Action')
20+
21+
22+
@dataclass(init=True, repr=True)
class SemiGradSARSAConfig(object):
    """Configuration class for semi-gradient SARSA algorithm.

    The configuration is checked by SemiGradSARSA._validate before
    training begins.
    """
    gamma: float = 1.0  # discount factor for future rewards
    alpha: float = 0.1  # learning rate (step size) for the weight updates
    n_itrs_per_episode: int = 100  # max iterations per episode; must be > 0
    policy: Policy = None  # must be a WithEstimatorMixin instance (checked in _validate)
31+
32+
class SemiGradSARSA:
    """Implements episodic semi-gradient SARSA for estimating the
    state-action value function. The implementation follows the
    algorithm at page 244 in Sutton & Barto, "Reinforcement Learning:
    An Introduction", 2nd edition.

    NOTE(review): this is a skeleton — the environment-interaction
    steps in on_episode are placeholders and the weight updates are
    not implemented yet.
    """

    def __init__(self, config: SemiGradSARSAConfig) -> None:
        """Store the algorithm configuration.

        Parameters
        ----------
        config: The configuration of the algorithm
        """
        self.config: SemiGradSARSAConfig = config

    def actions_before_training(self, env: Env, **options) -> None:
        """Specify any actions necessary before training begins.

        Parameters
        ----------
        env: The environment to train on
        options: Any key-value options passed by the client

        Returns
        -------
        None

        Raises
        ------
        InvalidParamValue: If the configuration is None or the policy
            is not a WithEstimatorMixin
        ValueError: If n_itrs_per_episode is not positive
        """
        self._validate()
        self._init()

    def on_episode(self, env: Env, **options) -> EpisodeInfo:
        """Train on the given environment for a single episode.

        Parameters
        ----------
        env: The environment to train on
        options: Any key-value options passed by the client

        Returns
        -------
        An EpisodeInfo instance holding the accumulated episode
        reward and the number of iterations performed

        NOTE(review): environment interaction is not implemented yet;
        state, action, reward and next_state below are placeholders.
        """
        episode_reward = 0.0
        episode_n_itrs = 0

        # select a state
        state: State = None

        # choose an action using the policy
        action: Action = None

        for _ in range(self.config.n_itrs_per_episode):

            # take action and observe reward and next_state
            reward: float = 0.0
            episode_reward += reward
            next_state: State = None

            # if next_state is terminal i.e. the done flag
            # is set, then update the weights;
            # otherwise choose the next action as a function of q_hat
            next_action: Action = None

            # update the weights

            # update state and action for the next iteration
            state = next_state
            action = next_action

            episode_n_itrs += 1

        episode_info = EpisodeInfo()
        episode_info.episode_score = episode_reward
        episode_info.episode_itrs = episode_n_itrs
        return episode_info

    def _weights_update_episode_done(self, state: State, reward: float,
                                     action: Action, next_state: State) -> None:
        """Update the weights due to the fact that
        the episode is finished.

        Parameters
        ----------
        state: The current state
        reward: The reward to use
        action: The action we took at state
        next_state: The observed state

        Returns
        -------
        None
        """
        # NOTE(review): not implemented yet
        pass

    def _init(self) -> None:
        """Any initializations needed before starting the training.

        Returns
        -------
        None
        """
        # NOTE(review): not implemented yet
        pass

    def _validate(self) -> None:
        """Validate the state of the agent. Is called before
        any training begins to check that the starting state is sane.

        Returns
        -------
        None

        Raises
        ------
        InvalidParamValue: If the configuration is None or the policy
            is not a WithEstimatorMixin
        ValueError: If n_itrs_per_episode is not positive
        """
        if self.config is None:
            raise InvalidParamValue(param_name="self.config", param_value="None")

        if self.config.n_itrs_per_episode <= 0:
            raise ValueError("n_itrs_per_episode should be greater than zero")

        if not isinstance(self.config.policy, WithEstimatorMixin):
            raise InvalidParamValue(param_name="policy", param_value=str(self.config.policy))
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import unittest
2+
import pytest
3+
4+
from src.algorithms.semi_gradient_sarsa import SemiGradSARSAConfig, SemiGradSARSA
5+
from src.algorithms.epsilon_greedy_q_estimator import EpsilonGreedyQEstimator
6+
from src.exceptions.exceptions import InvalidParamValue
7+
from src.spaces.tiled_environment import TiledEnv
8+
from src.spaces.discrete_state_environment import DiscreteStateEnvironment
9+
from src.datasets.datasets_loaders import MockSubjectsLoader, MockSubjectsData
10+
11+
class TestSemiGradSARSA(unittest.TestCase):
    """Unit tests for SemiGradSARSA and its configuration validation."""

    def test_constructor(self):
        """The agent stores the configuration it is given."""
        config = SemiGradSARSAConfig()
        semi_grad_sarsa = SemiGradSARSA(config)
        self.assertIsNotNone(semi_grad_sarsa.config)

    def test_actions_before_training_throws_1(self):
        """A None configuration is rejected before training."""
        semi_grad_sarsa = SemiGradSARSA(None)
        with pytest.raises(InvalidParamValue):
            semi_grad_sarsa.actions_before_training(env=None)

    def test_actions_before_training_throws_2(self):
        """A non-positive n_itrs_per_episode is rejected."""
        config = SemiGradSARSAConfig()
        config.n_itrs_per_episode = 0
        semi_grad_sarsa = SemiGradSARSA(config)

        # make sure this is valid
        self.assertIsNotNone(semi_grad_sarsa.config)

        with pytest.raises(ValueError):
            semi_grad_sarsa.actions_before_training(env=None)

    def test_actions_before_training_throws_3(self):
        """A policy that is not a WithEstimatorMixin is rejected."""
        config = SemiGradSARSAConfig()
        semi_grad_sarsa = SemiGradSARSA(config)

        # make sure this is valid
        self.assertIsNotNone(semi_grad_sarsa.config)

        with pytest.raises(InvalidParamValue):
            semi_grad_sarsa.actions_before_training(env=None)

    def test_on_episode_returns_info(self):
        """on_episode always returns an EpisodeInfo instance."""
        config = SemiGradSARSAConfig()
        semi_grad_sarsa = SemiGradSARSA(config)

        # make sure this is valid
        self.assertIsNotNone(semi_grad_sarsa.config)

        episode_info = semi_grad_sarsa.on_episode(env=None)
        self.assertIsNotNone(episode_info)

    def test_on_episode_trains(self):
        """End-to-end: validate, then run one episode on a tiled environment."""
        sarsa_config = SemiGradSARSAConfig(n_itrs_per_episode=1, policy=EpsilonGreedyQEstimator())
        semi_grad_sarsa = SemiGradSARSA(sarsa_config)

        # create a dataset using the default mock-subjects options
        ds_default_data = MockSubjectsData()
        ds = MockSubjectsLoader.from_options(filename=ds_default_data.FILENAME,
                                             names=ds_default_data.NAMES, drop_na=ds_default_data.DROP_NA,
                                             change_col_vals=ds_default_data.CHANGE_COLS_VALS,
                                             features_drop_names=ds_default_data.FEATURES_DROP_NAMES +
                                             ["preventative_treatment", "gender",
                                              "education", "mutation_status"],
                                             column_normalization=["salary"], column_types={"ethnicity": str,
                                                                                            "salary": float,
                                                                                            "diagnosis": int})

        # create the discrete environment and wrap it with tile coding
        discrete_env = DiscreteStateEnvironment.from_options(data_set=ds, action_space=None,
                                                             reward_manager=None, distortion_calculator=None)
        tiled_env = TiledEnv.from_options(env=discrete_env, max_size=4096, num_tilings=5, n_bins=10,
                                          column_ranges={"ethnicity": [0.0, 1.0],
                                                         "salary": [0.0, 1.0],
                                                         "diagnosis": [0.0, 1.0]}, tiling_dim=3)

        semi_grad_sarsa.actions_before_training(tiled_env)

        # make sure this is valid
        self.assertIsNotNone(semi_grad_sarsa.config)

        episode_info = semi_grad_sarsa.on_episode(env=tiled_env)
        self.assertIsNotNone(episode_info)
106+
# Allow running this test module directly with the unittest runner.
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)