
Commit b77fe7a

Add n-step semi-gradient SARSA algorithm
1 parent 8765f0b commit b77fe7a
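
For reference, the file implements episodic semi-gradient n-step SARSA (Sutton & Barto, "Reinforcement Learning: An Introduction", 2nd ed., Section 10.2). The update target built in on_episode is the n-step return

G_{t:t+n} = R_{t+1} + γ R_{t+2} + … + γ^{n−1} R_{t+n} + γ^n q̂(S_{t+n}, A_{t+n}, w),

which drives the semi-gradient weight update w ← w + α [G_{τ:τ+n} − q̂(S_τ, A_τ, w)] ∇q̂(S_τ, A_τ, w).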

1 file changed: +143 −0 lines changed

@@ -0,0 +1,143 @@
"""
Implementation of the semi-gradient SARSA algorithm.
The initial implementation is inspired by
https://michaeloneill.github.io/RL-tutorial.html
"""
import numpy as np
from typing import TypeVar

from src.utils.mixins import WithMaxActionMixin, WithQTableMixinBase
from src.algorithms.q_estimator import QEstimator
from src.exceptions.exceptions import InvalidParamValue

Env = TypeVar('Env')
Criterion = TypeVar('Criterion')
Policy = TypeVar('Policy')
Estimator = TypeVar('Estimator')


class SARSAnConfig:
    """
    Configuration parameters for the n-step
    semi-gradient SARSA agent
    """

    def __init__(self) -> None:
        self.gamma: float = 1.0
        self.alpha: float = 0.1
        self.n: int = 10
        self.n_itrs_per_episode: int = 100
        self.max_size: int = 4096
        self.use_trace: bool = False
        self.policy: Policy = None
        self.estimator: Estimator = None
        self.reset_estimator_z_traces: bool = False


class SARSAn(WithMaxActionMixin):
    """
    Implementation of the n-step semi-gradient SARSA algorithm
    """

    def __init__(self, sarsa_config: SARSAnConfig):
        self.name = "SARSAn"
        self.config = sarsa_config
        self.q_table = {}

    def play(self, env: Env, stop_criterion: Criterion) -> None:
        pass

    def actions_before_training(self, env: Env) -> None:
        """
        Any actions to execute before
        entering the training loop
        :param env:
        :return:
        """

        is_tiled = getattr(env, "IS_TILED_ENV_CONSTRAINT", None)
        if not is_tiled:
            raise ValueError("The given environment does not "
                             "satisfy the IS_TILED_ENV_CONSTRAINT constraint")

        if not isinstance(self.config.policy, WithQTableMixinBase):
            raise InvalidParamValue(param_name="policy",
                                    param_value=str(self.config.policy))

        if self.config.estimator is None:
            raise ValueError("Estimator has not been set")

        # reset the estimator
        self.config.estimator.reset(self.config.reset_estimator_z_traces)

    def actions_before_episode_begins(self, **options) -> None:
        """
        Actions for the agent to perform before
        an episode begins
        :param options:
        :return:
        """
        # reset the estimator
        self.config.estimator.reset(self.config.reset_estimator_z_traces)

    def on_episode(self, env: Env) -> tuple:
        """
        Train the agent on the given environment
        for a single episode
        :param env:
        :return: tuple of performance measures
        """

        # reset before the episode begins
        time_step = env.reset()
        state = time_step.observation

        # vars to measure performance
        episode_score = 0
        counter = 0
        total_distortion = 0

        T = float('inf')

        # select the first action using the current policy
        action_idx = self.config.policy(self.q_table, state)

        # store S_0 and A_0, plus a dummy reward so that
        # rewards[t] is the reward received at time step t
        states = [state]
        actions = [action_idx]
        rewards = [0.0]

        for itr in range(self.config.n_itrs_per_episode):

            if itr < T:
                # take action A_t, observe R_{t+1} and S_{t+1}
                action = env.get_action(actions[itr])
                next_time_step = env.step(action)
                next_state = next_time_step.observation
                reward = next_time_step.reward

                states.append(next_state)
                rewards.append(reward)
                episode_score += reward
                counter += 1

                if next_time_step.done:
                    T = itr + 1
                else:
                    # select A_{t+1} from S_{t+1} using
                    # the current policy
                    actions.append(self.config.policy(self.q_table, next_state))

            # tau: the time step whose estimate is updated
            update_time = itr + 1 - self.config.n

            if update_time >= 0:

                # build the n-step return target
                target = 0.0
                for i in range(update_time + 1, min(T, update_time + self.config.n) + 1):
                    target += np.power(self.config.gamma, i - update_time - 1) * rewards[i]

                # bootstrap with the discounted estimate
                # of Q at (S_{tau+n}, A_{tau+n})
                if update_time + self.config.n < T:
                    q_values_next = self.config.estimator.predict(states[update_time + self.config.n])
                    target += np.power(self.config.gamma, self.config.n) * \
                        q_values_next[actions[update_time + self.config.n]]

                # semi-gradient update of the estimator weights
                self.config.estimator.update(states[update_time], actions[update_time], target)

            if update_time == T - 1:
                break

        return episode_score, counter, total_distortion
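
A minimal usage sketch for the class above, not part of the commit: it assumes a tile-coded environment instance env exposing the IS_TILED_ENV_CONSTRAINT attribute and the reset/step/get_action API used in on_episode, an epsilon-greedy policy class EpsilonGreedyPolicy satisfying WithQTableMixinBase, and a QEstimator constructor taking the environment and a learning rate; those names and signatures are hypothetical.

# hypothetical wiring; env, EpsilonGreedyPolicy and the QEstimator
# constructor arguments are assumptions, not part of this commit
config = SARSAnConfig()
config.n = 10
config.gamma = 1.0
config.n_itrs_per_episode = 500
config.policy = EpsilonGreedyPolicy(eps=0.1)                # must satisfy WithQTableMixinBase
config.estimator = QEstimator(env=env, alpha=config.alpha)  # must expose reset/predict/update

agent = SARSAn(sarsa_config=config)
agent.actions_before_training(env)

for episode in range(100):
    agent.actions_before_episode_begins()
    episode_score, counter, total_distortion = agent.on_episode(env)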
