Skip to content

Commit 81d3eeb

Browse files
committed
#53 API updates for n-step SARSA algorithm
1 parent 88706b7 commit 81d3eeb

File tree

6 files changed

+144
-60
lines changed

6 files changed

+144
-60
lines changed

build_sphinx_doc.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
#sphinx-quickstart docs
22

3-
sphinx-apidoc -f -o docs/source docs/projectdir
4-
#sphinx-build -b html docs/source/ docs/build/html
3+
#sphinx-apidoc -f -o docs/source docs/source/API
4+
sphinx-build -b html docs/source/ docs/build/html

docs/source/examples.rst

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,8 @@ Examples
33

44
Some examples can be found below
55

6-
- `Qlearning agent on a three columns dataset <src/examples/qlearning_three_columns.py>`_
7-
- `N-step semi-gradient SARSA on a three columns dataset <src/examples/nstep_semi_grad_sarsa_three_columns.py>`_
6+
.. toctree::
7+
:maxdepth: 4
8+
9+
Examples/qlearning_three_columns
10+
Examples/nstep_semi_grad_sarsa_three_columns

src/examples/nstep_semi_grad_sarsa_three_columns.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from src.utils.string_distance_calculator import StringDistanceType
1717
from src.utils.distortion_calculator import DistortionCalculationType, DistortionCalculator
1818
from src.spaces.discrete_state_environment import DiscreteStateEnvironment, DiscreteEnvConfig
19-
from src.spaces.tiled_environment import TiledEnv
19+
from src.spaces.tiled_environment import TiledEnv, TiledEnvConfig
2020
from src.utils.iteration_control import IterationControl
2121
from src.utils.plot_utils import plot_running_avg
2222
from src.utils import INFO
@@ -168,10 +168,10 @@ def load_dataset() -> MockSubjectsLoader:
168168
# create the environment
169169
env = DiscreteStateEnvironment(env_config=env_config)
170170

171+
tiled_env_config = TiledEnvConfig(env=env, num_tilings=NUM_TILINGS, max_size=MAX_SIZE, tiling_dim=TILING_DIM,
172+
column_scales={"ethnicity": [0.0, 1.0], "salary": [0.0, 1.0]})
171173
# we will use a tiled environment in this example
172-
tiled_env = TiledEnv(env=env, max_size=MAX_SIZE,
173-
num_tilings=NUM_TILINGS,
174-
tiling_dim=TILING_DIM)
174+
tiled_env = TiledEnv(config=tiled_env_config)
175175
tiled_env.reset()
176176

177177
# save the data before distortion so that we can

src/spaces/discrete_state_environment.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,41 +7,41 @@
77
import numpy as np
88
from pathlib import Path
99
from typing import TypeVar, List
10+
from dataclasses import dataclass
1011
import multiprocessing as mp
1112

1213
from src.spaces.actions import ActionBase, ActionType
1314
from src.spaces.time_step import TimeStep, StepType
1415

16+
1517
DataSet = TypeVar("DataSet")
1618
RewardManager = TypeVar("RewardManager")
1719
ActionSpace = TypeVar("ActionSpace")
1820
DistortionCalculator = TypeVar('DistortionCalculator')
1921

2022

