
Commit 17bb7c0

Separates the logging functionality from the runner (#140)
This PR adds a new Logger class that removes the logging functionality from OnPolicyRunner, which also makes the DistillationRunner leaner. A few minor changes improve overall code quality, e.g., neptune_utils now follows the same structure as wandb_utils. Tested for PPO and Distillation with TensorBoard.
1 parent 6d6d3a4 commit 17bb7c0
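
As context for the refactor, here is a minimal sketch of what moving the writer handling out of the runner can look like. The class and method names (IterationLogger, log_metrics) are illustrative assumptions, not the Logger API introduced by this PR; only torch.utils.tensorboard.SummaryWriter is an existing dependency.

    # Illustrative sketch only; the actual Logger class in this PR may differ.
    from torch.utils.tensorboard import SummaryWriter

    class IterationLogger:
        """Hypothetical helper that owns the writer instead of the runner."""

        def __init__(self, log_dir: str) -> None:
            self.writer = SummaryWriter(log_dir=log_dir)

        def log_metrics(self, metrics: dict[str, float], step: int) -> None:
            # One scalar per entry, keyed the same way as the algorithm's loss_dict.
            for name, value in metrics.items():
                self.writer.add_scalar(f"Loss/{name}", value, step)

        def close(self) -> None:
            self.writer.close()

    # Runner-side usage: the runner only hands over its loss dictionary.
    # logger = IterationLogger(log_dir="logs/run_0")
    # logger.log_metrics({"value": 0.1, "surrogate": -0.02, "entropy": 1.3}, step=0)

With a split of this kind, DistillationRunner no longer needs its own copy of the writer setup and bookkeeping, which is what allows the large deletion in rsl_rl/runners/distillation_runner.py below.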


12 files changed: +408 −448 lines changed


rsl_rl/algorithms/distillation.py

Lines changed: 4 additions & 0 deletions
@@ -95,6 +95,10 @@ def process_env_step(
         self.transition.clear()
         self.policy.reset(dones)

+    def compute_returns(self, obs: TensorDict) -> None:
+        # Not needed for distillation
+        pass
+
     def update(self) -> dict[str, float]:
         self.num_updates += 1
         mean_behavior_loss = 0
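
The no-op lets a runner drive PPO and Distillation through the same interface: the rollout code can always call compute_returns without checking which algorithm it holds. A minimal illustration of that call pattern (the helper below is hypothetical, not code from this PR):

    # Hypothetical call site: PPO computes returns here, while Distillation's
    # newly added no-op simply ignores the call.
    def finish_rollout(alg, obs) -> None:
        alg.compute_returns(obs)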

rsl_rl/algorithms/ppo.py

Lines changed: 2 additions & 2 deletions
@@ -61,7 +61,7 @@ def __init__(
         self.gpu_world_size = 1

         # RND components
-        if rnd_cfg is not None:
+        if rnd_cfg:
             # Extract parameters used in ppo
             rnd_lr = rnd_cfg.pop("learning_rate", 1e-3)
             # Create RND module
@@ -404,7 +404,7 @@ def update(self) -> dict[str, float]:

         # Construct the loss dictionary
         loss_dict = {
-            "value_function": mean_value_loss,
+            "value": mean_value_loss,
             "surrogate": mean_surrogate_loss,
             "entropy": mean_entropy,
         }
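
The switch from an `is not None` check to a truthiness check means an empty dict now also disables RND, which lines up with the change in rsl_rl/modules/rnd.py below where the key is set to None when RND is not configured. A standalone illustration in plain Python (no rsl_rl code involved):

    # Both None and an empty dict are falsy, so the RND branch is skipped for either.
    for rnd_cfg in (None, {}, {"learning_rate": 1e-3}):
        if rnd_cfg:
            print("RND enabled with:", rnd_cfg)
        else:
            print("RND disabled for:", rnd_cfg)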

rsl_rl/modules/actor_critic_cnn.py

Lines changed: 3 additions & 3 deletions
@@ -103,7 +103,7 @@ def __init__(
                 print(f"Actor CNN for {obs_group}: {self.actor_cnns[obs_group]}")
                 # Get the output dimension of the CNN
                 if self.actor_cnns[obs_group].output_channels is None:
-                    encoding_dim += int(self.actor_cnns[obs_group].output_dim)  # type: ignore
+                    encoding_dim += int(self.actor_cnns[obs_group].output_dim)
                 else:
                     raise ValueError("The output of the actor CNN must be flattened before passing it to the MLP.")
             else:
@@ -149,7 +149,7 @@ def __init__(
                 print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}")
                 # Get the output dimension of the CNN
                 if self.critic_cnns[obs_group].output_channels is None:
-                    encoding_dim += int(self.critic_cnns[obs_group].output_dim)  # type: ignore
+                    encoding_dim += int(self.critic_cnns[obs_group].output_dim)
                 else:
                     raise ValueError("The output of the critic CNN must be flattened before passing it to the MLP.")
             else:
@@ -208,7 +208,7 @@ def act(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor:
         mlp_obs, cnn_obs = self.get_actor_obs(obs)
         mlp_obs = self.actor_obs_normalizer(mlp_obs)
         self._update_distribution(mlp_obs, cnn_obs)
-        return self.distribution.sample()
+        return self.distribution.sample()  # type: ignore

     def act_inference(self, obs: TensorDict) -> torch.Tensor:
         mlp_obs, cnn_obs = self.get_actor_obs(obs)

rsl_rl/modules/rnd.py

Lines changed: 2 additions & 0 deletions
@@ -205,4 +205,6 @@ def resolve_rnd_config(alg_cfg: dict, obs: TensorDict, obs_groups: dict[str, lis
         alg_cfg["rnd_cfg"]["obs_groups"] = obs_groups
         # Scale down the rnd weight with timestep
         alg_cfg["rnd_cfg"]["weight"] *= env.unwrapped.step_dt
+    else:
+        alg_cfg["rnd_cfg"] = None
     return alg_cfg

rsl_rl/modules/symmetry.py

Lines changed: 2 additions & 0 deletions
@@ -22,4 +22,6 @@ def resolve_symmetry_config(alg_cfg: dict, env: VecEnv) -> dict:
     # Note: This is used by the symmetry function for handling different observation terms
     if "symmetry_cfg" in alg_cfg and alg_cfg["symmetry_cfg"] is not None:
         alg_cfg["symmetry_cfg"]["_env"] = env
+    else:
+        alg_cfg["symmetry_cfg"] = None
     return alg_cfg
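
Both resolvers now guarantee that the optional entry exists in the algorithm config after resolution: either a populated dict or an explicit None. A generic sketch of this normalization pattern (not the rsl_rl functions themselves):

    def normalize_optional_cfg(alg_cfg: dict, key: str) -> dict:
        # Downstream code can rely on the key being present, even when the feature is off.
        if alg_cfg.get(key):
            alg_cfg[key]["resolved"] = True  # stand-in for the real resolution steps
        else:
            alg_cfg[key] = None
        return alg_cfg

    print(normalize_optional_cfg({"rnd_cfg": {"weight": 1.0}}, "rnd_cfg"))  # populated dict is kept
    print(normalize_optional_cfg({}, "symmetry_cfg"))                       # {'symmetry_cfg': None}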

rsl_rl/runners/distillation_runner.py

Lines changed: 12 additions & 130 deletions
@@ -5,152 +5,31 @@

 from __future__ import annotations

-import os
-import time
-import torch
-from collections import deque
 from tensordict import TensorDict

-import rsl_rl
 from rsl_rl.algorithms import Distillation
-from rsl_rl.env import VecEnv
 from rsl_rl.modules import StudentTeacher, StudentTeacherRecurrent
 from rsl_rl.runners import OnPolicyRunner
 from rsl_rl.storage import RolloutStorage
-from rsl_rl.utils import resolve_obs_groups, store_code_state


 class DistillationRunner(OnPolicyRunner):
-    """On-policy runner for training and evaluation of teacher-student training."""
-
-    def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, device: str = "cpu") -> None:
-        self.cfg = train_cfg
-        self.alg_cfg = train_cfg["algorithm"]
-        self.policy_cfg = train_cfg["policy"]
-        self.device = device
-        self.env = env
-
-        # Check if multi-GPU is enabled
-        self._configure_multi_gpu()
-
-        # Store training configuration
-        self.num_steps_per_env = self.cfg["num_steps_per_env"]
-        self.save_interval = self.cfg["save_interval"]
-
-        # Query observations from environment for algorithm construction
-        obs = self.env.get_observations()
-        self.cfg["obs_groups"] = resolve_obs_groups(obs, self.cfg["obs_groups"], default_sets=["teacher"])
-
-        # Create the algorithm
-        self.alg = self._construct_algorithm(obs)
-
-        # Decide whether to disable logging
-        # Note: We only log from the process with rank 0 (main process)
-        self.disable_logs = self.is_distributed and self.gpu_global_rank != 0
-
-        # Logging
-        self.log_dir = log_dir
-        self.writer = None
-        self.tot_timesteps = 0
-        self.tot_time = 0
-        self.current_learning_iteration = 0
-        self.git_status_repos = [rsl_rl.__file__]
+    """Distillation runner for training and evaluation of teacher-student methods."""

     def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False) -> None:
-        # Initialize writer
-        self._prepare_logging_writer()
-
         # Check if teacher is loaded
         if not self.alg.policy.loaded_teacher:
             raise ValueError("Teacher model parameters not loaded. Please load a teacher model to distill.")

-        # Randomize initial episode lengths (for exploration)
-        if init_at_random_ep_len:
-            self.env.episode_length_buf = torch.randint_like(
-                self.env.episode_length_buf, high=int(self.env.max_episode_length)
-            )
-
-        # Start learning
-        obs = self.env.get_observations().to(self.device)
-        self.train_mode()  # switch to train mode (for dropout for example)
+        super().learn(num_learning_iterations, init_at_random_ep_len)

-        # Book keeping
-        ep_infos = []
-        rewbuffer = deque(maxlen=100)
-        lenbuffer = deque(maxlen=100)
-        cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
-        cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
+    def _get_default_obs_sets(self) -> list[str]:
+        """Get the default observation sets required for the algorithm.

-        # Ensure all parameters are in-synced
-        if self.is_distributed:
-            print(f"Synchronizing parameters for rank {self.gpu_global_rank}...")
-            self.alg.broadcast_parameters()
-
-        # Start training
-        start_iter = self.current_learning_iteration
-        tot_iter = start_iter + num_learning_iterations
-        for it in range(start_iter, tot_iter):
-            start = time.time()
-            # Rollout
-            with torch.inference_mode():
-                for _ in range(self.num_steps_per_env):
-                    # Sample actions
-                    actions = self.alg.act(obs)
-                    # Step the environment
-                    obs, rewards, dones, extras = self.env.step(actions.to(self.env.device))
-                    # Move to device
-                    obs, rewards, dones = (obs.to(self.device), rewards.to(self.device), dones.to(self.device))
-                    # Process the step
-                    self.alg.process_env_step(obs, rewards, dones, extras)
-                    # Book keeping
-                    if self.log_dir is not None:
-                        if "episode" in extras:
-                            ep_infos.append(extras["episode"])
-                        elif "log" in extras:
-                            ep_infos.append(extras["log"])
-                        # Update rewards
-                        cur_reward_sum += rewards
-                        # Update episode length
-                        cur_episode_length += 1
-                        # Clear data for completed episodes
-                        new_ids = (dones > 0).nonzero(as_tuple=False)
-                        rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
-                        lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist())
-                        cur_reward_sum[new_ids] = 0
-                        cur_episode_length[new_ids] = 0
-
-                stop = time.time()
-                collection_time = stop - start
-                start = stop
-
-            # Update policy
-            loss_dict = self.alg.update()
-
-            stop = time.time()
-            learn_time = stop - start
-            self.current_learning_iteration = it
-
-            if self.log_dir is not None and not self.disable_logs:
-                # Log information
-                self.log(locals())
-                # Save model
-                if it % self.save_interval == 0:
-                    self.save(os.path.join(self.log_dir, f"model_{it}.pt"))
-
-            # Clear episode infos
-            ep_infos.clear()
-            # Save code state
-            if it == start_iter and not self.disable_logs:
-                # Obtain all the diff files
-                git_file_paths = store_code_state(self.log_dir, self.git_status_repos)
-                # If possible store them to wandb or neptune
-                if self.logger_type in ["wandb", "neptune"] and git_file_paths:
-                    for path in git_file_paths:
-                        self.writer.save_file(path)
-
-        # Save the final model after training
-        if self.log_dir is not None and not self.disable_logs:
-            self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt"))
+        .. note::
+            See :func:`resolve_obs_groups` for more details on the handling of observation sets.
+        """
+        return ["teacher"]

     def _construct_algorithm(self, obs: TensorDict) -> Distillation:
         """Construct the distillation algorithm."""
@@ -162,7 +41,7 @@ def _construct_algorithm(self, obs: TensorDict) -> Distillation:

         # Initialize the storage
         storage = RolloutStorage(
-            "distillation", self.env.num_envs, self.num_steps_per_env, obs, [self.env.num_actions], self.device
+            "distillation", self.env.num_envs, self.cfg["num_steps_per_env"], obs, [self.env.num_actions], self.device
         )

         # Initialize the algorithm
@@ -171,4 +50,7 @@ def _construct_algorithm(self, obs: TensorDict) -> Distillation:
             student_teacher, storage, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg
         )

+        # Set RND configuration to None as it does not apply to distillation
+        self.cfg["algorithm"]["rnd_cfg"] = None
+
         return alg
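
The leaner DistillationRunner relies on hooks in OnPolicyRunner: learn() defers to the base class, and _get_default_obs_sets() tells the base class which observation set to fall back to when resolving obs groups. A simplified sketch of that template-method shape; the base-class internals below are assumptions for illustration, not the actual OnPolicyRunner code.

    # Simplified template-method sketch; the real OnPolicyRunner does much more.
    class BaseRunner:
        def __init__(self, obs_groups: dict[str, list[str]]) -> None:
            # The base class consults the subclass hook while resolving obs groups.
            self.default_sets = self._get_default_obs_sets()
            self.obs_groups = {s: obs_groups.get(s, []) for s in self.default_sets}

        def _get_default_obs_sets(self) -> list[str]:
            return ["critic"]

    class StudentTeacherRunner(BaseRunner):
        # Only the hook changes; construction and the learning loop stay in the base class.
        def _get_default_obs_sets(self) -> list[str]:
            return ["teacher"]

    runner = StudentTeacherRunner({"teacher": ["privileged_obs"]})
    print(runner.obs_groups)  # {'teacher': ['privileged_obs']}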
