Skip to content

Commit 24d6bfa

Browse files
authored
Merge pull request #25 from pockerman/investigate_q_learning
Investigate q learning
2 parents 549a042 + ded1788 commit 24d6bfa

File tree

9 files changed

+158
-32
lines changed

9 files changed

+158
-32
lines changed

src/algorithms/q_learning.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,21 @@ def actions_after_episode_ends(self, **options):
5656

5757
self.config.policy.actions_after_episode(options['episode_idx'])
5858

59+
def play(self, env: Env) -> None:
60+
"""
61+
Play the game on the environment. This should produce
62+
a distorted dataset
63+
:param env:
64+
:return:
65+
"""
66+
67+
# loop over the columns and for the
68+
# column get the action that corresponds to
69+
# the max payout.
70+
# TODO: This will not work as the distortion is calculated
71+
# by summing over the columns.
72+
raise NotImplementedError("Function not implemented")
73+
5974
def train(self, env: Env, **options) -> tuple:
6075

6176
# episode score
@@ -73,10 +88,10 @@ def train(self, env: Env, **options) -> tuple:
7388

7489
action = env.get_action(action_idx)
7590

76-
if action.action_type.name == "GENERALIZE" and action.column_name == "salary":
77-
print("Attempt to generalize salary")
78-
else:
79-
print(action.action_type.name, " on ", action.column_name)
91+
#if action.action_type.name == "GENERALIZE" and action.column_name == "salary":
92+
# print("Attempt to generalize salary")
93+
#else:
94+
# print(action.action_type.name, " on ", action.column_name)
8095

8196
# take action A, observe R, S'
8297
next_time_step = env.step(action)

src/datasets/dataset_wrapper.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,16 +106,17 @@ def normalize_column(self, column_name) -> None:
106106
"""
107107

108108
data_type = self.columns[column_name]
109-
if data_type is not int or data_type is not float:
110-
raise InvalidDataTypeException(param_name=column_name, param_types="[int, float]")
109+
110+
if data_type is not type(1) and data_type is not type(1.0):
111+
raise InvalidDataTypeException(param_name=column_name, param_type=data_type, param_types="[int, float]")
111112

112113
col_vals = self.get_column(col_name=column_name).values
113114

114115
min_val = np.min(col_vals)
115116
max_val = np.max(col_vals)
116117

117118
for i in range(len(col_vals)):
118-
col_vals[i] = (col_vals[i] - min_val) / (max_val - min_val)
119+
col_vals[i] = float((col_vals[i] - min_val)) / float((max_val - min_val))
119120

120121
self.ds[column_name] = col_vals
121122

src/exceptions/exceptions.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any
2+
13

24
class Error(Exception):
35
"""
@@ -24,8 +26,8 @@ def __str__(self):
2426

2527

2628
class InvalidDataTypeException(Exception):
    """
    Raised when a parameter (for example a dataset column) has a
    data type that is not one of the accepted types
    """

    def __init__(self, param_name: str, param_type: Any, param_types: str):
        """
        Build the error message
        :param param_name: Name of the offending parameter
        :param param_type: The type that was actually encountered
        :param param_types: Textual description of the accepted types
        """
        # report the type that was actually passed in; formatting
        # typing.Any here would print "typing.Any" for every error
        self.message = "Parameter {0} has invalid type. Type {1} not in {2}".format(param_name,
                                                                                    str(param_type),
                                                                                    param_types)

    def __str__(self):
        """
        :return: The formatted error message
        """
        return self.message
@@ -48,7 +50,7 @@ def __str__(self):
4850

4951

5052
class IncompatibleVectorSizesException(Exception):
51-
def __iter__(self, size1: int, size2: int) -> None:
53+
def __init__(self, size1: int, size2: int) -> None:
5254
self.message = "Size {0} does not match size {1} ".format(size1, size2)
5355

5456
def __str__(self):

src/spaces/action_space.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
import random
88
from gym.spaces.discrete import Discrete
9-
from src.spaces.actions import ActionBase
9+
from src.spaces.actions import ActionBase, ActionType
1010

1111

1212
class ActionSpace(Discrete):
@@ -48,7 +48,7 @@ def shuffle(self) -> None:
4848
"""
4949
random.shuffle(self.actions)
5050

51-
def get_action_by_column_name(self, column_name: str) -> ActionBase:
51+
def get_action_by_name_and_type(self, column_name: str, action_type: ActionType) -> ActionBase:
5252
"""
5353
Get the action of the given type that corresponds to the column with
5454
the given name. Raises ValueError if such an action does not
@@ -58,10 +58,11 @@ def get_action_by_column_name(self, column_name: str) -> ActionBase:
5858
"""
5959

6060
for action in self.actions:
61-
if action.column_name == column_name:
61+
if action.column_name == column_name and \
62+
action.action_type == action_type:
6263
return action
6364

64-
raise ValueError("No action exists for column={0}".format(column_name))
65+
raise ValueError("No action exists for column={0} with type {1}".format(column_name, action_type.name))
6566

6667
def add(self, action: ActionBase) -> None:
6768
"""

