Skip to content

Commit 3f5a77f

Browse files
committed
#53 Update API and finalize the example
1 parent 38c4abc commit 3f5a77f

File tree

7 files changed

+139
-53
lines changed

7 files changed

+139
-53
lines changed

src/algorithms/sarsa_semi_gradient.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
"""
66
import numpy as np
77
from typing import TypeVar
8+
from dataclasses import dataclass
89

910
from src.utils.mixins import WithMaxActionMixin, WithQTableMixinBase
1011
from src.utils.episode_info import EpisodeInfo
12+
1113
from src.algorithms.q_estimator import QEstimator
1214
from src.exceptions.exceptions import InvalidParamValue
1315

@@ -16,28 +18,28 @@
1618
Policy = TypeVar('Policy')
1719
Estimator = TypeVar('Estimator')
1820

19-
21+
@dataclass(init=True, repr=True)
2022
class SARSAnConfig:
23+
"""Configuration class for n-step SARSA algorithm
2124
22-
def __init__(self) -> None:
23-
self.gamma: float = 1.0
24-
self.alpha = 0.1
25-
self.n = 10
26-
self.n_itrs_per_episode: int = 100
27-
self.max_size: int = 4096
28-
self.use_trace: bool = False
29-
self.policy: Policy = None
30-
self.estimator: Estimator = None
31-
self.reset_estimator_z_traces: bool = False
25+
"""
26+
gamma: float = 1.0
27+
alpha: float = 0.1
28+
n: int = 10
29+
n_itrs_per_episode: int = 100
30+
max_size: int = 4096
31+
use_trace: bool = False
32+
policy: Policy = None
33+
estimator: Estimator = None
34+
reset_estimator_z_traces: bool = False
3235

3336

3437
class SARSAn(WithMaxActionMixin):
35-
"""
36-
Implementation ofn-step semi-gradient SARSA algorithm
38+
"""Implementation of n-step semi-gradient SARSA algorithm
3739
"""
3840

3941
def __init__(self, sarsa_config: SARSAnConfig):
40-
42+
super(SARSAn, self).__init__()
4143
self.name = "SARSAn"
4244
self.config = sarsa_config
4345
self.q_table = {}
@@ -66,6 +68,9 @@ def actions_before_episode_begins(self, **options) -> None:
6668
# reset the estimator
6769
self.config.estimator.reset(self.config.reset_estimator_z_traces)
6870

