diff --git a/diffsynth/diffusion/loss.py b/diffsynth/diffusion/loss.py index 14fdfd3be..065e58909 100644 --- a/diffsynth/diffusion/loss.py +++ b/diffsynth/diffusion/loss.py @@ -91,7 +91,9 @@ def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_tea progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs()) latents_ = trajectory_teacher[progress_id_teacher] - target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma) + denom = sigma_ - sigma + denom = torch.sign(denom) * torch.clamp(denom.abs(), min=1e-6) + target = (latents_ - inputs_shared["latents"]) / denom loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep) return loss