48 changes: 47 additions & 1 deletion deepmd/pt/train/training.py
@@ -1025,6 +1025,7 @@ def run(self) -> None:
prof.start()

def step(_step_id: int, task_key: str = "Default") -> None:
display_step_id = _step_id + 1
if self.multi_task:
model_index = dp_random.choice(
np.arange(self.num_model, dtype=np.int_),
@@ -1058,7 +1059,27 @@ def step(_step_id: int, task_key: str = "Default") -> None:
**input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key
)
loss.backward()
# === Initialize gradient diagnostics variables ===
total_norm: torch.Tensor | None = None
pre_clip_named_norms: list[tuple[str, float]] = []
if self.gradient_max_norm > 0.0:
# Collect per-parameter gradient norms before clipping.
# NOTE: Under FSDP2 with ZeRO stage >= 2, p.grad is a sharded DTensor,
# so p.grad.norm() computes the shard-local L2 norm, not the full-parameter
# norm. Skip per-param collection in this case to avoid misleading values.
if (
self.enable_tensorboard
and self.zero_stage < 2
and (
display_step_id % self.tensorboard_freq == 0
or display_step_id == 1
)
):
pre_clip_named_norms = [
(name, p.grad.detach().norm().item())
for name, p in self.wrapper.named_parameters()
if p.grad is not None
]
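The NOTE above is the reason per-parameter collection is skipped for ZeRO stage >= 2: each rank only sees its own shard of each gradient. A minimal standalone sketch, not part of this change, of how full-parameter norms could still be recovered by reducing shard-local squared norms across ranks (the helper name is illustrative; it assumes torch.distributed is initialized and every parameter is sharded rather than replicated):

import torch
import torch.distributed as dist

def full_grad_norms(named_parameters):
    """Return (name, full L2 norm) pairs by summing shard-local squared norms over ranks."""
    names, local_sq = [], []
    for name, p in named_parameters:
        if p.grad is None:
            continue
        grad = p.grad
        # A DTensor exposes its local shard via to_local(); plain tensors pass through.
        local = grad.to_local() if hasattr(grad, "to_local") else grad
        names.append(name)
        local_sq.append(local.detach().float().pow(2).sum())
    if not names:
        return []
    stacked = torch.stack(local_sq)
    if dist.is_available() and dist.is_initialized():
        # Ranks hold disjoint shards, so summing squared norms gives the full squared norm.
        dist.all_reduce(stacked, op=dist.ReduceOp.SUM)
    return list(zip(names, stacked.sqrt().tolist()))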
# FSDP2 sharded DTensor gradients don't support error_if_nonfinite; use manual isfinite check instead.
total_norm = torch.nn.utils.clip_grad_norm_(
self.wrapper.parameters(),
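The clip_grad_norm_ call is cut off by the diff here; as the comment above says, error_if_nonfinite cannot be used with sharded DTensor gradients, so finiteness of the returned norm is checked by hand. A self-contained sketch of that pattern (the module, max_norm value, and error message are illustrative, not the trainer's actual code):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

total_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(), max_norm=1.0, error_if_nonfinite=False
)
if not torch.isfinite(total_norm):
    raise FloatingPointError(f"Non-finite gradient norm encountered: {total_norm}")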
@@ -1183,7 +1204,6 @@ def fake_model() -> dict:
self.train_loss_accu[task_key][item] += more_loss[item]

# Log and persist
display_step_id = _step_id + 1
if self.display_in_training and (
display_step_id % self.disp_freq == 0 or display_step_id == 1
):
@@ -1410,6 +1430,32 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
writer.add_scalar(
f"{task_key}/{item}", more_loss[item], display_step_id
)
# === Gradient diagnostics (pre-clip) ===
# Only log if total_norm was computed (i.e., not when using the LKF optimizer).
if self.gradient_max_norm > 0.0 and total_norm is not None:
writer.add_scalar(
f"{task_key}/grad/total_norm",
total_norm.item(),
display_step_id,
)
# Only log per-parameter norms if list is non-empty.
if pre_clip_named_norms:
# Use float32 for histogram to ensure numerical stability
# when gradients are in lower precision (FP16/BF16).
norms = torch.tensor(
[gn for _, gn in pre_clip_named_norms],
dtype=torch.float32,
device="cpu",
)
writer.add_histogram(
f"{task_key}/grad/param_norms", norms, display_step_id
)
# Log top-10 largest per-parameter gradient norms.
pre_clip_named_norms.sort(key=lambda x: x[1], reverse=True)
for name, gn in pre_clip_named_norms[:10]:
writer.add_scalar(
f"{task_key}/grad_top10/{name}", gn, display_step_id
)

self.wrapper.train()
self.t0 = time.time()
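For reference, the tag layout produced by the diagnostics block above, reproduced as a self-contained snippet (the log directory, parameter names, and values are stand-ins, not output from the trainer):

import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("tb_grad_diag")  # illustrative log dir
task_key, display_step_id = "Default", 100
total_norm = torch.tensor(3.7)  # stand-in for the clip_grad_norm_ return value
pre_clip_named_norms = [("descriptor.weight", 2.9), ("fitting_net.bias", 0.4)]  # stand-ins

writer.add_scalar(f"{task_key}/grad/total_norm", total_norm.item(), display_step_id)
norms = torch.tensor([gn for _, gn in pre_clip_named_norms], dtype=torch.float32)
writer.add_histogram(f"{task_key}/grad/param_norms", norms, display_step_id)
for name, gn in sorted(pre_clip_named_norms, key=lambda x: x[1], reverse=True)[:10]:
    writer.add_scalar(f"{task_key}/grad_top10/{name}", gn, display_step_id)
writer.close()

In the TensorBoard UI these appear under Default/grad and Default/grad_top10 next to the existing per-task loss scalars.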