Skip to content

Commit f0f14dd

Browse files
committed
API updates
1 parent 714fbcb commit f0f14dd

12 files changed

+172
-48
lines changed

src/algorithms/sarsa_semi_gradient.py renamed to src/algorithms/n_step_semi_gradient_sarsa.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
@dataclass(init=True, repr=True)
2222
class SARSAnConfig:
2323
"""Configuration class for n-step SARSA algorithm
24-
2524
"""
2625
gamma: float = 1.0
2726
alpha: float = 0.1

src/algorithms/trainer.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,21 +47,30 @@ def avg_distortion(self) -> np.array:
4747
avg[i] = self.total_distortions[i] / self.iterations_per_episode[i]
4848
return avg
4949

50-
def actions_before_training(self):
51-
"""
52-
Any actions to perform before training begins
53-
:return:
50+
def actions_before_training(self) -> None:
51+
"""Any actions to perform before training begins
52+
53+
Returns
54+
-------
55+
56+
None
5457
"""
58+
5559
self.total_rewards: np.array = np.zeros(self.configuration['n_episodes'])
5660
self.iterations_per_episode = []
57-
5861
self.agent.actions_before_training(self.env)
5962

6063
def actions_before_episode_begins(self, **options) -> None:
61-
"""
62-
Perform any actions necessary before the training begins
63-
:param options:
64-
:return:
64+
"""Perform any actions necessary before an episode begins
65+
66+
Parameters
67+
----------
68+
options: Any options passed by the client code
69+
70+
Returns
71+
-------
72+
73+
None
6574
"""
6675
self.agent.actions_before_episode_begins(**options)
6776

