Skip to content

Commit 2dc329b

Browse files
committed
#13 Update API
1 parent 361f72a commit 2dc329b

19 files changed

+545
-765
lines changed

src/algorithms/anonymity_a2c_ray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import ray.rllib.agents.a3c as a3c
77
from ray.tune.logger import pretty_print
88
from ray.rllib.env.env_context import EnvContext
9-
from src.spaces.environment import TimeStep, StepType
9+
from src.spaces.discrete_state_environment import TimeStep, StepType
1010
from src.spaces.observation_space import ObsSpace
1111

1212

src/algorithms/q_learning.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ def actions_before_training(self, env: Env, **options):
4242
if self.config.policy is None:
4343
raise InvalidParamValue(param_name="policy", param_value="None")
4444

45-
for state in range(env.observation_space.n):
46-
for action in range(env.action_space.n):
45+
for state in range(1, env.n_states):
46+
for action in range(env.n_actions):
4747
self.q_table[state, action] = 0.0
4848

4949
def actions_after_episode_ends(self, **options):
@@ -59,8 +59,9 @@ def actions_after_episode_ends(self, **options):
5959
def train(self, env: Env, **options) -> tuple:
6060

6161
# episode score
62-
episode_score = 0 # initialize score
62+
episode_score = 0
6363
counter = 0
64+
total_distortion = 0
6465

6566
time_step = env.reset()
6667
state = time_step.observation
@@ -72,24 +73,28 @@ def train(self, env: Env, **options) -> tuple:
7273

7374
action = env.get_action(action_idx)
7475

76+
if action.action_type.name == "GENERALIZE" and action.column_name == "salary":
77+
print("Attempt to generalize salary")
78+
else:
79+
print(action.action_type.name, " on ", action.column_name)
80+
7581
# take action A, observe R, S'
7682
next_time_step = env.step(action)
7783
next_state = next_time_step.observation
7884
reward = next_time_step.reward
7985

80-
next_state_id = next_state.state_id if next_state is not None else None
81-
8286
# add reward to agent's score
83-
episode_score += next_time_step.reward
84-
self._update_Q_table(state=state.state_id, action=action_idx, reward=reward,
85-
next_state=next_state_id, n_actions=env.action_space.n)
87+
episode_score += reward
88+
self._update_Q_table(state=state, action=action_idx, reward=reward,
89+
next_state=next_state, n_actions=env.n_actions)
8690
state = next_state # S <- S'
8791
counter += 1
92+
total_distortion += next_time_step.info["total_distortion"]
8893

8994
if next_time_step.last():
9095
break
9196

92-
return episode_score, counter
97+
return episode_score, total_distortion, counter
9398