23+
@dataclass(init=True, repr=True)
2124
class DiscreteEnvConfig(object):
22-
"""
23-
Configuration for discrete environment
25+
"""Configuration for discrete environment
2426
"""
2527

26-
def __init__(self) -> None:
27-
self.data_set: DataSet = None
28-
self.action_space: ActionSpace = None
29-
self.reward_manager: RewardManager = None
30-
self.average_distortion_constraint: float = 0.0
31-
self.gamma: float = 0.99
32-
self.n_states: int = 10
33-
self.min_distortion: float = 0.4
34-
self.max_distortion: float = 0.7
35-
self.punish_factor: float = 2.0
36-
self.reward_factor: float = 0.95
37-
self.n_rounds_below_min_distortion: int = 10
38-
self.distorted_set_path: Path = None
39-
self.distortion_calculator: DistortionCalculator = None
28+
data_set: DataSet = None
29+
action_space: ActionSpace = None
30+
reward_manager: RewardManager = None
31+
average_distortion_constraint: float = 0.0
32+
gamma: float = 0.99
33+
n_states: int = 10
34+
min_distortion: float = 0.4
35+
max_distortion: float = 0.7
36+
punish_factor: float = 2.0
37+
reward_factor: float = 0.95
38+
n_rounds_below_min_distortion: int = 10
39+
distorted_set_path: Path = None
40+
distortion_calculator: DistortionCalculator = None
4041

4142

4243
class DiscreteStateEnvironment(object):
43-
"""
44-
The DiscreteStateEnvironment class. Uses state aggregation in order
44+
"""The DiscreteStateEnvironment class. Uses state aggregation in order
4545
to create bins where the average total distortion of the dataset falls in
4646
"""
4747

@@ -80,6 +80,10 @@ def n_actions(self) -> int:
8080
def n_states(self) -> int:
8181
return self.config.n_states
8282

83+
@property
84+
def column_names(self) -> list:
85+
return self.config.data_set.get_columns_names()
86+
8387
def get_action(self, aidx: int) -> ActionBase:
8488
return self.config.action_space[aidx]
8589

@@ -268,7 +272,6 @@ def step(self, action: ActionBase) -> TimeStep:
268272

269273
# TODO: these modifications will cause the agent to always
270274
# move close to transition points
271-
# TODO: Remove the magic constants
272275
if next_state is not None and self.current_time_step.observation is not None:
273276
if next_state < min_dist_bin <= self.current_time_step.observation:
274277
# the agent chose to step into the chaos again

src/spaces/tiled_environment.py

Lines changed: 78 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,29 @@
22
Tile environment
33
"""
44

5+
import copy
56
from typing import TypeVar
7+
from dataclasses import dataclass
68
from src.extern.tile_coding import IHT, tiles
79
from src.spaces.actions import ActionBase, ActionType
810
from src.spaces.time_step import TimeStep
911
from src.exceptions.exceptions import InvalidParamValue
12+
from src.spaces.state import State
13+
from src.spaces.time_step import copy_time_step
1014

1115
Env = TypeVar('Env')
12-
State = TypeVar('State')
1316

1417

18+
@dataclass(init=True, repr=True)
1519
class TiledEnvConfig(object):
20+
"""Configuration for the TiledEnvironment
1621
"""
17-
Configuration for the TiledEnvironment
18-
"""
19-
def __init__(self):
20-
self.env: Env = None
21-
self.num_tilings: int = 0
22-
self.max_size = 0
23-
self.tiling_dim = 0
24-
self.column_scales = {}
22+
23+
env: Env = None
24+
num_tilings: int = 0
25+
max_size: int = 0
26+
tiling_dim: int = 0
27+
column_scales: dict = None
2528

2629

2730
class TiledEnv(object):
@@ -44,14 +47,6 @@ def __init__(self, config: TiledEnvConfig) -> None:
4447
self._validate()
4548
self.iht = IHT(self.max_size)
4649

47-
def step(self, action: ActionBase) -> TimeStep:
48-
"""
49-
Apply the action and return new state
50-
:param action: The action to apply
51-
:return:
52-
"""
53-
return self.env.step(action)
54-
5550
@property
5651
def action_space(self):
5752
return self.env.action_space
@@ -64,6 +59,72 @@ def n_actions(self) -> int:
6459
def n_states(self) -> int:
6560
return self.env.n_states
6661

