Skip to content

Commit 758fd51

Browse files
authored
Merge pull request #43 from pockerman/add_sarsa_semi_gradient
Debug tests
2 parents 2d77f0d + 13a0ad8 commit 758fd51

File tree

3 files changed

+49
-0
lines changed

3 files changed

+49
-0
lines changed

src/spaces/discrete_state_environment.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ class DiscreteStateEnvironment(object):
4545
to create bins where the average total distortion of the dataset falls in
4646
"""
4747

48+
IS_TILED_ENV_CONSTRAINT = False
49+
4850
def __init__(self, env_config: DiscreteEnvConfig) -> None:
4951
self.config = env_config
5052
self.n_rounds_below_min_distortion = 0

src/tests/test_suite.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import unittest
2+
3+
from .test_trainer import TestTrainer
4+
from .test_serial_hierarchy import TestSerialHierarchy
5+
from .test_preprocessor import TestPreprocessor
6+
from .test_actions import TestActions
7+
8+
def suite():
9+
suite = unittest.TestSuite()
10+
suite.addTest(TestTrainer)
11+
suite.addTest(TestSerialHierarchy)
12+
suite.addTest(TestPreprocessor)
13+
suite.addTest(TestActions)
14+
return suite
15+
16+
if __name__ == '__main__':
17+
runner = unittest.TextTestRunner()
18+
runner.run(suite())

src/tests/test_trainer.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""
2+
Unit-tests for class Trainer
3+
"""
4+
import unittest
5+
6+
from src.algorithms.trainer import Trainer
7+
from src.algorithms.sarsa_semi_gradient import SARSAnConfig, SARSAn
8+
from src.spaces.tiled_environment import TiledEnv
9+
10+
11+
class TestTrainer(unittest.TestCase):
12+
13+
def test_with_sarsa_semi_grad_agent(self):
14+
15+
# create tiled environment
16+
tiled_env = TiledEnv(env=None, num_tilings=10, max_size=4096,
17+
tiling_dim=5)
18+
19+
sarsa_config = SARSAnConfig()
20+
agent = SARSAn(sarsa_config=sarsa_config)
21+
22+
trainer = Trainer(agent=agent, env=tiled_env,
23+
configuration={"n_episodes": 1})
24+
25+
trainer.train()
26+
27+
28+
if __name__ == '__main__':
29+
unittest.main()

0 commit comments

Comments
 (0)