
[Feat] Adds LongCat-AudioDiT pipeline #13390

Open
RuixiangMa wants to merge 12 commits into huggingface:main from RuixiangMa:longcataudiodit

Conversation


@RuixiangMa RuixiangMa commented Apr 2, 2026

What does this PR do?

Adds LongCat-AudioDiT model support to diffusers.

Although LongCat-AudioDiT can be used for TTS-like generation, it is fundamentally a diffusion-based audio generation model (text conditioning + iterative latent denoising + VAE decoding) rather than a conventional autoregressive TTS model, so I think it fits naturally into diffusers.

Test

import soundfile as sf
import torch
from diffusers import LongCatAudioDiTPipeline

pipeline = LongCatAudioDiTPipeline.from_pretrained(
    "meituan-longcat/LongCat-AudioDiT-1B",
    torch_dtype=torch.float16,
)
pipeline = pipeline.to("cuda")

audio = pipeline(
    prompt="A calm ocean wave ambience with soft wind in the background.",
    audio_end_in_s=5.0,
    num_inference_steps=16,
    guidance_scale=4.0,
    output_type="pt",
).audios

output = audio[0, 0].float().cpu().numpy()
sf.write("longcat.wav", output, pipeline.sample_rate)

Result

longcat.wav

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Signed-off-by: Lancer <maruixiang6688@gmail.com>
@RuixiangMa RuixiangMa changed the title Longcataudiodit [Feat] Adds LongCat-AudioDiT support Apr 2, 2026
@RuixiangMa RuixiangMa changed the title [Feat] Adds LongCat-AudioDiT support [Feat] Adds LongCat-AudioDiT pipeline Apr 2, 2026
Signed-off-by: Lancer <maruixiang6688@gmail.com>
@dg845 dg845 requested review from dg845 and yiyixuxu April 4, 2026 00:31
)


def _pixel_shuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor:
Collaborator

Similarly, I think we should inline _pixel_shuffle_1d in UpsampleShortcut following #13390 (comment).
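For reference, a 1-D pixel shuffle can be written in a few lines, which is part of why inlining is reasonable. This is an illustrative numpy sketch; the PR's `_pixel_shuffle_1d` operates on torch tensors, and the exact interleaving order here is an assumption:

```python
import numpy as np

# Illustrative sketch of a 1-D pixel shuffle: rearrange (B, C*r, T) into
# (B, C, T*r) by interleaving channel groups along the time axis.
# The interleaving order is an assumption; the PR's torch version may differ.
def pixel_shuffle_1d(x: np.ndarray, factor: int) -> np.ndarray:
    batch, channels, time = x.shape
    x = x.reshape(batch, channels // factor, factor, time)  # split channel groups
    x = x.transpose(0, 1, 3, 2)                             # move factor next to time
    return x.reshape(batch, channels // factor, time * factor)

out = pixel_shuffle_1d(np.arange(8).reshape(1, 4, 2), factor=2)
# out.shape == (1, 2, 4)
```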

Comment on lines +515 to +519
self.time_embed = AudioDiTTimestepEmbedding(dim)
self.input_embed = AudioDiTEmbedder(latent_dim, dim)
self.text_embed = AudioDiTEmbedder(dit_text_dim, dim)
self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)
self.blocks = nn.ModuleList(
Collaborator

Suggested change
self.time_embed = AudioDiTTimestepEmbedding(dim)
self.input_embed = AudioDiTEmbedder(latent_dim, dim)
self.text_embed = AudioDiTEmbedder(dit_text_dim, dim)
self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)
self.blocks = nn.ModuleList(
self.time_embed = AudioDiTTimestepEmbedding(dim)
self.input_embed = AudioDiTEmbedder(latent_dim, dim)
self.text_embed = AudioDiTEmbedder(dit_text_dim, dim)
self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)
self.blocks = nn.ModuleList(

See #13390 (comment).

Comment on lines +584 to +589
batch_size = hidden_states.shape[0]
if timestep.ndim == 0:
timestep = timestep.repeat(batch_size)
timestep_embed = self.time_embed(timestep)
text_mask = encoder_attention_mask.bool()
encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask)
Collaborator

Suggested change
batch_size = hidden_states.shape[0]
if timestep.ndim == 0:
timestep = timestep.repeat(batch_size)
timestep_embed = self.time_embed(timestep)
text_mask = encoder_attention_mask.bool()
encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask)
batch_size = hidden_states.shape[0]
if timestep.ndim == 0:
timestep = timestep.repeat(batch_size)
timestep_embed = self.time_embed(timestep)
text_mask = encoder_attention_mask.bool()
encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask)

Can you also refactor forward here so that it is better organized, following #13390 (comment)? See, for example, the QwenImageTransformer2DModel.forward method.

Author

Reorganized parts of forward incrementally; kept the current structure otherwise to avoid unnecessary behavioral churn.

@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 11, 2026
Author

These CI failures do not appear to be related to this PR.


def _get_uniform_flow_match_scheduler_sigmas(num_inference_steps: int) -> list[float]:
num_inference_steps = max(int(num_inference_steps), 2)
num_updates = num_inference_steps - 1
Collaborator

I think we should define num_inference_steps to match the number of function evaluations we're performing (that is, to have the same semantics that num_updates currently has), which is the usual diffusers behavior. This would also allow us to remove the behavior where we overwrite num_inference_steps=1 below in __call__.

Author

num_inference_steps now matches the number of model evaluations, following the usual diffusers semantics. This also removes the previous behavior of promoting num_inference_steps=1 to 2.

return {key[len(prefix) :]: value for key, value in state_dict.items() if key.startswith(prefix)}


def _get_uniform_flow_match_scheduler_sigmas(num_inference_steps: int) -> list[float]:
Collaborator

I think we should inline _get_uniform_flow_match_scheduler_sigmas into __call__ so that it's easier to understand how the sigma schedule is being prepared. See e.g. Flux2Pipeline for an example of this:

sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas

We generally prefer not to have too many small functions in the pipeline code.

Author

I inlined the uniform flow-matching sigma schedule preparation into __call__, similar to Flux2Pipeline.
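For context, the usual diffusers pattern inlines a schedule with evenly spaced sigmas, one per model evaluation. A minimal sketch of that pattern (illustrative, not the exact PR code):

```python
import numpy as np

# Minimal sketch of an inlined uniform sigma schedule with one sigma per
# model evaluation, following the common diffusers pattern. Illustrative
# only; the PR's actual schedule preparation may differ.
num_inference_steps = 4
sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps).tolist()
# sigmas == [1.0, 0.75, 0.5, 0.25]
```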

Comment on lines +54 to +60
def _approx_duration_from_text(text: str | list[str], max_duration: float = 30.0) -> float:
if isinstance(text, list):
if not text:
return 0.0
return max(_approx_duration_from_text(prompt, max_duration=max_duration) for prompt in text)

en_dur_per_char = 0.082
Collaborator

Suggested change
def _approx_duration_from_text(text: str | list[str], max_duration: float = 30.0) -> float:
if isinstance(text, list):
if not text:
return 0.0
return max(_approx_duration_from_text(prompt, max_duration=max_duration) for prompt in text)
en_dur_per_char = 0.082
def _approx_duration_from_text(text: str | list[str], max_duration: float = 30.0) -> float:
if not text:
return 0.0
if isinstance(text, str):
text = [text]
en_dur_per_char = 0.082

nit: I think refactoring this function to be non-recursive (by making it work naturally with a list of strings) would make it more clear.
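A non-recursive version of the heuristic might look like the sketch below. Only the `en_dur_per_char = 0.082` constant comes from the PR; the character-count estimate and the clamp to `max_duration` are assumptions for illustration:

```python
# Hypothetical non-recursive sketch of the duration heuristic. Only
# en_dur_per_char = 0.082 comes from the PR; the character-count estimate
# and the clamp to max_duration are assumptions.
def approx_duration_from_text(text, max_duration: float = 30.0) -> float:
    if not text:
        return 0.0
    if isinstance(text, str):
        text = [text]
    en_dur_per_char = 0.082  # seconds per character (English), from the PR
    longest = max(len(prompt) * en_dur_per_char for prompt in text)
    return min(longest, max_duration)
```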

first_hidden = F.layer_norm(first_hidden, (first_hidden.shape[-1],), eps=1e-6)
prompt_embeds = prompt_embeds + first_hidden
lengths = attention_mask.sum(dim=1).to(device)
return prompt_embeds.float(), lengths
Collaborator

Suggested change
return prompt_embeds.float(), lengths
return prompt_embeds, lengths

Do we need to call .float() on prompt_embeds here? I think we should generally respect the output dtype from self.text_encoder.

Author

Removed the .float() cast from encode_prompt so that we respect the dtype produced by the text encoder.

)
self.scheduler.set_begin_index(0)
timesteps = self.scheduler.timesteps
sample = latents
Collaborator

I think using the standard name latents instead of sample would be more clear. It would also work better with PipelineTesterMixin tests.

Author

Renamed the denoising state from sample to the standard latents.

if latents is None:
duration = max(1, min(duration, max_duration))

text_condition, text_condition_len = self.encode_prompt(normalized_prompts, device)
Collaborator

Suggested change
text_condition, text_condition_len = self.encode_prompt(normalized_prompts, device)
prompt_embeds, text_condition_len = self.encode_prompt(normalized_prompts, device)

Similarly to #13390 (comment), I think using the standard name prompt_embeds would be better here.

Author

Renamed text_condition to the standard prompt_embeds

Comment on lines +536 to +537
if not return_dict:
return (waveform,)
Collaborator

Suggested change
if not return_dict:
return (waveform,)
self.maybe_free_model_hooks()
if not return_dict:
return (waveform,)

Calling self.maybe_free_model_hooks() allows the pipeline to clear model hooks correctly, such as those used to support model offloading.

Author

The pipeline now calls self.maybe_free_model_hooks() before returning

Comment on lines +527 to +530
if output_type == "latent":
if not return_dict:
return (sample,)
return AudioPipelineOutput(audios=sample)
Collaborator

@dg845 dg845 Apr 11, 2026

Suggested change
if output_type == "latent":
if not return_dict:
return (sample,)
return AudioPipelineOutput(audios=sample)
if output_type == "latent":
waveform = sample

A little simpler. Also makes it so that we don't have to call self.maybe_free_model_hooks() twice (see #13390 (comment)).

Author

The latent output path now assigns waveform = latents and shares the same final return path, so maybe_free_model_hooks() only needs to be called once.

latent_cond=latent_cond,
).sample
pred = null_pred + (pred - null_pred) * guidance_scale
sample = self.scheduler.step(pred, t, sample, return_dict=False)[0]
Collaborator

Suggested change
sample = self.scheduler.step(pred, t, sample, return_dict=False)[0]
sample = self.scheduler.step(pred, t, sample, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

Example for supporting callbacks. This assumes we use the standard names latents and prompt_embeds (see #13390 (comment), #13390 (comment)). See also how e.g. Flux2Pipeline supports callbacks:

if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

Author

Added callback_on_step_end support in the denoising loop
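The resulting control flow can be illustrated with a standalone sketch of the loop. Names follow the diffusers convention; the "scheduler step" here is a stand-in, not the actual PR code:

```python
# Standalone sketch of the callback_on_step_end pattern used across diffusers
# pipelines. The scheduler step is a stand-in; only the callback plumbing
# mirrors the suggestion above.
def denoise(num_steps, latents, callback_on_step_end=None,
            callback_on_step_end_tensor_inputs=("latents",)):
    for i, t in enumerate(range(num_steps)):
        latents = latents * 0.5  # stand-in for scheduler.step(...)
        if callback_on_step_end is not None:
            available = {"latents": latents}
            callback_kwargs = {k: available[k]
                               for k in callback_on_step_end_tensor_inputs}
            callback_outputs = callback_on_step_end(None, i, t, callback_kwargs)
            latents = callback_outputs.pop("latents", latents)
    return latents

steps_seen = []
def cb(pipe, i, t, kwargs):
    steps_seen.append(i)
    return kwargs

result = denoise(3, 8.0, callback_on_step_end=cb)
# steps_seen == [0, 1, 2]; result == 1.0
```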

guidance_scale: float = 4.0,
generator: torch.Generator | list[torch.Generator] | None = None,
output_type: str = "np",
return_dict: bool = True,
Collaborator

Suggested change
return_dict: bool = True,
return_dict: bool = True,
callback_on_step_end: Callable[[int, int], None] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],

Follow-up for callback support (see #13390 (comment)).

Author

Added callback_on_step_end and callback_on_step_end_tensor_inputs to __call__



class LongCatAudioDiTPipeline(DiffusionPipeline):
model_cpu_offload_seq = "text_encoder->transformer->vae"
Collaborator

@dg845 dg845 Apr 11, 2026

Suggested change
model_cpu_offload_seq = "text_encoder->transformer->vae"
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]

Follow up for callback support (#13390 (comment)). The callback tests specifically check for the name latents here, which is one reason to use it over sample.

Author

done



class LongCatAudioDiTTransformer(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = False
Collaborator

Suggested change
_supports_gradient_checkpointing = False
_supports_gradient_checkpointing = False
_repeated_blocks = ["AudioDiTBlock"]

Setting _repeated_blocks here enables regional compilation support. This also allows us to not skip the TestLongCatAudioDiTTransformerCompile.test_torch_compile_repeated_blocks test.

Author

done

Comment on lines +546 to +549
dtype = next(self.parameters()).dtype
hidden_states = hidden_states.to(dtype)
encoder_hidden_states = encoder_hidden_states.to(dtype)
timestep = timestep.to(dtype)
Collaborator

Suggested change
dtype = next(self.parameters()).dtype
hidden_states = hidden_states.to(dtype)
encoder_hidden_states = encoder_hidden_states.to(dtype)
timestep = timestep.to(dtype)
hidden_states = hidden_states
encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
timestep = timestep.to(hidden_states.dtype)

Matching the dtype of hidden_states rather than casting to next(self.parameters()).dtype makes supporting layerwise casting easier (as not all parameters will have the same dtype when layerwise casting is enabled).

Author

The transformer now follows hidden_states.dtype instead of casting activations to next(self.parameters()).dtype, which is more compatible with layerwise casting.
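The policy reduces to casting secondary inputs to the activation dtype rather than a parameter dtype. A minimal numpy illustration of the same idea (torch's `.to(...)` plays the role of `.astype(...)` in the PR):

```python
import numpy as np

# Sketch of the dtype policy: follow the activation dtype rather than a
# parameter dtype, so the code stays valid when parameters are layerwise
# cast to mixed dtypes. Numpy stands in for torch here.
hidden_states = np.zeros((2, 4), dtype=np.float16)
encoder_hidden_states = np.zeros((2, 4), dtype=np.float32)
timestep = np.array(0.5, dtype=np.float64)

# Cast secondary inputs to match the activations.
encoder_hidden_states = encoder_hidden_states.astype(hidden_states.dtype)
timestep = timestep.astype(hidden_states.dtype)
```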

encoder_hidden_states = encoder_hidden_states.masked_fill(text_mask.logical_not().unsqueeze(-1), 0.0)
hidden_states = self.input_embed(hidden_states, attention_mask)
if self.use_latent_condition and latent_cond is not None:
latent_cond = self.latent_embed(latent_cond.to(dtype), attention_mask)
Collaborator

Suggested change
latent_cond = self.latent_embed(latent_cond.to(dtype), attention_mask)
latent_cond = self.latent_embed(latent_cond.to(hidden_states.dtype), attention_mask)

Follow-up suggestion to #13390 (comment).

Author

done

return (latents,)
return LongCatAudioDiTVaeEncoderOutput(latents=latents)

def decode(
Collaborator

Suggested change
def decode(
@apply_forward_hook
def decode(

Using the apply_forward_hook decorator here allows us to apply the model offloading hook to the decode method as well, because that hook would normally operate on forward. This helps support model offloading in the pipeline by correctly applying the hook when we decode the latents back to a waveform.

Author

Added @apply_forward_hook to LongCatAudioDiTVae.decode.
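The idea behind the decorator can be shown with a standalone sketch: route the hook that would normally fire on forward through another method such as decode. The `_hf_hook`/`pre_forward` names follow accelerate's convention, but this implementation is hypothetical, not diffusers' actual `apply_forward_hook`:

```python
import functools

# Hypothetical sketch of the idea behind apply_forward_hook: run the module's
# offloading hook (normally triggered by forward) before a decorated method
# such as encode/decode. Not diffusers' actual implementation.
def apply_forward_hook_sketch(method):
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        hook = getattr(self, "_hf_hook", None)
        if hook is not None and hasattr(hook, "pre_forward"):
            hook.pre_forward(self)  # e.g. move weights to the execution device
        return method(self, *args, **kwargs)
    return wrapper

class TinyCodec:
    def __init__(self):
        self._hf_hook = None

    @apply_forward_hook_sketch
    def decode(self, x):
        return x * 2

class CountingHook:
    def __init__(self):
        self.calls = 0
    def pre_forward(self, module):
        self.calls += 1

codec = TinyCodec()
codec._hf_hook = CountingHook()
out = codec.decode(3)
# out == 6, and the hook fired exactly once
```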

upsample_shortcut=upsample_shortcut,
)

def encode(
Collaborator

Suggested change
def encode(
@apply_forward_hook
def encode(

Analogous change to #13390 (comment) for encode.

Author

done

from torch.nn.utils import weight_norm

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
Collaborator

Suggested change
from ...utils import BaseOutput
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook

Follow up change to #13390 (comment).

Author

done

Comment on lines +295 to +296
class LongCatAudioDiTVae(ModelMixin, AutoencoderMixin, ConfigMixin):
@register_to_config
Collaborator

@dg845 dg845 Apr 11, 2026

Suggested change
class LongCatAudioDiTVae(ModelMixin, AutoencoderMixin, ConfigMixin):
@register_to_config
class LongCatAudioDiTVae(ModelMixin, AutoencoderMixin, ConfigMixin):
_supports_group_offloading = False
@register_to_config

Group offloading is not compatible with torch.nn.utils.weight_norm; disabling it here means we don't have to manually skip the corresponding test.

Author

Set _supports_group_offloading = False for the VAE because it uses torch.nn.utils.weight_norm

Comment on lines +88 to +98
class TestLongCatAudioDiTTransformerMemory(LongCatAudioDiTTransformerTesterConfig, MemoryTesterMixin):
def test_layerwise_casting_memory(self):
pytest.skip("LongCatAudioDiTTransformer does not support standard layerwise casting memory tests yet.")

def test_layerwise_casting_training(self):
pytest.skip("LongCatAudioDiTTransformer does not support standard layerwise casting training tests yet.")

def test_group_offloading_with_layerwise_casting(self, *args, **kwargs):
pytest.skip(
"LongCatAudioDiTTransformer does not support combined group offloading and layerwise casting tests yet."
)
Collaborator

Suggested change
class TestLongCatAudioDiTTransformerMemory(LongCatAudioDiTTransformerTesterConfig, MemoryTesterMixin):
def test_layerwise_casting_memory(self):
pytest.skip("LongCatAudioDiTTransformer does not support standard layerwise casting memory tests yet.")
def test_layerwise_casting_training(self):
pytest.skip("LongCatAudioDiTTransformer does not support standard layerwise casting training tests yet.")
def test_group_offloading_with_layerwise_casting(self, *args, **kwargs):
pytest.skip(
"LongCatAudioDiTTransformer does not support combined group offloading and layerwise casting tests yet."
)
class TestLongCatAudioDiTTransformerMemory(LongCatAudioDiTTransformerTesterConfig, MemoryTesterMixin):
pass

Layerwise casting should work if #13390 (comment) is applied.

Author

I removed the layerwise casting training and combined group-offloading/layerwise-casting skips after updating the dtype handling. I kept test_layerwise_casting_memory skipped because the tiny transformer config does not provide stable peak-memory behavior for that assertion.

Comment on lines +101 to +103
class TestLongCatAudioDiTTransformerCompile(LongCatAudioDiTTransformerTesterConfig, TorchCompileTesterMixin):
def test_torch_compile_repeated_blocks(self):
pytest.skip("LongCatAudioDiTTransformer does not define repeated blocks for regional compilation.")
Collaborator

Suggested change
class TestLongCatAudioDiTTransformerCompile(LongCatAudioDiTTransformerTesterConfig, TorchCompileTesterMixin):
def test_torch_compile_repeated_blocks(self):
pytest.skip("LongCatAudioDiTTransformer does not define repeated blocks for regional compilation.")
class TestLongCatAudioDiTTransformerCompile(LongCatAudioDiTTransformerTesterConfig, TorchCompileTesterMixin):
pass

test_torch_compile_repeated_blocks should work if #13390 (comment) is applied.

Author

Removed the repeated-block compile skip after defining _repeated_blocks

Comment on lines +144 to +146
def test_save_load_optional_components(self):
self.skipTest("LongCatAudioDiTPipeline does not define optional components.")

Collaborator

Suggested change
def test_save_load_optional_components(self):
self.skipTest("LongCatAudioDiTPipeline does not define optional components.")

I think this test should handle the fact that LongCatAudioDiTPipeline doesn't have any optional components gracefully.

Author

Removed the optional-components skip

self.skipTest(
"LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test."
)

Collaborator

@dg845 dg845 Apr 11, 2026

Suggested change
def test_sequential_cpu_offload_forward_pass(self):
self.skipTest(
"LongCatAudioDiTPipeline uses `torch.nn.utils.weight_norm`, which is not compatible with sequential offloading."
)
def test_sequential_offload_forward_pass_twice(self):
self.skipTest(
"LongCatAudioDiTPipeline uses `torch.nn.utils.weight_norm`, which is not compatible with sequential offloading."
)

I believe that sequential offloading currently doesn't work with weight_norm, so document that here.

Author

The sequential offload skip messages now explicitly mention the torch.nn.utils.weight_norm incompatibility.

enable_full_determinism()


class DummyTokenizer:
Collaborator

Would it be possible to create a small tokenizer checkpoint that we could use for testing?

Author

I replaced the ad-hoc DummyTokenizer with AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5"), matching the pattern used by other diffusers pipeline tests. The tokenizer now participates in the normal save/load path, so we no longer need to manually inject a dummy tokenizer after reload.

Collaborator

@dg845 dg845 left a comment

Thanks for your continued work on this! Left some suggestions that should help LongCatAudioDiTPipeline support model offloading, layerwise casting, etc.

Signed-off-by: Lancer <maruixiang6688@gmail.com>
@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 11, 2026

Labels

documentation, models, pipelines, size/L (PR with diff > 200 LOC), tests, utils


3 participants