Skip to content

Commit 6bc73ad

Browse files
committed
#13 Add mixins
1 parent bec8f6b commit 6bc73ad

File tree

1 file changed

+82
-0
lines changed

1 file changed

+82
-0
lines changed

src/utils/mixins.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""
2+
Various mixin classes to use for simplifying code
3+
"""
4+
5+
import numpy as np
6+
from typing import TypeVar, Any
7+
8+
QTable = TypeVar('QTable')
9+
Hierarchy = TypeVar('Hierarchy')
10+
11+
12+
class WithHierarchyTable(object):
    """
    Mixin that stores hierarchies in a table indexed by key
    and keeps a list of iterators over the stored hierarchies.
    """

    def __init__(self) -> None:
        # the iterators stay empty until reset_iterators() is called
        self.iterators = []
        # maps a key to the hierarchy attached to it
        self.table = {}

    def add_hierarchy(self, key: str, hierarchy: Hierarchy) -> None:
        """
        Attach a hierarchy to the given key. A hierarchy that is
        already stored under the same key is replaced
        :param key: The key under which the hierarchy is stored
        :param hierarchy: The hierarchy to store
        :return: None
        """
        self.table[key] = hierarchy

    def reset_iterators(self):
        """
        Rebuild the list of iterators, creating one fresh
        iterator per hierarchy currently in the table
        :return: None
        """
        self.iterators = [iter(hierarchy) for hierarchy in self.table.values()]

    def finished(self) -> bool:
        """
        Check whether every hierarchy in the table reports
        that it is exhausted. An empty table counts as finished
        :return: True if no hierarchy has anything left, False otherwise
        """
        return all(hierarchy.is_exhausted() for hierarchy in self.table.values())
49+
50+
51+
class WithQTableMixin(object):
    """
    Mixin that gives an algorithm a q_table attribute
    for the cases where one is required.
    """

    def __init__(self):
        # Placeholder for the table that represents the q function.
        # The concrete table type is left to client code, so the
        # attribute starts out unset
        self.q_table: QTable = None
61+
62+
63+
class WithMaxActionMixin(object):
    """
    Mixin that adds greedy action selection on top of a q_table.
    The q_table is indexed as q_table[state, action].
    """

    def __init__(self):
        super(WithMaxActionMixin, self).__init__()
        # the table representing the q function; it starts out
        # unset and must be assigned before max_action is called
        self.q_table: QTable = None

    def max_action(self, state: Any, n_actions: int) -> int:
        """
        Return the action index that presents the maximum
        value at the given state. When several actions share
        the maximum value the lowest index is returned
        :param state: state index
        :param n_actions: Total number of actions allowed; the
        actions examined are 0, 1, ..., n_actions - 1
        :return: The action that corresponds to the maximum value
        :raises ValueError: if n_actions is smaller than 1, since
        there is no value to take the maximum of
        """
        # The values must be collected in a list first: handing a bare
        # generator to np.array yields a 0-d object array that wraps the
        # generator itself, and argmax of that is 0 whatever the q values are
        values = np.array([self.q_table[state, a] for a in range(n_actions)])
        return int(np.argmax(values))

0 commit comments

Comments
 (0)