1+ import unittest
2+ import pytest
3+
4+ from src .algorithms .semi_gradient_sarsa import SemiGradSARSAConfig , SemiGradSARSA
5+ from src .algorithms .epsilon_greedy_q_estimator import EpsilonGreedyQEstimator
6+ from src .exceptions .exceptions import InvalidParamValue
7+ from src .spaces .tiled_environment import TiledEnv
8+ from src .spaces .discrete_state_environment import DiscreteStateEnvironment
9+ from src .datasets .datasets_loaders import MockSubjectsLoader , MockSubjectsData
10+
11+ class TestSemiGradSARSA (unittest .TestCase ):
12+
13+ def test_constructor (self ):
14+ config = SemiGradSARSAConfig ()
15+ semi_grad_sarsa = SemiGradSARSA (config )
16+ self .assertIsNotNone (semi_grad_sarsa .config )
17+
18+ def test_actions_before_training_throws_1 (self ):
19+
20+ semi_grad_sarsa = SemiGradSARSA (None )
21+ with pytest .raises (InvalidParamValue ) as e :
22+ semi_grad_sarsa .actions_before_training (env = None )
23+
24+ def test_actions_before_training_throws_2 (self ):
25+ config = SemiGradSARSAConfig ()
26+ config .n_itrs_per_episode = 0
27+ semi_grad_sarsa = SemiGradSARSA (config )
28+
29+ # make sure this is valid
30+ self .assertIsNotNone (semi_grad_sarsa .config )
31+
32+ with pytest .raises (ValueError ) as e :
33+ semi_grad_sarsa .actions_before_training (env = None )
34+
35+ def test_actions_before_training_throws_3 (self ):
36+ config = SemiGradSARSAConfig ()
37+ semi_grad_sarsa = SemiGradSARSA (config )
38+
39+ # make sure this is valid
40+ self .assertIsNotNone (semi_grad_sarsa .config )
41+
42+ with pytest .raises (InvalidParamValue ) as e :
43+ semi_grad_sarsa .actions_before_training (env = None )
44+
45+ def test_on_episode_returns_info (self ):
46+ config = SemiGradSARSAConfig ()
47+ semi_grad_sarsa = SemiGradSARSA (config )
48+
49+ # make sure this is valid
50+ self .assertIsNotNone (semi_grad_sarsa .config )
51+
52+ episode_info = semi_grad_sarsa .on_episode (env = None )
53+ self .assertIsNotNone (episode_info )
54+
55+ def test_on_episode_trains (self ):
56+
57+ sarsa_config = SemiGradSARSAConfig (n_itrs_per_episode = 1 , policy = EpsilonGreedyQEstimator ())
58+ semi_grad_sarsa = SemiGradSARSA (sarsa_config )
59+
60+ # cretate a default data
61+ ds_default_data = MockSubjectsData ()
62+ ds = MockSubjectsLoader .from_options (filename = ds_default_data .FILENAME ,
63+ names = ds_default_data .NAMES , drop_na = ds_default_data .DROP_NA ,
64+ change_col_vals = ds_default_data .CHANGE_COLS_VALS ,
65+ features_drop_names = ds_default_data .FEATURES_DROP_NAMES +
66+ ["preventative_treatment" , "gender" ,
67+ "education" , "mutation_status" ],
68+ column_normalization = ["salary" ], column_types = {"ethnicity" : str ,
69+ "salary" : float ,
70+ "diagnosis" : int })
71+
72+ discrete_env = DiscreteStateEnvironment .from_options (data_set = ds , action_space = None ,
73+ reward_manager = None , distortion_calculator = None )
74+ tiled_env = TiledEnv .from_options (env = discrete_env , max_size = 4096 , num_tilings = 5 , n_bins = 10 ,
75+ column_ranges = {"ethnicity" : [0.0 , 1.0 ],
76+ "salary" : [0.0 , 1.0 ],
77+ "diagnosis" : [0.0 , 1.0 ]}, tiling_dim = 3 )
78+
79+ """
80+ # specify the columns to drop
81+ drop_columns = MockSubjectsLoader.FEATURES_DROP_NAMES + ["preventative_treatment", "gender",
82+ "education", "mutation_status"]
83+ MockSubjectsLoader.FEATURES_DROP_NAMES = drop_columns
84+
85+ # do a salary normalization so that we work with
86+ # salaries in [0, 1] this is needed as we will
87+ # be using normalized distances
88+ MockSubjectsLoader.NORMALIZED_COLUMNS = ["salary"]
89+
90+ # specify the columns to use
91+ MockSubjectsLoader.COLUMNS_TYPES = {"ethnicity": str, "salary": float, "diagnosis": int}
92+ ds = MockSubjectsLoader()
93+ """
94+
95+ # create the discrete environment
96+
97+ semi_grad_sarsa .actions_before_training (tiled_env )
98+
99+ # make sure this is valid
100+ self .assertIsNotNone (semi_grad_sarsa .config )
101+
102+ episode_info = semi_grad_sarsa .on_episode (env = tiled_env )
103+ self .assertIsNotNone (episode_info )
104+
105+
106+ if __name__ == '__main__' :
107+ unittest .main ()
0 commit comments