71+
def actions_after_episode_ends(self, **options) -> None:
72+
pass
73+
6974
def on_episode(self, env: Env) -> EpisodeInfo:
7075
"""
7176
Train the agent on the given algorithm
@@ -95,6 +100,7 @@ def on_episode(self, env: Env) -> EpisodeInfo:
95100
# take action A, observe R, S'
96101
next_time_step = env.step(action)
97102
next_state = next_time_step.observation
103+
states.append(next_state)
98104
reward = next_time_step.reward
99105

100106
total_distortion += next_time_step.info["total_distortion"]
@@ -107,7 +113,7 @@ def on_episode(self, env: Env) -> EpisodeInfo:
107113

108114
next_action_idx = self.config.policy(self.q_table, next_state)
109115
next_action = env.get_action(next_action_idx)
110-
actions.append(next_action)
116+
actions.append(next_action_idx)
111117

112118
# should we update
113119
update_time = itr + 1 - self.config.n
@@ -122,7 +128,14 @@ def on_episode(self, env: Env) -> EpisodeInfo:
122128
q_values_next = self.config.estimator.predict(states[update_time + self.config.n])
123129
target += q_values_next[actions[update_time + self.config.n]]
124130

125-
# Update step
131+
# Update step. What happens if update_time is greater than or equal to
132+
# len(states) or len(actions)?
133+
134+
if update_time >= len(states) or update_time >= len(actions):
135+
raise InvalidParamValue(param_name="update_time", param_value=str(update_time))
136+
137+
# update the state for the respective action
138+
# with the computed target
126139
self.config.estimator.update(states[update_time], actions[update_time], target)
127140

128141
if update_time == T - 1:
@@ -135,7 +148,7 @@ def on_episode(self, env: Env) -> EpisodeInfo:
135148
episode_info = EpisodeInfo()
136149
episode_info.episode_score = episode_score
137150
episode_info.total_distortion = total_distortion
138-
episode_info.info["m_iterations"] = counter
151+
episode_info.info["n_iterations"] = counter
139152
return episode_info
140153

141154
def _validate(self, env: Env) -> None:

src/algorithms/trainer.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44

55
import numpy as np
66
from typing import TypeVar
7+
78
from src.utils import INFO
89
from src.utils.function_wraps import time_func
10+
from src.utils.episode_info import EpisodeInfo
911

1012
Env = TypeVar("Env")
1113
Agent = TypeVar("Agent")
@@ -83,16 +85,18 @@ def train(self):
8385

8486
self.actions_before_episode_begins(**{"env": self.env})
8587
# train for a number of iterations
86-
episode_score, total_distortion, n_itrs = self.agent.on_episode(self.env)
88+
#episode_score, total_distortion, n_itrs = self.agent.on_episode(self.env)
89+
episode_info: EpisodeInfo = self.agent.on_episode(self.env)
8790

88-
print("{0} Episode score={1}, episode total distortion {2}".format(INFO, episode_score, total_distortion / n_itrs))
91+
print("{0} Episode score={1}, episode total avg distortion {2}".format(INFO, episode_info.episode_score,
92+
episode_info.total_distortion / episode_info.info["n_iterations"]))
8993

9094
#if episode % self.configuration['output_msg_frequency'] == 0:
91-
print("{0} Episode finished after {1} iterations".format(INFO, n_itrs))
95+
print("{0} Episode finished after {1} iterations".format(INFO, episode_info.info["n_iterations"]))
9296

93-
self.iterations_per_episode.append(n_itrs)
94-
self.total_rewards[episode] = episode_score
95-
self.total_distortions.append(total_distortion)
97+
self.iterations_per_episode.append(episode_info.info["n_iterations"])
98+
self.total_rewards[episode] = episode_info.episode_score
99+
self.total_distortions.append(episode_info.total_distortion)
96100
self.actions_after_episode_ends(**{"episode_idx": episode})
97101

98102
print("{0} Training finished for agent {1}".format(INFO, self.agent.name))

src/examples/nstep_semi_grad_sarsa_three_columns.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,9 @@ def load_dataset() -> MockSubjectsLoader:
169169
env = DiscreteStateEnvironment(env_config=env_config)
170170

171171
tiled_env_config = TiledEnvConfig(env=env, num_tilings=NUM_TILINGS, max_size=MAX_SIZE, tiling_dim=TILING_DIM,
172-
column_scales={"ethnicity": [0.0, 1.0], "salary": [0.0, 1.0]})
172+
column_ranges={"ethnicity": [0.0, 1.0],
173+
"salary": [0.0, 1.0],
174+
"diagnosis": [0.0, 5.0]})
173175
# we will use a tiled environment in this example
174176
tiled_env = TiledEnv(config=tiled_env_config)
175177
tiled_env.reset()

src/extern/tile_coding.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ def hashcoords(coordinates, m, readonly=False):
7777

7878

7979
def tiles(ihtORsize, numtilings, floats, ints=[], readonly=False):
80-
"""returns num-tilings tile indices corresponding to the floats and ints"""
80+
"""Returns num-tilings tile indices corresponding to the floats and ints
81+
"""
8182
qfloats = [floor(f * numtilings) for f in floats]
8283
Tiles = []
8384
for tiling in range(numtilings):

src/spaces/discrete_state_environment.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ def n_states(self) -> int:
8484
def column_names(self) -> list:
8585
return self.config.data_set.get_columns_names()
8686

87+
@property
88+
def column_distortions(self) -> dict:
89+
return self.column_distances
90+
8791
def get_action(self, aidx: int) -> ActionBase:
8892
return self.config.action_space[aidx]
8993

src/spaces/state.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,27 @@ def __init__(self):
6161
self.idx = -1
6262
self.bin_idx = -1
6363
self.total_distortion: float = 0.0
64-
self.column_names = []
64+
self.column_distortions = {}
6565

6666
def __contains__(self, item) -> bool:
67-
return item in self.column_names
67+
return item in self.column_distortions.keys()
6868

6969
def __iter__(self):
70-
return StateIterator(self.column_names)
70+
return StateIterator(list(self.column_distortions.keys()))
71+
72+
def __getitem__(self, name: str) -> float:
73+
"""
74+
Get the distortion corresponding to the column with the given name
75+
76+
Parameters
77+
----------
78+
name: The name of the column
79+
80+
Returns
81+
-------
82+
83+
The column distortion
84+
"""
85+
return self.column_distortions[name]
7186

7287

src/spaces/tiled_environment.py

Lines changed: 72 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
import copy
6-
from typing import TypeVar
6+
from typing import TypeVar, List
77
from dataclasses import dataclass
88
from src.extern.tile_coding import IHT, tiles
99
from src.spaces.actions import ActionBase, ActionType
@@ -13,6 +13,8 @@
1313
from src.spaces.time_step import copy_time_step
1414

1515
Env = TypeVar('Env')
16+
Tile = TypeVar('Tile')
17+
Config = TypeVar('Config')
1618

1719

1820
@dataclass(init=True, repr=True)
@@ -24,10 +26,14 @@ class TiledEnvConfig(object):
2426
num_tilings: int = 0
2527
max_size: int = 0
2628
tiling_dim: int = 0
27-
column_scales: dict = None
29+
column_ranges: dict = None
2830

2931

3032
class TiledEnv(object):
33+
"""The TiledEnv class. It models a tiled
34+
environment
35+
"""
36+
3137
IS_TILED_ENV_CONSTRAINT = True
3238

3339
def __init__(self, config: TiledEnvConfig) -> None:
@@ -40,11 +46,13 @@ def __init__(self, config: TiledEnvConfig) -> None:
4046
# set up the columns scaling
4147
# only the columns that are to be altered participate in the
4248
# tiling
43-
self.column_scales = config.column_scales
49+
self.column_ranges = config.column_ranges
50+
self.column_scales = {}
4451

4552
# Initialize index hash table (IHT) for tile coding.
4653
# This assigns a unique index to each tile up to max_size tiles.
4754
self._validate()
55+
self._create_column_scales()
4856
self.iht = IHT(self.max_size)
4957

5058
@property
@@ -59,6 +67,10 @@ def n_actions(self) -> int:
5967
def n_states(self) -> int:
6068
return self.env.n_states
6169

70+
@property
71+
def config(self) -> Config:
72+
return self.env.config
73+
6274
def step(self, action: ActionBase) -> TimeStep:
6375
"""Execute the action in the environment and return
6476
a new state for observation
@@ -83,16 +95,11 @@ def step(self, action: ActionBase) -> TimeStep:
8395
# of the bin that the total distortion falls into
8496
state.bin_idx = raw_time_step.observation
8597
state.total_distortion = raw_time_step.info["total_distortion"]
86-
state.column_names = self.env.column_names
98+
state.column_distortions = self.env.column_distortions
8799

