144 changes: 143 additions & 1 deletion diffsynth/diffusion/flow_match.py
@@ -1,5 +1,5 @@
import torch, math
from typing_extensions import Literal
from typing_extensions import Literal, Optional, List


class FlowMatchScheduler():
@@ -177,3 +177,145 @@ def training_weight(self, timestep):
        timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
        weights = self.linear_timesteps_weights[timestep_id]
        return weights


class FlowMatchSDEScheduler(FlowMatchScheduler):

    def __init__(self,
        template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1",
        noise_level: float = 0.1,
        noise_window: Optional[List[int]] = None,
        sde_step_num: Optional[int] = None,
        sde_type: Literal['Flow-SDE', 'Dance-SDE', 'CPS'] = 'Flow-SDE',
        seed: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(template=template, **kwargs)
        if noise_window is None:
            self.noise_window = list(range(self.num_train_timesteps))
        else:
            self.noise_window = list(noise_window)
Comment on lines +194 to +197
Contributor

critical

self.noise_window is initialized as a list. However, in the current_noise_steps property (line 208), it is indexed with a torch.Tensor on line 212 (self.noise_window[selected_indices]), which will raise a TypeError. To fix this, self.noise_window should be a torch.Tensor. This change will also make the return type of current_noise_steps consistent with its type hint (torch.Tensor).

Suggested change
        if noise_window is None:
            self.noise_window = list(range(self.num_train_timesteps))
        else:
            self.noise_window = list(noise_window)
        if noise_window is None:
            self.noise_window = torch.arange(self.num_train_timesteps)
        else:
            self.noise_window = torch.tensor(noise_window)
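
For illustration, a standalone sketch of the indexing failure this comment describes (the variable names are made up, not from the PR):

    import torch

    window = list(range(1000))
    idx = torch.randperm(len(window))[:4]
    # window[idx] fails: a multi-element tensor cannot be used as a list index
    window_t = torch.arange(len(window))
    print(window_t[idx])  # tensor indexing works and returns a torch.Tensor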


        self.noise_level = noise_level
        self.sde_step_num = sde_step_num or len(self.noise_window)
        self.sde_type = sde_type
        self.seed = seed or 42

    def set_seed(self, seed: int) -> None:
        self.seed = seed

    @property
    def current_noise_steps(self) -> torch.Tensor:
        if self.sde_step_num >= len(self.noise_window):
            return self.noise_window
        generator = torch.Generator().manual_seed(self.seed)
        selected_indices = torch.randperm(len(self.noise_window), generator=generator)[:self.sde_step_num]
        return self.noise_window[selected_indices]

    def get_current_noise_level(self, timestep) -> float:
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        if timestep_id in self.current_noise_steps:
            return self.noise_level
        else:
            return 0.0

    def step(self,
        model_output,
        timestep,
        sample,
        to_final=False,
        prev_sample=None,
        generator=None,
        return_log_prob=False,
        return_dict=False,
        **kwargs
    ):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        if to_final or timestep_id + 1 >= len(self.timesteps):
            sigma_ = 0
        else:
            sigma_ = self.sigmas[timestep_id + 1]

        # Convert to float32 for numerical stability
        model_output = model_output.float()
        sample = sample.float()

        current_noise_level = self.get_current_noise_level(timestep)

        dt = sigma_ - sigma
        sigma_max = self.sigmas[1]  # Use the max sigma < 1
        if self.sde_type == 'Flow-SDE':
            std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))) * current_noise_level
            prev_sample_mean = sample * (1 + std_dev_t**2 / (2 * sigma) * dt) + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt
            if prev_sample is None:
                variance_noise = torch.randn(
                    model_output.shape,
                    generator=generator,
                ).to(device=model_output.device, dtype=model_output.dtype)
                prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1 * dt) * variance_noise

            if return_log_prob:
                log_prob = (
                    -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1 * dt)) ** 2))
                    - torch.log(std_dev_t * torch.sqrt(-1 * dt))
                    - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
                )
                log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
Comment on lines +263 to +268
Contributor

high

The log probability calculation for Flow-SDE can be unstable if dt is zero, which would cause division by zero or log(0). It's safer to add a small epsilon for numerical stability. Additionally, torch.sqrt(-1 * dt) is computed multiple times; pre-calculating it can improve clarity and slightly improve efficiency.

Suggested change
                log_prob = (
                    -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1 * dt)) ** 2))
                    - torch.log(std_dev_t * torch.sqrt(-1 * dt))
                    - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
                )
                log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
                sqrt_neg_dt = torch.sqrt(-dt)
                variance = (std_dev_t * sqrt_neg_dt) ** 2
                log_prob = (
                    -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * variance + 1e-9)
                    - torch.log(std_dev_t * sqrt_neg_dt + 1e-9)
                    - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
                )
                log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
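
For illustration, a standalone check of the dt == 0 failure mode described above (made-up values, not from the PR):

    import torch

    std_dev_t, dt = torch.tensor(0.3), torch.tensor(0.0)
    sq_err = torch.tensor(0.1)
    without_eps = -sq_err / (2 * (std_dev_t * torch.sqrt(-dt)) ** 2)       # -inf
    with_eps = -sq_err / (2 * (std_dev_t * torch.sqrt(-dt)) ** 2 + 1e-9)  # finite
    print(without_eps, with_eps)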


        elif self.sde_type == 'Dance-SDE':
            pred_original_sample = sample - sigma * model_output
            std_dev_t = current_noise_level * torch.sqrt(-1 * dt)
            log_term = 0.5 * current_noise_level**2 * (sample - pred_original_sample * (1 - sigma)) / sigma**2
Contributor

critical

This line involves a division by sigma**2. Some scheduler configurations, like set_timesteps_wan, can produce sigma values of 0. This will lead to a division-by-zero error. Please add a small epsilon to the denominator for numerical stability.

Suggested change
            log_term = 0.5 * current_noise_level**2 * (sample - pred_original_sample * (1 - sigma)) / sigma**2
            log_term = 0.5 * current_noise_level**2 * (sample - pred_original_sample * (1 - sigma)) / (sigma**2 + 1e-9)

            prev_sample_mean = sample + (model_output + log_term) * dt
            if prev_sample is None:
                variance_noise = torch.randn(
                    model_output.shape,
                    generator=generator,
                ).to(device=model_output.device, dtype=model_output.dtype)
                prev_sample = prev_sample_mean + std_dev_t * variance_noise

            if return_log_prob:
                log_prob = (
                    (-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2)))
                    - math.log(std_dev_t)
                    - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
Comment on lines +284 to +286
Contributor

critical

There are two issues in this block:

  1. math.log is applied to the tensor std_dev_t (line 285). On the 0-dim tensor produced here it silently converts the value to a Python float, dropping the tensor type and detaching it from the autograd graph; on a tensor with more than one element it fails outright. Use torch.log instead.
  2. If dt is zero, std_dev_t will be zero, leading to division by zero in the first term and log(0) in the second term. Adding a small epsilon will improve numerical stability.
Suggested change
                    (-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2)))
                    - math.log(std_dev_t)
                    - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
                    (-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2) + 1e-9))
                    - torch.log(std_dev_t + 1e-9)
                    - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
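
For illustration, a minimal standalone sketch of the math.log pitfall (hypothetical values, not from the PR):

    import math
    import torch

    std = torch.tensor(0.5, requires_grad=True)
    print(type(math.log(std)))   # <class 'float'>: detached from the graph
    print(type(torch.log(std)))  # <class 'torch.Tensor'>: gradients still flow
    # math.log(torch.tensor([0.5, 0.6])) fails: a multi-element tensor
    # cannot be converted to a Python scalar.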

                )

                # mean along all but batch dimension
                log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

        elif self.sde_type == 'CPS':
            # Coefficient Preserving Sampling
            std_dev_t = sigma_ * torch.sin(current_noise_level * torch.pi / 2)
            pred_original_sample = sample - sigma * model_output
            noise_estimate = sample + model_output * (1 - sigma)
            prev_sample_mean = pred_original_sample * (1 - sigma_) + noise_estimate * torch.sqrt(sigma_**2 - std_dev_t**2)

            if prev_sample is None:
                variance_noise = torch.randn(
                    model_output.shape,
                    generator=generator,
                ).to(device=model_output.device, dtype=model_output.dtype)
                prev_sample = prev_sample_mean + std_dev_t * variance_noise

            if return_log_prob:
                log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2)
                log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
Comment on lines +307 to +308
Contributor

high

The calculation for log_prob for the CPS SDE type appears to be incomplete. It's currently calculating only the negative squared error. A proper log probability for a Gaussian distribution should also include terms for the variance and the normalization constant, following the formula: log P(x) = - (x - μ)² / (2σ²) - log(σ) - 0.5 * log(2π). Here, the variance is std_dev_t**2.

Suggested change
                log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2)
                log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
                variance = std_dev_t**2
                log_prob = (
                    -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * variance + 1e-9)
                    - torch.log(std_dev_t + 1e-9)
                    - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
                )
                log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
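
For illustration, the suggested expression matches torch.distributions.Normal.log_prob up to the epsilon terms (standalone sketch with made-up values):

    import torch

    mean, std = torch.tensor(0.2), torch.tensor(0.4)
    x = torch.tensor(0.5)
    manual = (
        -((x - mean) ** 2) / (2 * std**2)
        - torch.log(std)
        - torch.log(torch.sqrt(torch.tensor(2 * torch.pi)))
    )
    reference = torch.distributions.Normal(mean, std).log_prob(x)
    print(torch.allclose(manual, reference))  # True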


if not return_log_prob:
log_prob = torch.zeros(sample.shape[0], device=sample.device)

if return_dict:
return {
"prev_sample": prev_sample,
"log_prob": log_prob,
"prev_sample_mean": prev_sample_mean,
"std_dev_t": std_dev_t,
}

return prev_sample, log_prob, prev_sample_mean, std_dev_t
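
For illustration, a minimal sampling-loop sketch showing how step() might be called with log probabilities enabled (the constructor arguments, set_timesteps call, and latent shape are assumptions, and the random model_output stands in for a real denoiser):

    import torch

    scheduler = FlowMatchSDEScheduler(template="FLUX.1", sde_type="Flow-SDE", noise_level=0.1)
    scheduler.set_timesteps(20)  # assumes the base FlowMatchScheduler API
    latents = torch.randn(1, 16, 64, 64)
    log_probs = []
    for t in scheduler.timesteps:
        model_output = torch.randn_like(latents)  # stand-in for the denoiser
        latents, log_prob, _, _ = scheduler.step(
            model_output, t, latents, return_log_prob=True
        )
        log_probs.append(log_prob)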