Skip to content

Commit 8d2646e

Browse files
committed
Add state space
1 parent 6d5be8e commit 8d2646e

File tree

3 files changed

+155
-0
lines changed

3 files changed

+155
-0
lines changed

src/spaces/state_space.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""
2+
Discretized state space
3+
"""
4+
5+
from typing import TypeVar, List
6+
from gym.spaces.discrete import Discrete
7+
8+
from src.exceptions.exceptions import Error
9+
10+
ActionStatus = TypeVar("ActionStatus")
11+
Env = TypeVar("Env")
12+
13+
14+
class State:
    """
    A single environment state.

    Holds the name of the column the state refers to, the integer id
    assigned to it, and a list collecting the statuses recorded for it.
    """

    def __init__(self, column_name: str, state_id: int):
        # every instance starts with its own, empty history list
        self.history: List[ActionStatus] = []
        self.state_id: int = state_id
        self.column_name: str = column_name

    @property
    def key(self) -> tuple:
        """
        Identifying pair of the state

        :return: the tuple (column_name, state_id)
        """
        return (self.column_name, self.state_id)
26+
27+
28+
class StateSpace(Discrete):
    """
    The state space accumulates the discrete states.

    States are stored in ``self.states``, a dict keyed by column name.
    The inherited ``n`` attribute is kept equal to the number of stored
    states, so ``len(space) == space.n`` holds after every mutation.
    """

    def __init__(self):
        # NOTE(review): some gym releases reject n=0 in Discrete.__init__;
        # confirm against the gym version this project pins.
        super(StateSpace, self).__init__(n=0)
        self.states = {}

    def init_from_environment(self, env: Env):
        """
        Initialize from environment: add one State per feature name,
        numbering the states in the order they are added.

        :param env: object exposing ``feature_names``, an iterable of column names
        :return: None
        :raises ValueError: if a column name is already present in the space
        """
        for col_name in env.feature_names:
            # add_state performs the duplicate check and keeps ``n`` in sync,
            # so ``n`` is correct even if a later name raises
            self.add_state(State(column_name=col_name, state_id=len(self.states)))

    def add_state(self, state: State):
        """
        Add a single state to the space

        :param state: the state to add; it is stored under ``state.column_name``
        :return: None
        :raises ValueError: if a state with the same column name already exists
        """
        if state.column_name in self.states:
            raise ValueError("Column {0} already exists".format(state.column_name))

        self.states[state.column_name] = state

        # keep the number of discrete states consistent with the stored states
        self.n = len(self.states)

    def update_state(self, state_name, status: ActionStatus):
        """
        Append a status to the history of the named state

        :param state_name: column name the state was stored under
        :param status: the status to record
        :return: None
        :raises KeyError: if no state is stored under ``state_name``
        """
        self.states[state_name].history.append(status)

    def __len__(self):
        """
        :return: the number of states stored in the space
        """
        return len(self.states)

src/tests/test_space_state.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import unittest
2+
3+
import unittest
4+
from pathlib import Path
5+
6+
import pytest
7+
8+
from src.spaces.environment import Environment
9+
from src.spaces.action_space import ActionSpace
10+
from src.spaces.actions import ActionSuppress, ActionGeneralize
11+
from src.exceptions.exceptions import Error
12+
from src.utils.serial_hierarchy import SerialHierarchy
13+
from src.utils.string_distance_calculator import DistanceType
14+
from src.datasets.dataset_wrapper import PandasDSWrapper
15+
from src.spaces.state_space import StateSpace, State
16+
17+
class TestStateSpace(unittest.TestCase):
    """
    Tests that a StateSpace can be populated from an Environment
    """

    def setUp(self) -> None:
        """
        Setup the PandasDSWrapper to be used in the tests
        :return: None
        """

        # read the data; the path is anchored at this file (two levels above
        # its directory, i.e. the same location "../../data" points to when
        # the test is launched from its own directory) so the test does not
        # depend on the current working directory
        filename = Path(__file__).resolve().parents[2] / "data" / "mocksubjects.csv"

        cols_types = {"gender": str, "ethnicity": str, "education": int,
                      "salary": int, "diagnosis": int, "preventative_treatment": str,
                      "mutation_status": int, }

        self.ds = PandasDSWrapper(columns=cols_types)
        self.ds.read(filename=filename, **{"features_drop_names": ["NHSno", "given_name", "surname", "dob"],
                                           "names": ["NHSno", "given_name", "surname", "gender",
                                                     "dob", "ethnicity", "education", "salary",
                                                     "mutation_status", "preventative_treatment", "diagnosis"],
                                           "drop_na": True,
                                           "change_col_vals": {"diagnosis": [('N', 0)]}})

    def test_creation(self):
        """
        A StateSpace initialized from the environment has one state per feature
        """

        action_space = ActionSpace(n=3)

        # ethnicity value -> the single generalized value it maps to
        ethnicity_parents = {"Mixed White/Asian": "Mixed",
                             "Chinese": "Asian",
                             "Indian": "Asian",
                             "Mixed White/Black African": "Mixed",
                             "Black African": "Black",
                             "Asian other": "Asian",
                             "Black other": "Black",
                             "Mixed White/Black Caribbean": "Mixed",
                             "Mixed other": "Mixed",
                             "Arab": "Asian",
                             "White Irish": "White",
                             "Not stated": "Not stated",
                             "White Gypsy/Traveller": "White",
                             "White British": "White",
                             "Bangladeshi": "Asian",
                             "White other": "White",
                             "Black Caribbean": "Black",
                             "Pakistani": "Asian"}

        generalization_table = {value: SerialHierarchy(values=[parent, ])
                                for value, parent in ethnicity_parents.items()}

        action_space.add(ActionGeneralize(column_name="ethnicity", generalization_table=generalization_table))

        # create the environment from the given dataset
        env = Environment(data_set=self.ds, action_space=action_space, gamma=0.99, start_column="gender")

        state_space = StateSpace()
        state_space.init_from_environment(env=env)

        # the states are keyed by feature name, in the order the environment lists them
        self.assertEqual(list(env.feature_names), list(state_space.states.keys()))

        self.assertEqual(env.n_features, state_space.n)
        self.assertEqual(len(state_space), state_space.n)
74+
75+
76+
if __name__ == '__main__':
77+
unittest.main()

src/utils/reward_manager.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""
2+
Various utilities to handle reward assignment
3+
"""
4+
5+
6+
class RewardManager:
    """
    Helper used to assign rewards.

    Currently a placeholder: it holds no state and every reward is zero.
    """

    def __init__(self) -> None:
        """Create the manager; there is nothing to initialize."""

    def get_state_reward(self, *options) -> float:
        """
        Reward for a state

        :param options: arbitrary positional arguments; they are ignored
        :return: always 0.0
        """
        reward = 0.0
        return reward

0 commit comments

Comments
 (0)