
Commit bec8f6b

#13 Add Q-learning algorithm
1 parent 9175354 commit bec8f6b

File tree

1 file changed: +108 -0 lines changed


src/algorithms/q_learning.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
"""
Simple Q-learning algorithm
"""

import numpy as np
from typing import TypeVar

from src.exceptions.exceptions import InvalidParamValue
from src.utils.mixins import WithMaxActionMixin

Env = TypeVar('Env')
Policy = TypeVar('Policy')


class QLearnConfig(object):
    """Configuration for the QLearning agent."""

    def __init__(self):
        self.gamma: float = 1.0
        self.alpha: float = 0.1
        self.n_itrs_per_episode: int = 100
        self.policy: Policy = None


class QLearning(WithMaxActionMixin):
    """Tabular Q-learning agent."""

    def __init__(self, algo_config: QLearnConfig):
        super(QLearning, self).__init__()
        self.q_table = {}
        self.config = algo_config

        # monitor performance
        self.total_rewards: np.ndarray = None
        self.iterations_per_episode = []

    @property
    def name(self) -> str:
        return "QLearn"

    def actions_before_training(self, env: Env, **options):
        """
        Validate the configuration and initialize the Q-table
        to zero for every (state, action) pair
        """

        if self.config.policy is None:
            raise InvalidParamValue(param_name="policy", param_value="None")

        for state in range(env.observation_space.n):
            for action in range(env.action_space.n):
                self.q_table[state, action] = 0.0

    def actions_after_episode_ends(self, **options):
        """
        Execute any actions the algorithm needs after
        the episode ends, such as updating the policy
        :param options:
        :return:
        """

        self.config.policy.actions_after_episode(options['episode_idx'])

    def train(self, env: Env, **options) -> tuple:
        """
        Train the agent for a single episode and return
        the episode score and the number of iterations used
        """

        # episode score
        episode_score = 0  # initialize score
        counter = 0

        time_step = env.reset()
        state = time_step.observation

        for itr in range(self.config.n_itrs_per_episode):

            # epsilon-greedy action selection
            action_idx = self.config.policy(q_func=self.q_table, state=state)

            action = env.get_action(action_idx)

            # take action A, observe R, S'
            next_time_step = env.step(action)
            next_state = next_time_step.observation
            reward = next_time_step.reward

            next_state_id = next_state.state_id if next_state is not None else None

            # add reward to agent's score
            episode_score += next_time_step.reward
            self._update_Q_table(state=state.state_id, action=action_idx, reward=reward,
                                 next_state=next_state_id, n_actions=env.action_space.n)
            state = next_state  # S <- S'
            counter += 1

            if next_time_step.last():
                break

        return episode_score, counter

    def _update_Q_table(self, state: int, action: int, n_actions: int, reward: float, next_state: int = None) -> None:
        """
        Update the Q-value for the state:
        Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
        """

        # estimate in Q-table (for current state, action pair)
        q_s = self.q_table[state, action]

        # value of next state (0 if next_state is terminal/None)
        Qsa_next = \
            self.q_table[next_state, self.max_action(next_state, n_actions=n_actions)] if next_state is not None else 0
        # construct TD target
        target = reward + (self.config.gamma * Qsa_next)

        # get updated value
        new_value = q_s + (self.config.alpha * (target - q_s))
        self.q_table[state, action] = new_value
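
For reference, the sketch below shows one plausible way to drive the new agent. Everything named DummyChainEnv, _TimeStep or EpsilonGreedyPolicy is a hypothetical stand-in invented for illustration; it only mimics the interface the code above expects (an env with reset()/step()/get_action() and observation_space.n / action_space.n, time steps exposing observation, reward and last(), observations exposing state_id, and a policy that is callable with q_func and state and provides actions_after_episode). Only QLearning and QLearnConfig come from the file added in this commit.

# Hypothetical usage sketch (not part of this commit). DummyChainEnv,
# _TimeStep and EpsilonGreedyPolicy are illustrative stand-ins that only
# mimic the interface QLearning expects.
import random
from types import SimpleNamespace

from src.algorithms.q_learning import QLearning, QLearnConfig


class _TimeStep:
    # minimal stand-in for the time-step objects train() consumes
    def __init__(self, observation, reward, done):
        self.observation = observation
        self.reward = reward
        self.done = done

    def last(self) -> bool:
        return self.done


class DummyChainEnv:
    # walk along a chain of n_states states; reward 1.0 for reaching the end
    def __init__(self, n_states: int = 5):
        self.observation_space = SimpleNamespace(n=n_states)
        self.action_space = SimpleNamespace(n=2)  # 0 = left, 1 = right
        self.current = 0

    def reset(self) -> _TimeStep:
        self.current = 0
        return _TimeStep(SimpleNamespace(state_id=self.current), 0.0, False)

    def get_action(self, action_idx: int) -> int:
        # actions are just their indices in this toy environment
        return action_idx

    def step(self, action: int) -> _TimeStep:
        move = 1 if action == 1 else -1
        self.current = max(0, min(self.observation_space.n - 1, self.current + move))
        done = self.current == self.observation_space.n - 1
        reward = 1.0 if done else 0.0
        # no observation on the terminal step, matching the
        # next_state is None handling in _update_Q_table
        obs = None if done else SimpleNamespace(state_id=self.current)
        return _TimeStep(obs, reward, done)


class EpsilonGreedyPolicy:
    # callable as config.policy(q_func=..., state=...) and exposes
    # actions_after_episode, as QLearning requires
    def __init__(self, n_actions: int, eps: float = 1.0, decay: float = 0.99):
        self.n_actions = n_actions
        self.eps = eps
        self.decay = decay

    def __call__(self, q_func, state) -> int:
        if random.random() < self.eps:
            return random.randrange(self.n_actions)
        return max(range(self.n_actions), key=lambda a: q_func[state.state_id, a])

    def actions_after_episode(self, episode_idx: int) -> None:
        self.eps *= self.decay  # decay exploration after every episode


if __name__ == '__main__':
    env = DummyChainEnv()
    config = QLearnConfig()
    config.policy = EpsilonGreedyPolicy(n_actions=env.action_space.n)

    agent = QLearning(config)
    agent.actions_before_training(env)

    for episode in range(100):
        score, n_itrs = agent.train(env)
        agent.actions_after_episode_ends(episode_idx=episode)

Note that actions_before_training must run before the first call to train: q_table is a plain dict, so the update in _update_Q_table assumes every (state, action) entry has already been initialized.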
