22Tile environment
33"""
44
5+ import copy
56from typing import TypeVar
7+ from dataclasses import dataclass
68from src .extern .tile_coding import IHT , tiles
79from src .spaces .actions import ActionBase , ActionType
810from src .spaces .time_step import TimeStep
911from src .exceptions .exceptions import InvalidParamValue
12+ from src .spaces .state import State
13+ from src .spaces .time_step import copy_time_step
1014
1115Env = TypeVar ('Env' )
12- State = TypeVar ('State' )
1316
1417
18+ @dataclass (init = True , repr = True )
1519class TiledEnvConfig (object ):
20+ """Configuration for the TiledEnvironment
1621 """
17- Configuration for the TiledEnvironment
18- """
19- def __init__ (self ):
20- self .env : Env = None
21- self .num_tilings : int = 0
22- self .max_size = 0
23- self .tiling_dim = 0
24- self .column_scales = {}
22+
23+ env : Env = None
24+ num_tilings : int = 0
25+ max_size : int = 0
26+ tiling_dim : int = 0
27+ column_scales : dict = None
2528
2629
2730class TiledEnv (object ):
@@ -44,14 +47,6 @@ def __init__(self, config: TiledEnvConfig) -> None:
4447 self ._validate ()
4548 self .iht = IHT (self .max_size )
4649
47- def step (self , action : ActionBase ) -> TimeStep :
48- """
49- Apply the action and return new state
50- :param action: The action to apply
51- :return:
52- """
53- return self .env .step (action )
54-
5550 @property
5651 def action_space (self ):
5752 return self .env .action_space
@@ -64,6 +59,72 @@ def n_actions(self) -> int:
6459 def n_states (self ) -> int :
6560 return self .env .n_states
6661
62+ def step (self , action : ActionBase ) -> TimeStep :
63+ """Execute the action in the environment and return
64+ a new state for observation
65+
66+ Parameters
67+ ----------
68+ action: The action to execute
69+
70+ Returns
71+ -------
72+
73+ An instance of TimeStep type
74+
75+ """
76+
77+ raw_time_step = self .env .step (action )
78+
79+ # a state wrapper to communicate
80+ state = State ()
81+
82+ # the raw environment returns an index
83+ # of the bin that the total distortion falls into
84+ state .bin_idx = raw_time_step .observation
85+ state .total_distortion = raw_time_step .info ["total_distortion" ]
86+ state .column_names = self .env .column_names
87+
88+ time_step = copy_time_step (time_step = raw_time_step , ** {"observation" : state })
89+ #time_step = copy.deepcopy(raw_time_step)
90+ #time_step.observation = state
91+
92+ return time_step
93+
94+ return
95+
96+ def reset (self , ** options ) -> TimeStep :
97+ """Reset the environment so that a new sequence
98+ of episodes can be generated
99+
100+ Parameters
101+ ----------
102+ options: Client provided named options
103+
104+ Returns
105+ -------
106+
107+ An instance of TimeStep type
108+ """
109+
110+ raw_time_step = self .env .reset (** options )
111+
112+ # a state wrapper to communicate
113+ state = State ()
114+
115+ # the raw environment returns an index
116+ # of the bin that the total distortion falls into
117+ state .bin_idx = raw_time_step .observation
118+ state .total_distortion = raw_time_step .info ["total_distortion" ]
119+ state .column_names = self .env .column_names
120+
121+ time_step = copy_time_step (time_step = raw_time_step , ** {"observation" : state })
122+
123+ #time_step = copy.deepcopy(raw_time_step)
124+ #time_step.observation = state
125+
126+ return time_step
127+
67128 def get_action (self , aidx : int ) -> ActionBase :
68129 return self .env .action_space [aidx ]
69130
@@ -130,21 +191,6 @@ def total_current_distortion(self) -> float:
130191 """
131192 return self .env .total_current_distortion ()
132193
133- def reset (self , ** options ) -> TimeStep :
134- """
135- Starts a new sequence and returns the first `TimeStep` of this sequence.
136- Returns:
137- A `TimeStep` namedtuple containing:
138- step_type: A `StepType` of `FIRST`.
139- reward: `None`, indicating the reward is undefined.
140- discount: `None`, indicating the discount is undefined.
141- observation: A NumPy array, or a nested dict, list or tuple of arrays.
142- Scalar values that can be cast to NumPy arrays (e.g. Python floats)
143- are also valid in place of a scalar array. Must conform to the
144- specification returned by `observation_spec()`.
145- """
146- return self .env .reset (** options )
147-
148194 def get_scaled_state (self , state : State ) -> list :
149195 """
150196 Scales the state components ad returns the
0 commit comments