Skip to content

Commit 9175354

Browse files
committed
API updates
1 parent 61f884b commit 9175354

File tree

9 files changed

+271
-46
lines changed

9 files changed

+271
-46
lines changed

src/algorithms/trainer.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
Trainer
33
"""
44

5-
from src.utils import INFO
5+
import numpy as np
66
from typing import TypeVar
7+
from src.utils import INFO
78

89
Env = TypeVar("Env")
910
Agent = TypeVar("Agent")
@@ -15,22 +16,44 @@ def __init__(self, env: Env, agent: Agent, configuration: dir) -> None:
1516
self.env = env
1617
self.agent = agent
1718
self.configuration = configuration
19+
# monitor performance
20+
self.total_rewards: np.array = None
21+
self.iterations_per_episode = []
22+
23+
def actions_before_training(self):
24+
self.total_rewards: np.array = np.zeros(self.configuration['n_episodes'])
25+
self.iterations_per_episode = []
26+
27+
self.agent.actions_before_training(self.env)
28+
29+
def actions_after_episode_ends(self, **options):
30+
self.agent.actions_after_episode_ends(**options)
1831

1932
def train(self):
2033

2134
print("{0} Training agent {1}".format(INFO, self.agent.name))
35+
self.actions_before_training()
2236

23-
for episode in range(1, self.configuration["max_n_episodes"] + 1):
24-
print("INFO: Episode {0}/{1}".format(episode, self.configuration["max_n_episodes"]))
37+
for episode in range(0, self.configuration["n_episodes"]):
38+
print("INFO: Episode {0}/{1}".format(episode, self.configuration["n_episodes"]))
2539

2640
# reset the environment
2741
ignore = self.env.reset()
2842

2943
# train for a number of iterations
30-
self.agent.train(self.env)
44+
episode_score, n_itrs = self.agent.train(self.env)
45+
46+
if episode % self.configuration['output_msg_frequency'] == 0:
47+
print("{0}: On episode {1} training finished with "
48+
"{2} iterations. Total reward={3}".format(INFO, episode, n_itrs, episode_score))
49+
50+
self.iterations_per_episode.append(n_itrs)
51+
self.total_rewards[episode] = episode_score
3152

3253
# is it time to update the model?
3354
if self.configuration["update_frequency"] % episode == 0:
3455
self.agent.update()
3556

57+
self.actions_after_episode_ends(**{"episode_idx": episode})
58+
3659
print("{0} Training finished for agent {1}".format(INFO, self.agent.name))

src/datasets/dataset_wrapper.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,20 @@ def read(self, filename: Path, **options) -> None:
8080
# try to cast to the data types
8181
self.ds = change_column_types(ds=self.ds, column_types=self.columns)
8282

83+
def sample_column_name(self) -> str:
84+
"""
85+
Samples a name from the columns
86+
:return: a column name
87+
"""
88+
names = self.get_columns_names()
89+
return np.random.choice(names)
90+
8391
def set_columns_to_type(self, col_name_types) -> None:
92+
"""
93+
Set the types of the columns
94+
:param col_name_types: value forwarded as ``dtype`` to ``self.ds.astype``
95+
:return:
96+
"""
8497
self.ds.astype(dtype=col_name_types)
8598

8699
def attach_column_hierarchy(self, col_name: str, hierarchy: HierarchyBase):

src/exceptions/exceptions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,13 @@ def __str__(self):
1515
return self.message
1616

1717

18+
class InvalidParamValue(Exception):
19+
def __init__(self, param_name: str, param_value: str):
20+
self.message = "Parameter {0} has invalid value {1}".format(param_name, param_value)
21+
22+
def __str__(self):
23+
return self.message
24+
25+
26+
27+

src/spaces/action_space.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
actions in the actions.py module
44
"""
55

6+
import numpy as np
67
from gym.spaces.discrete import Discrete
78
from src.spaces.actions import ActionBase
89

