[Feat] Adds LongCat-AudioDiT pipeline #13390
RuixiangMa wants to merge 12 commits into huggingface:main
Conversation
Signed-off-by: Lancer <maruixiang6688@gmail.com>
Force-pushed from 9c4613f to d2a2621
src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py — resolved review comments (outdated)
```python
)


def _pixel_shuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor:
```

Similarly, I think we should inline `_pixel_shuffle_1d` in `UpsampleShortcut` following #13390 (comment).
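For context, a 1D pixel shuffle moves channel groups into the length dimension (the 1D analogue of `torch.nn.PixelShuffle`). A minimal standalone sketch, with the function name and shape convention assumed rather than taken from the PR:

```python
import torch


def pixel_shuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor:
    # (batch, channels, length) -> (batch, channels // factor, length * factor)
    batch, channels, length = hidden_states.shape
    hidden_states = hidden_states.view(batch, channels // factor, factor, length)
    return hidden_states.permute(0, 1, 3, 2).reshape(batch, channels // factor, length * factor)
```

Inlining this into `UpsampleShortcut` would just mean moving the view/permute/reshape into its `forward`.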
src/diffusers/models/transformers/transformer_longcat_audio_dit.py — resolved review comments (outdated)
```python
        self.time_embed = AudioDiTTimestepEmbedding(dim)
        self.input_embed = AudioDiTEmbedder(latent_dim, dim)
        self.text_embed = AudioDiTEmbedder(dit_text_dim, dim)
        self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)
        self.blocks = nn.ModuleList(
```

See #13390 (comment).
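For reference, the table behind a call like `AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)` is typically built from per-channel-pair inverse frequencies scaled by position. A hedged sketch (helper name assumed, not the PR's implementation):

```python
import torch


def rotary_freqs(dim_head: int, max_positions: int, base: float = 100000.0) -> torch.Tensor:
    # one inverse frequency per pair of head channels, scaled by position index
    inv_freq = 1.0 / (base ** (torch.arange(0, dim_head, 2).float() / dim_head))
    positions = torch.arange(max_positions).float()
    return torch.outer(positions, inv_freq)  # (max_positions, dim_head // 2)
```

The resulting angles are what get turned into the cos/sin pairs applied to queries and keys.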
```python
        batch_size = hidden_states.shape[0]
        if timestep.ndim == 0:
            timestep = timestep.repeat(batch_size)
        timestep_embed = self.time_embed(timestep)
        text_mask = encoder_attention_mask.bool()
        encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask)
```

Can you also refactor `forward` here so that it is better organized, following #13390 (comment)? See for example the `QwenImageTransformer2DModel.forward` method.
Reorganized parts of forward incrementally; kept the current structure otherwise to avoid unnecessary behavioral churn.
These CI failures do not appear to be related to this PR.
```python
def _get_uniform_flow_match_scheduler_sigmas(num_inference_steps: int) -> list[float]:
    num_inference_steps = max(int(num_inference_steps), 2)
    num_updates = num_inference_steps - 1
```
I think we should define `num_inference_steps` to match the number of function evaluations we're performing (that is, to have the same semantics that `num_updates` currently has), which is the usual diffusers behavior. This would also allow us to remove the behavior where we overwrite `num_inference_steps=1` below in `__call__`.
num_inference_steps now matches the number of model evaluations, following the usual diffusers semantics. This also removes the previous behavior of promoting num_inference_steps=1 to 2.
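Under those semantics, a uniform flow-matching schedule needs `num_inference_steps + 1` sigma boundaries running from 1.0 down to 0.0, one update per adjacent pair. A sketch with a hypothetical helper name, not the final inlined code:

```python
def uniform_flow_match_sigmas(num_inference_steps: int) -> list[float]:
    # num_inference_steps model evaluations -> num_inference_steps + 1 boundaries
    return [1.0 - i / num_inference_steps for i in range(num_inference_steps + 1)]
```

Note that `num_inference_steps=1` is now well defined (a single update from sigma 1.0 to 0.0), which is why the promotion to 2 can be dropped.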
```python
    return {key[len(prefix) :]: value for key, value in state_dict.items() if key.startswith(prefix)}


def _get_uniform_flow_match_scheduler_sigmas(num_inference_steps: int) -> list[float]:
```
I think we should inline `_get_uniform_flow_match_scheduler_sigmas` into `__call__` so that it's easier to understand how the sigma schedule is being prepared; see e.g. `Flux2Pipeline` for an example of this. We generally prefer not to have too many small functions in the pipeline code.
I inlined the uniform flow-matching sigma schedule preparation into __call__, similar to Flux2Pipeline.
Current:

```python
def _approx_duration_from_text(text: str | list[str], max_duration: float = 30.0) -> float:
    if isinstance(text, list):
        if not text:
            return 0.0
        return max(_approx_duration_from_text(prompt, max_duration=max_duration) for prompt in text)

    en_dur_per_char = 0.082
```

Suggested:

```python
def _approx_duration_from_text(text: str | list[str], max_duration: float = 30.0) -> float:
    if not text:
        return 0.0
    if isinstance(text, str):
        text = [text]
    en_dur_per_char = 0.082
```
nit: I think refactoring this function to be non-recursive (by making it work naturally with a list of strings) would make it more clear.
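A complete non-recursive shape might look like the sketch below. Only the 0.082 s/char constant is visible in the diff; the function name, the capping behavior, and the use of the longest prompt's estimate are assumptions for illustration:

```python
EN_DUR_PER_CHAR = 0.082  # seconds per character; the only rate visible in the diff


def approx_duration_from_text(text, max_duration: float = 30.0) -> float:
    # accept a single string or a batch; a batch uses the longest prompt's estimate
    if not text:
        return 0.0
    if isinstance(text, str):
        text = [text]
    return max(min(len(prompt) * EN_DUR_PER_CHAR, max_duration) for prompt in text)
```

Normalizing to a list up front removes the recursion while keeping the same batch semantics.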
Current:

```python
        first_hidden = F.layer_norm(first_hidden, (first_hidden.shape[-1],), eps=1e-6)
        prompt_embeds = prompt_embeds + first_hidden
        lengths = attention_mask.sum(dim=1).to(device)
        return prompt_embeds.float(), lengths
```

Suggested:

```python
        return prompt_embeds, lengths
```
Do we need to call .float() on prompt_embeds here? I think we should generally respect the output dtype from self.text_encoder.
Removed the .float() cast from encode_prompt so that we respect the dtype produced by the text encoder.
```python
        )
        self.scheduler.set_begin_index(0)
        timesteps = self.scheduler.timesteps
        sample = latents
```
I think using the standard name latents instead of sample would be more clear. It would also work better with PipelineTesterMixin tests.
Renamed the denoising state from sample to the standard latents.
Current:

```python
        if latents is None:
            duration = max(1, min(duration, max_duration))

        text_condition, text_condition_len = self.encode_prompt(normalized_prompts, device)
```

Suggested:

```python
        prompt_embeds, text_condition_len = self.encode_prompt(normalized_prompts, device)
```
Similarly to #13390 (comment), I think using the standard name prompt_embeds would be better here.
Renamed text_condition to the standard prompt_embeds.
Current:

```python
        if not return_dict:
            return (waveform,)
```

Suggested:

```python
        self.maybe_free_model_hooks()

        if not return_dict:
            return (waveform,)
```
Calling self.maybe_free_model_hooks() allows the pipeline to clear model hooks correctly, such as those used to support model offloading.
The pipeline now calls self.maybe_free_model_hooks() before returning.
Current:

```python
        if output_type == "latent":
            if not return_dict:
                return (sample,)
            return AudioPipelineOutput(audios=sample)
```

Suggested:

```python
        if output_type == "latent":
            waveform = sample
```
A little simpler. Also makes it so that we don't have to call self.maybe_free_model_hooks() twice (see #13390 (comment)).
The latent output path now assigns waveform = latents and shares the same final return path, so maybe_free_model_hooks() only needs to be called once.
Current:

```python
                    latent_cond=latent_cond,
                ).sample
                pred = null_pred + (pred - null_pred) * guidance_scale
                sample = self.scheduler.step(pred, t, sample, return_dict=False)[0]
```

Suggested (using the standard names `latents` and `prompt_embeds`):

```python
                latents = self.scheduler.step(pred, t, latents, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
```
Example for supporting callbacks. This assumes we use the standard names `latents` and `prompt_embeds` (see #13390 (comment), #13390 (comment)). See also how e.g. `Flux2Pipeline` supports callbacks: src/diffusers/pipelines/flux2/pipeline_flux2.py, lines 993 to 997 at dc8d903.
Added callback_on_step_end support in the denoising loop.
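On the caller's side, the diffusers callback contract (as in `Flux2Pipeline`) passes the pipeline, the step index, the timestep, and a dict of the requested tensors, and expects a dict of (possibly modified) tensors back. A hedged usage example with an illustrative, made-up edit:

```python
def zero_out_late_latents(pipe, step_index, timestep, callback_kwargs):
    # callback_kwargs holds the tensors named in callback_on_step_end_tensor_inputs
    latents = callback_kwargs["latents"]
    if step_index >= 8:
        latents = latents * 0.0  # illustrative modification only
    return {"latents": latents}
```

The pipeline then pops `"latents"` (and any other registered tensors) from the returned dict back into its denoising loop.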
Current:

```python
        guidance_scale: float = 4.0,
        generator: torch.Generator | list[torch.Generator] | None = None,
        output_type: str = "np",
        return_dict: bool = True,
```

Suggested:

```python
        return_dict: bool = True,
        callback_on_step_end: Callable[[int, int], None] | None = None,
        callback_on_step_end_tensor_inputs: list[str] = ["latents"],
```
Follow-up for callback support (see #13390 (comment)).
Added callback_on_step_end and callback_on_step_end_tensor_inputs to __call__.
Current:

```python
class LongCatAudioDiTPipeline(DiffusionPipeline):
    model_cpu_offload_seq = "text_encoder->transformer->vae"
```

Suggested:

```python
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds"]
```
Follow up for callback support (#13390 (comment)). The callback tests specifically check for the name latents here, which is one reason to use it over sample.
Current:

```python
class LongCatAudioDiTTransformer(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = False
```

Suggested:

```python
    _supports_gradient_checkpointing = False
    _repeated_blocks = ["AudioDiTBlock"]
```
Setting _repeated_blocks here enables regional compilation support. This also allows us to not skip the TestLongCatAudioDiTTransformerCompile.test_torch_compile_repeated_blocks test.
Current:

```python
        dtype = next(self.parameters()).dtype
        hidden_states = hidden_states.to(dtype)
        encoder_hidden_states = encoder_hidden_states.to(dtype)
        timestep = timestep.to(dtype)
```

Suggested:

```python
        encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
        timestep = timestep.to(hidden_states.dtype)
```
Matching the dtype of hidden_states rather than casting to next(self.parameters()).dtype makes supporting layerwise casting easier (as not all parameters will have the same dtype when layerwise casting is enabled).
The transformer now follows hidden_states.dtype instead of casting activations to next(self.parameters()).dtype, which is more compatible with layerwise casting.
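The difference matters because layerwise casting can leave different parameters in different dtypes, so `next(self.parameters()).dtype` is ill-defined; following the activation dtype sidesteps that. A minimal sketch of the pattern (hypothetical free function, not the PR's method):

```python
import torch


def cast_conditioning(hidden_states, encoder_hidden_states, timestep):
    # follow the activation dtype rather than a parameter's dtype, which may
    # not be representative when layerwise casting mixes parameter dtypes
    dtype = hidden_states.dtype
    return hidden_states, encoder_hidden_states.to(dtype), timestep.to(dtype)
```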
Current:

```python
        encoder_hidden_states = encoder_hidden_states.masked_fill(text_mask.logical_not().unsqueeze(-1), 0.0)
        hidden_states = self.input_embed(hidden_states, attention_mask)
        if self.use_latent_condition and latent_cond is not None:
            latent_cond = self.latent_embed(latent_cond.to(dtype), attention_mask)
```

Suggested:

```python
            latent_cond = self.latent_embed(latent_cond.to(hidden_states.dtype), attention_mask)
```
Follow-up suggestion to #13390 (comment).
Current:

```python
            return (latents,)
        return LongCatAudioDiTVaeEncoderOutput(latents=latents)

    def decode(
```

Suggested:

```python
    @apply_forward_hook
    def decode(
```
Using the apply_forward_hook decorator here allows us to apply the model offloading hook to the decode method as well, because that hook would normally operate on forward. This helps support model offloading in the pipeline by correctly applying the hook when we decode the latents back to a waveform.
Added @apply_forward_hook to LongCatAudioDiTVae.decode.
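Conceptually, `apply_forward_hook` routes a non-`forward` method through the module's attached accelerate hook so offloading still triggers when e.g. `decode` is called directly. A simplified re-implementation of the idea (not the real decorator, which lives in `diffusers.utils.accelerate_utils`):

```python
import functools


def apply_forward_hook_sketch(method):
    # run the accelerate pre-forward hook (if one is attached as `_hf_hook`)
    # before the wrapped method, mimicking what apply_forward_hook does
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        hook = getattr(self, "_hf_hook", None)
        if hook is not None and hasattr(hook, "pre_forward"):
            hook.pre_forward(self)
        return method(self, *args, **kwargs)

    return wrapper
```

Without this, an offloaded VAE would stay on CPU when `decode` is invoked outside `forward`.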
Current:

```python
            upsample_shortcut=upsample_shortcut,
        )

    def encode(
```

Suggested:

```python
    @apply_forward_hook
    def encode(
```
Analogous change to #13390 (comment) for encode.
Current:

```python
from torch.nn.utils import weight_norm

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
```

Suggested:

```python
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
```
Follow-up change to #13390 (comment).
Current:

```python
class LongCatAudioDiTVae(ModelMixin, AutoencoderMixin, ConfigMixin):
    @register_to_config
```

Suggested:

```python
class LongCatAudioDiTVae(ModelMixin, AutoencoderMixin, ConfigMixin):
    _supports_group_offloading = False

    @register_to_config
```
Group offloading is not compatible with `torch.nn.utils.weight_norm`, so disable it here so that we don't have to manually skip the corresponding test.
Set _supports_group_offloading = False for the VAE because it uses torch.nn.utils.weight_norm.
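For context, `torch.nn.utils.weight_norm` splits `weight` into a `weight_g`/`weight_v` pair and recomputes `weight` in a forward pre-hook, so machinery that moves individual parameter groups off-device behind its back can desynchronize the reparameterization. A quick check of the reparameterization:

```python
import torch
from torch import nn
from torch.nn.utils import weight_norm

# after wrapping, `weight` is no longer a Parameter; it is rebuilt from
# weight_g (magnitude) and weight_v (direction) on every forward
conv = weight_norm(nn.Conv1d(4, 8, kernel_size=3))
param_names = {name for name, _ in conv.named_parameters()}
```

(Newer torch versions deprecate this in favor of `torch.nn.utils.parametrizations.weight_norm`, but the hook-based legacy form is what the VAE uses here.)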
Current:

```python
class TestLongCatAudioDiTTransformerMemory(LongCatAudioDiTTransformerTesterConfig, MemoryTesterMixin):
    def test_layerwise_casting_memory(self):
        pytest.skip("LongCatAudioDiTTransformer does not support standard layerwise casting memory tests yet.")

    def test_layerwise_casting_training(self):
        pytest.skip("LongCatAudioDiTTransformer does not support standard layerwise casting training tests yet.")

    def test_group_offloading_with_layerwise_casting(self, *args, **kwargs):
        pytest.skip(
            "LongCatAudioDiTTransformer does not support combined group offloading and layerwise casting tests yet."
        )
```

Suggested:

```python
class TestLongCatAudioDiTTransformerMemory(LongCatAudioDiTTransformerTesterConfig, MemoryTesterMixin):
    pass
```
Layerwise casting should work if #13390 (comment) is applied.
I removed the layerwise casting training and combined group-offloading/layerwise-casting skips after updating the dtype handling. I kept test_layerwise_casting_memory skipped because the tiny transformer config does not provide stable peak-memory behavior for that assertion.
Current:

```python
class TestLongCatAudioDiTTransformerCompile(LongCatAudioDiTTransformerTesterConfig, TorchCompileTesterMixin):
    def test_torch_compile_repeated_blocks(self):
        pytest.skip("LongCatAudioDiTTransformer does not define repeated blocks for regional compilation.")
```

Suggested:

```python
class TestLongCatAudioDiTTransformerCompile(LongCatAudioDiTTransformerTesterConfig, TorchCompileTesterMixin):
    pass
```
test_torch_compile_repeated_blocks should work if #13390 (comment) is applied.
Removed the repeated-block compile skip after defining _repeated_blocks.
Current:

```python
    def test_save_load_optional_components(self):
        self.skipTest("LongCatAudioDiTPipeline does not define optional components.")
```

Suggested: remove this override.
I think this test should handle the fact that LongCatAudioDiTPipeline doesn't have any optional components gracefully.
Removed the optional-components skip.
Current:

```python
        self.skipTest(
            "LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test."
        )
```

Suggested:

```python
    def test_sequential_cpu_offload_forward_pass(self):
        self.skipTest(
            "LongCatAudioDiTPipeline uses `torch.nn.utils.weight_norm`, which is not compatible with sequential offloading."
        )

    def test_sequential_offload_forward_pass_twice(self):
        self.skipTest(
            "LongCatAudioDiTPipeline uses `torch.nn.utils.weight_norm`, which is not compatible with sequential offloading."
        )
```
I believe that sequential offloading currently doesn't work with weight_norm, so document that here.
The sequential offload skip messages now explicitly mention the torch.nn.utils.weight_norm incompatibility.
```python
enable_full_determinism()


class DummyTokenizer:
```
Would it be possible to create a small tokenizer checkpoint that we could use for testing?
I replaced the ad-hoc DummyTokenizer with AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5"), matching the pattern used by other diffusers pipeline tests. The tokenizer now participates in the normal save/load path, so we no longer need to manually inject a dummy tokenizer after reload.
dg845 left a comment:
Thanks for your continued work on this! Left some suggestions that should help LongCatAudioDiTPipeline support model offloading, layerwise casting, etc.
What does this PR do?
Adds LongCat-AudioDiT model support to diffusers.
Although LongCat-AudioDiT can be used for TTS-like generation, it is fundamentally a diffusion-based audio generation model (text conditioning + iterative latent denoising + VAE decoding) rather than a conventional autoregressive TTS model, so I think it fits naturally into diffusers.
Test
Result
longcat.wav
Before submitting: see the documentation guidelines and the tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.