Commit e8455fa

Adds perceptive actor-critic class (#114)
Adds a new perceptive actor-critic class that can define CNN layers for every 2D observation term.

Co-authored-by: ClemensSchwarke <clemens.schwarke@gmail.com>
1 parent c656748 commit e8455fa

File tree

12 files changed: +494 −31 lines
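The new class expects each observation term to be either a flat (B, C) tensor, which is concatenated into the MLP input, or an image-like (B, C, H, W) tensor, which is encoded by its own CNN. A minimal sketch of that layout is shown below; the group names and shapes are illustrative only and are not taken from the repository.

import torch
from tensordict import TensorDict

# Illustrative observation set: "proprio" is a 1D term, "depth" is a 2D term.
batch = 8
obs = TensorDict(
    {
        "proprio": torch.zeros(batch, 48),       # (B, C)       -> concatenated into the MLP input
        "depth": torch.zeros(batch, 1, 64, 64),  # (B, C, H, W) -> encoded by a dedicated CNN
    },
    batch_size=[batch],
)
obs_groups = {"policy": ["proprio", "depth"], "critic": ["proprio", "depth"]}

# The same shape-based split that ActorCriticCNN.__init__ performs in the new file below.
groups_1d = [g for g in obs_groups["policy"] if obs[g].dim() == 2]
groups_2d = [g for g in obs_groups["policy"] if obs[g].dim() == 4]
print(groups_1d, groups_2d)  # ['proprio'] ['depth']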

rsl_rl/algorithms/ppo.py

Lines changed: 3 additions & 3 deletions
@@ -11,7 +11,7 @@
 from itertools import chain
 from tensordict import TensorDict

-from rsl_rl.modules import ActorCritic, ActorCriticRecurrent
+from rsl_rl.modules import ActorCritic, ActorCriticCNN, ActorCriticRecurrent
 from rsl_rl.modules.rnd import RandomNetworkDistillation
 from rsl_rl.storage import RolloutStorage
 from rsl_rl.utils import string_to_callable
@@ -20,12 +20,12 @@
 class PPO:
     """Proximal Policy Optimization algorithm (https://arxiv.org/abs/1707.06347)."""

-    policy: ActorCritic | ActorCriticRecurrent
+    policy: ActorCritic | ActorCriticRecurrent | ActorCriticCNN
     """The actor critic module."""

     def __init__(
         self,
-        policy: ActorCritic | ActorCriticRecurrent,
+        policy: ActorCritic | ActorCriticRecurrent | ActorCriticCNN,
         num_learning_epochs: int = 5,
         num_mini_batches: int = 4,
         clip_param: float = 0.2,

rsl_rl/modules/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@
 """Definitions for neural-network components for RL-agents."""

 from .actor_critic import ActorCritic
+from .actor_critic_cnn import ActorCriticCNN
 from .actor_critic_recurrent import ActorCriticRecurrent
 from .rnd import RandomNetworkDistillation, resolve_rnd_config
 from .student_teacher import StudentTeacher
@@ -14,6 +15,7 @@

 __all__ = [
     "ActorCritic",
+    "ActorCriticCNN",
     "ActorCriticRecurrent",
     "RandomNetworkDistillation",
     "StudentTeacher",

rsl_rl/modules/actor_critic.py

Lines changed: 2 additions & 3 deletions
@@ -49,9 +49,8 @@ def __init__(
             assert len(obs[obs_group].shape) == 2, "The ActorCritic module only supports 1D observations."
             num_critic_obs += obs[obs_group].shape[-1]

-        self.state_dependent_std = state_dependent_std
-
         # Actor
+        self.state_dependent_std = state_dependent_std
         if self.state_dependent_std:
             self.actor = MLP(num_actor_obs, [2, num_actions], actor_hidden_dims, activation)
         else:
@@ -121,7 +120,7 @@ def action_std(self) -> torch.Tensor:
     def entropy(self) -> torch.Tensor:
         return self.distribution.entropy().sum(dim=-1)

-    def _update_distribution(self, obs: TensorDict) -> None:
+    def _update_distribution(self, obs: torch.Tensor) -> None:
         if self.state_dependent_std:
             # Compute mean and standard deviation
             mean_and_std = self.actor(obs)

rsl_rl/modules/actor_critic_cnn.py

Lines changed: 262 additions & 0 deletions
@@ -0,0 +1,262 @@
# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
import torch.nn as nn
from tensordict import TensorDict
from torch.distributions import Normal
from typing import Any

from rsl_rl.networks import CNN, MLP, EmpiricalNormalization

from .actor_critic import ActorCritic


class ActorCriticCNN(ActorCritic):
    def __init__(
        self,
        obs: TensorDict,
        obs_groups: dict[str, list[str]],
        num_actions: int,
        actor_obs_normalization: bool = False,
        critic_obs_normalization: bool = False,
        actor_hidden_dims: tuple[int] | list[int] = [256, 256, 256],
        critic_hidden_dims: tuple[int] | list[int] = [256, 256, 256],
        actor_cnn_cfg: dict[str, dict] | dict | None = None,
        critic_cnn_cfg: dict[str, dict] | dict | None = None,
        activation: str = "elu",
        init_noise_std: float = 1.0,
        noise_std_type: str = "scalar",
        state_dependent_std: bool = False,
        **kwargs: dict[str, Any],
    ) -> None:
        if kwargs:
            print(
                "ActorCriticCNN.__init__ got unexpected arguments, which will be ignored: "
                + str([key for key in kwargs])
            )
        super(ActorCritic, self).__init__()

        # Get the observation dimensions
        self.obs_groups = obs_groups
        num_actor_obs_1d = 0
        self.actor_obs_groups_1d = []
        actor_in_dims_2d = []
        actor_in_channels_2d = []
        self.actor_obs_groups_2d = []
        for obs_group in obs_groups["policy"]:
            if len(obs[obs_group].shape) == 4:  # B, C, H, W
                self.actor_obs_groups_2d.append(obs_group)
                actor_in_dims_2d.append(obs[obs_group].shape[2:4])
                actor_in_channels_2d.append(obs[obs_group].shape[1])
            elif len(obs[obs_group].shape) == 2:  # B, C
                self.actor_obs_groups_1d.append(obs_group)
                num_actor_obs_1d += obs[obs_group].shape[-1]
            else:
                raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")
        num_critic_obs_1d = 0
        self.critic_obs_groups_1d = []
        critic_in_dims_2d = []
        critic_in_channels_2d = []
        self.critic_obs_groups_2d = []
        for obs_group in obs_groups["critic"]:
            if len(obs[obs_group].shape) == 4:  # B, C, H, W
                self.critic_obs_groups_2d.append(obs_group)
                critic_in_dims_2d.append(obs[obs_group].shape[2:4])
                critic_in_channels_2d.append(obs[obs_group].shape[1])
            elif len(obs[obs_group].shape) == 2:  # B, C
                self.critic_obs_groups_1d.append(obs_group)
                num_critic_obs_1d += obs[obs_group].shape[-1]
            else:
                raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")

        # Assert that there are 2D observations
        assert self.actor_obs_groups_2d or self.critic_obs_groups_2d, (
            "No 2D observations are provided. If this is intentional, use the ActorCritic module instead."
        )

        # Actor CNN
        if self.actor_obs_groups_2d:
            # Resolve the actor CNN configuration
            assert actor_cnn_cfg is not None, "An actor CNN configuration is required for 2D actor observations."
            # If a single configuration dictionary is provided, create a dictionary for each 2D observation group
            if not all(isinstance(v, dict) for v in actor_cnn_cfg.values()):
                actor_cnn_cfg = {group: actor_cnn_cfg for group in self.actor_obs_groups_2d}
            # Check that the number of configs matches the number of observation groups
            assert len(actor_cnn_cfg) == len(self.actor_obs_groups_2d), (
                "The number of CNN configurations must match the number of 2D actor observations."
            )

            # Create CNNs for each 2D actor observation
            self.actor_cnns = nn.ModuleDict()
            encoding_dim = 0
            for idx, obs_group in enumerate(self.actor_obs_groups_2d):
                self.actor_cnns[obs_group] = CNN(
                    input_dim=actor_in_dims_2d[idx],
                    input_channels=actor_in_channels_2d[idx],
                    **actor_cnn_cfg[obs_group],
                )
                print(f"Actor CNN for {obs_group}: {self.actor_cnns[obs_group]}")
                # Get the output dimension of the CNN
                if self.actor_cnns[obs_group].output_channels is None:
                    encoding_dim += int(self.actor_cnns[obs_group].output_dim)  # type: ignore
                else:
                    raise ValueError("The output of the actor CNN must be flattened before passing it to the MLP.")
        else:
            self.actor_cnns = None
            encoding_dim = 0

        # Actor MLP
        self.state_dependent_std = state_dependent_std
        if self.state_dependent_std:
            self.actor = MLP(num_actor_obs_1d + encoding_dim, [2, num_actions], actor_hidden_dims, activation)
        else:
            self.actor = MLP(num_actor_obs_1d + encoding_dim, num_actions, actor_hidden_dims, activation)
        print(f"Actor MLP: {self.actor}")

        # Actor observation normalization (only for 1D actor observations)
        self.actor_obs_normalization = actor_obs_normalization
        if actor_obs_normalization:
            self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs_1d)
        else:
            self.actor_obs_normalizer = torch.nn.Identity()

        # Critic CNN
        if self.critic_obs_groups_2d:
            # Resolve the critic CNN configuration
            assert critic_cnn_cfg is not None, "A critic CNN configuration is required for 2D critic observations."
            # If a single configuration dictionary is provided, create a dictionary for each 2D observation group
            if not all(isinstance(v, dict) for v in critic_cnn_cfg.values()):
                critic_cnn_cfg = {group: critic_cnn_cfg for group in self.critic_obs_groups_2d}
            # Check that the number of configs matches the number of observation groups
            assert len(critic_cnn_cfg) == len(self.critic_obs_groups_2d), (
                "The number of CNN configurations must match the number of 2D critic observations."
            )

            # Create CNNs for each 2D critic observation
            self.critic_cnns = nn.ModuleDict()
            encoding_dim = 0
            for idx, obs_group in enumerate(self.critic_obs_groups_2d):
                self.critic_cnns[obs_group] = CNN(
                    input_dim=critic_in_dims_2d[idx],
                    input_channels=critic_in_channels_2d[idx],
                    **critic_cnn_cfg[obs_group],
                )
                print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}")
                # Get the output dimension of the CNN
                if self.critic_cnns[obs_group].output_channels is None:
                    encoding_dim += int(self.critic_cnns[obs_group].output_dim)  # type: ignore
                else:
                    raise ValueError("The output of the critic CNN must be flattened before passing it to the MLP.")
        else:
            self.critic_cnns = None
            encoding_dim = 0

        # Critic MLP
        self.critic = MLP(num_critic_obs_1d + encoding_dim, 1, critic_hidden_dims, activation)
        print(f"Critic MLP: {self.critic}")

        # Critic observation normalization (only for 1D critic observations)
        self.critic_obs_normalization = critic_obs_normalization
        if critic_obs_normalization:
            self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs_1d)
        else:
            self.critic_obs_normalizer = torch.nn.Identity()

        # Action noise
        self.noise_std_type = noise_std_type
        if self.state_dependent_std:
            torch.nn.init.zeros_(self.actor[-2].weight[num_actions:])
            if self.noise_std_type == "scalar":
                torch.nn.init.constant_(self.actor[-2].bias[num_actions:], init_noise_std)
            elif self.noise_std_type == "log":
                torch.nn.init.constant_(
                    self.actor[-2].bias[num_actions:], torch.log(torch.tensor(init_noise_std + 1e-7))
                )
            else:
                raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
        else:
            if self.noise_std_type == "scalar":
                self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
            elif self.noise_std_type == "log":
                self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
            else:
                raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

        # Action distribution
        # Note: Populated in update_distribution
        self.distribution = None

        # Disable args validation for speedup
        Normal.set_default_validate_args(False)

    def _update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]) -> None:
        if self.actor_cnns is not None:
            # Encode the 2D actor observations
            cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_groups_2d]
            cnn_enc = torch.cat(cnn_enc_list, dim=-1)
            # Concatenate to the MLP observations
            mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)

        super()._update_distribution(mlp_obs)

    def act(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor:
        mlp_obs, cnn_obs = self.get_actor_obs(obs)
        mlp_obs = self.actor_obs_normalizer(mlp_obs)
        self._update_distribution(mlp_obs, cnn_obs)
        return self.distribution.sample()

    def act_inference(self, obs: TensorDict) -> torch.Tensor:
        mlp_obs, cnn_obs = self.get_actor_obs(obs)
        mlp_obs = self.actor_obs_normalizer(mlp_obs)

        if self.actor_cnns is not None:
            # Encode the 2D actor observations
            cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_groups_2d]
            cnn_enc = torch.cat(cnn_enc_list, dim=-1)
            # Concatenate to the MLP observations
            mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)

        if self.state_dependent_std:
            return self.actor(mlp_obs)[..., 0, :]
        else:
            return self.actor(mlp_obs)

    def evaluate(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor:
        mlp_obs, cnn_obs = self.get_critic_obs(obs)
        mlp_obs = self.critic_obs_normalizer(mlp_obs)

        if self.critic_cnns is not None:
            # Encode the 2D critic observations
            cnn_enc_list = [self.critic_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.critic_obs_groups_2d]
            cnn_enc = torch.cat(cnn_enc_list, dim=-1)
            # Concatenate to the MLP observations
            mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)

        return self.critic(mlp_obs)

    def get_actor_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        obs_list_1d = [obs[obs_group] for obs_group in self.actor_obs_groups_1d]
        obs_dict_2d = {}
        for obs_group in self.actor_obs_groups_2d:
            obs_dict_2d[obs_group] = obs[obs_group]
        return torch.cat(obs_list_1d, dim=-1), obs_dict_2d

    def get_critic_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        obs_list_1d = [obs[obs_group] for obs_group in self.critic_obs_groups_1d]
        obs_dict_2d = {}
        for obs_group in self.critic_obs_groups_2d:
            obs_dict_2d[obs_group] = obs[obs_group]
        return torch.cat(obs_list_1d, dim=-1), obs_dict_2d

    def update_normalization(self, obs: TensorDict) -> None:
        if self.actor_obs_normalization:
            actor_obs, _ = self.get_actor_obs(obs)
            self.actor_obs_normalizer.update(actor_obs)
        if self.critic_obs_normalization:
            critic_obs, _ = self.get_critic_obs(obs)
            self.critic_obs_normalizer.update(critic_obs)
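A detail worth calling out from the constructor above: actor_cnn_cfg (and likewise critic_cnn_cfg) can either map each 2D observation group to its own CNN configuration or be a single configuration that is reused for every 2D group. A minimal sketch of that resolution step follows; the group names and the "hidden_channels" key are placeholders, since the actual configuration keys are defined by rsl_rl.networks.CNN, which is not among the files shown here.

# Sketch of the config broadcast performed in ActorCriticCNN.__init__ above.
actor_obs_groups_2d = ["depth", "height_scan"]  # illustrative 2D group names
actor_cnn_cfg = {"hidden_channels": [16, 32]}   # a single config (placeholder key), not per-group

# If the values are not dicts themselves, the same config is reused for every 2D group.
if not all(isinstance(v, dict) for v in actor_cnn_cfg.values()):
    actor_cnn_cfg = {group: actor_cnn_cfg for group in actor_obs_groups_2d}

assert len(actor_cnn_cfg) == len(actor_obs_groups_2d)
print(list(actor_cnn_cfg))  # ['depth', 'height_scan']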

rsl_rl/modules/actor_critic_recurrent.py

Lines changed: 2 additions & 3 deletions
@@ -61,9 +61,8 @@ def __init__(
             assert len(obs[obs_group].shape) == 2, "The ActorCriticRecurrent module only supports 1D observations."
             num_critic_obs += obs[obs_group].shape[-1]

-        self.state_dependent_std = state_dependent_std
-
         # Actor
+        self.state_dependent_std = state_dependent_std
         self.memory_a = Memory(num_actor_obs, rnn_hidden_dim, rnn_num_layers, rnn_type)
         if self.state_dependent_std:
             self.actor = MLP(rnn_hidden_dim, [2, num_actions], actor_hidden_dims, activation)
@@ -138,7 +137,7 @@ def reset(self, dones: torch.Tensor | None = None) -> None:
     def forward(self) -> NoReturn:
         raise NotImplementedError

-    def _update_distribution(self, obs: TensorDict) -> None:
+    def _update_distribution(self, obs: torch.Tensor) -> None:
         if self.state_dependent_std:
             # Compute mean and standard deviation
             mean_and_std = self.actor(obs)

rsl_rl/networks/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -5,11 +5,13 @@

 """Definitions for components of modules."""

+from .cnn import CNN
 from .memory import HiddenState, Memory
 from .mlp import MLP
 from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization

 __all__ = [
+    "CNN",
     "MLP",
     "EmpiricalDiscountedVariationNormalization",
     "EmpiricalNormalization",