src/spaces/discrete_state_environment.py

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
from src.utils.numeric_distance_type import NumericDistanceType
1818
from src.utils.numeric_distance_calculator import NumericDistanceCalculator
1919

20-
2120
DataSet = TypeVar("DataSet")
2221
RewardManager = TypeVar("RewardManager")
2322
ActionSpace = TypeVar("ActionSpace")
23+
DistortionCalculator = TypeVar('DistortionCalculator')
2424

2525
_Reward = TypeVar('_Reward')
2626
_Discount = TypeVar('_Discount')
@@ -72,33 +72,36 @@ class DiscreteEnvConfig(object):
7272
"""
7373
Configuration for discrete environment
7474
"""
75+
7576
def __init__(self) -> None:
7677
self.data_set: DataSet = None
7778
self.action_space: ActionSpace = None
7879
self.reward_manager: RewardManager = None
7980
self.average_distortion_constraint: float = 0.0
8081
self.gamma: float = 0.99
81-
self.string_column_distortion_type: StringDistanceType = StringDistanceType.INVALID
82-
self.numeric_column_distortion_metric_type: NumericDistanceType = NumericDistanceType.INVALID
82+
# self.string_column_distortion_type: StringDistanceType = StringDistanceType.INVALID
83+
# self.numeric_column_distortion_metric_type: NumericDistanceType = NumericDistanceType.INVALID
8384
self.n_states: int = 10
8485
self.min_distortion: float = 0.4
8586
self.max_distortion: float = 0.7
8687
self.n_rounds_below_min_distortion: int = 10
8788
self.distorted_set_path: Path = None
89+
self.distortion_calculator: DistortionCalculator = None
8890

8991

9092
class DiscreteStateEnvironment(object):
9193
"""
9294
The DiscreteStateEnvironment class. Uses state aggregation in order
9395
to create bins where the average total distortion of the dataset falls in
9496
"""
97+
9598
def __init__(self, env_config: DiscreteEnvConfig) -> None:
9699
self.config = env_config
97100
self.n_rounds_below_min_distortion = 0
98101
self.state_bins: List[float] = []
99102
self.distorted_data_set = copy.deepcopy(self.config.data_set)
100103
self.current_time_step: TimeStep = None
101-
self.string_distance_calculator: TextDistanceCalculator = None
104+
# self.string_distance_calculator: TextDistanceCalculator = None
102105

103106
# dictionary that holds the distortion for every column
104107
# in the dataset
@@ -126,7 +129,8 @@ def get_action(self, aidx: int) -> ActionBase:
126129
return self.config.action_space[aidx]
127130

128131
def save_current_dataset(self, episode_index: int) -> None:
129-
self.distorted_data_set.save_to_csv(filename=Path(str(self.config.distorted_set_path) + "_" + str(episode_index)))
132+
self.distorted_data_set.save_to_csv(
133+
filename=Path(str(self.config.distorted_set_path) + "_" + str(episode_index)))
130134

131135
def create_bins(self) -> None:
132136
"""
@@ -167,7 +171,7 @@ def initialize_distances(self) -> None:
167171
normalized distance to 0.0 meaning that no distortion is assumed initially
168172
:return: None
169173
"""
170-
self.string_distance_calculator = TextDistanceCalculator(dist_type=self.config.string_column_distortion_type)
174+
# self.string_distance_calculator = TextDistanceCalculator(dist_type=self.config.string_column_distortion_type)
171175
col_names = self.config.data_set.get_columns_names()
172176
for col in col_names:
173177
self.column_distances[col] = 0.0
@@ -194,14 +198,21 @@ def apply_action(self, action: ActionBase):
194198
current_column = self.distorted_data_set.get_column(col_name=action.column_name)
195199
start_column = self.config.data_set.get_column(col_name=action.column_name)
196200

