33"""
44
55import copy
6- from typing import TypeVar
6+ from typing import TypeVar , List
77from dataclasses import dataclass
88from src .extern .tile_coding import IHT , tiles
99from src .spaces .actions import ActionBase , ActionType
1313from src .spaces .time_step import copy_time_step
1414
1515Env = TypeVar ('Env' )
16+ Tile = TypeVar ('Tile' )
17+ Config = TypeVar ('Config' )
1618
1719
1820@dataclass (init = True , repr = True )
@@ -24,10 +26,14 @@ class TiledEnvConfig(object):
2426 num_tilings : int = 0
2527 max_size : int = 0
2628 tiling_dim : int = 0
27- column_scales : dict = None
29+ column_ranges : dict = None
2830
2931
3032class TiledEnv (object ):
33+ """The TiledEnv class. It models a tiled
34+ environment
35+ """
36+
3137 IS_TILED_ENV_CONSTRAINT = True
3238
3339 def __init__ (self , config : TiledEnvConfig ) -> None :
@@ -40,11 +46,13 @@ def __init__(self, config: TiledEnvConfig) -> None:
4046 # set up the columns scaling
4147 # only the columns that are to be altered participate in the
4248 # tiling
43- self .column_scales = config .column_scales
49+ self .column_ranges = config .column_ranges
50+ self .column_scales = {}
4451
4552 # Initialize index hash table (IHT) for tile coding.
4653 # This assigns a unique index to each tile up to max_size tiles.
4754 self ._validate ()
55+ self ._create_column_scales ()
4856 self .iht = IHT (self .max_size )
4957
5058 @property
@@ -59,6 +67,10 @@ def n_actions(self) -> int:
5967 def n_states (self ) -> int :
6068 return self .env .n_states
6169
@property
def config(self) -> Config:
    """Configuration of the wrapped environment.

    Returns
    -------

    The ``config`` attribute of ``self.env``

    """
    wrapped_env = self.env
    return wrapped_env.config
73+
6274 def step (self , action : ActionBase ) -> TimeStep :
6375 """Execute the action in the environment and return
6476 a new state for observation
@@ -83,16 +95,11 @@ def step(self, action: ActionBase) -> TimeStep:
8395 # of the bin that the total distortion falls into
8496 state .bin_idx = raw_time_step .observation
8597 state .total_distortion = raw_time_step .info ["total_distortion" ]
86- state .column_names = self .env .column_names
98+ state .column_distortions = self .env .column_distortions
8799
88100 time_step = copy_time_step (time_step = raw_time_step , ** {"observation" : state })
89- #time_step = copy.deepcopy(raw_time_step)
90- #time_step.observation = state
91-
92101 return time_step
93102
94- return
95-
96103 def reset (self , ** options ) -> TimeStep :
97104 """Reset the environment so that a new sequence
98105 of episodes can be generated
@@ -116,24 +123,29 @@ def reset(self, **options) -> TimeStep:
116123 # of the bin that the total distortion falls into
117124 state .bin_idx = raw_time_step .observation
118125 state .total_distortion = raw_time_step .info ["total_distortion" ]
119- state .column_names = self .env .column_names
126+ state .column_distortions = self .env .column_distortions
120127
121128 time_step = copy_time_step (time_step = raw_time_step , ** {"observation" : state })
122129
123- #time_step = copy.deepcopy(raw_time_step)
124- #time_step.observation = state
125-
126130 return time_step
127131
def get_action(self, aidx: int) -> ActionBase:
    """Look up an action of the wrapped environment by its index.

    Parameters
    ----------

    aidx: Index into the action space of the wrapped environment

    Returns
    -------

    The entry of ``self.env.action_space`` stored at ``aidx``

    """
    action_space = self.env.action_space
    return action_space[aidx]
130134
def save_current_dataset(self, episode_index: int, save_index: bool = False) -> None:
    """Persist the current data set for the given episode.

    The work is delegated to ``save_current_dataset`` of the wrapped
    environment; both arguments are forwarded positionally and unchanged.

    Parameters
    ----------

    episode_index: Episode index corresponding to the training episode
    save_index: if True the Pandas index is also saved

    Returns
    -------

    None

    """
    wrapped_env = self.env
    wrapped_env.save_current_dataset(episode_index, save_index)
139151
@@ -200,22 +212,54 @@ def get_scaled_state(self, state: State) -> list:
200212 """
201213 scaled_state_vals = []
202214 for name in state :
203- scaled_state_vals .append (state [name ] * self .columns_scales [name ])
215+ scaled_state_vals .append (state [name ] * self .column_scales [name ])
204216
205217 return scaled_state_vals
206218
def featurize_state_action(self, state: State, action: ActionBase) -> List[Tile]:
    """Tile-code a state-action pair.

    The state is first scaled with ``get_scaled_state`` and then handed,
    together with the action, to ``tiles`` using this environment's
    index hash table and number of tilings.

    Parameters
    ----------

    state: The environment state observed
    action: The action

    Returns
    -------

    A list of tiles

    """
    # Scale before touching the hash table so the call order matches
    # the order in which the tiles are requested.
    scaled = self.get_scaled_state(state)
    return tiles(self.iht, self.num_tilings, scaled, [action])
217237
238+ def _create_column_scales (self ) -> None :
239+ """
240+ Create the scales for each column
241+
242+ Returns
243+ -------
244+
245+ None
246+
247+ """
248+
249+ for name in self .column_ranges :
250+ range_ = self .column_ranges [name ]
251+ self .column_scales [name ] = self .tiling_dim / (range_ [1 ] - range_ [0 ])
252+
218253 def _validate (self ) -> None :
254+ """
255+ Validate the internal data structures
256+
257+ Returns
258+ -------
259+
260+ None
261+
262+ """
219263 if self .max_size <= 0 :
220264 raise InvalidParamValue (param_name = "max_size" ,
221265 param_value = str (self .max_size ) + " should be > 0" )
@@ -227,7 +271,10 @@ def _validate(self) -> None:
227271 param_value = str (self .max_size ) +
228272 " should be >=num_tilings * tiling_dim * tiling_dim" )
229273
230- if len (self .column_scales ) == 0 :
274+ if len (self .column_ranges ) == 0 :
231275 raise InvalidParamValue (param_name = "column_scales" ,
232276 param_value = str (len (self .column_scales )) + " should not be empty" )
233277
278+ if len (self .column_ranges ) != len (self .env .column_names ):
279+ raise ValueError ("Column ranges is not equal to number of columns" )
280+
0 commit comments