1515from src .spaces .actions import ActionBase , ActionType
1616from src .spaces .state_space import StateSpace , State
1717from src .utils .string_distance_calculator import DistanceType , TextDistanceCalculator
18+ from src .utils .numeric_distance_type import NumericDistanceType
19+ from src .datasets .dataset_information_leakage import state_leakage
1820
1921DataSet = TypeVar ("DataSet" )
2022RewardManager = TypeVar ("RewardManager" )
@@ -77,6 +79,7 @@ def __init__(self):
7779 self .average_distortion_constraint : float = 0
7880 self .start_column : str = "None_Column"
7981 self .gamma : float = 0.99
82+ self .numeric_column_distortion_metric_type : NumericDistanceType = NumericDistanceType .INVALID
8083
8184
8285class Environment (object ):
@@ -99,6 +102,7 @@ def __init__(self, env_config: EnvConfig):
99102 self .state_space = StateSpace ()
100103 self .distance_calculator = None
101104 self .reward_manager : RewardManager = env_config .reward_manager
105+ self .numeric_column_distortion_metric_type = env_config .numeric_column_distortion_metric_type
102106
103107 # initialize the state space
104108 self .state_space .init_from_environment (env = self )
@@ -219,15 +223,26 @@ def prepare_column_state(self, column_name):
219223 start_column = self .start_ds .get_column (col_name = column_name )
220224
221225 row_count = 0
222- print ("Distance {0} " .format (self .distance_calculator .calculate (txt1 = "" .join (current_column .values ),
223- txt2 = "" .join (start_column .values ))))
224226
227+ # join the column to calculate the distance
225228 self .column_distances [column_name ] = self .distance_calculator .calculate (txt1 = "" .join (current_column .values ),
226229 txt2 = "" .join (start_column .values ))
227- #for item1, item2 in zip(current_column.values, start_column.values):
228- # #self.column_distances[column_name][row_count] = self.distance_calculator.calculate(txt1=item1, txt2=item2)
229230
230- # row_count += 1
231+ def get_state_distortion (self , state_name ) -> float :
232+ """
233+ Returns the distortion for the state with the given name
234+ :param state_name:
235+ :return:
236+ """
237+ if self .start_ds .columns [state_name ] == str :
238+ return self .column_distances [state_name ]
239+ else :
240+
241+ current_column = self .data_set .get_column (col_name = state_name )
242+ start_column = self .start_ds .get_column (col_name = state_name )
243+
244+ return state_leakage (state1 = current_column ,
245+ state2 = start_column , dist_type = self .numeric_column_distortion_metric_type )
231246
232247 def prepare_columns_state (self ):
233248 """
@@ -299,6 +314,7 @@ def apply_action(self, action: ActionBase):
299314 :return:
300315 """
301316
317+ # nothing to act on identity
302318 if action .action_type == ActionType .IDENTITY :
303319 return
304320
@@ -333,14 +349,17 @@ def step(self, action: ActionBase) -> TimeStep:
333349 # update the state space
334350 self .state_space .update_state (state_name = action .column_name , status = action .action_type )
335351
352+ # prepare the column state. We only do work
353+ # if the column is a string
336354 self .prepare_column_state (column_name = action .column_name )
337355
338356 # perform the action on the data set
339357 #self.prepare_columns_state()
340358
341359 # calculate the information leakage and establish the reward
342360 # to return to the agent
343- reward = self .reward_manager .get_state_reward (self .state_space , action )
361+ state_distortion = self .get_state_distortion (state_name = action .column_name )
362+ reward = self .reward_manager .get_state_reward (action .column_name , action , state_distortion )
344363
345364 # what is the next state? maybe do it randomly?
346365 # or select the next column in the dataset
0 commit comments