src/datasets/datasets_loaders.py

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,46 +5,65 @@
55
"""
66

77
from pathlib import Path
8+
from typing import List
9+
from dataclasses import dataclass, field
10+
811
from src.datasets.dataset_wrapper import PandasDSWrapper
912

1013

11-
class MockSubjectsLoader(PandasDSWrapper):
12-
"""
13-
The class MockSubjectsLoader. Loads the mocksubjects.csv
14-
"""
14+
@dataclass(init=True, repr=True)
15+
class MockSubjectsData(object):
1516

1617
# Path to the dataset file
17-
FILENAME = Path("../../data/mocksubjects.csv")
18+
FILENAME: Path = Path("../../data/mocksubjects.csv")
1819

1920
# the assumed column types. We use this map to cast
2021
# the types of the columns
21-
COLUMNS_TYPES = {"gender": str, "ethnicity": str, "education": int,
22-
"salary": int, "diagnosis": int, "preventative_treatment": str,
23-
"mutation_status": int, }
22+
COLUMNS_TYPES: dict = field(default_factory=lambda: {"gender": str, "ethnicity": str, "education": int,
23+
"salary": int, "diagnosis": int, "preventative_treatment": str,
24+
"mutation_status": int,})
2425

2526
# features to drop
26-
FEATURES_DROP_NAMES = ["NHSno", "given_name", "surname", "dob"]
27+
FEATURES_DROP_NAMES: List[str] = field(default_factory=lambda: ["NHSno", "given_name", "surname", "dob"])
2728

2829
# Names of the columns in the dataset
29-
NAMES = ["NHSno", "given_name", "surname", "gender",
30-
"dob", "ethnicity", "education", "salary",
31-
"mutation_status", "preventative_treatment", "diagnosis"]
30+
NAMES: List[str] = field(default_factory=lambda: ["NHSno", "given_name", "surname", "gender",
31+
"dob", "ethnicity", "education", "salary",
32+
"mutation_status", "preventative_treatment", "diagnosis"])
3233

3334
# option to drop NaN
34-
DROP_NA = True
35+
DROP_NA: bool = True
3536

3637
# Map that holds for each column the transformations
3738
# we want to apply for each value
38-
CHANGE_COLS_VALS = {"diagnosis": [('N', 0)]}
39+
CHANGE_COLS_VALS: dict = field(default_factory=lambda: {"diagnosis": [('N', 0)]})
3940

4041
# list of columns to be normalized
41-
NORMALIZED_COLUMNS = []
42-
43-
def __init__(self):
44-
super(MockSubjectsLoader, self).__init__(columns=MockSubjectsLoader.COLUMNS_TYPES)
45-
self.read(filename=MockSubjectsLoader.FILENAME,
46-
**{"features_drop_names": MockSubjectsLoader.FEATURES_DROP_NAMES,
47-
"names": MockSubjectsLoader.NAMES,
48-
"drop_na": MockSubjectsLoader.DROP_NA,
49-
"change_col_vals": MockSubjectsLoader.CHANGE_COLS_VALS,
50-
"column_normalization": MockSubjectsLoader.NORMALIZED_COLUMNS})
42+
NORMALIZED_COLUMNS: List[str] = field(default_factory=list)
43+
44+
45+
class MockSubjectsLoader(PandasDSWrapper):
46+
"""The class MockSubjectsLoader. Loads the mocksubjects.csv
47+
"""
48+
49+
@classmethod
50+
def from_options(cls, *, filename: Path,
51+
column_types: dir, features_drop_names: List[str],
52+
names: List[str], drop_na: bool, change_col_vals: dict, column_normalization: List[str]):
53+
54+
data = MockSubjectsData(FILENAME=filename, COLUMNS_TYPES=column_types,
55+
FEATURES_DROP_NAMES=features_drop_names, NAMES=names,
56+
DROP_NA=drop_na, CHANGE_COLS_VALS=change_col_vals,
57+
NORMALIZED_COLUMNS=column_normalization)
58+
return cls(data=data)
59+
60+
def __init__(self, data: MockSubjectsData, do_read: bool=True):
61+
super(MockSubjectsLoader, self).__init__(columns=data.COLUMNS_TYPES)
62+
63+
if do_read:
64+
self.read(filename=data.FILENAME,
65+
**{"features_drop_names": data.FEATURES_DROP_NAMES,
66+
"names": data.NAMES,
67+
"drop_na": data.DROP_NA,
68+
"change_col_vals": data.CHANGE_COLS_VALS,
69+
"column_normalization": data.NORMALIZED_COLUMNS})

src/examples/nstep_semi_grad_sarsa_three_columns.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
from pathlib import Path
44

5-
from src.algorithms.sarsa_semi_gradient import SARSAnConfig, SARSAn
5+
from src.algorithms.n_step_semi_gradient_sarsa import SARSAnConfig, SARSAn
66
from src.algorithms.q_estimator import QEstimator
77
from src.algorithms.trainer import Trainer
88
from src.datasets.datasets_loaders import MockSubjectsLoader

src/spaces/discrete_state_environment.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from src.spaces.actions import ActionBase, ActionType
1414
from src.spaces.time_step import TimeStep, StepType
1515

16-
1716
DataSet = TypeVar("DataSet")
1817
RewardManager = TypeVar("RewardManager")
1918
ActionSpace = TypeVar("ActionSpace")
@@ -47,6 +46,24 @@ class DiscreteStateEnvironment(object):
4746

4847
IS_TILED_ENV_CONSTRAINT = False
4948

49+
@classmethod
50+
def from_options(cls, *, data_set: DataSet, action_space: ActionSpace,
51+
reward_manager: RewardManager, distortion_calculator: DistortionCalculator,
52+
average_distortion_constraint: float = 0.0,
53+
gamma: float = 0.99, n_states: int = 10, min_distortion: float = 0.4,
54+
max_distortion: float = 0.7, punish_factor: float = 2.0, reward_factor: float = 0.95,
55+
n_rounds_below_min_distortion: int = 10,
56+
distorted_set_path: Path = None):
57+
58+
config = DiscreteEnvConfig(data_set=data_set, action_space=action_space, reward_manager=reward_manager,
59+
distortion_calculator=distortion_calculator, distorted_set_path=distorted_set_path,
60+
reward_factor=reward_factor,
61+
n_rounds_below_min_distortion=n_rounds_below_min_distortion,
62+
punish_factor=punish_factor, max_distortion=max_distortion, gamma=gamma,
63+
n_states=n_states, min_distortion=min_distortion,
64+
average_distortion_constraint=average_distortion_constraint)
65+
return cls(env_config=config)
66+
5067
def __init__(self, env_config: DiscreteEnvConfig) -> None:
5168
self.config = env_config
5269
self.n_rounds_below_min_distortion = 0

src/spaces/tiled_environment.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
import copy
66
from typing import TypeVar, List
77
from dataclasses import dataclass
8+
9+
import numpy as np
10+
811
from src.extern.tile_coding import IHT, tiles
912
from src.spaces.actions import ActionBase, ActionType
1013
from src.spaces.time_step import TimeStep
@@ -26,6 +29,7 @@ class TiledEnvConfig(object):
2629
num_tilings: int = 0
2730
max_size: int = 0
2831
tiling_dim: int = 0
32+
n_bins: int = 1
2933
column_ranges: dict = None
3034

3135

@@ -36,12 +40,19 @@ class TiledEnv(object):
3640

3741
IS_TILED_ENV_CONSTRAINT = True
3842

43+
@classmethod
44+
def from_options(cls, *, env: Env, max_size: int, num_tilings: int,
45+
tiling_dim: int, n_bins: int, column_ranges: dict):
46+
return cls(TiledEnvConfig(env=env, max_size=max_size, num_tilings=num_tilings,
47+
tiling_dim=tiling_dim, n_bins=n_bins, column_ranges=column_ranges))
48+
3949
def __init__(self, config: TiledEnvConfig) -> None:
4050

4151
self.env = config.env
4252
self.max_size = config.max_size
4353
self.num_tilings = config.num_tilings
4454
self.tiling_dim = config.tiling_dim
55+
self.n_bins = config.n_bins
4556

4657
# set up the columns scaling
4758
# only the columns that are to be altered participate in the
@@ -55,6 +66,8 @@ def __init__(self, config: TiledEnvConfig) -> None:
5566
self._create_column_scales()
5667
self.iht = IHT(self.max_size)
5768

69+
self.column_bins = {}
70+
5871
@property
5972
def action_space(self):
6073
return self.env.action_space
@@ -170,7 +183,20 @@ def create_bins(self) -> None:
170183
None
171184
172185
"""
173-
self.env.create_bins()
186+
187+
# calculate the tile width for each column in the
188+
# data set
189+
190+
tile_widhs = {}
191+
for column in self.column_ranges:
192+
range_ = self.column_ranges[column]
193+
tile_width = (range_[1] + range_[0]) / self.n_bins
194+
self.column_bins[column] = np.zeros((self.num_tilings, self.n_bins))
195+
196+
# for each layer create an offset
197+
# bin
198+
for i in range(self.num_tilings):
199+
self.column_bins[column][i] = np.linspace(range_[0] + i * tile_width, range_[1] + i * tile_width, self.n_bins)
174200

