Skip to content

Commit 714fbcb

Browse files
committed
Update docstrings and API
1 parent 9de8124 commit 714fbcb

File tree

9 files changed

+80
-40
lines changed

9 files changed

+80
-40
lines changed

docs/source/API/actions.rst

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,8 @@
33

44
.. automodule:: actions
55

6-
7-
8-
96

10-
11-
12-
13-
14-
15-
7+
168
.. rubric:: Classes
179

1810
.. autosummary::

docs/source/modules.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@ API
55
:maxdepth: 4
66

77
API/actions
8+
API/state
89
generated/action_space
910
generated/q_estimator
1011
generated/q_learning
1112
generated/trainer
1213
generated/sarsa_semi_gradient
1314
generated/exceptions
1415
generated/action_space
15-
generated/actions
1616
generated/column_type
1717
generated/discrete_state_environment
1818
generated/observation_space

src/algorithms/q_learning.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,6 @@ def play(self, env: Env, stop_criterion: Criterion) -> None:
7777
total_dist = env.total_current_distortion()
7878
while stop_criterion.continue_itr(total_dist):
7979

80-
if stop_criterion.iteration_counter == 12:
81-
print("Break...")
82-
8380
# use the policy to select an action
8481
state_idx = env.get_aggregated_state(total_dist)
8582
action_idx = self.config.policy.on_state(state_idx)

src/algorithms/sarsa_semi_gradient.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,45 @@ class SARSAn(WithMaxActionMixin):
3939
"""
4040

4141
def __init__(self, sarsa_config: SARSAnConfig):
42-
super(SARSAn, self).__init__()
42+
super(SARSAn, self).__init__(table={})
4343
self.name = "SARSAn"
4444
self.config = sarsa_config
45-
self.q_table = {}
4645

4746
def play(self, env: Env, stop_criterion: Criterion) -> None:
48-
pass
47+
"""
48+
Apply the trained agent on the given environment.
49+
50+
Parameters
51+
----------
52+
env: The environment to apply the agent on
53+
stop_criterion: Criteria that specify when play should stop
54+
55+
Returns
56+
-------
57+
58+
None
59+
60+
"""
61+
# loop over the columns and, for each
62+
# column, get the action that corresponds to
63+
# the max payout.
64+
# TODO: This will not work as the distortion is calculated
65+
# by summing over the columns.
66+
67+
# set the q_table for the policy
68+
# this is the table we should be using to
69+
# make decisions
70+
self.config.policy.q_table = self.q_table
71+
total_dist = env.total_current_distortion()
72+
while stop_criterion.continue_itr(total_dist):
73+
# use the policy to select an action
74+
state_idx = env.get_aggregated_state(total_dist)
75+
action_idx = self.config.policy.on_state(state_idx)
76+
action = env.get_action(action_idx)
77+
print("{0} At state={1} with distortion={2} select action={3}".format("INFO: ", state_idx, total_dist,
78+
action.column_name + "-" + action.action_type.name))
79+
env.step(action=action)
80+
total_dist = env.total_current_distortion()
4981

5082
def actions_before_training(self, env: Env) -> None:
5183
"""

src/algorithms/trainer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
"""
2-
Trainer
1+
"""Module trainer. Specifies a utility class
2+
for training serial reinforcement learning algorithms
3+
34
"""
45

56
import numpy as np

src/examples/nstep_semi_grad_sarsa_three_columns.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def load_dataset() -> MockSubjectsLoader:
218218
title="Running distortion average over 100 episodes")
219219

220220

221-
'''
221+
222222
print("=============================================")
223223
print("{0} Generating distorted dataset".format(INFO))
224224
# Let's play
@@ -229,4 +229,3 @@ def load_dataset() -> MockSubjectsLoader:
229229
env.save_current_dataset(episode_index=-2, save_index=False)
230230
print("{0} Done....".format(INFO))
231231
print("=============================================")
232-
'''

src/policies/epsilon_greedy_policy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(self, eps: float, n_actions: int,
3232
max_eps: float = 1.0, min_eps: float = 0.001,
3333
epsilon_decay_factor: float = 0.01,
3434
user_defined_decrease_method: UserDefinedDecreaseMethod = None) -> None:
35-
super(WithMaxActionMixin, self).__init__()
35+
super(WithMaxActionMixin, self).__init__(table={})
3636
self._eps = eps
3737
self._n_actions = n_actions
3838
self._decay_op = decay_op

src/spaces/state.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
"""
2-
Discretized state space
1+
"""The state module. Specifies a wrapper
2+
to a state such that it exposes column distortions
3+
and the bin index of the overall distortion.
4+
35
"""
46

57
from typing import TypeVar, List, Any
@@ -54,17 +56,30 @@ def __len__(self):
5456

5557

5658
class State(object):
57-
"""
58-
Helper to represent a State
59+
"""Helper to represent a State
5960
"""
6061
def __init__(self):
6162
self.idx = -1
6263
self.bin_idx = -1
6364
self.total_distortion: float = 0.0
6465
self.column_distortions = {}
6566

66-
def __contains__(self, item) -> bool:
67-
return item in self.column_distortions.keys()
67+
def __contains__(self, column_name: str) -> bool:
68+
"""
69+
Returns true if column_name is in the column_distortions
70+
keys
71+
72+
Parameters
73+
----------
74+
column_name: The column name to query
75+
76+
Returns
77+
-------
78+
79+
A boolean indicating if column_name is in the column_distortions
80+
keys or not.
81+
"""
82+
return column_name in self.column_distortions.keys()
6883

6984
def __iter__(self):
7085
return StateIterator(list(self.column_distortions.keys()))

src/utils/mixins.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,24 +52,29 @@ def finished(self) -> bool:
5252

5353

5454
class WithQTableMixinBase(metaclass=abc.ABCMeta):
55-
"""
56-
Base class to impose the concept of Q-table
55+
"""Base class to impose the concept of Q-table
5756
"""
5857

59-
def __init__(self):
58+
def __init__(self, table: QTable = None):
6059
# the table representing the q function
6160
# client code should choose the type of
6261
# the table
63-
self.q_table: QTable = None
62+
self.q_table: QTable = table
6463

6564

6665
class WithQTableMixin(WithQTableMixinBase):
66+
"""Helper class to associate a q_table with an algorithm
6767
"""
68-
Helper class to associate a q_table with an algorithm
69-
if this is needed.
70-
"""
71-
def __init__(self):
72-
super(WithQTableMixin, self).__init__()
68+
def __init__(self, table: QTable = None):
69+
"""
70+
Constructor
71+
72+
Parameters
73+
----------
74+
table: The Q-table representing the Q-function
75+
76+
"""
77+
super(WithQTableMixin, self).__init__(table)
7378

7479
def state_action_values(self, state: Any, n_actions: int):
7580

@@ -81,12 +86,11 @@ def state_action_values(self, state: Any, n_actions: int):
8186

8287

8388
class WithMaxActionMixin(WithQTableMixin):
84-
"""
85-
The class WithMaxActionMixin.
89+
"""The class WithMaxActionMixin.
8690
"""
8791

88-
def __init__(self):
89-
super(WithMaxActionMixin, self).__init__()
92+
def __init__(self, table: QTable = None):
93+
super(WithMaxActionMixin, self).__init__(table)
9094

9195
def max_action(self, state: Any, n_actions: int) -> int:
9296
"""

0 commit comments

Comments
 (0)