Skip to content

Commit 5e07d3b

Browse files
committed
Add tests for environment
1 parent fa090da commit 5e07d3b

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed

tests/test_environment.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import unittest
2+
from pathlib import Path
3+
4+
import pytest
5+
import numpy as np
6+
7+
from spaces.environment import Environment
8+
from spaces.action_space import ActionSpace
9+
from exceptions.exceptions import Error
10+
from utils.string_sequence_calculator import DistanceType
11+
from utils.dataset_wrapper import PandasDSWrapper
12+
13+
14+
class TestEnvironment(unittest.TestCase):
15+
16+
def setUp(self) -> None:
17+
"""
18+
Setup the PandasDSWrapper to be used in the tests
19+
:return: None
20+
"""
21+
22+
# read the data
23+
filename = Path("../data/mocksubjects.csv")
24+
25+
cols_types = {"gender": str, "ethnicity": str, "education": int,
26+
"salary": int, "diagnosis": int, "preventative_treatment": str,
27+
"mutation_status": int, }
28+
29+
self.ds = PandasDSWrapper(columns=cols_types)
30+
self.ds.read(filename=filename, **{"features_drop_names": ["NHSno", "given_name", "surname", "dob"],
31+
"names": ["NHSno", "given_name", "surname", "gender",
32+
"dob", "ethnicity", "education", "salary",
33+
"mutation_status", "preventative_treatment", "diagnosis"],
34+
"drop_na": True,
35+
"change_col_vals": {"diagnosis": [('N', 0)]}})
36+
37+
#@pytest.mark.skip(reason="no way of currently testing this")
38+
def test_prepare_column_states_throw_Error(self):
39+
# specify the action space. We need to establish how these actions
40+
# are performed
41+
action_space = ActionSpace(n=1)
42+
43+
# create the environment and
44+
env = Environment(data_set=self.ds, action_space=action_space, gamma=0.99, start_column="gender")
45+
46+
with pytest.raises(Error):
47+
env.prepare_column_states()
48+
49+
#@pytest.mark.skip(reason="no way of currently testing this")
50+
def test_prepare_column_states(self):
51+
# specify the action space. We need to establish how these actions
52+
# are performed
53+
action_space = ActionSpace(n=1)
54+
55+
# create the environment and
56+
env = Environment(data_set=self.ds, action_space=action_space, gamma=0.99, start_column="gender")
57+
58+
env.initialize_text_distances(distance_type=DistanceType.COSINE)
59+
env.prepare_column_states()
60+
61+
def test_get_numeric_ds(self):
62+
# specify the action space. We need to establish how these actions
63+
# are performed
64+
action_space = ActionSpace(n=1)
65+
66+
# create the environment and
67+
env = Environment(data_set=self.ds, action_space=action_space, gamma=0.99, start_column="gender")
68+
69+
env.initialize_text_distances(distance_type=DistanceType.COSINE)
70+
env.prepare_column_states()
71+
72+
tensor = env.get_numeric_ds()
73+
74+
# test the shape of the tensor
75+
shape0 = tensor.size(dim=0)
76+
shape1 = tensor.size(dim=1)
77+
78+
self.assertEqual(shape0, env.start_ds.n_rows())
79+
self.assertEqual(shape1, env.start_ds.n_columns())
80+
81+
82+
83+
84+
85+
if __name__ == '__main__':
86+
unittest.main()

0 commit comments

Comments
 (0)