175201
def get_aggregated_state(self, state_val: float) -> int:
176202
"""
@@ -325,10 +351,18 @@ def _validate(self) -> None:
325351
param_value=str(self.max_size) +
326352
" should be >=num_tilings * tiling_dim * tiling_dim")
327353

354+
if self.column_ranges is None:
355+
raise InvalidParamValue(param_name="column_ranges",
356+
param_value="None")
357+
328358
if len(self.column_ranges) == 0:
329359
raise InvalidParamValue(param_name="column_scales",
330360
param_value=str(len(self.column_scales)) + " should not be empty")
331361

362+
if self.env is None:
363+
raise InvalidParamValue(param_name="env",
364+
param_value="None")
365+
332366
if len(self.column_ranges) != len(self.env.column_names):
333367
raise ValueError("Column ranges is not equal to number of columns")
334368

src/tests/test_sarsa_semi_gradient.py renamed to src/tests/test_n_step_sarsa_semi_gradient.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22
import pytest
3-
from src.algorithms.sarsa_semi_gradient import SARSAnConfig, SARSAn
3+
from src.algorithms.n_step_semi_gradient_sarsa import SARSAnConfig, SARSAn
44
from src.spaces.tiled_environment import TiledEnv, TiledEnvConfig
55
from src.policies.epsilon_greedy_policy import EpsilonGreedyPolicy, EpsilonDecayOption
66
from src.exceptions.exceptions import InvalidParamValue

src/tests/test_suite.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from .test_serial_hierarchy import TestSerialHierarchy
55
from .test_preprocessor import TestPreprocessor
66
from .test_actions import TestActions
7-
from .test_sarsa_semi_gradient import TestSARSAn
7+
from .test_n_step_sarsa_semi_gradient import TestSARSAn
8+
from .test_semi_gradient_sarsa import TestSemiGradSARSA
89
from .test_tiled_environment import TestTiledEnv
910

1011

@@ -15,6 +16,7 @@ def suite():
1516
suite.addTest(TestPreprocessor)
1617
suite.addTest(TestActions)
1718
suite.addTest(TestSARSAn)
19+
suite.addTest(TestSemiGradSARSA)
1820
suite.addTest(TestTiledEnv)
1921
return suite
2022

src/tests/test_tiled_environment.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
from src.spaces.tiled_environment import TiledEnv, TiledEnvConfig
77
from src.exceptions.exceptions import InvalidParamValue
88

9+
class DummyEnv(object):
10+
11+
def __init__(self):
12+
self.column_names = ["col1", "col2"]
913

1014
class TestTiledEnv(unittest.TestCase):
1115

@@ -27,16 +31,47 @@ def test_constructor_raises_invalid_max_size(self):
2731
with pytest.raises(InvalidParamValue) as e:
2832
env = TiledEnv(config)
2933

30-
def test_empty_column_scales(self):
34+
def test_none_column_ranges(self):
3135
config = TiledEnvConfig()
3236
config.env = None
3337
config.max_size = 4096
3438
config.tiling_dim = 2
3539
config.num_tilings = 5
36-
config.columns_scales = {}
40+
config.column_ranges = None
3741
with pytest.raises(InvalidParamValue) as e:
3842
env = TiledEnv(config)
3943

44+
def test_empty_column_ranges(self):
45+
config = TiledEnvConfig()
46+
config.env = None
47+
config.max_size = 4096
48+
config.tiling_dim = 2
49+
config.num_tilings = 5
50+
config.column_ranges = {}
51+
with pytest.raises(InvalidParamValue) as e:
52+
env = TiledEnv(config)
53+
54+
def test_create_bins(self):
55+
config = TiledEnvConfig()
56+
config.env = DummyEnv()
57+
config.max_size = 4096
58+
config.tiling_dim = 2
59+
config.num_tilings = 2
60+
config.column_ranges = {"col1": [0.0, 1.0], "col2": [0.0, 1.0]}
61+
env = TiledEnv(config)
62+
env.create_bins()
63+
64+
tiles = env.column_bins
65+
# we must have one set of bins per column
66+
self.assertEqual(2, len(tiles))
67+
68+
for column in tiles:
69+
# for each column we must have config.num_tilings
70+
self.assertEqual(config.num_tilings, len(tiles[column]))
71+
72+
# each tiling must have config.n_bins bins
73+
for tile in tiles[column]:
74+
self.assertEqual(config.n_bins, len(tile))
4075

4176

4277
if __name__ == '__main__':

src/tests/test_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66

77
from src.algorithms.trainer import Trainer
8-
from src.algorithms.sarsa_semi_gradient import SARSAnConfig, SARSAn
8+
from src.algorithms.n_step_semi_gradient_sarsa import SARSAnConfig, SARSAn
99
from src.spaces.tiled_environment import TiledEnv
1010

1111

0 commit comments

Comments
 (0)