11import unittest
22import pytest
33from src .algorithms .sarsa_semi_gradient import SARSAnConfig , SARSAn
4- from src .spaces .tiled_environment import TiledEnv
4+ from src .spaces .tiled_environment import TiledEnv , TiledEnvConfig
55from src .policies .epsilon_greedy_policy import EpsilonGreedyPolicy , EpsilonDecayOption
66from src .exceptions .exceptions import InvalidParamValue
77
@@ -18,15 +18,29 @@ def test_actions_before_training_throws_invalid_environment(self):
1818 "satisfy the IS_TILED_ENV_CONSTRAINT constraint" , str (e ))
1919
def test_actions_before_training_throws_invalid_policy(self):
    """Expect ``InvalidParamValue`` from ``actions_before_training``.

    The environment is a ``TiledEnv`` built from a ``TiledEnvConfig``
    (the config-object constructor replaces the old keyword-argument
    form ``TiledEnv(env=..., tiling_dim=..., num_tilings=..., max_size=...)``).
    The agent is built from a ``SARSAnConfig`` on which this test assigns
    nothing, so -- going by the test name -- the error is presumably
    triggered by the missing policy.
    NOTE(review): the exception message is not checked here; confirm
    against ``SARSAn.actions_before_training`` if a stricter assertion
    is wanted.
    """
    # Tiled environment wrapper; no underlying environment is supplied
    # (``env`` is None) because only the agent-side validation is exercised.
    env_config = TiledEnvConfig()
    env_config.max_size = 4096
    env_config.num_tilings = 5
    env_config.tiling_dim = 6
    env_config.env = None
    env_config.column_scales = {"col1": [0.0, 1.0]}
    env = TiledEnv(env_config)

    # Configuration is used exactly as constructed: no policy is assigned.
    config = SARSAnConfig()
    agent = SARSAn(sarsa_config=config)

    # The exception info object was never inspected, so it is not bound.
    with pytest.raises(InvalidParamValue):
        agent.actions_before_training(env=env)
2733
2834 def test_actions_before_training_throws_estimator_not_set (self ):
29- env = TiledEnv (env = None , tiling_dim = 10 , num_tilings = 4096 , max_size = 100 )
35+
36+ env_config = TiledEnvConfig ()
37+ env_config .max_size = 4096
38+ env_config .num_tilings = 5
39+ env_config .tiling_dim = 6
40+ env_config .env = None
41+ env_config .column_scales = {"col1" : [0.0 , 1.0 ]}
42+
43+ env = TiledEnv (env_config )
3044 policy = EpsilonGreedyPolicy (eps = 1.0 , n_actions = 1 , decay_op = EpsilonDecayOption .NONE )
3145 config = SARSAnConfig ()
3246 config .policy = policy
0 commit comments