From 4b68de7cf65f4411cade22cb918d09e0bc793001 Mon Sep 17 00:00:00 2001 From: David El Malih Date: Mon, 10 Nov 2025 15:00:10 +0100 Subject: [PATCH 1/6] Improve docstrings and type hints in scheduling_ddim.py - Add complete type hints for all function parameters - Enhance docstrings to follow project conventions - Add missing parameter descriptions Fixes #9567 --- src/diffusers/schedulers/scheduling_ddim.py | 90 +++++++++++++++++---- 1 file changed, 75 insertions(+), 15 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 5ee0d084f060..f098ab1fa53a 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -49,10 +49,10 @@ class DDIMSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): + num_diffusion_timesteps: int, + max_beta: float = 0.999, + alpha_transform_type: str = "cosine", +) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -60,7 +60,6 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. - Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to @@ -69,16 +68,16 @@ def betas_for_alpha_bar( Choose from `cosine` or `exp` Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs """ if alpha_transform_type == "cosine": - def alpha_bar_fn(t): + def alpha_bar_fn(t: float) -> float: return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 elif alpha_transform_type == "exp": - def alpha_bar_fn(t): + def alpha_bar_fn(t: float) -> float: return math.exp(t * -12.0) else: @@ -92,11 +91,10 @@ def alpha_bar_fn(t): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(betas): +def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. @@ -250,7 +248,25 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None """ return sample - def _get_variance(self, timestep, prev_timestep): + def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor: + """ + Computes the variance of the noise added at a given diffusion step. + + For a given `timestep` and its previous step, this method calculates the variance as defined in DDIM/DDPM + literature: + var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + where alpha_prod and beta_prod are cumulative products of alphas and betas, respectively. + + Args: + timestep (`int`): + The current timestep in the diffusion process. + prev_timestep (`int`): + The previous timestep in the diffusion process. If negative, uses `final_alpha_cumprod`. + + Returns: + `torch.Tensor`: + The variance for the current timestep. + """ alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t @@ -263,13 +279,21 @@ def _get_variance(self, timestep, prev_timestep): # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ - "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing pixels from saturation at each step. We find that dynamic thresholding results in significantly better - photorealism as well as better image-text alignment, especially when using very large guidance weights." + photorealism as well as better image-text alignment, especially when using very large guidance weights. + + See https://huggingface.co/papers/2205.11487 + + Args: + sample (`torch.Tensor`): + The sample to threshold. - https://huggingface.co/papers/2205.11487 + Returns: + `torch.Tensor`: + The thresholded sample. """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape @@ -294,13 +318,18 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: return sample - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. + device (`Union[str, torch.device]`, *optional*): + The device to use for the timesteps. + + Raises: + ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`. """ if num_inference_steps > self.config.num_train_timesteps: @@ -477,6 +506,21 @@ def add_noise( noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + """ + Adds noise to the original samples. + + Args: + original_samples (`torch.Tensor`): + The original samples to add noise to. + noise (`torch.Tensor`): + The noise to add to the original samples. + timesteps (`torch.IntTensor`): + The timesteps to add noise to. + + Returns: + `torch.Tensor`: + The noisy samples. + """ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # for the subsequent add_noise calls @@ -499,6 +543,22 @@ def add_noise( # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + """ + Computes the velocity of the sample. The velocity is defined as the difference between the original sample and + the noisy sample. See https://huggingface.co/papers/2010.02502 + + Args: + sample (`torch.Tensor`): + The sample to compute the velocity of. + noise (`torch.Tensor`): + The noise to compute the velocity of. + timesteps (`torch.IntTensor`): + The timesteps to compute the velocity of. + + Returns: + `torch.Tensor`: + The velocity of the sample. + """ # Make sure alphas_cumprod and timestep have same device and dtype as sample self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) From 6103d84cd1bbc716ceb32ae12918ed9f370bb8ae Mon Sep 17 00:00:00 2001 From: David El Malih Date: Mon, 10 Nov 2025 19:28:11 +0100 Subject: [PATCH 2/6] Enhance docstrings and type hints in scheduling_ddim.py - Update parameter types and descriptions for clarity - Improve explanations in method docstrings to align with project standards - Add optional annotations for parameters where applicable --- src/diffusers/schedulers/scheduling_ddim.py | 49 ++++++++++++--------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index f098ab1fa53a..6489c790986f 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -38,7 +38,7 @@ class DDIMSchedulerOutput(BaseOutput): prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the denoising loop. - pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images, *optional*): The predicted denoised sample `(x_{0})` based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. """ @@ -375,7 +375,7 @@ def step( sample: torch.Tensor, eta: float = 0.0, use_clipped_model_output: bool = False, - generator=None, + generator: Optional[torch.Generator] = None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[DDIMSchedulerOutput, Tuple]: @@ -386,20 +386,21 @@ def step( Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`float`): + timestep (`int`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. - eta (`float`): - The weight of noise for added noise in diffusion step. - use_clipped_model_output (`bool`, defaults to `False`): + eta (`float`, *optional*, defaults to 0.0): + The weight of noise for added noise in diffusion step. A value of 0 corresponds to DDIM (deterministic) + and 1 corresponds to DDPM (fully stochastic). + use_clipped_model_output (`bool`, *optional*, defaults to `False`): If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would coincide with the one provided as input and `use_clipped_model_output` has no effect. generator (`torch.Generator`, *optional*): - A random number generator. - variance_noise (`torch.Tensor`): + A random number generator for reproducible sampling. + variance_noise (`torch.Tensor`, *optional*): Alternative to generating noise with `generator` by directly providing the noise for the variance itself. Useful for methods such as [`CycleDiffusion`]. return_dict (`bool`, *optional*, defaults to `True`): @@ -507,19 +508,22 @@ def add_noise( timesteps: torch.IntTensor, ) -> torch.Tensor: """ - Adds noise to the original samples. + Add noise to the original samples according to the noise magnitude at each timestep. + + This implements the forward diffusion process using the formula: `noisy_sample = sqrt(alpha_prod) * + original_sample + sqrt(1 - alpha_prod) * noise` Args: original_samples (`torch.Tensor`): - The original samples to add noise to. + The original clean samples to which noise will be added. noise (`torch.Tensor`): - The noise to add to the original samples. + The noise tensor to add, typically sampled from a Gaussian distribution. timesteps (`torch.IntTensor`): - The timesteps to add noise to. + The timesteps indicating the noise level from the diffusion schedule. Returns: `torch.Tensor`: - The noisy samples. + The noisy samples with noise added according to the timestep schedule. """ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement @@ -544,20 +548,25 @@ def add_noise( # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: """ - Computes the velocity of the sample. The velocity is defined as the difference between the original sample and - the noisy sample. See https://huggingface.co/papers/2010.02502 + Compute the velocity prediction for v-prediction models. + + The velocity is computed using the formula: `velocity = sqrt(alpha_prod) * noise - sqrt(1 - alpha_prod) * + sample` + + This is used in v-prediction models where the model directly predicts the velocity instead of the noise or the + sample. See section 2.4 of Imagen Video paper: https://imagen.research.google/video/paper.pdf Args: sample (`torch.Tensor`): - The sample to compute the velocity of. + The input sample (x_t) at the current timestep. noise (`torch.Tensor`): - The noise to compute the velocity of. + The noise tensor corresponding to the sample. timesteps (`torch.IntTensor`): - The timesteps to compute the velocity of. + The timesteps at which to compute the velocity. Returns: `torch.Tensor`: - The velocity of the sample. + The velocity prediction computed from the sample and noise at the given timesteps. """ # Make sure alphas_cumprod and timestep have same device and dtype as sample self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) @@ -577,5 +586,5 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps From e69fae822497bd3cb0217388b9f28a564beab300 Mon Sep 17 00:00:00 2001 From: David El Malih Date: Mon, 10 Nov 2025 20:18:11 +0100 Subject: [PATCH 3/6] Refine type hints and docstrings in scheduling_ddim.py - Update parameter types to use Literal for specific string options - Enhance docstring descriptions for clarity and consistency - Ensure all parameters have appropriate type annotations and defaults --- src/diffusers/schedulers/scheduling_ddim.py | 42 +++++++++++---------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 6489c790986f..5f536ae31b4c 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -17,7 +17,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -51,7 +51,7 @@ class DDIMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: str = "cosine", + alpha_transform_type: Literal["cosine", "exp"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -61,14 +61,15 @@ def betas_for_alpha_bar( to that part of the diffusion process. Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` + num_diffusion_timesteps (`int`): + The number of betas to produce. + max_beta (`float`, defaults to 0.999): + The maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`Literal["cosine", "exp"]`, defaults to `"cosine"`): + The type of noise schedule for alpha_bar. Must be one of `"cosine"` or `"exp"`. Returns: - betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs + `torch.Tensor`: The betas used by the scheduler to step the model outputs. """ if alpha_transform_type == "cosine": @@ -141,9 +142,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): The starting `beta` value of inference. beta_end (`float`, defaults to 0.02): The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): - The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + beta_schedule (`Literal["linear", "scaled_linear", "squaredcos_cap_v2"]`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Must be one + of `"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. clip_sample (`bool`, defaults to `True`): @@ -156,9 +157,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): otherwise it uses the alpha value at step 0. steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. - prediction_type (`str`, defaults to `epsilon`, *optional*): - Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), - `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + prediction_type (`Literal["epsilon", "sample", "v_prediction"]`, defaults to `"epsilon"`): + Prediction type of the scheduler function. Must be one of `"epsilon"` (predicts the noise of the diffusion + process), `"sample"` (directly predicts the noisy sample), or `"v_prediction"` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such @@ -167,9 +168,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. sample_max_value (`float`, defaults to 1.0): The threshold value for dynamic thresholding. Valid only when `thresholding=True`. - timestep_spacing (`str`, defaults to `"leading"`): - The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + timestep_spacing (`Literal["leading", "trailing", "linspace"]`, defaults to `"leading"`): + The way the timesteps should be scaled. Must be one of `"leading"`, `"trailing"`, or `"linspace"`. Refer to + Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) for more information. rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to @@ -185,17 +187,17 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, - timestep_spacing: str = "leading", + timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading", rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: From 3f49eb33aa139389ac48a2bab67c2c673637ade1 Mon Sep 17 00:00:00 2001 From: David El Malih Date: Wed, 12 Nov 2025 23:24:21 +0100 Subject: [PATCH 4/6] Apply review feedback on scheduling_ddim.py - Replace "prevent singularities" with "avoid numerical instability" for better clarity - Add backticks around `alpha_bar` variable name for consistent formatting - Convert Imagen Video paper URLs to Hugging Face papers references --- src/diffusers/schedulers/scheduling_ddim.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 5f536ae31b4c..4389869b65f9 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -64,9 +64,9 @@ def betas_for_alpha_bar( num_diffusion_timesteps (`int`): The number of betas to produce. max_beta (`float`, defaults to 0.999): - The maximum beta to use; use values lower than 1 to prevent singularities. + The maximum beta to use; use values lower than 1 to avoid numerical instability. alpha_transform_type (`Literal["cosine", "exp"]`, defaults to `"cosine"`): - The type of noise schedule for alpha_bar. Must be one of `"cosine"` or `"exp"`. + The type of noise schedule for `alpha_bar`. Must be one of `"cosine"` or `"exp"`. Returns: `torch.Tensor`: The betas used by the scheduler to step the model outputs. @@ -160,7 +160,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): prediction_type (`Literal["epsilon", "sample", "v_prediction"]`, defaults to `"epsilon"`): Prediction type of the scheduler function. Must be one of `"epsilon"` (predicts the noise of the diffusion process), `"sample"` (directly predicts the noisy sample), or `"v_prediction"` (see section 2.4 of [Imagen - Video](https://imagen.research.google/video/paper.pdf) paper). + Video](https://huggingface.co/papers/2210.02303) paper). thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. @@ -556,7 +556,7 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor sample` This is used in v-prediction models where the model directly predicts the velocity instead of the noise or the - sample. See section 2.4 of Imagen Video paper: https://imagen.research.google/video/paper.pdf + sample. See section 2.4 of [Imagen Video](https://huggingface.co/papers/2210.02303) paper. Args: sample (`torch.Tensor`): From e1c7acff295a0e88f14b60a8aca4f5bc45c802c8 Mon Sep 17 00:00:00 2001 From: David El Malih Date: Thu, 13 Nov 2025 08:42:57 +0100 Subject: [PATCH 5/6] Propagate changes using 'make fix-copies' --- src/diffusers/schedulers/scheduling_ddim.py | 81 ++++--------------- .../schedulers/scheduling_ddim_inverse.py | 1 - .../schedulers/scheduling_ddim_parallel.py | 12 ++- src/diffusers/schedulers/scheduling_ddpm.py | 1 - .../schedulers/scheduling_ddpm_parallel.py | 1 - .../scheduling_dpmsolver_multistep.py | 1 - .../scheduling_euler_ancestral_discrete.py | 1 - .../schedulers/scheduling_euler_discrete.py | 1 - src/diffusers/schedulers/scheduling_lcm.py | 1 - src/diffusers/schedulers/scheduling_tcd.py | 19 ++++- .../schedulers/scheduling_unipc_multistep.py | 1 - 11 files changed, 43 insertions(+), 77 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 4389869b65f9..c63f1f4c1675 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -38,7 +38,7 @@ class DDIMSchedulerOutput(BaseOutput): prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the denoising loop. - pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images, *optional*): + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample `(x_{0})` based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. """ @@ -49,10 +49,10 @@ class DDIMSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( - num_diffusion_timesteps: int, - max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", -) -> torch.Tensor: + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -60,25 +60,25 @@ def betas_for_alpha_bar( Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. + Args: - num_diffusion_timesteps (`int`): - The number of betas to produce. - max_beta (`float`, defaults to 0.999): - The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`Literal["cosine", "exp"]`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Must be one of `"cosine"` or `"exp"`. + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` Returns: - `torch.Tensor`: The betas used by the scheduler to step the model outputs. + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ if alpha_transform_type == "cosine": - def alpha_bar_fn(t: float) -> float: + def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 elif alpha_transform_type == "exp": - def alpha_bar_fn(t: float) -> float: + def alpha_bar_fn(t): return math.exp(t * -12.0) else: @@ -281,21 +281,13 @@ def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor: # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ - Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing pixels from saturation at each step. We find that dynamic thresholding results in significantly better - photorealism as well as better image-text alignment, especially when using very large guidance weights. - - See https://huggingface.co/papers/2205.11487 + photorealism as well as better image-text alignment, especially when using very large guidance weights." - Args: - sample (`torch.Tensor`): - The sample to threshold. - - Returns: - `torch.Tensor`: - The thresholded sample. + https://huggingface.co/papers/2205.11487 """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape @@ -509,24 +501,6 @@ def add_noise( noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: - """ - Add noise to the original samples according to the noise magnitude at each timestep. - - This implements the forward diffusion process using the formula: `noisy_sample = sqrt(alpha_prod) * - original_sample + sqrt(1 - alpha_prod) * noise` - - Args: - original_samples (`torch.Tensor`): - The original clean samples to which noise will be added. - noise (`torch.Tensor`): - The noise tensor to add, typically sampled from a Gaussian distribution. - timesteps (`torch.IntTensor`): - The timesteps indicating the noise level from the diffusion schedule. - - Returns: - `torch.Tensor`: - The noisy samples with noise added according to the timestep schedule. - """ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # for the subsequent add_noise calls @@ -549,27 +523,6 @@ def add_noise( # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: - """ - Compute the velocity prediction for v-prediction models. - - The velocity is computed using the formula: `velocity = sqrt(alpha_prod) * noise - sqrt(1 - alpha_prod) * - sample` - - This is used in v-prediction models where the model directly predicts the velocity instead of the noise or the - sample. See section 2.4 of [Imagen Video](https://huggingface.co/papers/2210.02303) paper. - - Args: - sample (`torch.Tensor`): - The input sample (x_t) at the current timestep. - noise (`torch.Tensor`): - The noise tensor corresponding to the sample. - timesteps (`torch.IntTensor`): - The timesteps at which to compute the velocity. - - Returns: - `torch.Tensor`: - The velocity prediction computed from the sample and noise at the given timesteps. - """ # Make sure alphas_cumprod and timestep have same device and dtype as sample self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py index 49dba840d089..d13ac606805c 100644 --- a/src/diffusers/schedulers/scheduling_ddim_inverse.py +++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py @@ -95,7 +95,6 @@ def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. diff --git a/src/diffusers/schedulers/scheduling_ddim_parallel.py b/src/diffusers/schedulers/scheduling_ddim_parallel.py index 7c3f03a8dbe1..2c62fdbb2e4f 100644 --- a/src/diffusers/schedulers/scheduling_ddim_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddim_parallel.py @@ -97,7 +97,6 @@ def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. @@ -194,17 +193,17 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, - timestep_spacing: str = "leading", + timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading", rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: @@ -324,6 +323,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. + device (`Union[str, torch.device]`, *optional*): + The device to use for the timesteps. + + Raises: + ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`. """ if num_inference_steps > self.config.num_train_timesteps: diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 0fab6d910a82..b59fae066495 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -94,7 +94,6 @@ def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py index ec741f9ecb7d..c78bfe290f53 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -96,7 +96,6 @@ def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 8b523cd13f1f..0560a030321d 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -80,7 +80,6 @@ def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 9cdaa2c5e101..38ad401edc49 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -97,7 +97,6 @@ def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index f58d918dbfbe..59199bf71013 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -100,7 +100,6 @@ def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py index cd7a29fe675f..8a0fd480505c 100644 --- a/src/diffusers/schedulers/scheduling_lcm.py +++ b/src/diffusers/schedulers/scheduling_lcm.py @@ -99,7 +99,6 @@ def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py index 3fd5c341eca9..ce7d1d5316b4 100644 --- a/src/diffusers/schedulers/scheduling_tcd.py +++ b/src/diffusers/schedulers/scheduling_tcd.py @@ -98,7 +98,6 @@ def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. @@ -316,6 +315,24 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler._get_variance def _get_variance(self, timestep, prev_timestep): + """ + Computes the variance of the noise added at a given diffusion step. + + For a given `timestep` and its previous step, this method calculates the variance as defined in DDIM/DDPM + literature: + var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + where alpha_prod and beta_prod are cumulative products of alphas and betas, respectively. + + Args: + timestep (`int`): + The current timestep in the diffusion process. + prev_timestep (`int`): + The previous timestep in the diffusion process. If negative, uses `final_alpha_cumprod`. + + Returns: + `torch.Tensor`: + The variance for the current timestep. + """ alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 162a34bd2774..a596fef24559 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -80,7 +80,6 @@ def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. From bda3ea729046717b13ac4ac1ec33e3cc7ce7ed78 Mon Sep 17 00:00:00 2001 From: David El Malih Date: Thu, 13 Nov 2025 17:49:18 +0100 Subject: [PATCH 6/6] Add missing Literal --- src/diffusers/schedulers/scheduling_ddim_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddim_parallel.py b/src/diffusers/schedulers/scheduling_ddim_parallel.py index 2c62fdbb2e4f..deffdb4ff7d3 100644 --- a/src/diffusers/schedulers/scheduling_ddim_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddim_parallel.py @@ -17,7 +17,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch