
Commit 5d50039

Merge pull request #68 from pockerman/investigate_sarsa_semi_gradient
Update API
2 parents: c1f0335 + 669fe20

File tree: 8 files changed, +211 −46 lines


docs/source/conf.py

Lines changed: 1 addition & 4 deletions
@@ -39,14 +39,11 @@
     'sphinx.ext.doctest',
     'sphinx.ext.autodoc',
     'sphinx.ext.autosummary',
-    "numpydoc",
+    #"numpydoc",
     'sphinx.ext.napoleon'
 ]

 #extensions = ['sphinx.ext.napoleon']
-
-numpydoc_show_class_members = True
-
 # generate autosummary even if no references
 autosummary_generate = True
 autosummary_imported_members = False
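Note: this change disables numpydoc and leaves napoleon to parse the NumPy-style docstrings. For reference, the relevant block of conf.py presumably reads as follows after the commit (a sketch reconstructed from the diff above; surrounding options are whatever the file already defines):

    extensions = [
        'sphinx.ext.doctest',
        'sphinx.ext.autodoc',
        'sphinx.ext.autosummary',
        #"numpydoc",              # disabled; sphinx.ext.napoleon handles the docstrings
        'sphinx.ext.napoleon'
    ]

    # generate autosummary even if no references
    autosummary_generate = True
    autosummary_imported_members = False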

src/algorithms/epsilon_greedy_q_estimator.py

Lines changed: 4 additions & 2 deletions
@@ -79,9 +79,10 @@ def q_hat_value(self, state_action_vec: StateActionVec) -> float:

         return self.weights.dot(state_action_vec)

+    """
     def update_weights(self, total_reward: float, state_action: Action,
                        state_action_: Action, t: float) -> None:
-        """
+
         Update the weights

         Parameters
@@ -97,14 +98,15 @@ def update_weights(self, total_reward: float, state_action: Action,

         None

-        """
+

         if self.weights is None:
             raise InvalidParamValue(param_name="weights", param_value="None. Have you called initialize?")

         v1 = self.q_hat_value(state_action_vec=state_action)
         v2 = self.q_hat_value(state_action_vec=state_action_)
         self.weights += self.alpha / t * (total_reward + self.gamma * v2 - v1) * state_action
+    """

     def on_state(self, state: State) -> Action:
         """Returns the action on the given state

src/algorithms/q_learning.py

Lines changed: 42 additions & 2 deletions
@@ -7,6 +7,8 @@

 from src.exceptions.exceptions import InvalidParamValue
 from src.utils.mixins import WithMaxActionMixin, WithQTableMixinBase
+from src.utils.episode_info import EpisodeInfo
+from src.utils.function_wraps import time_func_wrapper

 Env = TypeVar('Env')
 Policy = TypeVar('Policy')
@@ -86,7 +88,42 @@ def play(self, env: Env, stop_criterion: Criterion) -> None:
             env.step(action=action)
             total_dist = env.total_current_distortion()

-    def on_episode(self, env: Env, **options) -> tuple:
+    def on_episode(self, env: Env, episode_idx: int, **options) -> EpisodeInfo:
+        """Train the algorithm on the episode
+
+        Parameters
+        ----------
+
+        env: The environment to train on
+        episode_idx: The index of the training episode
+        options: Any keyword based options passed by the client code
+
+        Returns
+        -------
+
+        An instance of EpisodeInfo
+        """
+
+        episode_info, total_time = self._do_train(env, episode_idx, **options)
+        episode_info.total_execution_time = total_time
+        return episode_info
+
+    @time_func_wrapper(show_time=False)
+    def _do_train(self, env: Env, episode_idx: int, **option) -> EpisodeInfo:
+        """Train the algorithm on the episode
+
+        Parameters
+        ----------
+
+        env: The environment to train on
+        episode_idx: The index of the training episode
+        options: Any keyword based options passed by the client code
+
+        Returns
+        -------
+
+        An instance of EpisodeInfo
+        """

         # episode score
         episode_score = 0
@@ -119,7 +156,10 @@ def on_episode(self, env: Env, **options) -> tuple:
             if next_time_step.last():
                 break

-        return episode_score, total_distortion, counter
+        episode_info = EpisodeInfo(episode_score=episode_score, total_distortion=total_distortion, episode_itrs=counter)
+        return episode_info
+
+

     def _update_Q_table(self, state: int, action: int, n_actions: int,
                         reward: float, next_state: int = None) -> None:
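on_episode no longer returns a bare (score, distortion, counter) tuple; it returns an EpisodeInfo and stamps the execution time obtained by unpacking the decorated _do_train call. Judging from the fields referenced throughout this commit (episode_score, total_distortion, episode_itrs, total_execution_time), EpisodeInfo is presumably a small dataclass along these lines; this is a sketch for illustration, not the actual src/utils/episode_info.py:

    from dataclasses import dataclass

    @dataclass
    class EpisodeInfo:
        # fields referenced by q_learning.py, semi_gradient_sarsa.py and trainer.py in this commit
        episode_score: float = 0.0
        total_distortion: float = 0.0
        episode_itrs: int = 0
        total_execution_time: float = 0.0

A caller such as the Trainer can then read episode_info.episode_itrs directly instead of indexing a tuple or an info dictionary.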

src/algorithms/semi_gradient_sarsa.py

Lines changed: 114 additions & 27 deletions
@@ -6,15 +6,15 @@

 """

-from dataclasses import dataclass
+from dataclasses import dataclass
 from typing import TypeVar

-from src.utils.mixins import WithMaxActionMixin, WithQTableMixinBase, WithEstimatorMixin
+from src.utils.mixins import WithEstimatorMixin
 from src.utils.episode_info import EpisodeInfo
 from src.spaces.time_step import TimeStep
+from src.utils.function_wraps import time_func_wrapper
 from src.exceptions.exceptions import InvalidParamValue

-
 Policy = TypeVar('Policy')
 Env = TypeVar('Env')
 State = TypeVar('State')
@@ -61,11 +61,6 @@ def actions_before_training(self, env: Env, **options) -> None:

         self._validate()
         self._init()
-        """
-        for state in range(1, env.n_states):
-            for action in range(env.n_actions):
-                self.q_table[state, action] = 0.0
-        """

     def actions_before_episode_begins(self, env: Env, episode_idx: int, **options) -> None:
         """Any actions to perform before the episode begins
@@ -107,6 +102,33 @@ def on_episode(self, env: Env, episode_idx: int, **options) -> EpisodeInfo:
         ----------

         env: The environment to train on
+        episode_idx: The index of the training episode
+        options: Any keyword based options passed by the client code
+
+        Returns
+        -------
+
+        An instance of EpisodeInfo
+        """
+
+        episode_info_, total_execution_time = self._do_train(env=env, episode_idx=episode_idx, **options)
+
+        episode_info = EpisodeInfo()
+        episode_info.episode_score = episode_info_.episode_score
+        episode_info.episode_itrs = episode_info_.episode_itrs
+        episode_info.total_distortion = episode_info_.total_distortion
+        episode_info.total_execution_time = total_execution_time
+        return episode_info
+
+    @time_func_wrapper(show_time=False)
+    def _do_train(self, env: Env, episode_idx: int, **options) -> EpisodeInfo:
+        """Train the algorithm on the episode
+
+        Parameters
+        ----------
+
+        env: The environment to train on
+        episode_idx: The index of the training episode
         options: Any keyword based options passed by the client code

         Returns
@@ -115,76 +137,142 @@ def on_episode(self, env: Env, episode_idx: int, **options) -> EpisodeInfo:
         An instance of EpisodeInfo
         """

-        episode_reward = 0.0
-        episode_n_itrs = 0
+        episode_reward: float = 0.0
+        episode_n_itrs: int = 0
+        total_episode_distortion: float = 0.0

         # reset the environment
         time_step = env.reset(**{"tiled_state": False})

-        # select a state
+        # obtain the initial state S
         state: State = time_step.observation

-        #choose an action using the policy
+        # initial action A
         action: Action = self.config.policy.on_state(state)

         for itr in range(self.config.n_itrs_per_episode):

-            # take action and observe reward and next_state
+            # take action A
             time_step: TimeStep = env.step(action, **{"tiled_state": False})

+            # ... observe reward R
             reward: float = time_step.reward
             episode_reward += reward
+            total_episode_distortion += time_step.info["total_distortion"]
+
+            # ... observe the S prime
             next_state: State = time_step.observation

             # if next_state is terminal i.e. the done flag
             # is set. then update the weights
+            if time_step.done:
+                self._weights_update_episode_done(env=env, state=state, action=action, reward=reward)
+                break
+
+            # choose action A prime as a function of q_hat(S prime, *, w)
+            next_action: Action = self.config.policy.on_state(next_state)

-            # otherwise chose next action as a function of q_hat
-            next_action: Action = None
-            # update the weights
+            # update the weights. This expects tiled vector states
+            self._weights_update(env=env, state=state, action=action,
+                                 next_state=next_state, next_action=next_action, reward=reward)

             # update state
-            state = next_state
+            state: State = next_state

             # update action
-            action = next_action
+            action: Action = next_action

             episode_n_itrs += 1

         episode_info = EpisodeInfo()
         episode_info.episode_score = episode_reward
         episode_info.episode_itrs = episode_n_itrs
+        episode_info.total_distortion = total_episode_distortion
         return episode_info

-    def _weights_update_episode_done(self, state: State, reward: float,
-                                     action: Action, next_state: State) -> None:
+    def _weights_update_episode_done(self, env: Env, state: State, action: Action,
+                                     reward: float, t: float = 1.0) -> None:
+        """Update the weights of the underlying Q-estimator
+
+        Parameters
+        ----------
+
+        state: The current state it is assumed to be a raw state
+        reward: The reward observed when taking the given action when at the given state
+        action: The action we took at the state
+
+
+        Returns
+        -------
+
+        None
+        """
+        action_id = action
+        if not isinstance(action, int):
+            action_id = action.idx
+
+        # get a copy of the weights
+        weights = self.config.policy.weights
+
+        tiled_state = env.featurize_state_action(action=action_id, state=state)
+        v1 = self.config.policy.q_hat_value(state_action_vec=tiled_state)
+
+        weights += self.config.alpha / t * (reward - v1) * tiled_state
+        self.config.policy.weights = weights
+
+    def _weights_update(self, env: Env, state: State, action: Action, reward: float,
+                        next_state: State, next_action: Action, t: float = 1.0) -> None:
         """Update the weights due to the fact that
         the episode is finished

         Parameters
         ----------

+        env: The environment instance that the training takes place
         state: The current state
-        reward: The reward to use
         action: The action we took at state
-        next_state: The observed state
+        reward: The reward observed when taking the given action when at the given state
+        next_state: The observed new state
+        next_action: The action to be executed in next_state

         Returns
         -------

         None
         """
-        pass
+
+        action_id_1 = action
+        if not isinstance(action, int):
+            action_id_1 = action.idx
+
+        action_id_2 = next_action
+        if not isinstance(action, int):
+            action_id_2 = next_action.idx
+
+        # get a copy of the weights
+        weights = self.config.policy.weights
+
+        tiled_state1 = env.featurize_state_action(action=action_id_1, state=state)
+        tiled_state2 = env.featurize_state_action(action=action_id_2, state=next_state)
+
+        v1 = self.config.policy.q_hat_value(state_action_vec=tiled_state1)
+        v2 = self.config.policy.q_hat_value(state_action_vec=tiled_state2)
+        weights += self.config.alpha / t * (reward + self.config.gamma * v2 - v1) * tiled_state1
+        self.config.policy.weights = weights

     def _init(self) -> None:
-        """
-        Any initializations needed before starting the training
+        """Any initializations needed before starting the training

         Returns
         -------
+
         None
+
         """
-        pass
+
+        if self.config.policy.weights is None or \
+                len(self.config.policy.weights) == 0:
+            self.config.policy.initialize()

     def _validate(self) -> None:
         """
@@ -205,4 +293,3 @@ def _validate(self) -> None:

         if not isinstance(self.config.policy, WithEstimatorMixin):
             raise InvalidParamValue(param_name="policy", param_value=str(self.config.policy))
-
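The split between _weights_update and _weights_update_episode_done follows the episodic semi-gradient SARSA scheme: on a terminal transition the target is just R (no bootstrap term), otherwise it is R + gamma * q_hat(S', A', w). A compact, self-contained sketch of that control flow with a linear estimator, using illustrative names rather than the repo's classes:

    import numpy as np

    def sarsa_weight_step(weights: np.ndarray, x_sa: np.ndarray, reward: float,
                          alpha: float, gamma: float, done: bool,
                          x_sa_next: np.ndarray = None, t: float = 1.0) -> np.ndarray:
        """Semi-gradient SARSA step; drops the bootstrap term on terminal transitions."""
        q_sa = weights.dot(x_sa)
        if done:
            target = reward                               # terminal: no q_hat(S', A') term
        else:
            target = reward + gamma * weights.dot(x_sa_next)
        return weights + alpha / t * (target - q_sa) * x_sa

In the commit, x_sa corresponds to env.featurize_state_action(action=..., state=...), i.e. the tiled state-action vector fed to q_hat_value.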

src/algorithms/trainer.py

Lines changed: 18 additions & 6 deletions
@@ -7,7 +7,7 @@
 from typing import TypeVar

 from src.utils import INFO
-from src.utils.function_wraps import time_func
+from src.utils.function_wraps import time_func, time_func_wrapper
 from src.utils.episode_info import EpisodeInfo

 Env = TypeVar("Env")
@@ -17,6 +17,17 @@
 class Trainer(object):

     def __init__(self, env: Env, agent: Agent, configuration: dir) -> None:
+        """Constructor. Initialize a trainer by passing the training environment
+        instance the agen to train and configuration dictionary
+
+        Parameters
+        ----------
+
+        env: The environment to train the agent
+        agent: The agent to train
+        configuration: Configuration parameters for the trainer
+
+        """
         self.env = env
         self.agent = agent
         self.configuration = configuration
@@ -96,9 +107,9 @@ def actions_after_episode_ends(self, env: Env, episode_idx: int, **options) -> N

         if episode_idx % self.configuration['output_msg_frequency'] == 0:
             if self.env.config.distorted_set_path is not None:
-                self.env.save_current_dataset(options["episode_idx"])
+                self.env.save_current_dataset(episode_idx)

-    @time_func
+    @time_func_wrapper(show_time=True)
     def train(self):

         print("{0} Training agent {1}".format(INFO, self.agent.name))
@@ -115,13 +126,14 @@ def train(self):
             #episode_score, total_distortion, n_itrs = self.agent.on_episode(self.env)
             episode_info: EpisodeInfo = self.agent.on_episode(self.env, episode)

+            print("{0} Episode {1} finished in {2} secs".format(INFO, episode, episode_info.total_execution_time))
             print("{0} Episode score={1}, episode total avg distortion {2}".format(INFO, episode_info.episode_score,
-                                                                                   episode_info.total_distortion / episode_info.info["n_iterations"]))
+                                                                                   episode_info.total_distortion / episode_info.episode_itrs))

             #if episode % self.configuration['output_msg_frequency'] == 0:
-            print("{0} Episode finished after {1} iterations".format(INFO, episode_info.info["n_iterations"]))
+            print("{0} Episode finished after {1} iterations".format(INFO, episode_info.episode_itrs))

-            self.iterations_per_episode.append(episode_info.info["n_iterations"])
+            self.iterations_per_episode.append(episode_info.episode_itrs)
             self.total_rewards[episode] = episode_info.episode_score
             self.total_distortions.append(episode_info.total_distortion)
             self.actions_after_episode_ends(self.env, episode, **{})
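Both the agents' _do_train methods and Trainer.train now rely on time_func_wrapper, and the agents unpack the decorated call as (episode_info, total_time). That implies the decorator returns the wrapped function's result together with the elapsed seconds and optionally prints the timing. A minimal sketch assuming that contract (the real src/utils/function_wraps.py may differ):

    import time
    from functools import wraps

    def time_func_wrapper(show_time: bool = False):
        """Decorator factory: the wrapped call returns (result, elapsed_seconds)."""
        def decorator(fn):
            @wraps(fn)
            def wrapped(*args, **kwargs):
                start = time.perf_counter()
                result = fn(*args, **kwargs)
                elapsed = time.perf_counter() - start
                if show_time:
                    print(f"{fn.__name__} finished in {elapsed:.4f} secs")
                return result, elapsed
            return wrapped
        return decorator

This matches how Q-learning's on_episode consumes _do_train (episode_info, total_time = self._do_train(...)) and how the semi-gradient SARSA agent stamps total_execution_time on the returned EpisodeInfo.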
