
Commit 61f884b

Merge pull request #12 from pockerman/add_actor_critic_algorithm
Add various utilities
2 parents 33b22fe + 8d2646e commit 61f884b

File tree

8 files changed: +322 −24 lines

src/spaces/actions.py

Lines changed: 52 additions & 0 deletions
@@ -46,6 +46,13 @@ def act(self, **ops) -> None:
         :return:
         """
 
+    @abc.abstractmethod
+    def get_maximum_number_of_transforms(self):
+        """
+        Returns the maximum number of transforms that the action applies
+        :return:
+        """
+
 
 def move_next(iterators: List) -> None:
     """
@@ -90,6 +97,13 @@ def act(self, **ops):
         """
         pass
 
+    def get_maximum_number_of_transforms(self):
+        """
+        Returns the maximum number of transforms that the action applies
+        :return:
+        """
+        return 1
+
 
 class ActionTransform(ActionBase):
 
@@ -106,6 +120,13 @@ def act(self, **ops):
         """
         pass
 
+    def get_maximum_number_of_transforms(self):
+        """
+        Returns the maximum number of transforms that the action applies
+        :return:
+        """
+        raise NotImplementedError("Method not implemented")
+
 
 class ActionSuppress(ActionBase, _WithTable):
 
@@ -138,6 +159,21 @@ def act(self, **ops) -> None:
         # update the generalization
         move_next(iterators=self.iterators)
 
+    def get_maximum_number_of_transforms(self):
+        """
+        Returns the maximum number of transforms that the action applies
+        :return:
+        """
+        max_transform = 0
+
+        for item in self.table:
+            size = len(self.table[item])
+
+            if size > max_transform:
+                max_transform = size
+
+        return max_transform
+
 
 class ActionGeneralize(ActionBase, _WithTable):
     """
@@ -181,5 +217,21 @@ def act(self, **ops):
     def add_generalization(self, key: str, values: HierarchyBase) -> None:
         self.table[key] = values
 
+    def get_maximum_number_of_transforms(self):
+        """
+        Returns the maximum number of transforms that the action applies
+        :return:
+        """
+        max_transform = 0
+
+        for item in self.table:
+            size = len(self.table[item])
+
+            if size > max_transform:
+                max_transform = size
+
+        return max_transform
+

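The two table-driven overrides above (ActionSuppress and ActionGeneralize) share the same scan for the longest entry in self.table. A standalone sketch of that scan follows; the table contents are purely illustrative and not from the repository:

# Sketch of the max-over-table scan shared by ActionSuppress and
# ActionGeneralize; the zipcode/age table below is illustrative only.
def get_maximum_number_of_transforms(table: dict) -> int:
    max_transform = 0
    for item in table:
        size = len(table[item])
        if size > max_transform:
            max_transform = size
    return max_transform

# e.g. a suppression table mapping each column to its transform chain
table = {"zipcode": ["1234*", "123**", "12***"], "age": ["20-60"]}
assert get_maximum_number_of_transforms(table) == 3
assert get_maximum_number_of_transforms({}) == 0  # empty table yields 0
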
src/spaces/environment.py

Lines changed: 45 additions & 15 deletions
@@ -13,9 +13,11 @@
 
 from src.exceptions.exceptions import Error
 from src.spaces.actions import ActionBase, ActionType
+from src.spaces.state_space import StateSpace, State
 from src.utils.string_distance_calculator import DistanceType, TextDistanceCalculator
 
 DataSet = TypeVar("DataSet")
+RewardManager = TypeVar("RewardManager")
 
 _Reward = TypeVar('_Reward')
 _Discount = TypeVar('_Discount')
@@ -65,20 +67,37 @@ def last(self) -> bool:
 class Environment(object):
 
     def __init__(self, data_set, action_space,
-                 gamma: float, start_column: str, ):
+                 gamma: float, start_column: str, reward_manager: RewardManager):
         self.data_set = data_set
         self.start_ds = copy.deepcopy(data_set)
         self.current_time_step = self.start_ds
         self.action_space = action_space
         self.gamma = gamma
         self.start_column = start_column
         self.column_distances = {}
+        self.state_space = StateSpace()
         self.distance_calculator = None
+        self.reward_manager: RewardManager = reward_manager
+
+        # initialize the state space
+        self.state_space.init_from_environment(env=self)
 
     @property
     def n_features(self) -> int:
+        """
+        Returns the number of features in the dataset
+        :return:
+        """
         return self.start_ds.n_columns
 
+    @property
+    def feature_names(self) -> list:
+        """
+        Returns the feature names in the dataset
+        :return:
+        """
+        return self.start_ds.get_columns_names()
+
     @property
     def n_examples(self) -> int:
         return self.start_ds.n_rows
@@ -99,6 +118,24 @@ def initialize_text_distances(self, distance_type: DistanceType) -> None:
     def sample_action(self) -> ActionBase:
         return self.action_space.sample_and_get()
 
+    def get_column_as_tensor(self, column_name) -> torch.Tensor:
+        """
+        Returns the column in the dataset as a torch tensor
+        :param column_name:
+        :return:
+        """
+        data = {}
+
+        if self.start_ds.columns[column_name] == str:
+
+            numpy_vals = self.column_distances[column_name]
+            data[column_name] = numpy_vals
+        else:
+            data[column_name] = self.data_set.get_column(col_name=column_name).to_numpy()
+
+        target_df = pd.DataFrame(data)
+        return torch.tensor(target_df.to_numpy(), dtype=torch.float64)
+
     def get_ds_as_tensor(self) -> torch.Tensor:
 
         """
@@ -111,7 +148,6 @@ def get_ds_as_tensor(self) -> torch.Tensor:
         for col in col_names:
 
             if self.start_ds.columns[col] == str:
-                #print("col: {0} type {1}".format(col, self.start_ds.get_column_type(col_name=col)))
                 numpy_vals = self.column_distances[col]
                 data[col] = numpy_vals
             else:
@@ -195,28 +231,22 @@ def step(self, action: ActionBase) -> TimeStep:
         `action` will be ignored.
         """
 
+        # apply the action
         self.apply_action(action=action)
 
-        # if the action is identity don't bother
-        # doing anything
-        #if action.action_type == ActionType.IDENTITY:
-        #    return TimeStep(step_type=StepType.MID, reward=0.0,
-        #                    observation=self.get_ds_as_tensor().float(), discount=self.gamma)
-
-        # apply the transform of the data set
-        #self.data_set.apply_column_transform(transform=action)
+        # update the state space
+        self.state_space.update_state(state_name=action.column_name, status=action.action_type)
 
         # perform the action on the data set
        self.prepare_column_states()
 
         # calculate the information leakage and establish the reward
         # to return to the agent
+        reward = self.reward_manager.get_state_reward(self.state_space, action)
 
-        return TimeStep(step_type=StepType.MID, reward=0.0,
-                        observation=self.get_ds_as_tensor().float(), discount=self.gamma)
-
-
-
+        return TimeStep(step_type=StepType.MID, reward=reward,
+                        observation=self.get_column_as_tensor(column_name=action.column_name).float(),
+                        discount=self.gamma)
 
 
 class MultiprocessEnv(object):
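
With these changes, step() scores each transition through the injected reward manager and returns only the acted-on column as the observation. A minimal interaction sketch follows; FlatRewardManager, my_data_set, and my_action_space are assumptions for illustration (only the RewardManager TypeVar and the get_state_reward call appear in the diff):

# Hypothetical duck-typed reward manager: step() above only needs a
# get_state_reward(state_space, action) method.
class FlatRewardManager:
    def get_state_reward(self, state_space, action) -> float:
        # a real manager would score the action history accumulated in
        # state_space (e.g. information leakage); 0.0 is a placeholder
        return 0.0

# my_data_set and my_action_space are assumed to satisfy the interfaces
# Environment uses (get_columns_names, sample_and_get, etc.)
env = Environment(data_set=my_data_set, action_space=my_action_space,
                  gamma=0.99, start_column="zipcode",
                  reward_manager=FlatRewardManager())

action = env.sample_action()
time_step = env.step(action)

# the observation is now a single-column tensor rather than the whole
# dataset, and the reward comes from the reward manager
print(time_step.reward, time_step.observation.shape)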

src/spaces/state_space.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+"""
+Discretized state space
+"""
+
+from typing import TypeVar, List
+from gym.spaces.discrete import Discrete
+
+from src.exceptions.exceptions import Error
+
+ActionStatus = TypeVar("ActionStatus")
+Env = TypeVar("Env")
+
+
+class State(object):
+    """
+    Describes an environment state
+    """
+    def __init__(self, column_name: str, state_id: int):
+        self.column_name: str = column_name
+        self.state_id: int = state_id
+        self.history: List[ActionStatus] = []
+
+    @property
+    def key(self) -> tuple:
+        return self.column_name, self.state_id
+
+
+class StateSpace(Discrete):
+    """
+    The StateSpace accumulates the discrete states
+    """
+
+    def __init__(self):
+        super(StateSpace, self).__init__(n=0)
+        self.states = {}
+
+    def init_from_environment(self, env: Env):
+        """
+        Initialize from environment
+        :param env:
+        :return:
+        """
+        names = env.feature_names
+        for col_name in names:
+
+            if col_name in self.states:
+                raise ValueError("Column {0} already exists".format(col_name))
+
+            self.states[col_name] = State(column_name=col_name, state_id=len(self.states))
+
+        # set the number of discrete states
+        self.n = len(self.states)
+
+    def add_state(self, state: State):
+        if state.column_name in self.states:
+            raise ValueError("Column {0} already exists".format(state.column_name))
+
+        self.states[state.column_name] = state
+
+    def update_state(self, state_name, status: ActionStatus):
+        self.states[state_name].history.append(status)
+
+    def __len__(self):
+        return len(self.states)
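
A short usage sketch for the new module. The column names and the "SUPPRESS" status string are illustrative (ActionStatus is only a TypeVar), and it assumes a gym version whose Discrete constructor accepts n=0, as StateSpace.__init__ above does:

from src.spaces.state_space import State, StateSpace

space = StateSpace()
space.add_state(State(column_name="zipcode", state_id=0))
space.add_state(State(column_name="salary", state_id=1))

# record that an action was applied to the "zipcode" column, as
# Environment.step() now does via update_state
space.update_state(state_name="zipcode", status="SUPPRESS")

print(len(space))                       # 2
print(space.states["zipcode"].key)      # ('zipcode', 0)
print(space.states["zipcode"].history)  # ['SUPPRESS']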
