 
 from src.utils.mixins import WithMaxActionMixin, WithQTableMixinBase, WithEstimatorMixin
 from src.utils.episode_info import EpisodeInfo
+from src.spaces.time_step import TimeStep
 from src.exceptions.exceptions import InvalidParamValue
 
+
 Policy = TypeVar('Policy')
 Env = TypeVar('Env')
 State = TypeVar('State')
@@ -38,11 +40,16 @@ class SemiGradSARSA(object):
     def __init__(self, config: SemiGradSARSAConfig) -> None:
         self.config: SemiGradSARSAConfig = config
 
+    @property
+    def name(self) -> str:
+        return "Semi-Grad SARSA"
+
     def actions_before_training(self, env: Env, **options) -> None:
         """Specify any actions necessary before training begins
 
         Parameters
         ----------
+
         env: The environment to train on
         options: Any key-value options passed by the client
 
@@ -60,27 +67,75 @@ def actions_before_training(self, env: Env, **options) -> None:
                 self.q_table[state, action] = 0.0
         """
 
-    def on_episode(self, env: Env, **options) -> EpisodeInfo:
+    def actions_before_episode_begins(self, env: Env, episode_idx: int, **options) -> None:
+        """Any actions to perform before the episode begins
+
+        Parameters
+        ----------
+
+        env: The instance of the training environment
+        episode_idx: The training episode index
+        options: Any keyword options passed by the client code
+
+        Returns
+        -------
+
+        None
+
+        """
+
+    def actions_after_episode_ends(self, env: Env, episode_idx: int, **options) -> None:
+        """Any actions after the training episode ends
+
+        Parameters
+        ----------
+
+        env: The training environment
+        episode_idx: The training episode index
+        options: Any options passed by the client code
+
+        Returns
+        -------
+
+        None
+        """
+
+    def on_episode(self, env: Env, episode_idx: int, **options) -> EpisodeInfo:
+        """Train the algorithm on the given episode
+
+        Parameters
+        ----------
+
+        env: The environment to train on
+        episode_idx: The training episode index
+        options: Any keyword-based options passed by the client code
+
+        Returns
+        -------
+
+        An instance of EpisodeInfo
+        """
 
         episode_reward = 0.0
         episode_n_itrs = 0
 
         # reset the environment
-        time_step = env.reset()
+        time_step = env.reset(**{"tiled_state": False})
 
         # select a state
         state: State = time_step.observation
 
         # choose an action using the policy
-        action: Action = self.config.policy(state)
+        action: Action = self.config.policy.on_state(state)
 
         for itr in range(self.config.n_itrs_per_episode):
 
             # take action and observe reward and next_state
-            time_step = env.step(action)
-            reward: float = 0.0
+            time_step: TimeStep = env.step(action, **{"tiled_state": False})
+
+            reward: float = time_step.reward
             episode_reward += reward
-            next_state: State = None
+            next_state: State = time_step.observation
 
             # if next_state is terminal, i.e. the done flag
             # is set, then update the weights
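
For reference, the weight update that _weights_update_episode_done (documented in the next hunk) is expected to apply is the standard semi-gradient SARSA rule: w <- w + alpha * [R - q_hat(S, A, w)] * grad q_hat(S, A, w) at a terminal transition, with a gamma * q_hat(S', A', w) bootstrap term added to the target otherwise. Below is a minimal sketch assuming a linear estimator over NumPy feature vectors; the function name and signature are illustrative, not part of this commit.

import numpy as np

def semi_grad_sarsa_update(weights: np.ndarray, features: np.ndarray,
                           reward: float, alpha: float, gamma: float = 1.0,
                           next_features: np.ndarray = None) -> np.ndarray:
    # for a linear approximator q_hat(s, a, w) = w . x(s, a),
    # the gradient of q_hat with respect to w is simply x(s, a)
    q_hat = weights.dot(features)
    if next_features is None:
        # terminal next state: the target is the bare reward
        target = reward
    else:
        # non-terminal: bootstrap from the next state-action pair
        target = reward + gamma * weights.dot(next_features)
    # semi-gradient step: only q_hat(S, A, w) is differentiated, not the target
    return weights + alpha * (target - q_hat) * features
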
@@ -109,6 +163,7 @@ def _weights_update_episode_done(self, state: State, reward: float,
 
         Parameters
         ----------
+
         state: The current state
         reward: The reward to use
         action: The action we took at state
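
Taken together, the hooks introduced in this commit suggest a training loop along the following lines. This is a hedged usage sketch: the environment factory, the config construction, and the episode count are hypothetical placeholders; only the agent's method names come from the code above.

# hypothetical driver loop; make_env() and the config values are placeholders
config = SemiGradSARSAConfig()   # assumed to carry the policy, n_itrs_per_episode, etc.
env = make_env()                 # hypothetical environment factory
agent = SemiGradSARSA(config=config)

agent.actions_before_training(env)

episode_infos = []
for episode_idx in range(100):
    agent.actions_before_episode_begins(env, episode_idx)
    episode_infos.append(agent.on_episode(env, episode_idx))
    agent.actions_after_episode_ends(env, episode_idx)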