1+ import unittest
2+ from pathlib import Path
3+
4+ import pytest
5+ import numpy as np
6+
7+ from spaces .environment import Environment
8+ from spaces .action_space import ActionSpace
9+ from exceptions .exceptions import Error
10+ from utils .string_sequence_calculator import DistanceType
11+ from utils .dataset_wrapper import PandasDSWrapper
12+
13+
14+ class TestEnvironment (unittest .TestCase ):
15+
16+ def setUp (self ) -> None :
17+ """
18+ Setup the PandasDSWrapper to be used in the tests
19+ :return: None
20+ """
21+
22+ # read the data
23+ filename = Path ("../data/mocksubjects.csv" )
24+
25+ cols_types = {"gender" : str , "ethnicity" : str , "education" : int ,
26+ "salary" : int , "diagnosis" : int , "preventative_treatment" : str ,
27+ "mutation_status" : int , }
28+
29+ self .ds = PandasDSWrapper (columns = cols_types )
30+ self .ds .read (filename = filename , ** {"features_drop_names" : ["NHSno" , "given_name" , "surname" , "dob" ],
31+ "names" : ["NHSno" , "given_name" , "surname" , "gender" ,
32+ "dob" , "ethnicity" , "education" , "salary" ,
33+ "mutation_status" , "preventative_treatment" , "diagnosis" ],
34+ "drop_na" : True ,
35+ "change_col_vals" : {"diagnosis" : [('N' , 0 )]}})
36+
37+ #@pytest.mark.skip(reason="no way of currently testing this")
38+ def test_prepare_column_states_throw_Error (self ):
39+ # specify the action space. We need to establish how these actions
40+ # are performed
41+ action_space = ActionSpace (n = 1 )
42+
43+ # create the environment and
44+ env = Environment (data_set = self .ds , action_space = action_space , gamma = 0.99 , start_column = "gender" )
45+
46+ with pytest .raises (Error ):
47+ env .prepare_column_states ()
48+
49+ #@pytest.mark.skip(reason="no way of currently testing this")
50+ def test_prepare_column_states (self ):
51+ # specify the action space. We need to establish how these actions
52+ # are performed
53+ action_space = ActionSpace (n = 1 )
54+
55+ # create the environment and
56+ env = Environment (data_set = self .ds , action_space = action_space , gamma = 0.99 , start_column = "gender" )
57+
58+ env .initialize_text_distances (distance_type = DistanceType .COSINE )
59+ env .prepare_column_states ()
60+
61+ def test_get_numeric_ds (self ):
62+ # specify the action space. We need to establish how these actions
63+ # are performed
64+ action_space = ActionSpace (n = 1 )
65+
66+ # create the environment and
67+ env = Environment (data_set = self .ds , action_space = action_space , gamma = 0.99 , start_column = "gender" )
68+
69+ env .initialize_text_distances (distance_type = DistanceType .COSINE )
70+ env .prepare_column_states ()
71+
72+ tensor = env .get_numeric_ds ()
73+
74+ # test the shape of the tensor
75+ shape0 = tensor .size (dim = 0 )
76+ shape1 = tensor .size (dim = 1 )
77+
78+ self .assertEqual (shape0 , env .start_ds .n_rows ())
79+ self .assertEqual (shape1 , env .start_ds .n_columns ())
80+
81+
82+
83+
84+
85+ if __name__ == '__main__' :
86+ unittest .main ()
0 commit comments