
Commit 6d6d3a4

Restructure rollout storage for clarity (#137)
1 parent e1e7071

6 files changed (+76, -114 lines)

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -56,5 +56,5 @@ reportMissingImports = "none"
 # This is required to ignore type checks of modules with stubs missing.
 reportMissingModuleSource = "none" # -> most common: prettytable in mdp managers
 reportGeneralTypeIssues = "none" # -> usage of literal MISSING in dataclasses
-reportOptionalMemberAccess = "warning"
+reportOptionalMemberAccess = "none"
 reportPrivateUsage = "warning"

rsl_rl/algorithms/distillation.py

Lines changed: 5 additions & 23 deletions

@@ -21,6 +21,7 @@ class Distillation:
     def __init__(
         self,
         policy: StudentTeacher | StudentTeacherRecurrent,
+        storage: RolloutStorage,
         num_learning_epochs: int = 1,
         gradient_length: int = 15,
         learning_rate: float = 1e-3,

@@ -46,12 +47,12 @@ def __init__(
         # Distillation components
         self.policy = policy
         self.policy.to(self.device)
-        self.storage = None  # Initialized later

-        # Initialize the optimizer
+        # Create the optimizer
         self.optimizer = resolve_optimizer(optimizer)(self.policy.parameters(), lr=learning_rate)

-        # Initialize the transition
+        # Add storage
+        self.storage = storage
         self.transition = RolloutStorage.Transition()
         self.last_hidden_states = (None, None)

@@ -73,24 +74,6 @@ def __init__(

         self.num_updates = 0

-    def init_storage(
-        self,
-        training_type: str,
-        num_envs: int,
-        num_transitions_per_env: int,
-        obs: TensorDict,
-        actions_shape: tuple[int],
-    ) -> None:
-        # Create rollout storage
-        self.storage = RolloutStorage(
-            training_type,
-            num_envs,
-            num_transitions_per_env,
-            obs,
-            actions_shape,
-            self.device,
-        )
-
     def act(self, obs: TensorDict) -> torch.Tensor:
         # Compute the actions
         self.transition.actions = self.policy.act(obs).detach()

@@ -104,12 +87,11 @@ def process_env_step(
     ) -> None:
         # Update the normalizers
         self.policy.update_normalization(obs)
-
         # Record the rewards and dones
         self.transition.rewards = rewards
         self.transition.dones = dones
         # Record the transition
-        self.storage.add_transitions(self.transition)
+        self.storage.add_transition(self.transition)
         self.transition.clear()
         self.policy.reset(dones)
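
For downstream code that constructed the algorithm directly, the two-step setup is replaced by constructor injection. A minimal before/after sketch; the surrounding variables (student_teacher, obs, num_envs, num_steps_per_env, num_actions, alg_cfg, device) are assumed to exist and are not part of this commit:

from rsl_rl.storage import RolloutStorage

# Before: the algorithm created its own buffer lazily.
# alg = Distillation(student_teacher, device=device, **alg_cfg)
# alg.init_storage("distillation", num_envs, num_steps_per_env, obs, [num_actions])

# After: the caller builds the storage and passes it in as the second argument.
storage = RolloutStorage("distillation", num_envs, num_steps_per_env, obs, [num_actions], device)
alg = Distillation(student_teacher, storage, device=device, **alg_cfg)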

rsl_rl/algorithms/ppo.py

Lines changed: 25 additions & 26 deletions

@@ -26,6 +26,7 @@ class PPO:
     def __init__(
         self,
         policy: ActorCritic | ActorCriticRecurrent | ActorCriticCNN,
+        storage: RolloutStorage,
         num_learning_epochs: int = 5,
         num_mini_batches: int = 4,
         clip_param: float = 0.2,

@@ -38,8 +39,8 @@ def __init__(
         use_clipped_value_loss: bool = True,
         schedule: str = "adaptive",
         desired_kl: float = 0.01,
-        device: str = "cpu",
         normalize_advantage_per_mini_batch: bool = False,
+        device: str = "cpu",
         # RND parameters
         rnd_cfg: dict | None = None,
         # Symmetry parameters

@@ -100,11 +101,11 @@ def __init__(
         self.policy = policy
         self.policy.to(self.device)

-        # Create optimizer
+        # Create the optimizer
         self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)

-        # Create rollout storage
-        self.storage: RolloutStorage | None = None
+        # Add storage
+        self.storage = storage
         self.transition = RolloutStorage.Transition()

         # PPO parameters

@@ -122,24 +123,6 @@ def __init__(
         self.learning_rate = learning_rate
         self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch

-    def init_storage(
-        self,
-        training_type: str,
-        num_envs: int,
-        num_transitions_per_env: int,
-        obs: TensorDict,
-        actions_shape: tuple[int] | list[int],
-    ) -> None:
-        # Create rollout storage
-        self.storage = RolloutStorage(
-            training_type,
-            num_envs,
-            num_transitions_per_env,
-            obs,
-            actions_shape,
-            self.device,
-        )
-
     def act(self, obs: TensorDict) -> torch.Tensor:
         if self.policy.is_recurrent:
             self.transition.hidden_states = self.policy.get_hidden_states()

@@ -180,16 +163,32 @@ def process_env_step(
            )

         # Record the transition
-        self.storage.add_transitions(self.transition)
+        self.storage.add_transition(self.transition)
         self.transition.clear()
         self.policy.reset(dones)

     def compute_returns(self, obs: TensorDict) -> None:
+        st = self.storage
         # Compute value for the last step
         last_values = self.policy.evaluate(obs).detach()
-        self.storage.compute_returns(
-            last_values, self.gamma, self.lam, normalize_advantage=not self.normalize_advantage_per_mini_batch
-        )
+        # Compute returns and advantages
+        advantage = 0
+        for step in reversed(range(st.num_transitions_per_env)):
+            # If we are at the last step, bootstrap the return value
+            next_values = last_values if step == st.num_transitions_per_env - 1 else st.values[step + 1]
+            # 1 if we are not in a terminal state, 0 otherwise
+            next_is_not_terminal = 1.0 - st.dones[step].float()
+            # TD error: r_t + gamma * V(s_{t+1}) - V(s_t)
+            delta = st.rewards[step] + next_is_not_terminal * self.gamma * next_values - st.values[step]
+            # Advantage: A(s_t, a_t) = delta_t + gamma * lambda * A(s_{t+1}, a_{t+1})
+            advantage = delta + next_is_not_terminal * self.gamma * self.lam * advantage
+            # Return: R_t = A(s_t, a_t) + V(s_t)
+            st.returns[step] = advantage + st.values[step]
+        # Compute the advantages
+        st.advantages = st.returns - st.values
+        # Normalize the advantages if per minibatch normalization is not used
+        if not self.normalize_advantage_per_mini_batch:
+            st.advantages = (st.advantages - st.advantages.mean()) / (st.advantages.std() + 1e-8)

     def update(self) -> dict[str, float]:
         mean_value_loss = 0
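
The return/advantage computation that previously lived in RolloutStorage.compute_returns is now inlined in PPO.compute_returns. The recursion is standard GAE: delta_t = r_t + gamma * (1 - done_t) * V_{t+1} - V_t, A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}, and R_t = A_t + V_t. Below is a self-contained sketch of the same backward pass on dummy tensors; the shapes and numbers are illustrative only, not taken from the repository:

import torch

# Standalone GAE sketch mirroring the loop now inlined in PPO.compute_returns.
T, N = 4, 2                      # transitions per env, number of envs (illustrative)
gamma, lam = 0.99, 0.95
rewards = torch.ones(T, N, 1)
values = torch.zeros(T, N, 1)
dones = torch.zeros(T, N, 1)
last_values = torch.zeros(N, 1)  # bootstrap value for the step after the rollout

returns = torch.zeros_like(values)
advantage = 0
for step in reversed(range(T)):
    # Bootstrap with last_values on the final step, otherwise use the next stored value
    next_values = last_values if step == T - 1 else values[step + 1]
    next_is_not_terminal = 1.0 - dones[step].float()
    # TD error, GAE recursion, and return
    delta = rewards[step] + next_is_not_terminal * gamma * next_values - values[step]
    advantage = delta + next_is_not_terminal * gamma * lam * advantage
    returns[step] = advantage + values[step]

advantages = returns - values
# Normalize here only if per-mini-batch normalization is not used (as in the diff above)
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)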

rsl_rl/runners/distillation_runner.py

Lines changed: 7 additions & 10 deletions

@@ -16,6 +16,7 @@
 from rsl_rl.env import VecEnv
 from rsl_rl.modules import StudentTeacher, StudentTeacherRecurrent
 from rsl_rl.runners import OnPolicyRunner
+from rsl_rl.storage import RolloutStorage
 from rsl_rl.utils import resolve_obs_groups, store_code_state


@@ -159,19 +160,15 @@ def _construct_algorithm(self, obs: TensorDict) -> Distillation:
             obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg
         ).to(self.device)

+        # Initialize the storage
+        storage = RolloutStorage(
+            "distillation", self.env.num_envs, self.num_steps_per_env, obs, [self.env.num_actions], self.device
+        )
+
         # Initialize the algorithm
         alg_class = eval(self.alg_cfg.pop("class_name"))
         alg: Distillation = alg_class(
-            student_teacher, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg
-        )
-
-        # Initialize the storage
-        alg.init_storage(
-            "distillation",
-            self.env.num_envs,
-            self.num_steps_per_env,
-            obs,
-            [self.env.num_actions],
+            student_teacher, storage, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg
         )

         return alg

rsl_rl/runners/on_policy_runner.py

Lines changed: 8 additions & 9 deletions

@@ -23,6 +23,7 @@
     resolve_rnd_config,
     resolve_symmetry_config,
 )
+from rsl_rl.storage import RolloutStorage
 from rsl_rl.utils import resolve_obs_groups, store_code_state


@@ -424,17 +425,15 @@ def _construct_algorithm(self, obs: TensorDict) -> PPO:
             obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg
         ).to(self.device)

+        # Initialize the storage
+        storage = RolloutStorage(
+            "rl", self.env.num_envs, self.num_steps_per_env, obs, [self.env.num_actions], self.device
+        )
+
         # Initialize the algorithm
         alg_class = eval(self.alg_cfg.pop("class_name"))
-        alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg)
-
-        # Initialize the storage
-        alg.init_storage(
-            "rl",
-            self.env.num_envs,
-            self.num_steps_per_env,
-            obs,
-            [self.env.num_actions],
+        alg: PPO = alg_class(
+            actor_critic, storage, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg
         )

         return alg
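
The PPO constructor now mirrors the distillation one: storage is the second positional argument, and device has moved after normalize_advantage_per_mini_batch, so callers passing these positionally need updating. A hedged sketch of building PPO outside the runner, using keyword arguments to stay robust to the reordering; variable names other than those in the diff are placeholders:

from rsl_rl.algorithms import PPO
from rsl_rl.storage import RolloutStorage

# Placeholders: num_envs, num_steps_per_env, obs, num_actions, actor_critic, device
storage = RolloutStorage("rl", num_envs, num_steps_per_env, obs, [num_actions], device)
alg = PPO(
    actor_critic,
    storage,
    normalize_advantage_per_mini_batch=False,
    device=device,
)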

rsl_rl/storage/rollout_storage.py

Lines changed: 30 additions & 45 deletions

@@ -14,7 +14,15 @@


 class RolloutStorage:
+    """Storage for the data collected during a rollout.
+
+    The rollout storage is populated by adding transitions during the rollout phase. It then returns a generator for
+    learning, depending on the algorithm and the policy architecture.
+    """
+
     class Transition:
+        """Storage for a single state transition."""
+
         def __init__(self) -> None:
             self.observations: TensorDict | None = None
             self.actions: torch.Tensor | None = None

@@ -75,7 +83,7 @@ def __init__(
         # Counter for the number of transitions stored
         self.step = 0

-    def add_transitions(self, transition: Transition) -> None:
+    def add_transition(self, transition: Transition) -> None:
         # Check if the transition is valid
         if self.step >= self.num_transitions_per_env:
             raise OverflowError("Rollout buffer overflow! You should call clear() before adding new transitions.")

@@ -103,53 +111,9 @@ def add_transitions(self, transition: Transition) -> None:
         # Increment the counter
         self.step += 1

-    def _save_hidden_states(self, hidden_states: tuple[HiddenState, HiddenState]) -> None:
-        if hidden_states == (None, None):
-            return
-        # Make a tuple out of GRU hidden states to match the LSTM format
-        hidden_state_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
-        hidden_state_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)
-        # Initialize hidden states if needed
-        if self.saved_hidden_state_a is None:
-            self.saved_hidden_state_a = [
-                torch.zeros(self.observations.shape[0], *hidden_state_a[i].shape, device=self.device)
-                for i in range(len(hidden_state_a))
-            ]
-            self.saved_hidden_state_c = [
-                torch.zeros(self.observations.shape[0], *hidden_state_c[i].shape, device=self.device)
-                for i in range(len(hidden_state_c))
-            ]
-        # Copy the states
-        for i in range(len(hidden_state_a)):
-            self.saved_hidden_state_a[i][self.step].copy_(hidden_state_a[i])
-            self.saved_hidden_state_c[i][self.step].copy_(hidden_state_c[i])
-
     def clear(self) -> None:
         self.step = 0

-    def compute_returns(
-        self, last_values: torch.Tensor, gamma: float, lam: float, normalize_advantage: bool = True
-    ) -> None:
-        advantage = 0
-        for step in reversed(range(self.num_transitions_per_env)):
-            # If we are at the last step, bootstrap the return value
-            next_values = last_values if step == self.num_transitions_per_env - 1 else self.values[step + 1]
-            # 1 if we are not in a terminal state, 0 otherwise
-            next_is_not_terminal = 1.0 - self.dones[step].float()
-            # TD error: r_t + gamma * V(s_{t+1}) - V(s_t)
-            delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
-            # Advantage: A(s_t, a_t) = delta_t + gamma * lambda * A(s_{t+1}, a_{t+1})
-            advantage = delta + next_is_not_terminal * gamma * lam * advantage
-            # Return: R_t = A(s_t, a_t) + V(s_t)
-            self.returns[step] = advantage + self.values[step]
-
-        # Compute the advantages
-        self.advantages = self.returns - self.values
-        # Normalize the advantages if flag is set
-        # Note: This is to prevent double normalization (i.e. if per minibatch normalization is used)
-        if normalize_advantage:
-            self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)
-
     # For distillation
     def generator(self) -> Generator:
         if self.training_type != "distillation":

@@ -289,3 +253,24 @@ def recurrent_mini_batch_generator(self, num_mini_batches: int, num_epochs: int
                )

                first_traj = last_traj
+
+    def _save_hidden_states(self, hidden_states: tuple[HiddenState, HiddenState]) -> None:
+        if hidden_states == (None, None):
+            return
+        # Make a tuple out of GRU hidden states to match the LSTM format
+        hidden_state_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
+        hidden_state_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)
+        # Initialize hidden states if needed
+        if self.saved_hidden_state_a is None:
+            self.saved_hidden_state_a = [
+                torch.zeros(self.observations.shape[0], *hidden_state_a[i].shape, device=self.device)
+                for i in range(len(hidden_state_a))
+            ]
+            self.saved_hidden_state_c = [
+                torch.zeros(self.observations.shape[0], *hidden_state_c[i].shape, device=self.device)
+                for i in range(len(hidden_state_c))
+            ]
+        # Copy the states
+        for i in range(len(hidden_state_a)):
+            self.saved_hidden_state_a[i][self.step].copy_(hidden_state_a[i])
+            self.saved_hidden_state_c[i][self.step].copy_(hidden_state_c[i])
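
With add_transitions renamed to add_transition, the per-step contract is unchanged: at most num_transitions_per_env transitions can be added before clear() resets the write index, otherwise add_transition raises an OverflowError. A schematic collection loop under those assumptions; env, alg, obs, and num_learning_iterations are placeholders, and the exact process_env_step signature is only sketched here:

# Schematic rollout loop; names and signatures outside this diff are placeholders.
for _ in range(num_learning_iterations):
    # Collection phase: at most storage.num_transitions_per_env calls per iteration,
    # otherwise add_transition raises OverflowError.
    for _ in range(storage.num_transitions_per_env):
        actions = alg.act(obs)                                # also fills alg.transition
        obs, rewards, dones, infos = env.step(actions)        # argument names are schematic
        alg.process_env_step(obs, rewards, dones, infos)      # ends in storage.add_transition(...)
    alg.compute_returns(obs)   # PPO only: the GAE pass shown in the ppo.py hunk above
    alg.update()               # learning phase; storage.clear() must run before the next rollout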
