Commit e315d3a
#32 Update API
1 parent 10cba59 commit e315d3a

4 files changed: +24 -28 lines

4 files changed

+24
-28
lines changed

src/algorithms/q_learning.py
2 additions & 2 deletions

@@ -74,7 +74,7 @@ def play(self, env: Env, stop_criterion: Criterion) -> None:

         # set the q_table for the policy
         self.config.policy.q_table = self.q_table
-        total_dist = env.total_average_current_distortion()
+        total_dist = env.total_current_distortion()
         while stop_criterion.continue_itr(total_dist):

             if stop_criterion.iteration_counter == 12:
@@ -87,7 +87,7 @@ def play(self, env: Env, stop_criterion: Criterion) -> None:
             print("{0} At state={1} with distortion={2} select action={3}".format("INFO: ", state_idx, total_dist,
                                                                                   action.column_name + "-" + action.action_type.name))
             env.step(action=action)
-            total_dist = env.total_average_current_distortion()
+            total_dist = env.total_current_distortion()

     def train(self, env: Env, **options) -> tuple:
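
Note: the rename changes how the agent polls the environment between steps. Below is a minimal, self-contained sketch of that polling pattern; the Env and Criterion classes here are hypothetical stand-ins for the repo's real ones, and the distortion increments are made up.

class Env:
    def __init__(self) -> None:
        self._distortion = 0.0

    def total_current_distortion(self) -> float:
        # renamed from total_average_current_distortion in this commit
        return self._distortion

    def step(self, action) -> None:
        self._distortion += 0.1  # pretend each action distorts the data a bit


class Criterion:
    def __init__(self, n_itrs: int, max_dist: float) -> None:
        self.iteration_counter = 0
        self.n_itrs = n_itrs
        self.max_dist = max_dist

    def continue_itr(self, dist: float) -> bool:
        self.iteration_counter += 1
        return self.iteration_counter <= self.n_itrs and dist < self.max_dist


env = Env()
criterion = Criterion(n_itrs=10, max_dist=0.5)
total_dist = env.total_current_distortion()
while criterion.continue_itr(total_dist):
    env.step(action=None)
    total_dist = env.total_current_distortion()
print(total_dist)  # ~0.5 after five steps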

src/datasets/dataset_wrapper.py
7 additions & 13 deletions

@@ -29,7 +29,6 @@ def read(self, filename: Path, **options) -> None:


 class PandasDSWrapper(DSWrapper[pd.DataFrame]):
-
     """
     Simple wrapper to a pandas DataFrame object.
     Facilitates various actions on the original dataset
@@ -60,15 +59,15 @@ def n_columns(self) -> int:
     def schema(self) -> dict:
         return pd.io.json.build_table_schema(self.ds)

-    def save_to_csv(self, filename: Path) -> None:
+    def save_to_csv(self, filename: Path, save_index: bool) -> None:
         """
         Save the underlying dataset in a csv format
         :param filename:
         :return:
         """
-        self.ds.to_csv(filename)
+        self.ds.to_csv(filename, index=save_index)

-    def read(self, filename: Path, **options) -> None:
+    def read(self, filename: Path, **options) -> None:
         """
         Load a data set from a file
         :param filename:
@@ -145,14 +144,14 @@ def get_column(self, col_name: str):
         return self.ds.loc[:, col_name]

     def get_column_unique_values(self, col_name: str):
-        """
+        """
         Returns the unique values for the column
         :param col_name:
         :return:
         """
-        col = self.get_column(col_name=col_name)
-        vals = col.values.ravel()
-        return pd.unique(vals)
+        col = self.get_column(col_name=col_name)
+        vals = col.values.ravel()
+        return pd.unique(vals)

     def get_columns_types(self):
         return list(self.ds.dtypes)
@@ -181,8 +180,3 @@ def apply_column_transform(self, column_name: str, transform: Transform) -> None
         column = self.get_column(col_name=column_name)
         column = transform.act(**{"data": column.values})
         self.ds[transform.column_name] = column
-
-
-
-
-
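
Note: the new save_index flag forwards straight to pandas' DataFrame.to_csv(index=...). A quick sketch of the forwarding with plain pandas; the DataFrame contents are made up and the free function stands in for the wrapper method.

from pathlib import Path

import pandas as pd

ds = pd.DataFrame({"gender": ["M", "F"], "salary": [100, 120]})

def save_to_csv(ds: pd.DataFrame, filename: Path, save_index: bool) -> None:
    # index=False keeps pandas from writing its row index
    # as a first, unnamed CSV column
    ds.to_csv(filename, index=save_index)

save_to_csv(ds, Path("distorted_set_-1"), save_index=False)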

src/examples/qlearning_three_columns.py
6 additions & 6 deletions

@@ -151,11 +151,10 @@ def get_ethinicity_hierarchy():
     # create the environment
     env = DiscreteStateEnvironment(env_config=env_config)
     env.reset()
-    env.save_current_dataset(episode_index=-1)

-    # save the original dataset for comparison
-    env.save_current_dataset(episode_index=-1)
-    env.reset()
+    # save the data before distortion so that we can
+    # later load it on ARX
+    env.save_current_dataset(episode_index=-1, save_index=False)

     # configuration for the Q-learner
     algo_config = QLearnConfig()
@@ -195,7 +194,8 @@ def get_ethinicity_hierarchy():

     stop_criterion = IterationControl(n_itrs=10, min_dist=MIN_DISTORTION, max_dist=MAX_DISTORTION)
     agent.play(env=env, stop_criterion=stop_criterion)
-    env.save_current_dataset(episode_index=-2)
-
+    env.save_current_dataset(episode_index=-2, save_index=False)
+    print("{0} Done....".format(INFO))
+    print("=============================================")

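
Note: save_index=False matters here because with pandas' default index=True the CSV gets an unnamed leading index column, which ARX would import as an extra attribute. A small demonstration (column values are made up):

import pandas as pd

df = pd.DataFrame({"gender": ["M", "F"]})

print(df.to_csv())             # ,gender   <- stray unnamed index column
                               # 0,M
                               # 1,F
print(df.to_csv(index=False))  # gender    <- what ARX should load
                               # M
                               # F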

src/spaces/discrete_state_environment.py
9 additions & 7 deletions

@@ -128,14 +128,16 @@ def n_states(self) -> int:
     def get_action(self, aidx: int) -> ActionBase:
         return self.config.action_space[aidx]

-    def save_current_dataset(self, episode_index: int) -> None:
+    def save_current_dataset(self, episode_index: int, save_index: bool = False) -> None:
         """
         Save the current distorted datase for the given episode index
         :param episode_index:
+        :param save_index:
         :return:
         """
         self.distorted_data_set.save_to_csv(
-            filename=Path(str(self.config.distorted_set_path) + "_" + str(episode_index)))
+            filename=Path(str(self.config.distorted_set_path) + "_" + str(episode_index)),
+            save_index=save_index)

     def create_bins(self) -> None:
         """
@@ -216,15 +218,14 @@ def apply_action(self, action: ActionBase):

         self.column_distances[action.column_name] = distance

-    def total_average_current_distortion(self) -> float:
+    def total_current_distortion(self) -> float:
         """
-        Calculates the average total distortion of the dataset
-        by summing over the current computed distances for each column
+        Calculates the current total distortion of the dataset.
         :return:
         """

         return self.config.distortion_calculator.total_distortion(
-            list(self.column_distances.values()))  # float(np.mean(list(self.column_distances.values())))
+            list(self.column_distances.values()))

     def reset(self, **options) -> TimeStep:
         """
@@ -270,7 +271,7 @@ def step(self, action: ActionBase) -> TimeStep:
         self.apply_action(action=action)

         # calculate the distortion of the dataset
-        current_distortion = self.total_average_current_distortion()
+        current_distortion = self.total_current_distortion()

         # get the reward for the current distortion
         reward = self.config.reward_manager.get_reward_for_state(state=current_distortion, **{"action": action})
@@ -312,6 +313,7 @@ def step(self, action: ActionBase) -> TimeStep:

         # TODO: these modifications will cause the agent to always
         # move close to transition points
+        # TODO: Remove the magic constants
         if next_state is not None and self.current_time_step.observation is not None:
             if next_state < min_dist_bin <= self.current_time_step.observation:
                 # the agent chose to step into the chaos again
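
Note: the rename matches what the method now does: it hands the raw per-column distances to the configured calculator and returns a total, not an average (the commented-out np.mean fallback is gone). A minimal sketch of that contract, assuming a calculator that simply sums; the repo's real distortion_calculator may combine distances differently, and the distances below are made up.

class SumDistortionCalculator:
    # hypothetical stand-in for config.distortion_calculator
    def total_distortion(self, distances: list) -> float:
        # total, not average: sum the current per-column distances
        return float(sum(distances))

column_distances = {"ethnicity": 0.35, "gender": 0.20, "salary": 0.10}
calculator = SumDistortionCalculator()
print(calculator.total_distortion(list(column_distances.values())))  # ~0.65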
