diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index 9d2298febc..4c050741a8 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -1025,6 +1025,7 @@ def run(self) -> None:
             prof.start()

         def step(_step_id: int, task_key: str = "Default") -> None:
+            display_step_id = _step_id + 1
             if self.multi_task:
                 model_index = dp_random.choice(
                     np.arange(self.num_model, dtype=np.int_),
@@ -1058,7 +1059,27 @@ def step(_step_id: int, task_key: str = "Default") -> None:
                     **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key
                 )
                 loss.backward()
+                # === Initialize gradient diagnostics variables ===
+                total_norm: torch.Tensor | None = None
+                pre_clip_named_norms: list[tuple[str, float]] = []
                 if self.gradient_max_norm > 0.0:
+                    # Collect per-parameter gradient norms before clipping.
+                    # NOTE: Under FSDP2 with ZeRO stage >= 2, p.grad is a sharded DTensor,
+                    # so p.grad.norm() computes the shard-local L2 norm, not the full-parameter
+                    # norm. Skip per-param collection in this case to avoid misleading values.
+                    if (
+                        self.enable_tensorboard
+                        and self.zero_stage < 2
+                        and (
+                            display_step_id % self.tensorboard_freq == 0
+                            or display_step_id == 1
+                        )
+                    ):
+                        pre_clip_named_norms = [
+                            (name, p.grad.detach().norm().item())
+                            for name, p in self.wrapper.named_parameters()
+                            if p.grad is not None
+                        ]
                     # FSDP2 sharded DTensor gradients don't support error_if_nonfinite; use manual isfinite check instead.
                     total_norm = torch.nn.utils.clip_grad_norm_(
                         self.wrapper.parameters(),
@@ -1183,7 +1204,6 @@ def fake_model() -> dict:
                     self.train_loss_accu[task_key][item] += more_loss[item]

             # Log and persist
-            display_step_id = _step_id + 1
             if self.display_in_training and (
                 display_step_id % self.disp_freq == 0 or display_step_id == 1
             ):
@@ -1410,6 +1430,32 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                     writer.add_scalar(
                         f"{task_key}/{item}", more_loss[item], display_step_id
                     )
+                # === Gradient diagnostics (pre-clip) ===
+                # Only log if total_norm was computed (i.e., not LKF optimizer).
+                if self.gradient_max_norm > 0.0 and total_norm is not None:
+                    writer.add_scalar(
+                        f"{task_key}/grad/total_norm",
+                        total_norm.item(),
+                        display_step_id,
+                    )
+                    # Only log per-parameter norms if list is non-empty.
+                    if pre_clip_named_norms:
+                        # Use float32 for histogram to ensure numerical stability
+                        # when gradients are in lower precision (FP16/BF16).
+                        norms = torch.tensor(
+                            [gn for _, gn in pre_clip_named_norms],
+                            dtype=torch.float32,
+                            device="cpu",
+                        )
+                        writer.add_histogram(
+                            f"{task_key}/grad/param_norms", norms, display_step_id
+                        )
+                        # Log top-10 largest per-parameter gradient norms.
+                        pre_clip_named_norms.sort(key=lambda x: x[1], reverse=True)
+                        for name, gn in pre_clip_named_norms[:10]:
+                            writer.add_scalar(
+                                f"{task_key}/grad_top10/{name}", gn, display_step_id
+                            )
             self.wrapper.train()
             self.t0 = time.time()
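
For reference, the logging pattern added above (per-parameter gradient norms collected before clipping, then written to TensorBoard as a total norm, a histogram, and a top-10 list) can be exercised outside the trainer. The following is a minimal standalone sketch, not part of the patch, assuming a plain non-sharded `nn.Module`; the model, `gradient_max_norm`, `tensorboard_freq`, and the log directory are illustrative placeholders only.

```python
import torch
from torch.utils.tensorboard import SummaryWriter

# Toy model and writer; "runs/grad_diag_demo" is an illustrative log dir.
model = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.Tanh(),
    torch.nn.Linear(16, 1),
)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
writer = SummaryWriter(log_dir="runs/grad_diag_demo")
gradient_max_norm = 1.0  # assumed clipping threshold
tensorboard_freq = 10    # assumed logging interval

for step_id in range(100):
    display_step_id = step_id + 1
    x = torch.randn(32, 8)
    loss = model(x).pow(2).mean()
    optimizer.zero_grad(set_to_none=True)
    loss.backward()

    # Per-parameter gradient norms, taken *before* clipping. On a non-sharded
    # model, p.grad.norm() is the full-parameter L2 norm.
    pre_clip_named_norms = [
        (name, p.grad.detach().norm().item())
        for name, p in model.named_parameters()
        if p.grad is not None
    ]
    # clip_grad_norm_ returns the total pre-clip norm over all parameters.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_max_norm)
    optimizer.step()

    if display_step_id % tensorboard_freq == 0 or display_step_id == 1:
        writer.add_scalar("grad/total_norm", total_norm.item(), display_step_id)
        norms = torch.tensor(
            [gn for _, gn in pre_clip_named_norms], dtype=torch.float32
        )
        writer.add_histogram("grad/param_norms", norms, display_step_id)
        pre_clip_named_norms.sort(key=lambda x: x[1], reverse=True)
        for name, gn in pre_clip_named_norms[:10]:
            writer.add_scalar(f"grad_top10/{name}", gn, display_step_id)

writer.close()
```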