88100
time_step = copy_time_step(time_step=raw_time_step, **{"observation": state})
89-
#time_step = copy.deepcopy(raw_time_step)
90-
#time_step.observation = state
91-
92101
return time_step
93102

94-
return
95-
96103
def reset(self, **options) -> TimeStep:
97104
"""Reset the environment so that a new sequence
98105
of episodes can be generated
@@ -116,24 +123,29 @@ def reset(self, **options) -> TimeStep:
116123
# of the bin that the total distortion falls into
117124
state.bin_idx = raw_time_step.observation
118125
state.total_distortion = raw_time_step.info["total_distortion"]
119-
state.column_names = self.env.column_names
126+
state.column_distortions = self.env.column_distortions
120127

121128
time_step = copy_time_step(time_step=raw_time_step, **{"observation": state})
122129

123-
#time_step = copy.deepcopy(raw_time_step)
124-
#time_step.observation = state
125-
126130
return time_step
127131

128132
def get_action(self, aidx: int) -> ActionBase:
129133
return self.env.action_space[aidx]
130134

131135
def save_current_dataset(self, episode_index: int, save_index: bool = False) -> None:
132136
"""
133-
Save the current distorted datase for the given episode index
134-
:param episode_index:
135-
:param save_index:
136-
:return:
137+
Save the current data set at the given episode index
138+
Parameters
139+
----------
140+
141+
episode_index: Episode index corresponding to the training episode
142+
save_index: if True the Pandas index is also saved
143+
144+
Returns
145+
-------
146+
147+
None
148+
137149
"""
138150
self.env.save_current_dataset(episode_index, save_index)
139151

@@ -200,22 +212,54 @@ def get_scaled_state(self, state: State) -> list:
200212
"""
201213
scaled_state_vals = []
202214
for name in state:
203-
scaled_state_vals.append(state[name] * self.columns_scales[name])
215+
scaled_state_vals.append(state[name] * self.column_scales[name])
204216

205217
return scaled_state_vals
206218

207-
def featurize_state_action(self, state, action: ActionBase) -> None:
208-
"""
209-
Returns the featurized representation for a state-action pair
210-
:param state:
211-
:param action:
212-
:return:
219+
def featurize_state_action(self, state: State, action: ActionBase) -> List[Tile]:
220+
"""Get a list of Tiles for the given state and action
221+
222+
Parameters
223+
----------
224+
state: The environment state observed
225+
action: The action
226+
227+
Returns
228+
-------
229+
230+
A list of tiles
231+
213232
"""
233+
214234
scaled_state = self.get_scaled_state(state)
215235
featurized = tiles(self.iht, self.num_tilings, scaled_state, [action])
216236
return featurized
217237

238+
def _create_column_scales(self) -> None:
239+
"""
240+
Create the scales for each column
241+
242+
Returns
243+
-------
244+
245+
None
246+
247+
"""
248+
249+
for name in self.column_ranges:
250+
range_ = self.column_ranges[name]
251+
self.column_scales[name] = self.tiling_dim / (range_[1] - range_[0])
252+
218253
def _validate(self) -> None:
254+
"""
255+
Validate the internal data structures
256+
257+
Returns
258+
-------
259+
260+
None
261+
262+
"""
219263
if self.max_size <= 0:
220264
raise InvalidParamValue(param_name="max_size",
221265
param_value=str(self.max_size) + " should be > 0")
@@ -227,7 +271,10 @@ def _validate(self) -> None:
227271
param_value=str(self.max_size) +
228272
" should be >=num_tilings * tiling_dim * tiling_dim")
229273

230-
if len(self.column_scales) == 0:
274+
if len(self.column_ranges) == 0:
231275
raise InvalidParamValue(param_name="column_scales",
232276
param_value=str(len(self.column_scales)) + " should not be empty")
233277

278+
if len(self.column_ranges) != len(self.env.column_names):
279+
raise ValueError("Column ranges is not equal to number of columns")
280+

0 commit comments

Comments
 (0)