|
| 1 | +# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# SPDX-License-Identifier: BSD-3-Clause |
| 5 | + |
| 6 | +from __future__ import annotations |
| 7 | + |
| 8 | +import torch |
| 9 | +import torch.nn as nn |
| 10 | +from tensordict import TensorDict |
| 11 | +from torch.distributions import Normal |
| 12 | +from typing import Any |
| 13 | + |
| 14 | +from rsl_rl.networks import CNN, MLP, EmpiricalNormalization |
| 15 | + |
| 16 | +from .actor_critic import ActorCritic |
| 17 | + |
| 18 | + |
| 19 | +class ActorCriticCNN(ActorCritic): |
| 20 | + def __init__( |
| 21 | + self, |
| 22 | + obs: TensorDict, |
| 23 | + obs_groups: dict[str, list[str]], |
| 24 | + num_actions: int, |
| 25 | + actor_obs_normalization: bool = False, |
| 26 | + critic_obs_normalization: bool = False, |
| 27 | + actor_hidden_dims: tuple[int] | list[int] = [256, 256, 256], |
| 28 | + critic_hidden_dims: tuple[int] | list[int] = [256, 256, 256], |
| 29 | + actor_cnn_cfg: dict[str, dict] | dict | None = None, |
| 30 | + critic_cnn_cfg: dict[str, dict] | dict | None = None, |
| 31 | + activation: str = "elu", |
| 32 | + init_noise_std: float = 1.0, |
| 33 | + noise_std_type: str = "scalar", |
| 34 | + state_dependent_std: bool = False, |
| 35 | + **kwargs: dict[str, Any], |
| 36 | + ) -> None: |
| 37 | + if kwargs: |
| 38 | + print( |
| 39 | + "ActorCriticCNN.__init__ got unexpected arguments, which will be ignored: " |
| 40 | + + str([key for key in kwargs]) |
| 41 | + ) |
| 42 | + super(ActorCritic, self).__init__() |
| 43 | + |
| 44 | + # Get the observation dimensions |
| 45 | + self.obs_groups = obs_groups |
| 46 | + num_actor_obs_1d = 0 |
| 47 | + self.actor_obs_groups_1d = [] |
| 48 | + actor_in_dims_2d = [] |
| 49 | + actor_in_channels_2d = [] |
| 50 | + self.actor_obs_groups_2d = [] |
| 51 | + for obs_group in obs_groups["policy"]: |
| 52 | + if len(obs[obs_group].shape) == 4: # B, C, H, W |
| 53 | + self.actor_obs_groups_2d.append(obs_group) |
| 54 | + actor_in_dims_2d.append(obs[obs_group].shape[2:4]) |
| 55 | + actor_in_channels_2d.append(obs[obs_group].shape[1]) |
| 56 | + elif len(obs[obs_group].shape) == 2: # B, C |
| 57 | + self.actor_obs_groups_1d.append(obs_group) |
| 58 | + num_actor_obs_1d += obs[obs_group].shape[-1] |
| 59 | + else: |
| 60 | + raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}") |
| 61 | + num_critic_obs_1d = 0 |
| 62 | + self.critic_obs_groups_1d = [] |
| 63 | + critic_in_dims_2d = [] |
| 64 | + critic_in_channels_2d = [] |
| 65 | + self.critic_obs_groups_2d = [] |
| 66 | + for obs_group in obs_groups["critic"]: |
| 67 | + if len(obs[obs_group].shape) == 4: # B, C, H, W |
| 68 | + self.critic_obs_groups_2d.append(obs_group) |
| 69 | + critic_in_dims_2d.append(obs[obs_group].shape[2:4]) |
| 70 | + critic_in_channels_2d.append(obs[obs_group].shape[1]) |
| 71 | + elif len(obs[obs_group].shape) == 2: # B, C |
| 72 | + self.critic_obs_groups_1d.append(obs_group) |
| 73 | + num_critic_obs_1d += obs[obs_group].shape[-1] |
| 74 | + else: |
| 75 | + raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}") |
| 76 | + |
| 77 | + # Assert that there are 2D observations |
| 78 | + assert self.actor_obs_groups_2d or self.critic_obs_groups_2d, ( |
| 79 | + "No 2D observations are provided. If this is intentional, use the ActorCritic module instead." |
| 80 | + ) |
| 81 | + |
| 82 | + # Actor CNN |
| 83 | + if self.actor_obs_groups_2d: |
| 84 | + # Resolve the actor CNN configuration |
| 85 | + assert actor_cnn_cfg is not None, "An actor CNN configuration is required for 2D actor observations." |
| 86 | + # If a single configuration dictionary is provided, create a dictionary for each 2D observation group |
| 87 | + if not all(isinstance(v, dict) for v in actor_cnn_cfg.values()): |
| 88 | + actor_cnn_cfg = {group: actor_cnn_cfg for group in self.actor_obs_groups_2d} |
| 89 | + # Check that the number of configs matches the number of observation groups |
| 90 | + assert len(actor_cnn_cfg) == len(self.actor_obs_groups_2d), ( |
| 91 | + "The number of CNN configurations must match the number of 2D actor observations." |
| 92 | + ) |
| 93 | + |
| 94 | + # Create CNNs for each 2D actor observation |
| 95 | + self.actor_cnns = nn.ModuleDict() |
| 96 | + encoding_dim = 0 |
| 97 | + for idx, obs_group in enumerate(self.actor_obs_groups_2d): |
| 98 | + self.actor_cnns[obs_group] = CNN( |
| 99 | + input_dim=actor_in_dims_2d[idx], |
| 100 | + input_channels=actor_in_channels_2d[idx], |
| 101 | + **actor_cnn_cfg[obs_group], |
| 102 | + ) |
| 103 | + print(f"Actor CNN for {obs_group}: {self.actor_cnns[obs_group]}") |
| 104 | + # Get the output dimension of the CNN |
| 105 | + if self.actor_cnns[obs_group].output_channels is None: |
| 106 | + encoding_dim += int(self.actor_cnns[obs_group].output_dim) # type: ignore |
| 107 | + else: |
| 108 | + raise ValueError("The output of the actor CNN must be flattened before passing it to the MLP.") |
| 109 | + else: |
| 110 | + self.actor_cnns = None |
| 111 | + encoding_dim = 0 |
| 112 | + |
| 113 | + # Actor MLP |
| 114 | + self.state_dependent_std = state_dependent_std |
| 115 | + if self.state_dependent_std: |
| 116 | + self.actor = MLP(num_actor_obs_1d + encoding_dim, [2, num_actions], actor_hidden_dims, activation) |
| 117 | + else: |
| 118 | + self.actor = MLP(num_actor_obs_1d + encoding_dim, num_actions, actor_hidden_dims, activation) |
| 119 | + print(f"Actor MLP: {self.actor}") |
| 120 | + |
| 121 | + # Actor observation normalization (only for 1D actor observations) |
| 122 | + self.actor_obs_normalization = actor_obs_normalization |
| 123 | + if actor_obs_normalization: |
| 124 | + self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs_1d) |
| 125 | + else: |
| 126 | + self.actor_obs_normalizer = torch.nn.Identity() |
| 127 | + |
| 128 | + # Critic CNN |
| 129 | + if self.critic_obs_groups_2d: |
| 130 | + # Resolve the critic CNN configuration |
| 131 | + assert critic_cnn_cfg is not None, "A critic CNN configuration is required for 2D critic observations." |
| 132 | + # If a single configuration dictionary is provided, create a dictionary for each 2D observation group |
| 133 | + if not all(isinstance(v, dict) for v in critic_cnn_cfg.values()): |
| 134 | + critic_cnn_cfg = {group: critic_cnn_cfg for group in self.critic_obs_groups_2d} |
| 135 | + # Check that the number of configs matches the number of observation groups |
| 136 | + assert len(critic_cnn_cfg) == len(self.critic_obs_groups_2d), ( |
| 137 | + "The number of CNN configurations must match the number of 2D critic observations." |
| 138 | + ) |
| 139 | + |
| 140 | + # Create CNNs for each 2D critic observation |
| 141 | + self.critic_cnns = nn.ModuleDict() |
| 142 | + encoding_dim = 0 |
| 143 | + for idx, obs_group in enumerate(self.critic_obs_groups_2d): |
| 144 | + self.critic_cnns[obs_group] = CNN( |
| 145 | + input_dim=critic_in_dims_2d[idx], |
| 146 | + input_channels=critic_in_channels_2d[idx], |
| 147 | + **critic_cnn_cfg[obs_group], |
| 148 | + ) |
| 149 | + print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}") |
| 150 | + # Get the output dimension of the CNN |
| 151 | + if self.critic_cnns[obs_group].output_channels is None: |
| 152 | + encoding_dim += int(self.critic_cnns[obs_group].output_dim) # type: ignore |
| 153 | + else: |
| 154 | + raise ValueError("The output of the critic CNN must be flattened before passing it to the MLP.") |
| 155 | + else: |
| 156 | + self.critic_cnns = None |
| 157 | + encoding_dim = 0 |
| 158 | + |
| 159 | + # Critic MLP |
| 160 | + self.critic = MLP(num_critic_obs_1d + encoding_dim, 1, critic_hidden_dims, activation) |
| 161 | + print(f"Critic MLP: {self.critic}") |
| 162 | + |
| 163 | + # Critic observation normalization (only for 1D critic observations) |
| 164 | + self.critic_obs_normalization = critic_obs_normalization |
| 165 | + if critic_obs_normalization: |
| 166 | + self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs_1d) |
| 167 | + else: |
| 168 | + self.critic_obs_normalizer = torch.nn.Identity() |
| 169 | + |
| 170 | + # Action noise |
| 171 | + self.noise_std_type = noise_std_type |
| 172 | + if self.state_dependent_std: |
| 173 | + torch.nn.init.zeros_(self.actor[-2].weight[num_actions:]) |
| 174 | + if self.noise_std_type == "scalar": |
| 175 | + torch.nn.init.constant_(self.actor[-2].bias[num_actions:], init_noise_std) |
| 176 | + elif self.noise_std_type == "log": |
| 177 | + torch.nn.init.constant_( |
| 178 | + self.actor[-2].bias[num_actions:], torch.log(torch.tensor(init_noise_std + 1e-7)) |
| 179 | + ) |
| 180 | + else: |
| 181 | + raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'") |
| 182 | + else: |
| 183 | + if self.noise_std_type == "scalar": |
| 184 | + self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) |
| 185 | + elif self.noise_std_type == "log": |
| 186 | + self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions))) |
| 187 | + else: |
| 188 | + raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'") |
| 189 | + |
| 190 | + # Action distribution |
| 191 | + # Note: Populated in update_distribution |
| 192 | + self.distribution = None |
| 193 | + |
| 194 | + # Disable args validation for speedup |
| 195 | + Normal.set_default_validate_args(False) |
| 196 | + |
| 197 | + def _update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]) -> None: |
| 198 | + if self.actor_cnns is not None: |
| 199 | + # Encode the 2D actor observations |
| 200 | + cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_groups_2d] |
| 201 | + cnn_enc = torch.cat(cnn_enc_list, dim=-1) |
| 202 | + # Concatenate to the MLP observations |
| 203 | + mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1) |
| 204 | + |
| 205 | + super()._update_distribution(mlp_obs) |
| 206 | + |
| 207 | + def act(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor: |
| 208 | + mlp_obs, cnn_obs = self.get_actor_obs(obs) |
| 209 | + mlp_obs = self.actor_obs_normalizer(mlp_obs) |
| 210 | + self._update_distribution(mlp_obs, cnn_obs) |
| 211 | + return self.distribution.sample() |
| 212 | + |
| 213 | + def act_inference(self, obs: TensorDict) -> torch.Tensor: |
| 214 | + mlp_obs, cnn_obs = self.get_actor_obs(obs) |
| 215 | + mlp_obs = self.actor_obs_normalizer(mlp_obs) |
| 216 | + |
| 217 | + if self.actor_cnns is not None: |
| 218 | + # Encode the 2D actor observations |
| 219 | + cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_groups_2d] |
| 220 | + cnn_enc = torch.cat(cnn_enc_list, dim=-1) |
| 221 | + # Concatenate to the MLP observations |
| 222 | + mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1) |
| 223 | + |
| 224 | + if self.state_dependent_std: |
| 225 | + return self.actor(mlp_obs)[..., 0, :] |
| 226 | + else: |
| 227 | + return self.actor(mlp_obs) |
| 228 | + |
| 229 | + def evaluate(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor: |
| 230 | + mlp_obs, cnn_obs = self.get_critic_obs(obs) |
| 231 | + mlp_obs = self.critic_obs_normalizer(mlp_obs) |
| 232 | + |
| 233 | + if self.critic_cnns is not None: |
| 234 | + # Encode the 2D critic observations |
| 235 | + cnn_enc_list = [self.critic_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.critic_obs_groups_2d] |
| 236 | + cnn_enc = torch.cat(cnn_enc_list, dim=-1) |
| 237 | + # Concatenate to the MLP observations |
| 238 | + mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1) |
| 239 | + |
| 240 | + return self.critic(mlp_obs) |
| 241 | + |
| 242 | + def get_actor_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: |
| 243 | + obs_list_1d = [obs[obs_group] for obs_group in self.actor_obs_groups_1d] |
| 244 | + obs_dict_2d = {} |
| 245 | + for obs_group in self.actor_obs_groups_2d: |
| 246 | + obs_dict_2d[obs_group] = obs[obs_group] |
| 247 | + return torch.cat(obs_list_1d, dim=-1), obs_dict_2d |
| 248 | + |
| 249 | + def get_critic_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: |
| 250 | + obs_list_1d = [obs[obs_group] for obs_group in self.critic_obs_groups_1d] |
| 251 | + obs_dict_2d = {} |
| 252 | + for obs_group in self.critic_obs_groups_2d: |
| 253 | + obs_dict_2d[obs_group] = obs[obs_group] |
| 254 | + return torch.cat(obs_list_1d, dim=-1), obs_dict_2d |
| 255 | + |
| 256 | + def update_normalization(self, obs: TensorDict) -> None: |
| 257 | + if self.actor_obs_normalization: |
| 258 | + actor_obs, _ = self.get_actor_obs(obs) |
| 259 | + self.actor_obs_normalizer.update(actor_obs) |
| 260 | + if self.critic_obs_normalization: |
| 261 | + critic_obs, _ = self.get_critic_obs(obs) |
| 262 | + self.critic_obs_normalizer.update(critic_obs) |
0 commit comments