201+
datatype = 'float'
197202
# calculate column distortion
198203
if self.distorted_data_set.columns[action.column_name] == str:
204+
current_column = "".join(current_column.values)
205+
start_column = "".join(start_column.values)
206+
datatype = 'str'
207+
199208
# join the column to calculate the distance
200-
distance = self.string_distance_calculator.calculate(txt1="".join(current_column.values),
201-
txt2="".join(start_column.values))
202-
else:
203-
distance = NumericDistanceCalculator(dist_type=self.config.numeric_column_distortion_metric_type)\
204-
.calculate(state1=current_column, state2=start_column)
209+
# distance = self.string_distance_calculator.calculate(txt1="".join(current_column.values),
210+
# txt2="".join(start_column.values))
211+
# else:
212+
# distance = NumericDistanceCalculator(dist_type=self.config.numeric_column_distortion_metric_type)\
213+
# .calculate(state1=current_column, state2=start_column)
214+
215+
distance = self.config.distortion_calculator.calculate(current_column, start_column, datatype)
205216

206217
self.column_distances[action.column_name] = distance
207218

@@ -212,7 +223,8 @@ def total_average_current_distortion(self) -> float:
212223
:return:
213224
"""
214225

215-
return float(np.mean(list(self.column_distances.values())))
226+
return self.config.distortion_calculator.total_distortion(
227+
list(self.column_distances.values())) # float(np.mean(list(self.column_distances.values())))
216228

217229
def reset(self, **options) -> TimeStep:
218230
"""
@@ -294,16 +306,43 @@ def step(self, action: ActionBase) -> TimeStep:
294306
step_type = StepType.MID
295307
next_state = self.get_aggregated_state(state_val=current_distortion)
296308

309+
# get the bin for the min distortion
310+
min_dist_bin = self.get_aggregated_state(state_val=self.config.min_distortion)
311+
max_dist_bin = self.get_aggregated_state(state_val=self.config.max_distortion)
312+
313+
# TODO: these modifications will cause the agent to always
314+
# move close to transition points
315+
if next_state < min_dist_bin <= self.current_time_step.observation:
316+
# the agent chose to step into the chaos again
317+
# we punish him with double the reward
318+
reward = 2.0 * self.config.reward_manager.out_of_min_bound_reward
319+
elif next_state > max_dist_bin >= self.current_time_step.observation:
320+
# the agent is going to chaos from above
321+
# punish him
322+
reward = 2.0 * self.config.reward_manager.out_of_max_bound_reward
323+
324+
elif next_state >= min_dist_bin > self.current_time_step.observation:
325+
# the agent goes towards the transition of min point so give a higher reward
326+
# for this
327+
reward = 0.95 * self.config.reward_manager.in_bounds_reward
328+
329+
elif next_state <= max_dist_bin < self.current_time_step.observation:
330+
# the agent goes towards the transition of max point so give a higher reward
331+
# for this
332+
reward = 0.95 * self.config.reward_manager.in_bounds_reward
333+
297334
if next_state >= self.n_states:
298335
done = True
299336

300337
if done:
301338
step_type = StepType.LAST
302339
next_state = None
303340

304-
return TimeStep(step_type=step_type, reward=reward,
305-
observation=next_state,
306-
discount=self.config.gamma, info={"total_distortion": current_distortion})
341+
self.current_time_step = TimeStep(step_type=step_type, reward=reward,
342+
observation=next_state,
343+
discount=self.config.gamma, info={"total_distortion": current_distortion})
344+
345+
return self.current_time_step
307346

308347

309348
class MultiprocessEnv(object):