9499
def _update_Q_table(self, state: int, action: int, n_actions: int, reward: float, next_state: int = None) -> None:
95100
"""

src/algorithms/trainer.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,24 @@ def __init__(self, env: Env, agent: Agent, configuration: dir) -> None:
1919
# monitor performance
2020
self.total_rewards: np.array = np.zeros(configuration['n_episodes'])
2121
self.iterations_per_episode = []
22+
self.total_distortions = []
23+
24+
def avg_rewards(self) -> np.array:
25+
"""
26+
Returns the average reward per episode
27+
:return:
28+
"""
29+
avg = np.zeros(self.configuration['n_episodes'])
30+
31+
for i in range(self.total_rewards.shape[0]):
32+
avg[i] = self.total_rewards[i] / self.iterations_per_episode[i]
33+
return avg
2234

2335
def actions_before_training(self):
36+
"""
37+
Any actions to perform before training begins
38+
:return:
39+
"""
2440
self.total_rewards: np.array = np.zeros(self.configuration['n_episodes'])
2541
self.iterations_per_episode = []
2642

@@ -29,27 +45,32 @@ def actions_before_training(self):
2945
def actions_after_episode_ends(self, **options):
3046
self.agent.actions_after_episode_ends(**options)
3147

48+
if options["episode_idx"] % self.configuration['output_msg_frequency'] == 0:
49+
if self.env.config.distorted_set_path is not None:
50+
self.env.save_current_dataset(options["episode_idx"])
51+
3252
def train(self):
3353

3454
print("{0} Training agent {1}".format(INFO, self.agent.name))
3555
self.actions_before_training()
3656

3757
for episode in range(0, self.configuration["n_episodes"]):
38-
print("INFO: Episode {0}/{1}".format(episode, self.configuration["n_episodes"]))
58+
print("{0} On episode {1}/{2}".format(INFO, episode, self.configuration["n_episodes"]))
3959

4060
# reset the environment
4161
ignore = self.env.reset()
4262

4363
# train for a number of iterations
44-
episode_score, n_itrs = self.agent.train(self.env)
64+
episode_score, total_distortion, n_itrs = self.agent.train(self.env)
4565

46-
if episode % self.configuration['output_msg_frequency'] == 0:
47-
print("{0}: On episode {1} training finished with "
48-
"{2} iterations. Total reward={3}".format(INFO, episode, n_itrs, episode_score))
66+
print("{0} Episode score={1}, episode total distortion {2}".format(INFO, episode_score, total_distortion / n_itrs))
67+
68+
#if episode % self.configuration['output_msg_frequency'] == 0:
69+
print("{0} Episode finished after {1} iterations".format(INFO, n_itrs))
4970

5071
self.iterations_per_episode.append(n_itrs)
5172
self.total_rewards[episode] = episode_score
52-
73+
self.total_distortions.append(total_distortion)
5374
self.actions_after_episode_ends(**{"episode_idx": episode})
5475

5576
print("{0} Training finished for agent {1}".format(INFO, self.agent.name))

src/apps/qlearning_on_mock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from src.algorithms.trainer import Trainer
66
from src.utils.string_distance_calculator import StringDistanceType
77
from src.spaces.actions import ActionSuppress, ActionIdentity, ActionStringGeneralize, ActionTransform
8-
from src.spaces.environment import Environment, EnvConfig
8+
from src.spaces.discrete_state_environment import Environment, EnvConfig
99
from src.spaces.action_space import ActionSpace
1010
from src.datasets.datasets_loaders import MockSubjectsLoader
1111
from src.utils.reward_manager import RewardManager

src/datasets/dataset_wrapper.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66

77
from src.preprocessor.cleanup_utils import read_csv, replace, change_column_types
8+
from src.exceptions.exceptions import InvalidDataTypeException
89

910
DS = TypeVar("DS")
1011
HierarchyBase = TypeVar('HierarchyBase')
@@ -41,7 +42,7 @@ def __init__(self, columns: dir) -> None:
4142

4243
# map that holds the hierarchy to be applied
4344
# on each column in the dataset
44-
self.column_hierarchy = {}
45+
#self.column_hierarchy = {}
4546

4647
@property
4748
def n_rows(self) -> int:
@@ -63,6 +64,14 @@ def n_columns(self) -> int:
6364
def schema(self) -> dict:
6465
return pd.io.json.build_table_schema(self.ds)
6566

67+
def save_to_csv(self, filename: Path) -> None:
68+
"""
69+
Save the underlying dataset in a csv format
70+
:param filename:
71+
:return:
72+
"""
73+
self.ds.to_csv(filename)
74+
6675
def read(self, filename: Path, **options) -> None:
6776
"""
6877
Load a data set from a file
@@ -82,6 +91,25 @@ def read(self, filename: Path, **options) -> None:
8291
# try to cast to the data types
8392
self.ds = change_column_types(ds=self.ds, column_types=self.columns)
8493

94+
def normalize_column(self, column_name) -> None:
95+
"""
96+
Normalizes the column with the given name using the following
97+
transformation:
98+
99+
z_i = \frac{x_i - min(x)}{max(x) - min(x)}
100+
101+
if the column is not of numeric type then this function
102+
throws an InvalidDataTypeException
103+
:param column_name:
104+
:return:
105+
"""
106+
107+
data_type = self.columns[column_name]
108+
if data_type is not int or data_type is not float:
109+
raise InvalidDataTypeException(param_name=column_name, param_types="[int, float]")
110+
111+
raise NotImplementedError("Function is not implemented")
112+
85113
def sample_column_name(self) -> str:
86114
"""
87115
Samples a name from the columns
@@ -98,18 +126,23 @@ def set_columns_to_type(self, col_name_types) -> None:
98126
"""
99127
self.ds.astype(dtype=col_name_types)
100128

