From 38cf8fd6a923148356246d599066408d39382308 Mon Sep 17 00:00:00 2001
From: David El Malih
Date: Mon, 10 Nov 2025 17:08:06 +0100
Subject: [PATCH 1/3] Improve docstrings and type hints in scheduling_amused.py

- Add complete type hints for helper functions (gumbel_noise, mask_by_random_topk)
- Enhance AmusedSchedulerOutput with proper Optional typing
- Add comprehensive docstrings for AmusedScheduler class
- Improve __init__, set_timesteps, step, and add_noise methods
- Fix type hints to match documentation conventions
- All changes follow project standards from issue #9567
---
 src/diffusers/schedulers/scheduling_amused.py | 122 ++++++++++++++++--
 1 file changed, 114 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_amused.py b/src/diffusers/schedulers/scheduling_amused.py
index 238b8d869171..632b29dbbd3b 100644
--- a/src/diffusers/schedulers/scheduling_amused.py
+++ b/src/diffusers/schedulers/scheduling_amused.py
@@ -9,13 +9,48 @@
 from .scheduling_utils import SchedulerMixin
 
 
-def gumbel_noise(t, generator=None):
+def gumbel_noise(t: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
+    """
+    Generate Gumbel noise for sampling.
+
+    Args:
+        t (`torch.Tensor`):
+            Input tensor to match the shape and dtype of the output noise.
+        generator (`torch.Generator`, *optional*):
+            A random number generator for reproducible sampling.
+
+    Returns:
+        `torch.Tensor`:
+            Gumbel-distributed noise with the same shape as the input tensor.
+    """
     device = generator.device if generator is not None else t.device
     noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device)
     return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20))
 
 
-def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
+def mask_by_random_topk(
+    mask_len: torch.Tensor,
+    probs: torch.Tensor,
+    temperature: float = 1.0,
+    generator: Optional[torch.Generator] = None,
+) -> torch.Tensor:
+    """
+    Mask tokens by selecting the top-k lowest confidence scores with temperature-based randomness.
+
+    Args:
+        mask_len (`torch.Tensor`):
+            Number of tokens to mask per sample in the batch.
+        probs (`torch.Tensor`):
+            Probability scores for each token.
+        temperature (`float`, *optional*, defaults to 1.0):
+            Temperature parameter for controlling randomness in the masking process.
+        generator (`torch.Generator`, *optional*):
+            A random number generator for reproducible sampling.
+
+    Returns:
+        `torch.Tensor`:
+            Boolean mask indicating which tokens should be masked.
+    """
     confidence = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator)
     sorted_confidence = torch.sort(confidence, dim=-1).values
     cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
@@ -38,13 +73,14 @@ class AmusedSchedulerOutput(BaseOutput):
     """
 
     prev_sample: torch.Tensor
-    pred_original_sample: torch.Tensor = None
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 class AmusedScheduler(SchedulerMixin, ConfigMixin):
     order = 1
 
-    temperatures: torch.Tensor
+    temperatures: Optional[torch.Tensor]
+    timesteps: Optional[torch.Tensor]
 
     @register_to_config
     def __init__(
@@ -52,6 +88,16 @@ def __init__(
         mask_token_id: int,
         masking_schedule: str = "cosine",
     ):
+        """
+        Create a new AmusedScheduler instance.
+
+        Args:
+            mask_token_id (`int`):
+                The token ID used to represent masked tokens in the sequence.
+            masking_schedule (`str`, *optional*, defaults to `"cosine"`):
+                The schedule type for determining the mask ratio at each timestep. Can be either `"cosine"` or
+                `"linear"`.
+        """
         self.temperatures = None
         self.timesteps = None
 
@@ -60,7 +106,21 @@ def set_timesteps(
         self,
         num_inference_steps: int,
         temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
         device: Union[str, torch.device] = None,
-    ):
+    ) -> None:
+        """
+        Set the discrete timesteps used for the diffusion chain.
+
+        Args:
+            num_inference_steps (`int`):
+                The number of diffusion steps used when generating samples with a pre-trained model.
+            temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to `(2, 0)`):
+                Temperature parameter(s) for controlling the randomness of sampling. If a tuple or list is provided,
+                temperatures will be linearly interpolated between the first and second values across all timesteps. If
+                a single value is provided, temperatures will be linearly interpolated from that value to 0.01.
+            device (`Union[str, torch.device]`, *optional*):
+                The device to which the timesteps and temperatures should be moved. If not specified, uses the default
+                device.
+        """
         self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)
 
         if isinstance(temperature, (tuple, list)):
@@ -71,12 +131,38 @@ def step(
         self,
         model_output: torch.Tensor,
-        timestep: torch.long,
+        timestep: int,
         sample: torch.LongTensor,
-        starting_mask_ratio: int = 1,
+        starting_mask_ratio: float = 1.0,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
     ) -> Union[AmusedSchedulerOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by masking tokens based on confidence scores.
+
+        Args:
+            model_output (`torch.Tensor`):
+                The direct output from the learned diffusion model. Typically of shape `(batch_size, num_tokens,
+                codebook_size)` or `(batch_size, codebook_size, height, width)` for 2D inputs.
+            timestep (`int`):
+                The current discrete timestep in the diffusion chain.
+            sample (`torch.LongTensor`):
+                A current instance of a sample created by the diffusion process. Contains token IDs, with masked
+                positions indicated by `mask_token_id`.
+            starting_mask_ratio (`float`, *optional*, defaults to 1.0):
+                A multiplier applied to the mask ratio schedule. Values less than 1.0 will result in fewer tokens being
+                masked at each step.
+            generator (`torch.Generator`, *optional*):
+                A random number generator for reproducible sampling.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether to return an [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or a plain tuple.
+
+        Returns:
+            [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or `tuple`:
+                If `return_dict` is `True`, returns [`~schedulers.scheduling_amused.AmusedSchedulerOutput`], otherwise
+                returns a tuple where the first element is the sample tensor and the second element is the predicted
+                original sample tensor.
+        """
         two_dim_input = sample.ndim == 3 and model_output.ndim == 4
 
         if two_dim_input:
@@ -137,7 +223,27 @@ def step(
 
         return AmusedSchedulerOutput(prev_sample, pred_original_sample)
 
-    def add_noise(self, sample, timesteps, generator=None):
+    def add_noise(
+        self,
+        sample: torch.LongTensor,
+        timesteps: int,
+        generator: Optional[torch.Generator] = None,
+    ) -> torch.LongTensor:
+        """
+        Add noise to a sample by randomly masking tokens according to the masking schedule.
+
+        Args:
+            sample (`torch.LongTensor`):
+                The input sample containing token IDs to be partially masked.
+            timesteps (`int`):
+                The timestep that determines how much masking to apply. Higher timesteps result in more masking.
+            generator (`torch.Generator`, *optional*):
+                A random number generator for reproducible masking.
+
+        Returns:
+            `torch.LongTensor`:
+                The sample with some tokens replaced by `mask_token_id` according to the masking schedule.
+        """
         step_idx = (self.timesteps == timesteps).nonzero()
         ratio = (step_idx + 1) / len(self.timesteps)

From bb689ca86fd90beb7a2c620572ec31070a116cec Mon Sep 17 00:00:00 2001
From: David El Malih
Date: Mon, 10 Nov 2025 19:59:03 +0100
Subject: [PATCH 2/3] Enhance type hints and docstrings in scheduling_amused.py

- Update type hints for `prev_sample` and `pred_original_sample` in `AmusedSchedulerOutput` to reflect their tensor types.
- Improve docstring for `gumbel_noise` to specify the output tensor's dtype and device.
- Refine `AmusedScheduler` class documentation, including detailed descriptions of the masking schedule and temperature parameters.
- Adjust type hints in `set_timesteps` and `step` methods for better clarity and consistency.
---
 src/diffusers/schedulers/scheduling_amused.py | 67 ++++++++++---------
 1 file changed, 37 insertions(+), 30 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_amused.py b/src/diffusers/schedulers/scheduling_amused.py
index 632b29dbbd3b..5698d95b0650 100644
--- a/src/diffusers/schedulers/scheduling_amused.py
+++ b/src/diffusers/schedulers/scheduling_amused.py
@@ -1,6 +1,6 @@
 import math
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
 
 import torch
 
@@ -21,7 +21,7 @@ def gumbel_noise(t: torch.Tensor, generator: Optional[torch.Generator] = None) -
 
     Returns:
         `torch.Tensor`:
-            Gumbel-distributed noise with the same shape as the input tensor.
+            Gumbel-distributed noise with the same shape, dtype, and device as the input tensor.
     """
     device = generator.device if generator is not None else t.device
     noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device)
@@ -64,12 +64,12 @@ class AmusedSchedulerOutput(BaseOutput):
     Output class for the scheduler's `step` function output.
 
     Args:
-        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
-            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
-            denoising loop.
-        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
-            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
-            `pred_original_sample` can be used to preview progress or for guidance.
+        prev_sample (`torch.LongTensor` of shape `(batch_size, height, width)` or `(batch_size, sequence_length)`):
+            Computed sample `(x_{t-1})` of previous timestep with token IDs. `prev_sample` should be used as next model
+            input in the denoising loop.
+        pred_original_sample (`torch.LongTensor` of shape `(batch_size, height, width)` or `(batch_size, sequence_length)`, *optional*):
+            The predicted fully denoised sample `(x_{0})` with token IDs based on the model output from the current
+            timestep. `pred_original_sample` can be used to preview progress or for guidance.
     """
 
     prev_sample: torch.Tensor
@@ -77,6 +77,23 @@ class AmusedScheduler(SchedulerMixin, ConfigMixin):
+    """
+    A scheduler for masked token generation as used in [Amused](https://huggingface.co/amused).
+
+    This scheduler iteratively unmasks tokens based on their confidence scores, following either a cosine or linear
+    schedule. Unlike traditional diffusion schedulers that work with continuous pixel values, this scheduler operates
+    on discrete token IDs, making it suitable for non-autoregressive masked token generation models.
+
+    This scheduler inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the
+    generic methods the library implements for all schedulers such as loading and saving.
+
+    Args:
+        mask_token_id (`int`):
+            The token ID used to represent masked tokens in the sequence.
+        masking_schedule (`Literal["cosine", "linear"]`, *optional*, defaults to `"cosine"`):
+            The schedule type for determining the mask ratio at each timestep. Can be either `"cosine"` or `"linear"`.
+    """
+
     order = 1
 
     temperatures: Optional[torch.Tensor]
@@ -86,40 +103,30 @@ class AmusedScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
         mask_token_id: int,
-        masking_schedule: str = "cosine",
+        masking_schedule: Literal["cosine", "linear"] = "cosine",
     ):
-        """
-        Create a new AmusedScheduler instance.
-
-        Args:
-            mask_token_id (`int`):
-                The token ID used to represent masked tokens in the sequence.
-            masking_schedule (`str`, *optional*, defaults to `"cosine"`):
-                The schedule type for determining the mask ratio at each timestep. Can be either `"cosine"` or
-                `"linear"`.
-        """
         self.temperatures = None
         self.timesteps = None
 
     def set_timesteps(
         self,
         num_inference_steps: int,
-        temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
-        device: Union[str, torch.device] = None,
+        temperature: Union[float, Tuple[float, float], List[float]] = (2, 0),
+        device: Optional[Union[str, torch.device]] = None,
     ) -> None:
         """
-        Set the discrete timesteps used for the diffusion chain.
+        Set the discrete timesteps used for the diffusion chain (to be run before inference).
 
         Args:
             num_inference_steps (`int`):
                 The number of diffusion steps used when generating samples with a pre-trained model.
-            temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to `(2, 0)`):
+            temperature (`Union[float, Tuple[float, float], List[float]]`, *optional*, defaults to `(2, 0)`):
                 Temperature parameter(s) for controlling the randomness of sampling. If a tuple or list is provided,
                 temperatures will be linearly interpolated between the first and second values across all timesteps. If
                 a single value is provided, temperatures will be linearly interpolated from that value to 0.01.
-            device (`Union[str, torch.device]`, *optional*):
-                The device to which the timesteps and temperatures should be moved. If not specified, uses the default
-                device.
+            device (`str` or `torch.device`, *optional*):
+                The device to which the timesteps and temperatures should be moved. If `None`, the timesteps are not
+                moved.
         """
         self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)
 
@@ -136,7 +143,7 @@ def step(
         starting_mask_ratio: float = 1.0,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
-    ) -> Union[AmusedSchedulerOutput, Tuple]:
+    ) -> Union[AmusedSchedulerOutput, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Predict the sample at the previous timestep by masking tokens based on confidence scores.
 
@@ -159,9 +166,9 @@ def step(
 
         Returns:
             [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or `tuple`:
-                If `return_dict` is `True`, returns [`~schedulers.scheduling_amused.AmusedSchedulerOutput`], otherwise
-                returns a tuple where the first element is the sample tensor and the second element is the predicted
-                original sample tensor.
+                If `return_dict` is `True`, [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] is returned,
+                otherwise a tuple is returned where the first element is the sample tensor (`prev_sample`) and the
+                second element is the predicted original sample tensor (`pred_original_sample`).
         """
         two_dim_input = sample.ndim == 3 and model_output.ndim == 4
 

From c5cc592a2bcecdfd18493e99b4925b2be6596830 Mon Sep 17 00:00:00 2001
From: David El Malih
Date: Wed, 12 Nov 2025 23:30:02 +0100
Subject: [PATCH 3/3] Apply review feedback on scheduling_amused.py

- Replace generic [Amused] reference with specific [`AmusedPipeline`] reference for consistency with project documentation conventions
---
 src/diffusers/schedulers/scheduling_amused.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/schedulers/scheduling_amused.py b/src/diffusers/schedulers/scheduling_amused.py
index 5698d95b0650..a0b8fbc862b0 100644
--- a/src/diffusers/schedulers/scheduling_amused.py
+++ b/src/diffusers/schedulers/scheduling_amused.py
@@ -78,7 +78,7 @@ class AmusedScheduler(SchedulerMixin, ConfigMixin):
     """
-    A scheduler for masked token generation as used in [Amused](https://huggingface.co/amused).
+    A scheduler for masked token generation as used in [`AmusedPipeline`].
 
     This scheduler iteratively unmasks tokens based on their confidence scores, following either a cosine or linear
     schedule. Unlike traditional diffusion schedulers that work with continuous pixel values, this scheduler operates
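A minimal usage sketch of the API documented in this patch series. The values below are illustrative assumptions, not taken from the PR: `codebook_size`, `mask_token_id`, the tensor shapes, and the random logits standing in for a real transformer forward pass would all come from an actual pipeline's VQ model and transformer.

```python
import torch

from diffusers.schedulers.scheduling_amused import AmusedScheduler

# Illustrative assumptions; a real pipeline derives these from its VQ model.
codebook_size = 8192
mask_token_id = codebook_size  # assume the mask ID sits just outside the codebook
batch_size, seq_len = 2, 256

scheduler = AmusedScheduler(mask_token_id=mask_token_id)
scheduler.set_timesteps(num_inference_steps=12, temperature=(2, 0))

# Generation starts from a fully masked token sequence.
sample = torch.full((batch_size, seq_len), mask_token_id, dtype=torch.long)
generator = torch.Generator().manual_seed(0)

for t in scheduler.timesteps:
    # A real pipeline would call its transformer here; random logits stand in.
    model_output = torch.randn(batch_size, seq_len, codebook_size)
    # Each step keeps high-confidence predictions and re-masks the rest.
    sample = scheduler.step(model_output, t, sample, generator=generator).prev_sample
```

By the final timestep every position holds a codebook ID; intermediate `pred_original_sample` values can be decoded to preview progress, as the docstrings note.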