12 changes: 12 additions & 0 deletions src/transformers/audio_utils.py
@@ -219,6 +219,18 @@ def load_audio_as(
raise ValueError(f"Error loading audio: {e}")


def conv1d_output_length(module: "torch.nn.Conv1d", input_length: int) -> int:
"""
Computes the output length of a 1D convolution layer according to torch's documentation:
https://docs.pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
"""
return int(
(input_length + 2 * module.padding[0] - module.dilation[0] * (module.kernel_size[0] - 1) - 1)
/ module.stride[0]
+ 1
)


def is_valid_audio(audio):
return is_numpy_array(audio) or is_torch_tensor(audio)

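As a quick sanity check of the new helper, here is a minimal usage sketch (not part of the diff; the layer hyperparameters are illustrative only, not Xcodec's): the value returned by conv1d_output_length should match the frame count torch itself produces.

import torch
import torch.nn as nn

from transformers.audio_utils import conv1d_output_length  # available once this PR is merged

conv = nn.Conv1d(in_channels=1, out_channels=8, kernel_size=7, stride=2, padding=3)
waveform = torch.randn(1, 1, 16000)

# Both give 8000: int((16000 + 2*3 - 1*(7 - 1) - 1) / 2 + 1)
assert conv1d_output_length(conv, 16000) == conv(waveform).shape[-1]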
46 changes: 42 additions & 4 deletions src/transformers/models/xcodec/modeling_xcodec.py
@@ -16,13 +16,15 @@

import math
from dataclasses import dataclass
from functools import lru_cache
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ... import initialization as init
from ...audio_utils import conv1d_output_length
from ...modeling_utils import PreTrainedAudioTokenizerBase
from ...utils import ModelOutput, auto_docstring
from ..auto import AutoModel
@@ -396,6 +398,40 @@ def remove_weight_norm(self):
if hasattr(m, "parametrizations") and "weight" in m.parametrizations:
torch.nn.utils.parametrize.remove_parametrizations(m, "weight", leave_parametrized=True)

@lru_cache
def _get_conv1d_layers(self, module):
"""
Recursively iterate to fetch all Conv1d layers.
"""

def get_conv1d_layers_recursive(module: nn.Module):
params_list = []

if isinstance(module, nn.Conv1d):
params_list.append(module)

# Recursively check all child modules
for child in module.children():
params_list.extend(get_conv1d_layers_recursive(child))

return params_list

return tuple(get_conv1d_layers_recursive(module))

def _get_conv1d_output_lengths(self, input_length, module=None):
"""
For a given module, compute the output length that would be obtained after all Conv1d layers.
"""
if module is None:
module = self

conv1d_layers = self._get_conv1d_layers(module)

for layer in conv1d_layers:
input_length = conv1d_output_length(layer, input_length)

return input_length
Comment on lines +401 to +433
Member: Is this something we can do earlier, in the config? It would avoid having to do these recursions on modules etc.!

Contributor Author: When we standardize this computation (see #41203), we would be able to do it from the config, provided we ensure that every convolution layer's parameters are stored there, and in the correct order relative to their appearance in the forward pass. However, I don't think expanding the config with such parameters is a good idea, especially since we decided to hardcode them to avoid exposing them to the user.

I was thinking more along the lines of handling this directly during model initialization, when the modules are already being iterated over. For now, is it okay to merge this?
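
For readers following this thread, here is a minimal sketch (not part of the PR; the layer sizes below are made up) of what the two helpers above compute: collect the Conv1d layers in registration order and chain the per-layer output-length formula. Like the helpers, it assumes no other length-changing operations (pooling, transposed convolutions) sit between the Conv1d layers.

import torch.nn as nn

from transformers.audio_utils import conv1d_output_length

toy_encoder = nn.Sequential(
    nn.Conv1d(1, 32, kernel_size=7, stride=1, padding=3),
    nn.Conv1d(32, 32, kernel_size=4, stride=2, padding=1),
    nn.Conv1d(32, 32, kernel_size=4, stride=2, padding=1),
)

length = 16000
for layer in toy_encoder.modules():  # registration order matches forward order for Sequential
    if isinstance(layer, nn.Conv1d):
        length = conv1d_output_length(layer, length)

print(length)  # 4000: 16000 -> 16000 -> 8000 -> 4000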



@auto_docstring(custom_intro="""The Xcodec neural audio codec model.""")
class XcodecModel(XcodecPreTrainedModel):
@@ -476,11 +512,13 @@ def encode(

e_semantic_input = self._extract_semantic_features(input_values).detach()
e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
e_acoustic = self.acoustic_encoder(input_values)

if e_acoustic.shape[2] != e_semantic.shape[2]:
# make sure they line up if frames don't match
e_acoustic = self.acoustic_encoder(F.pad(input_values[:, 0, :], (self.pad, self.pad)).unsqueeze(1))
# the original codebase runs the acoustic encoder to get the output length, but we can directly infer it
# from the model and know whether we should pad
if self._get_conv1d_output_lengths(input_values.shape[2], self.acoustic_encoder) != e_semantic.shape[2]:
e_acoustic = self.acoustic_encoder(F.pad(input_values, (self.pad, self.pad)))
else:
e_acoustic = self.acoustic_encoder(input_values)

embeddings = torch.cat([e_acoustic, e_semantic], dim=1)
embeddings = self.fc(embeddings.transpose(1, 2)).transpose(1, 2)
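Finally, a hedged end-to-end sketch of the decision the rewritten encode path makes (toy modules and numbers, not Xcodec's real acoustic encoder, hop lengths, or self.pad value): predict the acoustic frame count up front and, if it would not match the semantic branch, pad the waveform once instead of running the acoustic encoder twice.

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.audio_utils import conv1d_output_length

def predicted_frames(module: nn.Module, length: int) -> int:
    # same chaining as _get_conv1d_output_lengths, without the cached layer lookup
    for layer in module.modules():
        if isinstance(layer, nn.Conv1d):
            length = conv1d_output_length(layer, length)
    return length

acoustic_encoder = nn.Sequential(
    nn.Conv1d(1, 8, kernel_size=4, stride=2, padding=1),
    nn.Conv1d(8, 8, kernel_size=4, stride=2, padding=1),
)
pad = 2                                 # stand-in for self.pad
waveform = torch.randn(1, 1, 16001)     # odd length, so the two branches disagree
semantic_frames = 4001                  # pretend frame count from the semantic branch

if predicted_frames(acoustic_encoder, waveform.shape[2]) != semantic_frames:  # 4000 != 4001
    e_acoustic = acoustic_encoder(F.pad(waveform, (pad, pad)))                # padded input -> 4001 frames
else:
    e_acoustic = acoustic_encoder(waveform)

assert e_acoustic.shape[2] == semantic_frames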