14 | 14 |
15 | 15 |
16 | 16 | class RolloutStorage: |
| 17 | + """Storage for the data collected during a rollout. |
| 18 | +
| 19 | + The rollout storage is populated by adding transitions during the rollout phase. For the learning phase, it then |
| 20 | + provides a generator whose type depends on the algorithm and the policy architecture. |
| 21 | + """ |
| 22 | + |
17 | 23 | class Transition: |
| 24 | + """Storage for a single state transition.""" |
| 25 | + |
18 | 26 | def __init__(self) -> None: |
19 | 27 | self.observations: TensorDict | None = None |
20 | 28 | self.actions: torch.Tensor | None = None |
@@ -75,7 +83,7 @@ def __init__( |
75 | 83 | # Counter for the number of transitions stored |
76 | 84 | self.step = 0 |
77 | 85 |
78 | | - def add_transitions(self, transition: Transition) -> None: |
| 86 | + def add_transition(self, transition: Transition) -> None: |
79 | 87 | # Check if the transition is valid |
80 | 88 | if self.step >= self.num_transitions_per_env: |
81 | 89 | raise OverflowError("Rollout buffer overflow! You should call clear() before adding new transitions.") |
@@ -103,53 +111,9 @@ def add_transitions(self, transition: Transition) -> None: |
103 | 111 | # Increment the counter |
104 | 112 | self.step += 1 |
105 | 113 |
106 | | - def _save_hidden_states(self, hidden_states: tuple[HiddenState, HiddenState]) -> None: |
107 | | - if hidden_states == (None, None): |
108 | | - return |
109 | | - # Make a tuple out of GRU hidden states to match the LSTM format |
110 | | - hidden_state_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],) |
111 | | - hidden_state_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],) |
112 | | - # Initialize hidden states if needed |
113 | | - if self.saved_hidden_state_a is None: |
114 | | - self.saved_hidden_state_a = [ |
115 | | - torch.zeros(self.observations.shape[0], *hidden_state_a[i].shape, device=self.device) |
116 | | - for i in range(len(hidden_state_a)) |
117 | | - ] |
118 | | - self.saved_hidden_state_c = [ |
119 | | - torch.zeros(self.observations.shape[0], *hidden_state_c[i].shape, device=self.device) |
120 | | - for i in range(len(hidden_state_c)) |
121 | | - ] |
122 | | - # Copy the states |
123 | | - for i in range(len(hidden_state_a)): |
124 | | - self.saved_hidden_state_a[i][self.step].copy_(hidden_state_a[i]) |
125 | | - self.saved_hidden_state_c[i][self.step].copy_(hidden_state_c[i]) |
126 | | - |
127 | 114 | def clear(self) -> None: |
128 | 115 | self.step = 0 |
129 | 116 |
130 | | - def compute_returns( |
131 | | - self, last_values: torch.Tensor, gamma: float, lam: float, normalize_advantage: bool = True |
132 | | - ) -> None: |
133 | | - advantage = 0 |
134 | | - for step in reversed(range(self.num_transitions_per_env)): |
135 | | - # If we are at the last step, bootstrap the return value |
136 | | - next_values = last_values if step == self.num_transitions_per_env - 1 else self.values[step + 1] |
137 | | - # 1 if we are not in a terminal state, 0 otherwise |
138 | | - next_is_not_terminal = 1.0 - self.dones[step].float() |
139 | | - # TD error: r_t + gamma * V(s_{t+1}) - V(s_t) |
140 | | - delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step] |
141 | | - # Advantage: A(s_t, a_t) = delta_t + gamma * lambda * A(s_{t+1}, a_{t+1}) |
142 | | - advantage = delta + next_is_not_terminal * gamma * lam * advantage |
143 | | - # Return: R_t = A(s_t, a_t) + V(s_t) |
144 | | - self.returns[step] = advantage + self.values[step] |
145 | | - |
146 | | - # Compute the advantages |
147 | | - self.advantages = self.returns - self.values |
148 | | - # Normalize the advantages if flag is set |
149 | | - # Note: This is to prevent double normalization (i.e. if per minibatch normalization is used) |
150 | | - if normalize_advantage: |
151 | | - self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8) |
152 | | - |
153 | 117 | # For distillation |
154 | 118 | def generator(self) -> Generator: |
155 | 119 | if self.training_type != "distillation": |
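
The compute_returns method deleted above implemented Generalized Advantage Estimation (GAE). For reference, the same recursion can be written as a standalone function over tensors shaped [num_transitions_per_env, num_envs, 1]; this is only a sketch of the math in the deleted lines, not a statement about where that logic now lives.

    import torch

    def compute_gae(rewards, values, dones, last_values, gamma=0.99, lam=0.95, normalize=True):
        # Mirrors the deleted compute_returns: backward recursion over the rollout.
        num_steps = rewards.shape[0]
        returns = torch.zeros_like(values)
        advantage = torch.zeros_like(last_values)
        for step in reversed(range(num_steps)):
            # Bootstrap with last_values at the final step
            next_values = last_values if step == num_steps - 1 else values[step + 1]
            next_is_not_terminal = 1.0 - dones[step].float()
            # TD error: r_t + gamma * V(s_{t+1}) - V(s_t)
            delta = rewards[step] + next_is_not_terminal * gamma * next_values - values[step]
            # Advantage: A_t = delta_t + gamma * lambda * A_{t+1}
            advantage = delta + next_is_not_terminal * gamma * lam * advantage
            # Return: R_t = A_t + V(s_t)
            returns[step] = advantage + values[step]
        advantages = returns - values
        if normalize:
            # Avoid double normalization if per-mini-batch normalization is used downstream
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        return returns, advantages
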
@@ -289,3 +253,24 @@ def recurrent_mini_batch_generator(self, num_mini_batches: int, num_epochs: int |
289 | 253 | ) |
290 | 254 |
291 | 255 | first_traj = last_traj |
| 256 | + |
| 257 | + def _save_hidden_states(self, hidden_states: tuple[HiddenState, HiddenState]) -> None: |
| 258 | + if hidden_states == (None, None): |
| 259 | + return |
| 260 | + # Make a tuple out of GRU hidden states to match the LSTM format |
| 261 | + hidden_state_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],) |
| 262 | + hidden_state_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],) |
| 263 | + # Initialize hidden states if needed |
| 264 | + if self.saved_hidden_state_a is None: |
| 265 | + self.saved_hidden_state_a = [ |
| 266 | + torch.zeros(self.observations.shape[0], *hidden_state_a[i].shape, device=self.device) |
| 267 | + for i in range(len(hidden_state_a)) |
| 268 | + ] |
| 269 | + self.saved_hidden_state_c = [ |
| 270 | + torch.zeros(self.observations.shape[0], *hidden_state_c[i].shape, device=self.device) |
| 271 | + for i in range(len(hidden_state_c)) |
| 272 | + ] |
| 273 | + # Copy the states |
| 274 | + for i in range(len(hidden_state_a)): |
| 275 | + self.saved_hidden_state_a[i][self.step].copy_(hidden_state_a[i]) |
| 276 | + self.saved_hidden_state_c[i][self.step].copy_(hidden_state_c[i]) |
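
The wrapping at the top of _save_hidden_states is needed because torch.nn.GRU exposes its hidden state as a single tensor while torch.nn.LSTM exposes an (h, c) tuple; the method normalizes both into a tuple of tensors before copying them into the per-step buffers. A small self-contained illustration of that convention (the as_tuple helper exists only for this example, it is not part of the class):

    import torch

    # GRU hidden state: a single tensor of shape [num_layers, batch, hidden_size].
    gru_hidden = torch.zeros(1, 4, 64)
    # LSTM hidden state: an (h, c) pair of such tensors.
    lstm_hidden = (torch.zeros(1, 4, 64), torch.zeros(1, 4, 64))

    def as_tuple(hidden):
        # Same wrapping as in _save_hidden_states: bare tensors become 1-tuples.
        return hidden if isinstance(hidden, tuple) else (hidden,)

    assert len(as_tuple(gru_hidden)) == 1   # GRU: only h is stored
    assert len(as_tuple(lstm_hidden)) == 2  # LSTM: both h and c are stored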