Skip to content

Commit 8cd2fb5

Browse files
committed
#27 Add softmax policy
1 parent 0065764 commit 8cd2fb5

File tree

3 files changed

+121
-0
lines changed

3 files changed

+121
-0
lines changed

src/policies/softmax_policy.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import numpy as np
2+
from typing import TypeVar, Any
3+
from src.utils.mixins import WithQTableMixin
4+
5+
QTable = TypeVar('QTable')
6+
7+
8+
class SoftMaxPolicy(WithQTableMixin):
9+
10+
def __init__(self, n_actions: int, tau: float) -> None:
11+
self.n_actions = n_actions
12+
self.tau = tau
13+
14+
def __str__(self) -> str:
15+
return "SoftMaxPolicy"
16+
17+
def __call__(self, q_table: QTable, state: Any) -> int:
18+
"""
19+
Execute the policy
20+
:param q_table:
21+
:param state:
22+
:return:
23+
"""
24+
self.q_table = q_table
25+
action_values = [q_table[state, a] for a in range(self.n_actions)]
26+
softmax = np.exp(np.array(action_values) / self.tau) / np.sum( np.exp(np.array(action_values) / self.tau) )
27+
28+
# return the action index by choosing from
29+
return np.random.choice( [a for a in range(self.n_actions)], p=softmax)
30+
31+
def on_state(self, state: Any) -> int:
32+
"""
33+
Returns the optimal action on the current state
34+
:param state:
35+
:return:
36+
"""
37+
action_values = [self.q_table[state, a] for a in range(self.n_actions)]
38+
softmax = np.exp(np.array(action_values) / self.tau) / np.sum(np.exp(np.array(action_values) / self.tau))
39+
40+
# return the action index by choosing from
41+
return np.random.choice([a for a in range(self.n_actions)], p=softmax)
42+
43+
def actions_after_episode(self, episode_idx: int, **options) -> None:
44+
pass

src/utils/iteration_control.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""
2+
Utility to control iteration
3+
"""
4+
5+
from src.utils import INFO, VERSION
6+
7+
8+
class IterationControl(object):
9+
"""
10+
Helper class to control iteration
11+
"""
12+
13+
def __init__(self, n_itrs: int, min_dist: float, max_dist: float) -> None:
14+
self.n_itrs = n_itrs
15+
self.min_dist = min_dist
16+
self.max_dist = max_dist
17+
self.iteration_counter = 0
18+
19+
def continue_itr(self, distortion: float):
20+
21+
if self.min_dist <= distortion <= self.max_dist:
22+
print("{0} Finished iteration with distortion={1} "
23+
"in [{2}, {3}]. Number of iterations={4}".format(INFO, distortion,
24+
self.min_dist, self.max_dist,
25+
self.iteration_counter))
26+
return False
27+
28+
# TODO: There is no way currently we go back
29+
# to acceptable distortions
30+
if VERSION == '0.0.1-alpha':
31+
if distortion > self.max_dist:
32+
print("{0} Finished iteration with distortion={1} "
33+
"in [{2}, {3}]. Number of iterations={4}".format(INFO, distortion, self.min_dist,
34+
self.max_dist, self.iteration_counter))
35+
return False
36+
37+
if 0 <= self.n_itrs <= self.iteration_counter:
38+
print("{0} Reached maximum number of iterations. With distortion={1} "
39+
"in [{2}, {3}]. Number of iterations={4}".format(INFO, distortion,
40+
self.min_dist, self.max_dist,
41+
self.iteration_counter))
42+
return False
43+
44+
self.iteration_counter += 1
45+
print("{0} Iteration={1} Distortion={2}".format(INFO, self.iteration_counter, distortion))
46+
return True

src/utils/plot_utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
4+
5+
def plot_running_avg(avg_array, steps: int,
6+
xlabel: str, ylabel: str,
7+
title: str, show_grid: bool = True):
8+
"""
9+
Plot a running average of the values in the arra
10+
:param title:
11+
:param xlabel:
12+
:param ylabel:
13+
:param show_grid:
14+
:param avg_array:
15+
:param steps:
16+
:return:
17+
"""
18+
19+
running_avg = np.empty(avg_array.shape[0])
20+
for t in range(avg_array.shape[0]):
21+
running_avg[t] = np.mean(avg_array[max(0, t-steps): (t+1)])
22+
23+
plt.plot(running_avg)
24+
plt.xlabel(xlabel)
25+
plt.ylabel(ylabel)
26+
plt.title(title)
27+
28+
if show_grid:
29+
plt.grid()
30+
31+
plt.show()

0 commit comments

Comments
 (0)