@@ -130,11 +130,23 @@ def reset(self, **options) -> TimeStep:
130130 return time_step
131131
132132 def get_action (self , aidx : int ) -> ActionBase :
133+ """Returns the action that corresponds to the given index
134+
135+ Parameters
136+ ----------
137+ aidx: The index of the action to return
138+
139+ Returns
140+ -------
141+
142+ An instance of the ActionBase class
143+
144+ """
133145 return self .env .action_space [aidx ]
134146
135147 def save_current_dataset (self , episode_index : int , save_index : bool = False ) -> None :
136- """
137- Save the current data set at the given episode index
148+ """Save the current data set at the given episode index
149+
138150 Parameters
139151 ----------
140152
@@ -150,65 +162,107 @@ def save_current_dataset(self, episode_index: int, save_index: bool = False) ->
150162 self .env .save_current_dataset (episode_index , save_index )
151163
152164 def create_bins (self ) -> None :
153- """
154- Create the bins
155- :return:
165+ """Create the bins
166+
167+ Returns
168+ -------
169+
170+ None
171+
156172 """
157173 self .env .create_bins ()
158174
159175 def get_aggregated_state (self , state_val : float ) -> int :
160176 """
161177 Returns the bin index that the state_val corresponds to
162- :param state_val: The value of the state. This typically will be
163- either a column normalized distortion value or the dataset average total
164- distortion
165- :return:
178+
179+ Parameters
180+ ----------
181+
182+ state_val: The bin index that the distortion corresponds to
183+
184+ Returns
185+ -------
186+
187+ The bin index corresponding to the distortion
166188 """
167189 return self .env .get_aggregated_state (state_val )
168190
169191 def initialize_column_counts (self ) -> None :
170192 """
171193 Set the column visit counts to zero
172- :return:
194+ Returns
195+ -------
196+
197+ None
198+
173199 """
174200 self .env .initialize_column_counts ()
175201
176202 def all_columns_visited (self ) -> bool :
177203 """
178204 Returns True is all column counts are greater than zero
179- :return:
205+
206+ Returns
207+ -------
208+
209+ Returns True is all column counts are greater than zero
210+ otherwise False
211+
180212 """
181213 return self .env .all_columns_visited ()
182214
183215 def initialize_distances (self ) -> None :
184216 """
185217 Initialize the text distances for features of type string. We set the
186218 normalized distance to 0.0 meaning that no distortion is assumed initially
187- :return: None
219+
220+ Returns
221+ -------
222+
223+ None
224+
188225 """
189226 self .env .initialize_distances ()
190227
191- def apply_action (self , action : ActionBase ):
192- """
193- Apply the action on the environment
194- :param action: The action to apply on the environment
195- :return:
228+ def apply_action (self , action : ActionBase ) -> None :
229+ """Apply the action on the environment
230+
231+ Parameters
232+ ----------
233+ action: The action to apply
234+
235+ Returns
236+ -------
237+
238+ None
239+
196240 """
197241 self .env .apply_action (action )
198242
199243 def total_current_distortion (self ) -> float :
200- """
201- Calculates the current total distortion of the dataset.
202- :return:
244+ """Calculates the current total distortion of the dataset.
245+
246+ Returns
247+ -------
248+ The total current distortion
249+
203250 """
204251 return self .env .total_current_distortion ()
205252
206253 def get_scaled_state (self , state : State ) -> list :
207- """
208- Scales the state components ad returns the
254+ """Scales the state components and returns the
209255 scaled state
210- :param state:
211- :return:
256+
257+ Parameters
258+ ----------
259+ state: The state to scale
260+
261+ Returns
262+ -------
263+
264+ A list of scaled state values
265+
212266 """
213267 scaled_state_vals = []
214268 for name in state :
0 commit comments