@@ -128,128 +128,8 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals
                 self.save(os.path.join(self.logger.log_dir, f"model_{it}.pt"))  # type: ignore
 
         # Save the final model after training
-        if self.log_dir is not None and not self.disable_logs:
-            self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt"))
-
-    def log(self, locs: dict, width: int = 80, pad: int = 35) -> None:
-        # Compute the collection size
-        collection_size = self.num_steps_per_env * self.env.num_envs * self.gpu_world_size
-        # Update total time-steps and time
-        self.tot_timesteps += collection_size
-        self.tot_time += locs["collection_time"] + locs["learn_time"]
-        iteration_time = locs["collection_time"] + locs["learn_time"]
-
-        # Log episode information
-        ep_string = ""
-        if locs["ep_infos"]:
-            for key in locs["ep_infos"][0]:
-                infotensor = torch.tensor([], device=self.device)
-                for ep_info in locs["ep_infos"]:
-                    # Handle scalar and zero dimensional tensor infos
-                    if key not in ep_info:
-                        continue
-                    if not isinstance(ep_info[key], torch.Tensor):
-                        ep_info[key] = torch.Tensor([ep_info[key]])
-                    if len(ep_info[key].shape) == 0:
-                        ep_info[key] = ep_info[key].unsqueeze(0)
-                    infotensor = torch.cat((infotensor, ep_info[key].to(self.device)))
-                value = torch.mean(infotensor)
-                # Log to logger and terminal
-                if "/" in key:
-                    self.writer.add_scalar(key, value, locs["it"])
-                    ep_string += f"""{f"{key}:":>{pad}} {value:.4f}\n"""
-                else:
-                    self.writer.add_scalar("Episode/" + key, value, locs["it"])
-                    ep_string += f"""{f"Mean episode {key}:":>{pad}} {value:.4f}\n"""
-
-        mean_std = self.alg.policy.action_std.mean()
-        fps = int(collection_size / (locs["collection_time"] + locs["learn_time"]))
-
-        # Log losses
-        for key, value in locs["loss_dict"].items():
-            self.writer.add_scalar(f"Loss/{key}", value, locs["it"])
-        self.writer.add_scalar("Loss/learning_rate", self.alg.learning_rate, locs["it"])
-
-        # Log noise std
-        self.writer.add_scalar("Policy/mean_noise_std", mean_std.item(), locs["it"])
-
-        # Log performance
-        self.writer.add_scalar("Perf/total_fps", fps, locs["it"])
-        self.writer.add_scalar("Perf/collection time", locs["collection_time"], locs["it"])
-        self.writer.add_scalar("Perf/learning_time", locs["learn_time"], locs["it"])
-
-        # Log training
-        if len(locs["rewbuffer"]) > 0:
-            # Separate logging for intrinsic and extrinsic rewards
-            if hasattr(self.alg, "rnd") and self.alg.rnd:
-                self.writer.add_scalar("Rnd/mean_extrinsic_reward", statistics.mean(locs["erewbuffer"]), locs["it"])
-                self.writer.add_scalar("Rnd/mean_intrinsic_reward", statistics.mean(locs["irewbuffer"]), locs["it"])
-                self.writer.add_scalar("Rnd/weight", self.alg.rnd.weight, locs["it"])
-            # Everything else
-            self.writer.add_scalar("Train/mean_reward", statistics.mean(locs["rewbuffer"]), locs["it"])
-            self.writer.add_scalar("Train/mean_episode_length", statistics.mean(locs["lenbuffer"]), locs["it"])
-            if self.logger_type != "wandb":  # wandb does not support non-integer x-axis logging
-                self.writer.add_scalar("Train/mean_reward/time", statistics.mean(locs["rewbuffer"]), self.tot_time)
-                self.writer.add_scalar(
-                    "Train/mean_episode_length/time", statistics.mean(locs["lenbuffer"]), self.tot_time
-                )
-
-        str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m "
-
-        run_name = self.cfg.get("run_name")
-        run_name_string = f"""{"Run name:":>{pad}} {run_name}\n""" if run_name else ""
-
-        if len(locs["rewbuffer"]) > 0:
-            log_string = (
-                f"""{"#" * width}\n"""
-                f"""{str.center(width, " ")}\n\n"""
-                f"""{run_name_string}"""
-                f"""{"Computation:":>{pad}} {fps:.0f} steps/s (collection: {locs["collection_time"]:.3f}s, learning {
-                    locs["learn_time"]:.3f}s)\n"""
-                f"""{"Mean action noise std:":>{pad}} {mean_std.item():.2f}\n"""
-            )
-            # Print losses
-            for key, value in locs["loss_dict"].items():
-                log_string += f"""{f"Mean {key} loss:":>{pad}} {value:.4f}\n"""
-            # Print rewards
-            if hasattr(self.alg, "rnd") and self.alg.rnd:
-                log_string += (
-                    f"""{"Mean extrinsic reward:":>{pad}} {statistics.mean(locs["erewbuffer"]):.2f}\n"""
-                    f"""{"Mean intrinsic reward:":>{pad}} {statistics.mean(locs["irewbuffer"]):.2f}\n"""
-                )
-            log_string += f"""{"Mean reward:":>{pad}} {statistics.mean(locs["rewbuffer"]):.2f}\n"""
-            # Print episode information
-            log_string += f"""{"Mean episode length:":>{pad}} {statistics.mean(locs["lenbuffer"]):.2f}\n"""
-        else:
-            log_string = (
-                f"""{"#" * width}\n"""
-                f"""{str.center(width, " ")}\n\n"""
-                f"""{run_name_string}"""
-                f"""{"Computation:":>{pad}} {fps:.0f} steps/s (collection: {locs["collection_time"]:.3f}s, learning {
-                    locs["learn_time"]:.3f}s)\n"""
-                f"""{"Mean action noise std:":>{pad}} {mean_std.item():.2f}\n"""
-            )
-            for key, value in locs["loss_dict"].items():
-                log_string += f"""{f"{key}:":>{pad}} {value:.4f}\n"""
-
-        log_string += ep_string
-        log_string += (
-            f"""{"-" * width}\n"""
-            f"""{"Total timesteps:":>{pad}} {self.tot_timesteps}\n"""
-            f"""{"Iteration time:":>{pad}} {iteration_time:.2f}s\n"""
-            f"""{"Time elapsed:":>{pad}} {time.strftime("%H:%M:%S", time.gmtime(self.tot_time))}\n"""
-            f"""{"ETA:":>{pad}} {
-                time.strftime(
-                    "%H:%M:%S",
-                    time.gmtime(
-                        self.tot_time
-                        / (locs["it"] - locs["start_iter"] + 1)
-                        * (locs["start_iter"] + locs["num_learning_iterations"] - locs["it"])
-                    ),
-                )
-            }\n"""
-        )
-        print(log_string)
+        if self.logger.log_dir is not None and not self.logger.disable_logs:
+            self.save(os.path.join(self.logger.log_dir, f"model_{self.current_learning_iteration}.pt"))
 
     def save(self, path: str, infos: dict | None = None) -> None:
         # Save model
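The replacement lines read `log_dir` and `disable_logs` off a `self.logger` object instead of attributes on the runner itself. Below is a minimal sketch of the interface those two lines rely on; only the attribute names `log_dir` and `disable_logs` come from the diff, while the class name `RunnerLogger`, its constructor, and the example values are assumptions for illustration:

```python
from __future__ import annotations

import os


class RunnerLogger:
    """Hypothetical logging backend owning the output directory (sketch only)."""

    def __init__(self, log_dir: str | None, disable_logs: bool = False) -> None:
        self.log_dir = log_dir            # None when no logging directory is configured
        self.disable_logs = disable_logs  # e.g. True on non-rank-0 distributed workers


logger = RunnerLogger(log_dir="logs/example_run", disable_logs=False)
current_learning_iteration = 1500

# Mirrors the new final-save guard added in learn():
if logger.log_dir is not None and not logger.disable_logs:
    checkpoint_path = os.path.join(logger.log_dir, f"model_{current_learning_iteration}.pt")
    print(f"would save checkpoint to {checkpoint_path}")
```

Centralizing both flags on the logger means the runner no longer needs to mirror logging state, which is presumably why the final-save guard was the last call site updated.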
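For reference, the deleted `log()` buried its ETA estimate inside a multi-line f-string: elapsed time divided by iterations completed this run, scaled by iterations remaining. The same arithmetic as a standalone helper (the function name `eta_string` is hypothetical; the formula is copied from the removed lines):

```python
import time


def eta_string(tot_time: float, it: int, start_iter: int, num_learning_iterations: int) -> str:
    # Average wall-clock seconds per iteration completed in this run...
    completed_iterations = it - start_iter + 1
    # ...scaled by the number of iterations still to go.
    remaining_iterations = start_iter + num_learning_iterations - it
    eta_seconds = tot_time / completed_iterations * remaining_iterations
    return time.strftime("%H:%M:%S", time.gmtime(eta_seconds))


# Example: 120 s elapsed after 10 of 100 iterations -> "00:18:12" remaining.
print(eta_string(tot_time=120.0, it=9, start_iter=0, num_learning_iterations=100))
```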