From 88f03c1c10b69fd55f0aa072748a88169aaa0882 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Wed, 5 Nov 2025 17:33:05 +0100
Subject: [PATCH 1/4] nit on dac!

---
 src/transformers/models/dac/modeling_dac.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/dac/modeling_dac.py b/src/transformers/models/dac/modeling_dac.py
index 81cfcbb931d4..1489c0c70b4b 100644
--- a/src/transformers/models/dac/modeling_dac.py
+++ b/src/transformers/models/dac/modeling_dac.py
@@ -263,7 +263,7 @@ def forward(self, hidden_state):
         return hidden_state
 
 
-class DacResidualVectorQuantize(nn.Module):
+class DacResidualVectorQuantizer(nn.Module):
     """
     ResidualVectorQuantize block - Introduced in SoundStream: An end2end neural audio codec (https://huggingface.co/papers/2107.03312)
     """
@@ -566,7 +566,7 @@ def __init__(self, config: DacConfig):
         self.encoder = DacEncoder(config)
         self.decoder = DacDecoder(config)
 
-        self.quantizer = DacResidualVectorQuantize(config)
+        self.quantizer = DacResidualVectorQuantizer(config)
 
         self.bits_per_codebook = int(math.log2(self.config.codebook_size))
         if 2**self.bits_per_codebook != self.config.codebook_size:

From 3b3311f8834023856428d2cb6e23900558082455 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Fri, 7 Nov 2025 16:56:53 +0100
Subject: [PATCH 2/4] fix

---
 src/transformers/audio_utils.py                  | 12 +++++
 .../models/xcodec/modeling_xcodec.py             | 46 +++++++++++++++++--
 2 files changed, 54 insertions(+), 4 deletions(-)

diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py
index df4bbf1ca604..ed9b39e1d499 100644
--- a/src/transformers/audio_utils.py
+++ b/src/transformers/audio_utils.py
@@ -219,6 +219,18 @@ def load_audio_as(
         raise ValueError(f"Error loading audio: {e}")
 
 
+def conv1d_output_length(module: "torch.nn.Conv1d", input_length: int) -> int:
+    """
+    Computes the output length of a 1D convolution layer according to torch's documentation:
+    https://docs.pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+    """
+    return int(
+        (input_length + 2 * module.padding[0] - module.dilation[0] * (module.kernel_size[0] - 1) - 1)
+        / module.stride[0]
+        + 1
+    )
+
+
 def is_valid_audio(audio):
     return is_numpy_array(audio) or is_torch_tensor(audio)
 
diff --git a/src/transformers/models/xcodec/modeling_xcodec.py b/src/transformers/models/xcodec/modeling_xcodec.py
index 774f9c74b8de..e85e2ae3516a 100644
--- a/src/transformers/models/xcodec/modeling_xcodec.py
+++ b/src/transformers/models/xcodec/modeling_xcodec.py
@@ -16,12 +16,14 @@
 
 import math
 from dataclasses import dataclass
+from functools import lru_cache
 from typing import Optional, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from ...audio_utils import conv1d_output_length
 from ...modeling_utils import PreTrainedAudioTokenizerBase
 from ...utils import ModelOutput, auto_docstring
 from ..auto import AutoModel
@@ -394,6 +396,40 @@ def remove_weight_norm(self):
             if hasattr(m, "parametrizations") and "weight" in m.parametrizations:
                 torch.nn.utils.parametrize.remove_parametrizations(m, "weight", leave_parametrized=True)
 
+    @lru_cache
+    def _get_conv1d_layers(self, module):
+        """
+        Recursively iterate to fetch all Conv1d layers.
+ """ + + def get_conv1d_layers_recursive(module: nn.Module): + params_list = [] + + if isinstance(module, nn.Conv1d): + params_list.append(module) + + # Recursively check all child modules + for child in module.children(): + params_list.extend(get_conv1d_layers_recursive(child)) + + return params_list + + return tuple(get_conv1d_layers_recursive(module)) + + def _get_conv1d_output_lengths(self, input_length, module=None): + """ + For a given module, compute the output length that would be obtained after all Conv1d layers. + """ + if module is None: + module = self + + conv1d_layers = self._get_conv1d_layers(module) + + for layer in conv1d_layers: + input_length = conv1d_output_length(layer, input_length) + + return input_length + @auto_docstring(custom_intro="""The Xcodec neural audio codec model.""") class XcodecModel(XcodecPreTrainedModel): @@ -474,11 +510,13 @@ def encode( e_semantic_input = self._extract_semantic_features(input_values).detach() e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2)) - e_acoustic = self.acoustic_encoder(input_values) - if e_acoustic.shape[2] != e_semantic.shape[2]: - # make sure they line up if frames don't match - e_acoustic = self.acoustic_encoder(F.pad(input_values[:, 0, :], (self.pad, self.pad)).unsqueeze(1)) + # orignal codebase infer to get the output length, but we can directly infer it + # from the model and know wether we should pad + if self._get_conv1d_output_lengths(input_values.shape[2], self.acoustic_encoder) != e_semantic.shape[2]: + e_acoustic = self.acoustic_encoder(F.pad(input_values, (self.pad, self.pad))) + else: + e_acoustic = self.acoustic_encoder(input_values) embeddings = torch.cat([e_acoustic, e_semantic], dim=1) embeddings = self.fc(embeddings.transpose(1, 2)).transpose(1, 2) From 53fea56c3cf67e27367cd5ab4b2aa665a5185fe4 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Fri, 7 Nov 2025 17:02:12 +0100 Subject: [PATCH 3/4] not for this pr --- src/transformers/models/dac/modeling_dac.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/dac/modeling_dac.py b/src/transformers/models/dac/modeling_dac.py index 1489c0c70b4b..81cfcbb931d4 100644 --- a/src/transformers/models/dac/modeling_dac.py +++ b/src/transformers/models/dac/modeling_dac.py @@ -263,7 +263,7 @@ def forward(self, hidden_state): return hidden_state -class DacResidualVectorQuantizer(nn.Module): +class DacResidualVectorQuantize(nn.Module): """ ResidualVectorQuantize block - Introduced in SoundStream: An end2end neural audio codec (https://huggingface.co/papers/2107.03312) """ @@ -566,7 +566,7 @@ def __init__(self, config: DacConfig): self.encoder = DacEncoder(config) self.decoder = DacDecoder(config) - self.quantizer = DacResidualVectorQuantizer(config) + self.quantizer = DacResidualVectorQuantize(config) self.bits_per_codebook = int(math.log2(self.config.codebook_size)) if 2**self.bits_per_codebook != self.config.codebook_size: From 806fbb0f45cf8370ec9453f2b960005eb6f4993f Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Tue, 25 Nov 2025 15:30:43 +0100 Subject: [PATCH 4/4] make style --- src/transformers/models/xcodec/modeling_xcodec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/xcodec/modeling_xcodec.py b/src/transformers/models/xcodec/modeling_xcodec.py index cbd0db20d58a..04018af855c5 100644 --- a/src/transformers/models/xcodec/modeling_xcodec.py +++ b/src/transformers/models/xcodec/modeling_xcodec.py @@ -23,8 +23,8 @@ 
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ...audio_utils import conv1d_output_length
 from ... import initialization as init
+from ...audio_utils import conv1d_output_length
 from ...modeling_utils import PreTrainedAudioTokenizerBase
 from ...utils import ModelOutput, auto_docstring
 from ..auto import AutoModel
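
Note for reviewers: a minimal standalone sketch of what the new helper computes, run outside the patch. It mirrors conv1d_output_length and checks the formula against a real torch.nn.Conv1d forward pass; the layer hyperparameters below are arbitrary, and the composed check assumes (as _get_conv1d_output_lengths does) that the Conv1d layers are registered in execution order and applied sequentially.

import torch
import torch.nn as nn


def conv1d_output_length(module: nn.Conv1d, input_length: int) -> int:
    # L_out = floor((L_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1),
    # the formula documented for torch.nn.Conv1d
    return int(
        (input_length + 2 * module.padding[0] - module.dilation[0] * (module.kernel_size[0] - 1) - 1)
        / module.stride[0]
        + 1
    )


# arbitrary hyperparameters, chosen only to exercise padding, stride, and dilation
conv = nn.Conv1d(1, 4, kernel_size=7, stride=3, padding=2, dilation=2)
length = 16000
assert conv1d_output_length(conv, length) == conv(torch.randn(1, 1, length)).shape[-1]

# composing the helper over stacked Conv1d layers, as _get_conv1d_output_lengths does,
# predicts the final frame count without running a forward pass
stack = nn.Sequential(
    nn.Conv1d(1, 8, kernel_size=7, stride=2, padding=3),
    nn.ReLU(),
    nn.Conv1d(8, 8, kernel_size=3, stride=2, padding=1),
)
predicted = length
for m in stack.modules():
    if isinstance(m, nn.Conv1d):
        predicted = conv1d_output_length(m, predicted)
assert predicted == stack(torch.randn(1, 1, length)).shape[-1]

This is why encode can now decide up front whether the acoustic branch needs F.pad to line up with the semantic branch, instead of running the acoustic encoder a second time when the frame counts disagree.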