 from src.utils.numeric_distance_type import NumericDistanceType
 from src.utils.numeric_distance_calculator import NumericDistanceCalculator

-
 DataSet = TypeVar("DataSet")
 RewardManager = TypeVar("RewardManager")
 ActionSpace = TypeVar("ActionSpace")
+DistortionCalculator = TypeVar('DistortionCalculator')

 _Reward = TypeVar('_Reward')
 _Discount = TypeVar('_Discount')
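This commit replaces the two type-specific distance calculators with a single injected DistortionCalculator. Its concrete interface is not part of the diff; inferred from the call sites in apply_action and total_average_current_distortion below, it would look roughly like the following sketch (the Protocol name and signatures are assumptions):

from typing import Iterable, Protocol


class DistortionCalculatorProtocol(Protocol):
    """Hypothetical interface; the real class does not appear in this diff."""

    def calculate(self, current_column, start_column, datatype: str) -> float:
        """Return the normalized distortion between a distorted column and the original."""
        ...

    def total_distortion(self, distortions: Iterable[float]) -> float:
        """Aggregate per-column distortions into a single value."""
        ...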
@@ -72,33 +72,36 @@ class DiscreteEnvConfig(object):
     """
     Configuration for discrete environment
     """
+
     def __init__(self) -> None:
         self.data_set: DataSet = None
         self.action_space: ActionSpace = None
         self.reward_manager: RewardManager = None
         self.average_distortion_constraint: float = 0.0
         self.gamma: float = 0.99
-        self.string_column_distortion_type: StringDistanceType = StringDistanceType.INVALID
-        self.numeric_column_distortion_metric_type: NumericDistanceType = NumericDistanceType.INVALID
+        # self.string_column_distortion_type: StringDistanceType = StringDistanceType.INVALID
+        # self.numeric_column_distortion_metric_type: NumericDistanceType = NumericDistanceType.INVALID
         self.n_states: int = 10
         self.min_distortion: float = 0.4
         self.max_distortion: float = 0.7
         self.n_rounds_below_min_distortion: int = 10
         self.distorted_set_path: Path = None
+        self.distortion_calculator: DistortionCalculator = None

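With the distance-type enums commented out, callers now inject the calculator through the config. A minimal wiring sketch, where my_data_set, my_action_space, my_reward_manager, and MyDistortionCalculator are hypothetical stand-ins for objects built elsewhere:

config = DiscreteEnvConfig()
config.data_set = my_data_set
config.action_space = my_action_space
config.reward_manager = my_reward_manager
# the single calculator replaces the two commented-out distance-type fields
config.distortion_calculator = MyDistortionCalculator()
config.distorted_set_path = Path("distorted_set")
env = DiscreteStateEnvironment(env_config=config)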
 class DiscreteStateEnvironment(object):
     """
     The DiscreteStateEnvironment class. Uses state aggregation to create
     bins into which the average total distortion of the dataset falls
     """
+
     def __init__(self, env_config: DiscreteEnvConfig) -> None:
         self.config = env_config
         self.n_rounds_below_min_distortion = 0
         self.state_bins: List[float] = []
         self.distorted_data_set = copy.deepcopy(self.config.data_set)
         self.current_time_step: TimeStep = None
-        self.string_distance_calculator: TextDistanceCalculator = None
+        # self.string_distance_calculator: TextDistanceCalculator = None

         # dictionary that holds the distortion for every column
         # in the dataset
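The docstring above refers to state aggregation: the continuous total distortion is mapped into one of n_states bins, and the bin index serves as the observation. create_bins and get_aggregated_state sit mostly outside this diff; a common way to implement such binning (an assumed sketch, not the repository's code) is:

import numpy as np

def create_equal_width_bins(n_states: int) -> list:
    # distortion is normalized, so partition the interval [0, 1]
    return list(np.linspace(0.0, 1.0, n_states + 1))

def get_bin_index(state_val: float, bins: list) -> int:
    # np.digitize returns 1-based positions; shift to 0-based bin indices
    return int(np.digitize(state_val, bins)) - 1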
@@ -126,7 +129,8 @@ def get_action(self, aidx: int) -> ActionBase:
         return self.config.action_space[aidx]

     def save_current_dataset(self, episode_index: int) -> None:
-        self.distorted_data_set.save_to_csv(filename=Path(str(self.config.distorted_set_path) + "_" + str(episode_index)))
+        self.distorted_data_set.save_to_csv(
+            filename=Path(str(self.config.distorted_set_path) + "_" + str(episode_index)))

     def create_bins(self) -> None:
         """
@@ -167,7 +171,7 @@ def initialize_distances(self) -> None:
         normalized distance to 0.0, meaning that no distortion is assumed initially
         :return: None
         """
-        self.string_distance_calculator = TextDistanceCalculator(dist_type=self.config.string_column_distortion_type)
+        # self.string_distance_calculator = TextDistanceCalculator(dist_type=self.config.string_column_distortion_type)
         col_names = self.config.data_set.get_columns_names()
         for col in col_names:
             self.column_distances[col] = 0.0
@@ -194,14 +198,21 @@ def apply_action(self, action: ActionBase):
         current_column = self.distorted_data_set.get_column(col_name=action.column_name)
         start_column = self.config.data_set.get_column(col_name=action.column_name)

+        datatype = 'float'
         # calculate column distortion
         if self.distorted_data_set.columns[action.column_name] == str:
+            current_column = "".join(current_column.values)
+            start_column = "".join(start_column.values)
+            datatype = 'str'
+
             # join the column to calculate the distance
-            distance = self.string_distance_calculator.calculate(txt1="".join(current_column.values),
-                                                                 txt2="".join(start_column.values))
-        else:
-            distance = NumericDistanceCalculator(dist_type=self.config.numeric_column_distortion_metric_type)\
-                .calculate(state1=current_column, state2=start_column)
+            # distance = self.string_distance_calculator.calculate(txt1="".join(current_column.values),
+            #                                                      txt2="".join(start_column.values))
+        # else:
+        #     distance = NumericDistanceCalculator(dist_type=self.config.numeric_column_distortion_metric_type)\
+        #         .calculate(state1=current_column, state2=start_column)
+
+        distance = self.config.distortion_calculator.calculate(current_column, start_column, datatype)

         self.column_distances[action.column_name] = distance

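apply_action now collapses the two branches into one call: string columns are joined into single strings and tagged 'str', all other columns keep the 'float' tag, and the injected calculator dispatches on that tag. One plausible implementation, reusing the two calculators this commit comments out (a sketch, not the project's actual class):

class DefaultDistortionCalculator:
    """Hypothetical calculator dispatching on the datatype tag."""

    def __init__(self, text_dist_type, numeric_dist_type):
        self.text_calculator = TextDistanceCalculator(dist_type=text_dist_type)
        self.numeric_dist_type = numeric_dist_type

    def calculate(self, current_column, start_column, datatype: str) -> float:
        if datatype == 'str':
            # string columns arrive pre-joined into single strings
            return self.text_calculator.calculate(txt1=current_column, txt2=start_column)
        return NumericDistanceCalculator(dist_type=self.numeric_dist_type) \
            .calculate(state1=current_column, state2=start_column)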
@@ -212,7 +223,8 @@ def total_average_current_distortion(self) -> float:
         :return:
         """

-        return float(np.mean(list(self.column_distances.values())))
+        return self.config.distortion_calculator.total_distortion(
+            list(self.column_distances.values()))  # float(np.mean(list(self.column_distances.values())))

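The aggregation is delegated as well; the trailing comment records the formula being replaced, a plain mean over the per-column distortions. A calculator preserving that behaviour (continuing the hypothetical class sketched above) would add:

    def total_distortion(self, distortions) -> float:
        # reproduce the replaced formula: the mean of per-column distortions,
        # e.g. total_distortion([0.2, 0.4]) == 0.3
        return float(np.mean(list(distortions)))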
     def reset(self, **options) -> TimeStep:
         """
@@ -294,16 +306,43 @@ def step(self, action: ActionBase) -> TimeStep:
         step_type = StepType.MID
         next_state = self.get_aggregated_state(state_val=current_distortion)

+        # get the bins for the min and max distortion
+        min_dist_bin = self.get_aggregated_state(state_val=self.config.min_distortion)
+        max_dist_bin = self.get_aggregated_state(state_val=self.config.max_distortion)
+
+        # TODO: these modifications will cause the agent to always
+        # move close to transition points
+        if next_state < min_dist_bin <= self.current_time_step.observation:
+            # the agent chose to step into the chaos again,
+            # so we punish it with double the reward
+            reward = 2.0 * self.config.reward_manager.out_of_min_bound_reward
+        elif next_state > max_dist_bin >= self.current_time_step.observation:
+            # the agent is going into chaos from above,
+            # so punish it
+            reward = 2.0 * self.config.reward_manager.out_of_max_bound_reward
+
+        elif next_state >= min_dist_bin > self.current_time_step.observation:
+            # the agent moves towards the min transition point,
+            # so give a higher reward for this
+            reward = 0.95 * self.config.reward_manager.in_bounds_reward
+
+        elif next_state <= max_dist_bin < self.current_time_step.observation:
+            # the agent moves towards the max transition point,
+            # so give a higher reward for this
+            reward = 0.95 * self.config.reward_manager.in_bounds_reward
+
         if next_state >= self.n_states:
             done = True

         if done:
             step_type = StepType.LAST
             next_state = None

-        return TimeStep(step_type=step_type, reward=reward,
-                        observation=next_state,
-                        discount=self.config.gamma, info={"total_distortion": current_distortion})
+        self.current_time_step = TimeStep(step_type=step_type, reward=reward,
+                                          observation=next_state,
+                                          discount=self.config.gamma, info={"total_distortion": current_distortion})
+
+        return self.current_time_step


 class MultiprocessEnv(object):
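The shaping terms added to step() compare bin indices rather than raw distortion values. A worked example with assumed numbers, taking n_states = 10 and supposing min_distortion = 0.4 and max_distortion = 0.7 fall into bins 4 and 7:

# observation = 5 (inside the band), next_state = 3 (below it):
#   next_state < min_dist_bin <= observation -> reward = 2.0 * out_of_min_bound_reward
# observation = 3 (below the band), next_state = 4 (crosses in):
#   next_state >= min_dist_bin > observation -> reward = 0.95 * in_bounds_reward
# moves that stay strictly inside (or strictly outside) the band fall through
# to the reward computed earlier in the method

As the author's TODO notes, rewarding moves toward the band edges this strongly can bias the agent to hover near the transition points.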