From 2b421334342e4b41285b8dfd1504afbfe6fd94bd Mon Sep 17 00:00:00 2001
From: Jayce-Ping <315229706@qq.com>
Date: Wed, 17 Dec 2025 11:21:49 +0800
Subject: [PATCH] Add noise scheduler

---
 diffsynth/diffusion/flow_match.py | 144 +++++++++++++++++++++++++++++-
 1 file changed, 143 insertions(+), 1 deletion(-)

diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py
index bb5fbc52..817ffe98 100644
--- a/diffsynth/diffusion/flow_match.py
+++ b/diffsynth/diffusion/flow_match.py
@@ -1,5 +1,5 @@
 import torch, math
-from typing_extensions import Literal
+from typing_extensions import Literal, Optional, List
 
 
 class FlowMatchScheduler():
@@ -177,3 +177,145 @@ def training_weight(self, timestep):
         timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
         weights = self.linear_timesteps_weights[timestep_id]
         return weights
+
+
+class FlowMatchSDEScheduler(FlowMatchScheduler):
+
+    def __init__(self,
+        template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1",
+        noise_level: float = 0.1,
+        noise_window: Optional[List[int]] = None,
+        sde_step_num: Optional[int] = None,
+        sde_type: Literal['Flow-SDE', 'Dance-SDE', 'CPS'] = 'Flow-SDE',
+        seed: Optional[int] = None,
+        **kwargs,
+    ):
+        super().__init__(template=template, **kwargs)
+        if noise_window is None:
+            self.noise_window = list(range(self.num_train_timesteps))
+        else:
+            self.noise_window = list(noise_window)
+
+        self.noise_level = noise_level
+        self.sde_step_num = sde_step_num or len(self.noise_window)
+        self.sde_type = sde_type
+        self.seed = seed or 42
+
+    def set_seed(self, seed: int) -> None:
+        self.seed = seed
+
+    @property
+    def current_noise_steps(self) -> torch.Tensor:
+        if self.sde_step_num >= len(self.noise_window):
+            return torch.tensor(self.noise_window)
+        generator = torch.Generator().manual_seed(self.seed)
+        selected_indices = torch.randperm(len(self.noise_window), generator=generator)[:self.sde_step_num]
+        return torch.tensor(self.noise_window)[selected_indices]
+
+    def get_current_noise_level(self, timestep) -> float:
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.cpu()
+        timestep_id = torch.argmin((self.timesteps - timestep).abs())
+        if timestep_id in self.current_noise_steps:
+            return self.noise_level
+        else:
+            return 0.0
+
+    def step(self,
+        model_output,
+        timestep,
+        sample,
+        to_final=False,
+        prev_sample=None,
+        generator=None,
+        return_log_prob=False,
+        return_dict=False,
+        **kwargs
+    ):
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.cpu()
+        timestep_id = torch.argmin((self.timesteps - timestep).abs())
+        sigma = self.sigmas[timestep_id]
+        if to_final or timestep_id + 1 >= len(self.timesteps):
+            sigma_ = 0
+        else:
+            sigma_ = self.sigmas[timestep_id + 1]
+
+        # Convert to float32 for numerical stability
+        model_output = model_output.float()
+        sample = sample.float()
+
+        current_noise_level = self.get_current_noise_level(timestep)
+
+        dt = sigma_ - sigma
+        sigma_max = self.sigmas[1]  # Use the max sigma < 1
+        if self.sde_type == 'Flow-SDE':
+            std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))) * current_noise_level
+            prev_sample_mean = sample * (1 + std_dev_t**2 / (2 * sigma) * dt) + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt
+            if prev_sample is None:
+                variance_noise = torch.randn(
+                    model_output.shape,
+                    generator=generator,
+                ).to(device=model_output.device, dtype=model_output.dtype)
+                prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1 * dt) * variance_noise
+
+            if return_log_prob:
+                log_prob = (
+                    -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1 * dt)) ** 2))
+                    - torch.log(std_dev_t * torch.sqrt(-1 * dt))
+                    - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
+                )
+                log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
+
+        elif self.sde_type == 'Dance-SDE':
+            pred_original_sample = sample - sigma * model_output
+            std_dev_t = current_noise_level * torch.sqrt(-1 * dt)
+            log_term = 0.5 * current_noise_level**2 * (sample - pred_original_sample * (1 - sigma)) / sigma**2
+            prev_sample_mean = sample + (model_output + log_term) * dt
+            if prev_sample is None:
+                variance_noise = torch.randn(
+                    model_output.shape,
+                    generator=generator,
+                ).to(device=model_output.device, dtype=model_output.dtype)
+                prev_sample = prev_sample_mean + std_dev_t * variance_noise
+
+            if return_log_prob:
+                log_prob = (
+                    (-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2)))
+                    - torch.log(std_dev_t)
+                    - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
+                )
+
+                # mean along all but batch dimension
+                log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
+
+        elif self.sde_type == 'CPS':
+            # Coefficient Preserving Sampling
+            std_dev_t = sigma_ * torch.sin(current_noise_level * torch.pi / 2)
+            pred_original_sample = sample - sigma * model_output
+            noise_estimate = sample + model_output * (1 - sigma)
+            prev_sample_mean = pred_original_sample * (1 - sigma_) + noise_estimate * torch.sqrt(sigma_**2 - std_dev_t**2)
+
+            if prev_sample is None:
+                variance_noise = torch.randn(
+                    model_output.shape,
+                    generator=generator,
+                ).to(device=model_output.device, dtype=model_output.dtype)
+                prev_sample = prev_sample_mean + std_dev_t * variance_noise
+
+            if return_log_prob:
+                log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2)
+                log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
+
+        if not return_log_prob:
+            log_prob = torch.zeros(sample.shape[0], device=sample.device)
+
+        if return_dict:
+            return {
+                "prev_sample": prev_sample,
+                "log_prob": log_prob,
+                "prev_sample_mean": prev_sample_mean,
+                "std_dev_t": std_dev_t,
+            }
+
+        return prev_sample, log_prob, prev_sample_mean, std_dev_t
\ No newline at end of file
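
Usage sketch: a minimal driving loop for the new scheduler, assuming the base FlowMatchScheduler exposes a set_timesteps() method and a timesteps attribute (both referenced in the hunk above); the model callable, the latent shape, and the argument values below are placeholders chosen for illustration, not part of the change.

    # Minimal sketch. Assumptions: set_timesteps() comes from the base class;
    # `model` is a placeholder standing in for a velocity-prediction network.
    import torch
    from diffsynth.diffusion.flow_match import FlowMatchSDEScheduler

    model = lambda x, t: torch.zeros_like(x)   # placeholder denoiser for illustration

    scheduler = FlowMatchSDEScheduler(
        template="FLUX.1",
        noise_level=0.7,        # scale of the injected stochasticity
        sde_type="Flow-SDE",
        seed=123,
    )
    scheduler.set_timesteps(30)                # assumed base-class API: 30 inference steps

    latents = torch.randn(1, 16, 64, 64)       # placeholder latent shape
    for t in scheduler.timesteps:
        model_output = model(latents, t)
        latents, log_prob, prev_mean, std = scheduler.step(
            model_output, t, latents, return_log_prob=True
        )

With return_log_prob=True, step() additionally returns the per-sample Gaussian log-probability of the drawn prev_sample alongside prev_sample_mean and std_dev_t, matching the tuple returned at the end of the method; with return_dict=True the same values come back as a dictionary.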