@@ -5,152 +5,31 @@
 
 from __future__ import annotations
 
-import os
-import time
-import torch
-from collections import deque
 from tensordict import TensorDict
 
-import rsl_rl
 from rsl_rl.algorithms import Distillation
-from rsl_rl.env import VecEnv
 from rsl_rl.modules import StudentTeacher, StudentTeacherRecurrent
 from rsl_rl.runners import OnPolicyRunner
 from rsl_rl.storage import RolloutStorage
-from rsl_rl.utils import resolve_obs_groups, store_code_state
 
 
 class DistillationRunner(OnPolicyRunner):
-    """On-policy runner for training and evaluation of teacher-student training."""
-
-    def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, device: str = "cpu") -> None:
-        self.cfg = train_cfg
-        self.alg_cfg = train_cfg["algorithm"]
-        self.policy_cfg = train_cfg["policy"]
-        self.device = device
-        self.env = env
-
-        # Check if multi-GPU is enabled
-        self._configure_multi_gpu()
-
-        # Store training configuration
-        self.num_steps_per_env = self.cfg["num_steps_per_env"]
-        self.save_interval = self.cfg["save_interval"]
-
-        # Query observations from environment for algorithm construction
-        obs = self.env.get_observations()
-        self.cfg["obs_groups"] = resolve_obs_groups(obs, self.cfg["obs_groups"], default_sets=["teacher"])
-
-        # Create the algorithm
-        self.alg = self._construct_algorithm(obs)
-
-        # Decide whether to disable logging
-        # Note: We only log from the process with rank 0 (main process)
-        self.disable_logs = self.is_distributed and self.gpu_global_rank != 0
-
-        # Logging
-        self.log_dir = log_dir
-        self.writer = None
-        self.tot_timesteps = 0
-        self.tot_time = 0
-        self.current_learning_iteration = 0
-        self.git_status_repos = [rsl_rl.__file__]
+    """Distillation runner for training and evaluation of teacher-student methods."""
 
     def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False) -> None:
-        # Initialize writer
-        self._prepare_logging_writer()
-
         # Check if teacher is loaded
         if not self.alg.policy.loaded_teacher:
             raise ValueError("Teacher model parameters not loaded. Please load a teacher model to distill.")
 
-        # Randomize initial episode lengths (for exploration)
-        if init_at_random_ep_len:
-            self.env.episode_length_buf = torch.randint_like(
-                self.env.episode_length_buf, high=int(self.env.max_episode_length)
-            )
-
-        # Start learning
-        obs = self.env.get_observations().to(self.device)
-        self.train_mode()  # switch to train mode (for dropout for example)
+        super().learn(num_learning_iterations, init_at_random_ep_len)
 
-        # Book keeping
-        ep_infos = []
-        rewbuffer = deque(maxlen=100)
-        lenbuffer = deque(maxlen=100)
-        cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
-        cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
+    def _get_default_obs_sets(self) -> list[str]:
+        """Get the default observation sets required for the algorithm.
 
-        # Ensure all parameters are in-synced
-        if self.is_distributed:
-            print(f"Synchronizing parameters for rank {self.gpu_global_rank}...")
-            self.alg.broadcast_parameters()
-
-        # Start training
-        start_iter = self.current_learning_iteration
-        tot_iter = start_iter + num_learning_iterations
-        for it in range(start_iter, tot_iter):
-            start = time.time()
-            # Rollout
-            with torch.inference_mode():
-                for _ in range(self.num_steps_per_env):
-                    # Sample actions
-                    actions = self.alg.act(obs)
-                    # Step the environment
-                    obs, rewards, dones, extras = self.env.step(actions.to(self.env.device))
-                    # Move to device
-                    obs, rewards, dones = (obs.to(self.device), rewards.to(self.device), dones.to(self.device))
-                    # Process the step
-                    self.alg.process_env_step(obs, rewards, dones, extras)
-                    # Book keeping
-                    if self.log_dir is not None:
-                        if "episode" in extras:
-                            ep_infos.append(extras["episode"])
-                        elif "log" in extras:
-                            ep_infos.append(extras["log"])
-                        # Update rewards
-                        cur_reward_sum += rewards
-                        # Update episode length
-                        cur_episode_length += 1
-                        # Clear data for completed episodes
-                        new_ids = (dones > 0).nonzero(as_tuple=False)
-                        rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
-                        lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist())
-                        cur_reward_sum[new_ids] = 0
-                        cur_episode_length[new_ids] = 0
-
-            stop = time.time()
-            collection_time = stop - start
-            start = stop
-
-            # Update policy
-            loss_dict = self.alg.update()
-
-            stop = time.time()
-            learn_time = stop - start
-            self.current_learning_iteration = it
-
-            if self.log_dir is not None and not self.disable_logs:
-                # Log information
-                self.log(locals())
-                # Save model
-                if it % self.save_interval == 0:
-                    self.save(os.path.join(self.log_dir, f"model_{it}.pt"))
-
-            # Clear episode infos
-            ep_infos.clear()
-            # Save code state
-            if it == start_iter and not self.disable_logs:
-                # Obtain all the diff files
-                git_file_paths = store_code_state(self.log_dir, self.git_status_repos)
-                # If possible store them to wandb or neptune
-                if self.logger_type in ["wandb", "neptune"] and git_file_paths:
-                    for path in git_file_paths:
-                        self.writer.save_file(path)
-
-        # Save the final model after training
-        if self.log_dir is not None and not self.disable_logs:
-            self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt"))
+        .. note::
+            See :func:`resolve_obs_groups` for more details on the handling of observation sets.
+        """
+        return ["teacher"]
 
     def _construct_algorithm(self, obs: TensorDict) -> Distillation:
         """Construct the distillation algorithm."""
@@ -162,7 +41,7 @@ def _construct_algorithm(self, obs: TensorDict) -> Distillation:
 
         # Initialize the storage
         storage = RolloutStorage(
-            "distillation", self.env.num_envs, self.num_steps_per_env, obs, [self.env.num_actions], self.device
+            "distillation", self.env.num_envs, self.cfg["num_steps_per_env"], obs, [self.env.num_actions], self.device
         )
 
         # Initialize the algorithm
@@ -171,4 +50,7 @@ def _construct_algorithm(self, obs: TensorDict) -> Distillation:
             student_teacher, storage, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg
         )
 
+        # Set RND configuration to None as it does not apply to distillation
+        self.cfg["algorithm"]["rnd_cfg"] = None
+
         return alg
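
The diff follows a template-method split: the base OnPolicyRunner keeps construction, the rollout/update loop, logging, and checkpointing, while the distillation subclass only overrides small hooks (_get_default_obs_sets, _construct_algorithm) and adds a teacher-loaded precondition before deferring to super().learn(). Below is a minimal, self-contained sketch of that pattern; the class names and bodies are illustrative placeholders, not rsl_rl code.

# Illustrative sketch of the hook/override pattern used in the refactor above.
# BaseRunner/DistillRunner and their bodies are placeholders, not rsl_rl classes.
class BaseRunner:
    def __init__(self, cfg: dict) -> None:
        self.cfg = cfg
        # The base class asks a hook for the default observation sets,
        # so subclasses customize behavior without redefining __init__.
        self.obs_sets = cfg.get("obs_groups") or self._get_default_obs_sets()
        self.alg = self._construct_algorithm()

    def learn(self, num_iterations: int) -> None:
        # The shared rollout/update/logging loop lives only in the base class.
        for _ in range(num_iterations):
            pass

    def _get_default_obs_sets(self) -> list[str]:
        return ["policy"]

    def _construct_algorithm(self) -> str:
        return "on-policy algorithm"


class DistillRunner(BaseRunner):
    def learn(self, num_iterations: int) -> None:
        # The subclass adds only its precondition, then defers to the shared loop.
        if not self.cfg.get("teacher_loaded", False):
            raise ValueError("Teacher model parameters not loaded.")
        super().learn(num_iterations)

    def _get_default_obs_sets(self) -> list[str]:
        return ["teacher"]

    def _construct_algorithm(self) -> str:
        return "distillation algorithm"


runner = DistillRunner({"teacher_loaded": True})
runner.learn(num_iterations=3)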