Skip to content

Commit d7b73fa

Browse files
committed
#53 Fix test failing
1 parent e754b63 commit d7b73fa

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

src/tests/test_sarsa_semi_gradient.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22
import pytest
33
from src.algorithms.sarsa_semi_gradient import SARSAnConfig, SARSAn
4-
from src.spaces.tiled_environment import TiledEnv
4+
from src.spaces.tiled_environment import TiledEnv, TiledEnvConfig
55
from src.policies.epsilon_greedy_policy import EpsilonGreedyPolicy, EpsilonDecayOption
66
from src.exceptions.exceptions import InvalidParamValue
77

@@ -18,15 +18,29 @@ def test_actions_before_training_throws_invalid_environment(self):
1818
"satisfy the IS_TILED_ENV_CONSTRAINT constraint", str(e))
1919

2020
def test_actions_before_training_throws_invalid_policy(self):
21-
22-
env = TiledEnv(env=None, tiling_dim=10, num_tilings=4096, max_size=100)
23-
config = SARSAnConfig()
24-
agent = SARSAn(sarsa_config=config)
25-
with pytest.raises(InvalidParamValue) as e:
26-
agent.actions_before_training(env=env)
21+
env_config = TiledEnvConfig()
22+
env_config.max_size = 4096
23+
env_config.num_tilings = 5
24+
env_config.tiling_dim = 6
25+
env_config.env = None
26+
env_config.column_scales = {"col1": [0.0, 1.0]}
27+
28+
env = TiledEnv(env_config)
29+
config = SARSAnConfig()
30+
agent = SARSAn(sarsa_config=config)
31+
with pytest.raises(InvalidParamValue) as e:
32+
agent.actions_before_training(env=env)
2733

2834
def test_actions_before_training_throws_estimator_not_set(self):
29-
env = TiledEnv(env=None, tiling_dim=10, num_tilings=4096, max_size=100)
35+
36+
env_config = TiledEnvConfig()
37+
env_config.max_size = 4096
38+
env_config.num_tilings = 5
39+
env_config.tiling_dim = 6
40+
env_config.env = None
41+
env_config.column_scales = {"col1": [0.0, 1.0]}
42+
43+
env = TiledEnv(env_config)
3044
policy = EpsilonGreedyPolicy(eps=1.0, n_actions=1, decay_op=EpsilonDecayOption.NONE)
3145
config = SARSAnConfig()
3246
config.policy = policy

0 commit comments

Comments
 (0)