101-
def attach_column_hierarchy(self, col_name: str, hierarchy: HierarchyBase):
102-
self.column_hierarchy[col_name] = hierarchy
103-
104129
def get_column(self, col_name: str):
130+
"""
131+
Returns the column with the given name
132+
:param col_name:
133+
:return:
134+
"""
105135
return self.ds.loc[:, col_name]
106136

107137
def get_column_unique_values(self, col_name: str):
108-
# what are the unique values?
109-
110-
col = self.get_column(col_name=col_name)
111-
vals = col.values.ravel()
112-
return pd.unique(vals)
138+
"""
139+
Returns the unique values for the column
140+
:param col_name:
141+
:return:
142+
"""
143+
col = self.get_column(col_name=col_name)
144+
vals = col.values.ravel()
145+
return pd.unique(vals)
113146

114147
def get_columns_types(self):
115148
return list(self.ds.dtypes)
@@ -136,7 +169,7 @@ def apply_column_transform(self, column_name: str, transform: Transform) -> None
136169

137170
# get the column
138171
column = self.get_column(col_name=column_name)
139-
column = transform.act(**{"data": column})
172+
column = transform.act(**{"data": column.values})
140173
self.ds[transform.column_name] = column
141174

142175

src/exceptions/exceptions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ def __str__(self):
2323
return self.message
2424

2525

26+
class InvalidDataTypeException(Exception):
27+
def __init__(self, param_name: str, param_types: str):
28+
self.message = "Parameter {0} has invalid type. Type not in {1}".format(param_name, param_types)
29+
30+
def __str__(self):
31+
return self.message
32+
33+
2634
class InvalidSchemaException(Exception):
2735
def __init__(self, message: str) -> None:
2836
self.message = message
@@ -39,6 +47,14 @@ def __str__(self):
3947
return self.message
4048

4149

50+
class IncompatibleVectorSizesException(Exception):
51+
def __iter__(self, size1: int, size2: int) -> None:
52+
self.message = "Size {0} does not match size {1} ".format(size1, size2)
53+
54+
def __str__(self):
55+
return self.message
56+
57+
4258

4359

4460

src/spaces/action_space.py

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ def __setitem__(self, key: int, value: ActionBase) -> None:
3838
"""
3939
self.actions[key] = value
4040

41+
def __len__(self) -> int:
42+
return len(self.actions)
43+
4144
def shuffle(self) -> None:
4245
"""
4346
Randomly shuffle the actions in the space
@@ -91,44 +94,4 @@ def sample_and_get(self) -> ActionBase:
9194
action_idx = self.sample()
9295
return self.actions[action_idx]
9396

94-
def get_non_exhausted_actions(self) -> list:
95-
"""
96-
Returns a list of actions that have not exhausted the
97-
transformations that apply on a column.
98-
:return: list of actions. List may be empty. Client code should handle this
99-
"""
100-
actions_ = []
101-
for action in self.actions:
102-
if not action.is_exhausted():
103-
actions_.append(action)
104-
105-
return actions_
106-
107-
def sample_and_get_non_exhausted(self) -> ActionBase:
108-
"""
109-
Sample an action from the non exhausted actions
110-
:return: A non-exhausted action
111-
"""
112-
actions = self.get_non_exhausted_actions()
113-
return np.random.choice(actions)
114-
115-
def is_exhausted(self) -> bool:
116-
"""
117-
Returns true if all the actions in the space are exhausted
118-
:return:
119-
"""
120-
finished = True
121-
for action in self.actions:
122-
if not action.is_exhausted():
123-
return False
124-
125-
return finished
126-
127-
def reset(self) -> None:
128-
"""
129-
Reset every action in the action space
130-
:return:
131-
"""
132-
for action in self.actions:
133-
action.reinitialize()
13497

0 commit comments

Comments
 (0)