62+
def step(self, action: ActionBase) -> TimeStep:
63+
"""Execute the action in the environment and return
64+
a new state for observation
65+
66+
Parameters
67+
----------
68+
action: The action to execute
69+
70+
Returns
71+
-------
72+
73+
An instance of TimeStep type
74+
75+
"""
76+
77+
raw_time_step = self.env.step(action)
78+
79+
# a state wrapper to communicate
80+
state = State()
81+
82+
# the raw environment returns an index
83+
# of the bin that the total distortion falls into
84+
state.bin_idx = raw_time_step.observation
85+
state.total_distortion = raw_time_step.info["total_distortion"]
86+
state.column_names = self.env.column_names
87+
88+
time_step = copy_time_step(time_step=raw_time_step, **{"observation": state})
89+
#time_step = copy.deepcopy(raw_time_step)
90+
#time_step.observation = state
91+
92+
return time_step
93+
94+
return
95+
96+
def reset(self, **options) -> TimeStep:
97+
"""Reset the environment so that a new sequence
98+
of episodes can be generated
99+
100+
Parameters
101+
----------
102+
options: Client provided named options
103+
104+
Returns
105+
-------
106+
107+
An instance of TimeStep type
108+
"""
109+
110+
raw_time_step = self.env.reset(**options)
111+
112+
# a state wrapper to communicate
113+
state = State()
114+
115+
# the raw environment returns an index
116+
# of the bin that the total distortion falls into
117+
state.bin_idx = raw_time_step.observation
118+
state.total_distortion = raw_time_step.info["total_distortion"]
119+
state.column_names = self.env.column_names
120+
121+
time_step = copy_time_step(time_step=raw_time_step, **{"observation": state})
122+
123+
#time_step = copy.deepcopy(raw_time_step)
124+
#time_step.observation = state
125+
126+
return time_step
127+
67128
def get_action(self, aidx: int) -> ActionBase:
68129
return self.env.action_space[aidx]
69130

@@ -130,21 +191,6 @@ def total_current_distortion(self) -> float:
130191
"""
131192
return self.env.total_current_distortion()
132193

133-
def reset(self, **options) -> TimeStep:
134-
"""
135-
Starts a new sequence and returns the first `TimeStep` of this sequence.
136-
Returns:
137-
A `TimeStep` namedtuple containing:
138-
step_type: A `StepType` of `FIRST`.
139-
reward: `None`, indicating the reward is undefined.
140-
discount: `None`, indicating the discount is undefined.
141-
observation: A NumPy array, or a nested dict, list or tuple of arrays.
142-
Scalar values that can be cast to NumPy arrays (e.g. Python floats)
143-
are also valid in place of a scalar array. Must conform to the
144-
specification returned by `observation_spec()`.
145-
"""
146-
return self.env.reset(**options)
147-
148194
def get_scaled_state(self, state: State) -> list:
149195
"""
150196
Scales the state components and returns the

src/spaces/time_step.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
"""
44

5+
import copy
56
import enum
67
from typing import NamedTuple, Generic, Optional, TypeVar
78

@@ -52,4 +53,35 @@ def last(self) -> bool:
5253

5354
@property
5455
def done(self) -> bool:
55-
return self.last()
56+
return self.last()
57+
58+
59+
def copy_time_step(time_step: TimeStep, **copy_options) -> TimeStep:
60+
"""Helper to copy partly or in whole a TimeStep namedtuple.
61+
If copy_options is None or empty it returns a deep copy
62+
of the given time step
63+
64+
Parameters
65+
----------
66+
time_step: The time step to copy
67+
copy_options: Members to be copied
68+
69+
Returns
70+
-------
71+
72+
An instance of the TimeStep namedtuple
73+
74+
"""
75+
if not copy_options or len(copy_options) == 0:
76+
return copy.deepcopy(time_step)
77+
78+
observation = copy_options["observation"] if "observation" in copy_options else time_step.observation
79+
step_type = copy_options["step_type"] if "step_type" in copy_options else time_step.step_type
80+
info = copy_options["info"] if "info" in copy_options else time_step.info
81+
reward = copy_options["reward"] if "reward" in copy_options else time_step.reward
82+
discount = copy_options["discount"] if "discount" in copy_options else time_step.discount
83+
return TimeStep(observation=observation, step_type=step_type, info=info,
84+
reward=reward, discount=discount)
85+
86+
87+

0 commit comments

Comments
 (0)