Skip to content

Commit 72e98cc

Browse files
committed
#13 Add policies
1 parent 6bc73ad commit 72e98cc

File tree

3 files changed

+111
-0
lines changed

3 files changed

+111
-0
lines changed

src/policies/__init__.py

Whitespace-only changes.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import numpy as np
2+
from typing import TypeVar
3+
4+
from src.policies.policy_adaptor_base import PolicyAdaptorBase
5+
6+
PolicyBase = TypeVar('PolicyBase')
7+
8+
9+
class DeterministicAdaptorPolicy(PolicyAdaptorBase):
    """Greedy policy adaptor.

    Makes the given policy deterministic at one state by assigning
    probability 1.0 to the action with the highest value.
    """

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, policy: PolicyBase, *args, **kwargs) -> PolicyBase:
        """Assign probability 1.0 to the greedy action for a state.

        :param policy: The policy to adapt; indexable as ``policy[state][action]``
        :param kwargs: Must contain ``"s"`` (the state index) and
            ``"state_actions"`` (array of action values for that state)
        :return: The adapted policy
        """
        state_idx: int = kwargs["s"]
        action_values: np.ndarray = kwargs["state_actions"]
        greedy_action = np.argmax(action_values)
        policy[state_idx][greedy_action] = 1.0
        return policy
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""
2+
Epsilon greedy policy implementation
3+
"""
4+
import random
5+
import numpy as np
6+
from enum import Enum
7+
from typing import Any, TypeVar
8+
9+
10+
from src.utils.mixins import WithMaxActionMixin
11+
12+
UserDefinedDecreaseMethod = TypeVar('UserDefinedDecreaseMethod')
13+
Env = TypeVar("Env")
14+
15+
16+
class EpsilonDecreaseOption(Enum):
    """Enumerates the supported schedules for reducing epsilon."""

    NONE = 0           # keep epsilon fixed
    EXPONENTIAL = 1    # anneal with an exponential schedule
    INVERSE_STEP = 2   # epsilon proportional to 1 / episode index
    CONSTANT_RATE = 3  # subtract a constant amount each episode
    USER_DEFINED = 4   # delegate to a user-supplied callable
26+
27+
28+
class EpsilonGreedyPolicy(WithMaxActionMixin):
    """
    Epsilon-greedy policy: with probability 1 - epsilon select the action
    that maximizes the q-function, otherwise select a random action.
    Epsilon may be annealed after every episode according to the
    configured ``EpsilonDecreaseOption``.
    """

    def __init__(self, env: Env, eps: float,
                 decay_op: EpsilonDecreaseOption,
                 max_eps: float = 1.0, min_eps: float = 0.001,
                 epsilon_decay_factor: float = 0.01,
                 user_defined_decrease_method: UserDefinedDecreaseMethod = None) -> None:
        """
        :param env: The environment; must expose ``action_space.n``
        :param eps: Initial exploration rate
        :param decay_op: How epsilon is reduced after each episode
        :param max_eps: Upper bound used by the exponential schedule
        :param min_eps: Floor that epsilon is clamped to after decay
        :param epsilon_decay_factor: Rate constant for the exponential and
            constant-rate schedules
        :param user_defined_decrease_method: Callable
            ``(eps, episode_idx) -> eps`` used when ``decay_op`` is
            ``USER_DEFINED``
        """
        # Bug fix: the original called super(WithMaxActionMixin, self).__init__(),
        # which skips WithMaxActionMixin.__init__ in the MRO entirely.
        super().__init__()
        self._eps = eps
        self._n_actions = env.action_space.n
        self._decay_op = decay_op
        self._max_eps = max_eps
        self._min_eps = min_eps
        self._epsilon_decay_factor = epsilon_decay_factor
        self.user_defined_decrease_method: UserDefinedDecreaseMethod = user_defined_decrease_method

    def __call__(self, q_func: Any, state: Any) -> int:
        """
        Select an action for ``state`` with the epsilon-greedy rule.

        :param q_func: The q-table to act greedily with respect to
        :param state: The current state
        :return: The selected action index
        """
        # Exploit: with probability 1 - epsilon take the max-value action
        # (the original comment said "with probability epsilon", which
        # described the opposite branch).
        if random.random() > self._eps:
            self.q_table = q_func
            return self.max_action(state=state, n_actions=self._n_actions)

        # Explore: pick an action uniformly at random. random.randrange
        # returns a plain int, matching the declared return type; the
        # original random.choice(np.arange(...)) returned a numpy integer.
        # NOTE(review): an action that has exhausted its transforms can
        # still be selected here — confirm whether it should be masked out.
        return random.randrange(self._n_actions)

    def actions_after_episode(self, episode_idx: int, **options) -> None:
        """
        Apply the configured epsilon decay after the end of an episode.

        :param episode_idx: The episode index
        :param options: Unused; accepted for interface compatibility
        :return: None
        """
        if self._decay_op == EpsilonDecreaseOption.NONE:
            return

        # The options are mutually exclusive, so chain with elif.
        if self._decay_op == EpsilonDecreaseOption.USER_DEFINED:
            self._eps = self.user_defined_decrease_method(self._eps, episode_idx)
        elif self._decay_op == EpsilonDecreaseOption.INVERSE_STEP:
            # Guard against division by zero on the first episode.
            self._eps = 1.0 / max(episode_idx, 1)
        elif self._decay_op == EpsilonDecreaseOption.EXPONENTIAL:
            self._eps = self._min_eps + \
                (self._max_eps - self._min_eps) * np.exp(-self._epsilon_decay_factor * episode_idx)
        elif self._decay_op == EpsilonDecreaseOption.CONSTANT_RATE:
            self._eps -= self._epsilon_decay_factor

        # Never let exploration drop below the configured floor.
        if self._eps < self._min_eps:
            self._eps = self._min_eps
87+
88+

0 commit comments

Comments
 (0)