@@ -66,3 +67,37 @@ def sample_and_get(self) -> ActionBase:
6667
"""
6768
action_idx = self.sample()
6869
return self.actions[action_idx]
70+
71+
def get_non_exhausted_actions(self) -> list:
72+
73+
actions_ = []
74+
75+
for action in self.actions:
76+
if not action.is_exhausted():
77+
actions_.append(action)
78+
79+
return actions_
80+
81+
def sample_and_get_non_exhausted(self) -> ActionBase:
82+
83+
actions = self.get_non_exhausted_actions()
84+
return np.random.choice(actions)
85+
86+
def is_exhausted(self):
87+
88+
finished = True
89+
90+
for action in self.actions:
91+
if not action.is_exhausted():
92+
return False
93+
94+
return finished
95+
96+
def reset(self) -> None:
97+
"""
98+
Reset every action in the action space
99+
:return:
100+
"""
101+
for action in self.actions:
102+
action.reinit()
103+

src/spaces/actions.py

Lines changed: 106 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,18 @@
33
from typing import List
44

55
from src.utils.hierarchy_base import HierarchyBase
6+
from src.utils.mixins import WithHierarchyTable
7+
8+
9+
def move_next(iterators: List) -> None:
10+
"""
11+
Loop over the iterators and move them
12+
to the next item
13+
:param iterators: The list of iterators to propagate
14+
:return: None
15+
"""
16+
for item in iterators:
17+
next(item)
618

719

820
class ActionType(enum.IntEnum):
@@ -47,39 +59,27 @@ def act(self, **ops) -> None:
4759
"""
4860

4961
@abc.abstractmethod
50-
def get_maximum_number_of_transforms(self):
62+
def get_maximum_number_of_transforms(self) -> int:
5163
"""
5264
Returns the maximum number of transforms that the action applies
5365
:return:
5466
"""
5567

68+
@abc.abstractmethod
69+
def is_exhausted(self) -> bool:
70+
"""
71+
Returns true if the action has exhausted all its
72+
transforms
73+
:return:
74+
"""
5675

57-
def move_next(iterators: List) -> None:
58-
"""
59-
Loop over the iterators and move them
60-
to the next item
61-
:param iterators: The list of iterators to propagate
62-
:return: None
63-
"""
64-
for item in iterators:
65-
next(item)
66-
67-
68-
class _WithTable(object):
69-
70-
def __init__(self) -> None:
71-
super(_WithTable, self).__init__()
72-
self.table = {}
73-
self.iterators = []
74-
75-
def add_hierarchy(self, key: str, hierarchy: HierarchyBase) -> None:
76+
@abc.abstractmethod
77+
def reinit(self) -> None:
7678
"""
77-
Add a hierarchy for the given key
78-
:param key: The key to attach the Hierarchy
79-
:param hierarchy: The hierarchy to attach
80-
:return: None
79+
Reinitialize the action to the state it had
80+
right after the constructor was called
81+
:return:
8182
"""
82-
self.table[key] = hierarchy
8383

8484

8585
class ActionIdentity(ActionBase):
@@ -89,13 +89,14 @@ class ActionIdentity(ActionBase):
8989

9090
def __init__(self, column_name: str) -> None:
9191
super(ActionIdentity, self).__init__(column_name=column_name, action_type=ActionType.IDENTITY)
92+
self.called = False
9293

93-
def act(self, **ops):
94+
def act(self, **ops) -> None:
9495
"""
9596
Perform the action
9697
:return:
9798
"""
98-
pass
99+
self.called = True
99100

100101
def get_maximum_number_of_transforms(self):
101102
"""
@@ -104,6 +105,22 @@ def get_maximum_number_of_transforms(self):
104105
"""
105106
return 1
106107

108+
def is_exhausted(self) -> bool:
109+
"""
110+
Returns true if the action has exhausted all its
111+
transforms
112+
:return:
113+
"""
114+
return self.called
115+
116+
def reinit(self) -> None:
117+
"""
118+
Reinitialize the action to the state it had
119+
right after the constructor was called
120+
:return:
121+
"""
122+
self.called = False
123+
107124

