1313
1414from src .exceptions .exceptions import Error
1515from src .spaces .actions import ActionBase , ActionType
16+ from src .spaces .state_space import StateSpace , State
1617from src .utils .string_distance_calculator import DistanceType , TextDistanceCalculator
1718
1819DataSet = TypeVar ("DataSet" )
20+ RewardManager = TypeVar ("RewardManager" )
1921
2022_Reward = TypeVar ('_Reward' )
2123_Discount = TypeVar ('_Discount' )
@@ -65,20 +67,37 @@ def last(self) -> bool:
6567class Environment (object ):
6668
6769 def __init__ (self , data_set , action_space ,
68- gamma : float , start_column : str , ):
70+ gamma : float , start_column : str , reward_manager : RewardManager ):
6971 self .data_set = data_set
7072 self .start_ds = copy .deepcopy (data_set )
7173 self .current_time_step = self .start_ds
7274 self .action_space = action_space
7375 self .gamma = gamma
7476 self .start_column = start_column
7577 self .column_distances = {}
78+ self .state_space = StateSpace ()
7679 self .distance_calculator = None
80+ self .reward_manager : RewardManager = reward_manager
81+
82+ # initialize the state space
83+ self .state_space .init_from_environment (env = self )
7784
7885 @property
7986 def n_features (self ) -> int :
87+ """
88+ Returns the number of features in the dataset
89+ :return:
90+ """
8091 return self .start_ds .n_columns
8192
93+ @property
94+ def feature_names (self ) -> list :
95+ """
96+ Returns the feature names in the dataset
97+ :return:
98+ """
99+ return self .start_ds .get_columns_names ()
100+
82101 @property
83102 def n_examples (self ) -> int :
84103 return self .start_ds .n_rows
@@ -99,6 +118,24 @@ def initialize_text_distances(self, distance_type: DistanceType) -> None:
99118 def sample_action (self ) -> ActionBase :
100119 return self .action_space .sample_and_get ()
101120
121+ def get_column_as_tensor (self , column_name ) -> torch .Tensor :
122+ """
123+ Returns the column in the dataset as a torch tensor
124+ :param column_name:
125+ :return:
126+ """
127+ data = {}
128+
129+ if self .start_ds .columns [column_name ] == str :
130+
131+ numpy_vals = self .column_distances [column_name ]
132+ data [column_name ] = numpy_vals
133+ else :
134+ data [column_name ] = self .data_set .get_column (col_name = column_name ).to_numpy ()
135+
136+ target_df = pd .DataFrame (data )
137+ return torch .tensor (target_df .to_numpy (), dtype = torch .float64 )
138+
102139 def get_ds_as_tensor (self ) -> torch .Tensor :
103140
104141 """
@@ -111,7 +148,6 @@ def get_ds_as_tensor(self) -> torch.Tensor:
111148 for col in col_names :
112149
113150 if self .start_ds .columns [col ] == str :
114- #print("col: {0} type {1}".format(col, self.start_ds.get_column_type(col_name=col)))
115151 numpy_vals = self .column_distances [col ]
116152 data [col ] = numpy_vals
117153 else :
@@ -195,28 +231,22 @@ def step(self, action: ActionBase) -> TimeStep:
195231 `action` will be ignored.
196232 """
197233
234+ # apply the action
198235 self .apply_action (action = action )
199236
200- # if the action is identity don't bother
201- # doing anything
202- #if action.action_type == ActionType.IDENTITY:
203- # return TimeStep(step_type=StepType.MID, reward=0.0,
204- # observation=self.get_ds_as_tensor().float(), discount=self.gamma)
205-
206- # apply the transform of the data set
207- #self.data_set.apply_column_transform(transform=action)
237+ # update the state space
238+ self .state_space .update_state (state_name = action .column_name , status = action .action_type )
208239
209240 # perform the action on the data set
210241 self .prepare_column_states ()
211242
212243 # calculate the information leakage and establish the reward
213244 # to return to the agent
245+ reward = self .reward_manager .get_state_reward (self .state_space , action )
214246
215- return TimeStep (step_type = StepType .MID , reward = 0.0 ,
216- observation = self .get_ds_as_tensor ().float (), discount = self .gamma )
217-
218-
219-
247+ return TimeStep (step_type = StepType .MID , reward = reward ,
248+ observation = self .get_column_as_tensor (column_name = action .column_name ).float (),
249+ discount = self .gamma )
220250
221251
222252class MultiprocessEnv (object ):
0 commit comments