
"""

-from dataclasses import dataclass
+from dataclasses import dataclass
from typing import TypeVar

-from src.utils.mixins import WithMaxActionMixin, WithQTableMixinBase, WithEstimatorMixin
+from src.utils.mixins import WithEstimatorMixin
from src.utils.episode_info import EpisodeInfo
from src.spaces.time_step import TimeStep
+from src.utils.function_wraps import time_func_wrapper
from src.exceptions.exceptions import InvalidParamValue

-
Policy = TypeVar('Policy')
Env = TypeVar('Env')
State = TypeVar('State')
@@ -61,11 +61,6 @@ def actions_before_training(self, env: Env, **options) -> None:

        self._validate()
        self._init()
-        """
-        for state in range(1, env.n_states):
-            for action in range(env.n_actions):
-                self.q_table[state, action] = 0.0
-        """

    def actions_before_episode_begins(self, env: Env, episode_idx: int, **options) -> None:
        """Any actions to perform before the episode begins
@@ -107,6 +102,33 @@ def on_episode(self, env: Env, episode_idx: int, **options) -> EpisodeInfo:
        ----------

        env: The environment to train on
+        episode_idx: The index of the training episode
+        options: Any keyword based options passed by the client code
+
+        Returns
+        -------
+
+        An instance of EpisodeInfo
+        """
+
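+        # _do_train is wrapped with time_func_wrapper, which is assumed to return
+        # the (EpisodeInfo, execution time) pair that is unpacked below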
+        episode_info_, total_execution_time = self._do_train(env=env, episode_idx=episode_idx, **options)
+
+        episode_info = EpisodeInfo()
+        episode_info.episode_score = episode_info_.episode_score
+        episode_info.episode_itrs = episode_info_.episode_itrs
+        episode_info.total_distortion = episode_info_.total_distortion
+        episode_info.total_execution_time = total_execution_time
+        return episode_info
+
+    @time_func_wrapper(show_time=False)
+    def _do_train(self, env: Env, episode_idx: int, **options) -> EpisodeInfo:
+        """Train the algorithm on the episode
+
+        Parameters
+        ----------
+
+        env: The environment to train on
+        episode_idx: The index of the training episode
        options: Any keyword based options passed by the client code

        Returns
@@ -115,76 +137,142 @@ def on_episode(self, env: Env, episode_idx: int, **options) -> EpisodeInfo:
        An instance of EpisodeInfo
        """

-        episode_reward = 0.0
-        episode_n_itrs = 0
+        episode_reward: float = 0.0
+        episode_n_itrs: int = 0
+        total_episode_distortion: float = 0.0

        # reset the environment
        time_step = env.reset(**{"tiled_state": False})

-        # select a state
+        # obtain the initial state S
        state: State = time_step.observation

-        #choose an action using the policy
+        # initial action A
        action: Action = self.config.policy.on_state(state)

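+        # episodic semi-gradient SARSA loop: take A, observe R and S', choose A'
+        # from the current policy, update the weights, then set S <- S' and A <- A'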
        for itr in range(self.config.n_itrs_per_episode):

-            # take action and observe reward and next_state
+            # take action A
            time_step: TimeStep = env.step(action, **{"tiled_state": False})

+            # ... observe reward R
            reward: float = time_step.reward
            episode_reward += reward
+            total_episode_distortion += time_step.info["total_distortion"]
+
+            # ... observe the S prime
            next_state: State = time_step.observation

            # if next_state is terminal, i.e. the done flag
            # is set, then update the weights
+            if time_step.done:
+                self._weights_update_episode_done(env=env, state=state, action=action, reward=reward)
+                break
+
+            # choose action A prime as a function of q_hat(S prime, *, w)
+            next_action: Action = self.config.policy.on_state(next_state)

-            # otherwise chose next action as a function of q_hat
-            next_action: Action = None
-            # update the weights
+            # update the weights. This expects tiled vector states
+            self._weights_update(env=env, state=state, action=action,
+                                 next_state=next_state, next_action=next_action, reward=reward)

            # update state
-            state = next_state
+            state: State = next_state

            # update action
-            action = next_action
+            action: Action = next_action

            episode_n_itrs += 1

        episode_info = EpisodeInfo()
        episode_info.episode_score = episode_reward
        episode_info.episode_itrs = episode_n_itrs
+        episode_info.total_distortion = total_episode_distortion
        return episode_info

-    def _weights_update_episode_done(self, state: State, reward: float,
-                                     action: Action, next_state: State) -> None:
+    def _weights_update_episode_done(self, env: Env, state: State, action: Action,
+                                     reward: float, t: float = 1.0) -> None:
+        """Update the weights of the underlying Q-estimator when the episode terminates
+
+        Parameters
+        ----------
+
+        env: The environment instance that the training takes place on
+        state: The current state; it is assumed to be a raw state
+        action: The action we took at the state
+        reward: The reward observed when taking the given action at the given state
+        t: Scaling factor that divides the learning rate alpha
+
+        Returns
+        -------
+
+        None
+        """
+        action_id = action
+        if not isinstance(action, int):
+            action_id = action.idx
+
+        # get a copy of the weights
+        weights = self.config.policy.weights
+
+        tiled_state = env.featurize_state_action(action=action_id, state=state)
+        v1 = self.config.policy.q_hat_value(state_action_vec=tiled_state)
+
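+        # terminal update: the value of the terminal next state is treated as zero,
+        # so w <- w + (alpha / t) * (R - q_hat(S, A, w)) * x(S, A),
+        # where x(S, A) is the tiled state-action feature vector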
+        weights += self.config.alpha / t * (reward - v1) * tiled_state
+        self.config.policy.weights = weights
+
+    def _weights_update(self, env: Env, state: State, action: Action, reward: float,
+                        next_state: State, next_action: Action, t: float = 1.0) -> None:
        """Update the weights of the underlying Q-estimator
        for a non-terminal transition of the episode

        Parameters
        ----------

+        env: The environment instance that the training takes place on
        state: The current state
-        reward: The reward to use
        action: The action we took at state
-        next_state: The observed state
+        reward: The reward observed when taking the given action at the given state
+        next_state: The observed new state
+        next_action: The action to be executed in next_state
+        t: Scaling factor that divides the learning rate alpha

        Returns
        -------

        None
        """
-        pass
+
+        action_id_1 = action
+        if not isinstance(action, int):
+            action_id_1 = action.idx
+
+        action_id_2 = next_action
+        if not isinstance(next_action, int):
+            action_id_2 = next_action.idx
+
+        # get a copy of the weights
+        weights = self.config.policy.weights
+
+        tiled_state1 = env.featurize_state_action(action=action_id_1, state=state)
+        tiled_state2 = env.featurize_state_action(action=action_id_2, state=next_state)
+
+        v1 = self.config.policy.q_hat_value(state_action_vec=tiled_state1)
+        v2 = self.config.policy.q_hat_value(state_action_vec=tiled_state2)
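+        # semi-gradient SARSA update (assuming q_hat is linear in w, so that its
+        # gradient is the tiled feature vector x(S, A)):
+        # w <- w + (alpha / t) * (R + gamma * q_hat(S', A', w) - q_hat(S, A, w)) * x(S, A)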
+        weights += self.config.alpha / t * (reward + self.config.gamma * v2 - v1) * tiled_state1
+        self.config.policy.weights = weights

    def _init(self) -> None:
-        """
-        Any initializations needed before starting the training
+        """Any initializations needed before starting the training

        Returns
        -------
+
        None
+
        """
-        pass
+
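+        # the estimator weights are created lazily: only initialize the policy
+        # if its weight vector has not been set up yet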
+        if self.config.policy.weights is None or \
+                len(self.config.policy.weights) == 0:
+            self.config.policy.initialize()

    def _validate(self) -> None:
        """
@@ -205,4 +293,3 @@ def _validate(self) -> None:

        if not isinstance(self.config.policy, WithEstimatorMixin):
            raise InvalidParamValue(param_name="policy", param_value=str(self.config.policy))
-