108125
class ActionTransform(ActionBase):
109126

@@ -127,17 +144,32 @@ def get_maximum_number_of_transforms(self):
127144
"""
128145
raise NotImplementedError("Method not implemented")
129146

147+
def is_exhausted(self) -> bool:
148+
"""
149+
Returns true if the action has exhausted all its
150+
transforms
151+
:return:
152+
"""
153+
raise NotImplementedError("Method not implemented")
154+
155+
def reinit(self) -> None:
156+
"""
157+
Reinitialize the action to the state it had
158+
right after the constructor was called
159+
:return:
160+
"""
161+
raise NotImplementedError("Method not implemented")
162+
130163

131-
class ActionSuppress(ActionBase, _WithTable):
164+
class ActionSuppress(ActionBase, WithHierarchyTable):
132165

133166
"""
134167
Implements the suppress action
135168
"""
136169
def __init__(self, column_name: str, suppress_table=None):
137170
super(ActionSuppress, self).__init__(column_name=column_name, action_type=ActionType.SUPPRESS)
138171

139-
if suppress_table is not None:
140-
self.table = suppress_table
172+
self.table = suppress_table
141173

142174
# fill in the iterators
143175
self.iterators = [iter(self.table[item]) for item in self.table]
@@ -148,16 +180,21 @@ def act(self, **ops) -> None:
148180
:return: None
149181
"""
150182

183+
# get the values of the column
184+
col_vals = ops['data'].values
185+
151186
# generalize the data given
152187
for i, item in enumerate(ops["data"]):
188+
value = self.table[item].value
189+
col_vals[i] = value
153190

154-
if item in self.table:
155-
value = self.table[item].value
156-
item = value
157-
ops["data"][i] = value
191+
ops["data"] = col_vals
158192

159-
# update the generalization
193+
# update the generalization iterators
194+
# so next time we visit we update according to
195+
# the new values
160196
move_next(iterators=self.iterators)
197+
return ops['data']
161198

162199
def get_maximum_number_of_transforms(self):
163200
"""
@@ -174,17 +211,32 @@ def get_maximum_number_of_transforms(self):
174211

175212
return max_transform
176213

214+
def is_exhausted(self) -> bool:
215+
"""
216+
Returns true if the action has exhausted all its
217+
transforms
218+
:return:
219+
"""
220+
return self.finished()
221+
222+
def reinit(self) -> None:
223+
"""
224+
Reinitialize the action to the state it had
225+
right after the constructor was called
226+
:return:
227+
"""
228+
self.reset_iterators()
177229

178-
class ActionGeneralize(ActionBase, _WithTable):
230+
231+
class ActionGeneralize(ActionBase, WithHierarchyTable):
179232
"""
180233
Implements the generalization action
181234
"""
182235

183236
def __init__(self, column_name: str, generalization_table: dict = None):
184237
super(ActionGeneralize, self).__init__(column_name=column_name, action_type=ActionType.GENERALIZE)
185238

186-
if generalization_table is not None:
187-
self.table = generalization_table
239+
self.table = generalization_table
188240

189241
# fill in the iterators
190242
self.iterators = [iter(self.table[item]) for item in self.table]
@@ -201,7 +253,6 @@ def act(self, **ops):
201253
# generalize the data given
202254
for i, item in enumerate(col_vals):
203255

204-
#print(item)
205256
# How do we update the generalizations?
206257
value = self.table[item].value
207258
col_vals[i] = value
@@ -232,6 +283,22 @@ def get_maximum_number_of_transforms(self):
232283

233284
return max_transform
234285

286+
def is_exhausted(self) -> bool:
287+
"""
288+
Returns true if the action has exhausted all its
289+
transforms
290+
:return:
291+
"""
292+
return self.finished()
293+
294+
def reinit(self) -> None:
295+
"""
296+
Reinitialize the action to the state it had
297+
right after the constructor was called
298+
:return:
299+
"""
300+
self.reset_iterators()
301+
235302

236303

237304

0 commit comments

Comments
 (0)