src/utils/distortion_calculator.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""
2+
Utilities for dataset distortion calculation
3+
"""
4+
import enum
5+
from typing import TypeVar
6+
from src.utils.numeric_distance_type import NumericDistanceType
7+
from src.utils.numeric_distance_calculator import NumericDistanceCalculator
8+
from src.utils.string_distance_calculator import StringDistanceType, TextDistanceCalculator
9+
from src.exceptions.exceptions import InvalidParamValue
10+
11+
Vector = TypeVar('Vector')
12+
13+
14+
class DistortionCalculationType(enum.IntEnum):
    """
    Enumeration of the ways the per-column distortions are
    combined into one dataset-level distortion value
    """

    INVALID = -1
    SUM = 0
    AVG = 1


class DistortionCalculator(object):
    """
    Computes the distortion of a single column and combines
    per-column distortions into a dataset-level value
    """

    def __init__(self, numeric_column_distortion_metric_type: NumericDistanceType,
                 string_column_distortion_metric_type: StringDistanceType,
                 dataset_distortion_type: DistortionCalculationType):
        """
        :param numeric_column_distortion_metric_type: Metric handed to
        NumericDistanceCalculator for 'float' and 'int' columns
        :param string_column_distortion_metric_type: Metric handed to
        TextDistanceCalculator for 'str' columns
        :param dataset_distortion_type: How total_distortion combines
        the per-column values
        """
        self.numeric_column_distortion_metric_type = numeric_column_distortion_metric_type
        self.string_column_distortion_metric_type = string_column_distortion_metric_type
        self.dataset_distortion_type = dataset_distortion_type

    def calculate(self, vec1: Vector, vec2: Vector, datatype: str) -> float:
        """
        Calculate the distance between the two given columns
        :param vec1: First column
        :param vec2: Second column
        :param datatype: One of 'str', 'float' or 'int'; selects the calculator
        :return: The value returned by the selected distance calculator
        :raises InvalidParamValue: If datatype is not one of the above
        """

        if datatype == 'str':
            return TextDistanceCalculator(dist_type=self.string_column_distortion_metric_type).calculate(txt1=vec1,
                                                                                                         txt2=vec2)
        elif datatype == 'float' or datatype == 'int':
            return NumericDistanceCalculator(dist_type=self.numeric_column_distortion_metric_type).calculate(state1=vec1,
                                                                                                             state2=vec2)
        raise InvalidParamValue(param_name='datatype', param_value=datatype)

    def total_distortion(self, distortions: Vector) -> float:
        """
        Combine the per-column distortions according to
        the configured dataset_distortion_type
        :param distortions: Sized collection of per-column distortion values
        :return: The sum (SUM) or the arithmetic mean (AVG) of the values;
        0.0 when the collection is empty
        :raises InvalidParamValue: If dataset_distortion_type is neither SUM nor AVG
        """

        if self.dataset_distortion_type == DistortionCalculationType.SUM:
            return float(sum(distortions))
        elif self.dataset_distortion_type == DistortionCalculationType.AVG:
            # SUM yields 0.0 for an empty collection; do the same here
            # instead of failing with ZeroDivisionError
            if len(distortions) == 0:
                return 0.0
            return float(sum(distortions) / len(distortions))

        # the configured value may not be an enum member at all (e.g. None);
        # fall back to its string form so the intended error is the one raised
        type_name = getattr(self.dataset_distortion_type, 'name', str(self.dataset_distortion_type))
        raise InvalidParamValue(param_name='dataset_distortion_type', param_value=type_name)
51+

src/utils/mixins.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def max_action(self, state: Any, n_actions: int) -> int:
7777
:param n_actions: Total number of actions allowed
7878
:return: The action that corresponds to the maximum value
7979
"""
80-
values = np.array(self.q_table[state, a] for a in range(n_actions))
80+
values = [self.q_table[state, a] for a in range(n_actions)]
81+
values = np.array(values)
8182
action = np.argmax(values)
8283
return int(action)

src/utils/numeric_distance_calculator.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,27 @@ def _numeric_distance_calculator(state1: Vector, state2: Vector, dist_type: Nume
2525
raise IncompatibleVectorSizesException(size1=len(state1), size2=len(state2))
2626

2727
if dist_type == NumericDistanceType.L1:
28-
return _l1_state_leakage(state1=state1, state2=state2)
28+
return np.linalg.norm(state1 - state2, ord=1)
2929
elif dist_type == NumericDistanceType.L2:
30-
return _l1_state_leakage(state1=state1, state2=state2)
30+
return np.linalg.norm(state1 - state2, ord=None)
3131
elif dist_type == NumericDistanceType.L2_NORMALIZED:
3232
return _normalized_l2_distance(state1=state1, state2=state2)
33+
elif dist_type == NumericDistanceType.L2_AVG:
34+
return _avg_l2_distance(state1=state1, state2=state2)
3335

3436
raise InvalidParamValue(param_name="dist_type", param_value=dist_type.name)
3537

3638

39+
def _avg_l2_distance(state1: Vector, state2: Vector) -> float:
40+
41+
size = len(state1)
42+
dist = 0.0
43+
for item1, item2 in zip(state1, state2):
44+
dist += ((item1 - item2) ** 2)
45+
46+
return np.sqrt(dist / float(size))
47+
48+
3749
def _normalized_l2_distance(state1: Vector, state2: Vector) -> float:
3850
"""
3951
Returns the normalized L2 norm between the two vectors
@@ -49,8 +61,11 @@ def _normalized_l2_distance(state1: Vector, state2: Vector) -> float:
4961

5062
return np.sqrt(dist)
5163

64+
65+
"""
5266
def _l2_state_leakage(state1: Vector, state2: Vector) -> float:
5367
return np.linalg.norm(state1 - state2, ord=None)
5468
5569
def _l1_state_leakage(state1: Vector, state2: Vector) -> float:
5670
return np.linalg.norm(state1 - state2, ord=1)
71+
"""

src/utils/numeric_distance_type.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ class NumericDistanceType(enum.IntEnum):
1616
L2 = 1
1717
L2_NORMALIZED = 2
1818
L1_NORMALIZED = 3
19+
L2_AVG = 4

0 commit comments

Comments
 (0)