Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 79 additions & 56 deletions nemo/collections/tts/parts/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,15 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""
Utility helpers for TTS data processing, attention alignment, spectrogram
plotting, audio logging, and segment manipulation used across NeMo TTS models.
"""

import string
from collections import defaultdict
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import librosa
import matplotlib.pylab as plt
Expand Down Expand Up @@ -81,7 +86,7 @@ class OperationMode(Enum):
infer = 2


def get_batch_size(train_dataloader):
def get_batch_size(train_dataloader: "DataLoader") -> int:
if train_dataloader.batch_size is not None:
return train_dataloader.batch_size
elif train_dataloader.batch_sampler is not None:
Expand All @@ -93,11 +98,11 @@ def get_batch_size(train_dataloader):
raise ValueError(f'Could not find batch_size from train_dataloader: {train_dataloader}')


def get_num_workers(trainer):
def get_num_workers(trainer: "pl.Trainer") -> int:
return trainer.num_devices * trainer.num_nodes


def binarize_attention(attn, in_len, out_len):
def binarize_attention(attn: torch.Tensor, in_len: torch.Tensor, out_len: torch.Tensor) -> torch.Tensor:
"""Convert soft attention matrix to hard attention matrix.

Args:
Expand All @@ -118,7 +123,7 @@ def binarize_attention(attn, in_len, out_len):
return attn_out


def binarize_attention_parallel(attn, in_lens, out_lens):
def binarize_attention_parallel(attn: torch.Tensor, in_lens: torch.Tensor, out_lens: torch.Tensor) -> torch.Tensor:
"""For training purposes only. Binarizes attention with MAS.
These will no longer receive a gradient.

Expand Down Expand Up @@ -191,7 +196,7 @@ def unsort_tensor(ordered: torch.Tensor, indices: torch.Tensor, dim: Optional[in


@jit(nopython=True)
def mas(attn_map, width=1):
def mas(attn_map: np.ndarray, width: int = 1) -> np.ndarray:
# assumes mel x text
opt = np.zeros_like(attn_map)
attn_map = np.log(attn_map)
Expand Down Expand Up @@ -222,7 +227,7 @@ def mas(attn_map, width=1):


@jit(nopython=True)
def mas_width1(log_attn_map):
def mas_width1(log_attn_map: np.ndarray) -> np.ndarray:
"""mas with hardcoded width=1"""
# assumes mel x text
neg_inf = log_attn_map.dtype.type(-np.inf)
Expand Down Expand Up @@ -251,7 +256,7 @@ def mas_width1(log_attn_map):


@jit(nopython=True, parallel=True)
def b_mas(b_log_attn_map, in_lens, out_lens, width=1):
def b_mas(b_log_attn_map: np.ndarray, in_lens: np.ndarray, out_lens: np.ndarray, width: int = 1) -> np.ndarray:
assert width == 1
attn_out = np.zeros_like(b_log_attn_map)

Expand All @@ -261,7 +266,7 @@ def b_mas(b_log_attn_map, in_lens, out_lens, width=1):
return attn_out


def griffin_lim(magnitudes, n_iters=50, n_fft=1024):
def griffin_lim(magnitudes: np.ndarray, n_iters: int = 50, n_fft: int = 1024) -> np.ndarray:
"""
Griffin-Lim algorithm to convert magnitude spectrograms to audio signals
"""
Expand All @@ -281,17 +286,17 @@ def griffin_lim(magnitudes, n_iters=50, n_fft=1024):

@rank_zero_only
def log_audio_to_tb(
swriter,
spect,
name,
step,
griffin_lim_mag_scale=1024,
griffin_lim_power=1.2,
sr=22050,
n_fft=1024,
n_mels=80,
fmax=8000,
):
swriter: Any,
spect: torch.Tensor,
name: str,
step: int,
griffin_lim_mag_scale: int = 1024,
griffin_lim_power: float = 1.2,
sr: int = 22050,
n_fft: int = 1024,
n_mels: int = 80,
fmax: int = 8000,
) -> None:
filterbank = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmax=fmax)
log_mel = spect.data.cpu().numpy().T
mel = np.exp(log_mel)
Expand All @@ -300,7 +305,15 @@ def log_audio_to_tb(
swriter.add_audio(name, audio / max(np.abs(audio)), step, sample_rate=sr)


def plot_alignment_to_numpy(alignment, title='', info=None, phoneme_seq=None, vmin=None, vmax=None, attended=None):
def plot_alignment_to_numpy(
alignment: np.ndarray,
title: str = '',
info: Optional[str] = None,
phoneme_seq: Optional[List[str]] = None,
vmin: Optional[float] = None,
vmax: Optional[float] = None,
attended: Optional[List[int]] = None,
) -> np.ndarray:
if phoneme_seq:
fig, ax = plt.subplots(figsize=(15, 10))
else:
Expand Down Expand Up @@ -331,16 +344,16 @@ def plot_alignment_to_numpy(alignment, title='', info=None, phoneme_seq=None, vm


def plot_alignment_to_numpy_for_speechllm(
alignment,
title='',
info=None,
phoneme_seq=None,
vmin=None,
vmax=None,
phoneme_ver=0,
phone_offset=2,
h_offset=True,
):
alignment: np.ndarray,
title: str = '',
info: Optional[str] = None,
phoneme_seq: Optional[List[str]] = None,
vmin: Optional[float] = None,
vmax: Optional[float] = None,
phoneme_ver: int = 0,
phone_offset: int = 2,
h_offset: bool = True,
) -> np.ndarray:
alignment = np.clip(alignment, a_min=0, a_max=None)
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none', vmin=vmin, vmax=vmax)
Expand Down Expand Up @@ -387,7 +400,7 @@ def plot_alignment_to_numpy_for_speechllm(
return data


def plot_pitch_to_numpy(pitch, ylim_range=None):
def plot_pitch_to_numpy(pitch: np.ndarray, ylim_range: Optional[Tuple[float, float]] = None) -> np.ndarray:
fig, ax = plt.subplots(figsize=(12, 3))
plt.plot(pitch)
if ylim_range is not None:
Expand All @@ -402,7 +415,9 @@ def plot_pitch_to_numpy(pitch, ylim_range=None):
return data


def plot_multipitch_to_numpy(pitch_gt, pitch_pred, ylim_range=None):
def plot_multipitch_to_numpy(
pitch_gt: np.ndarray, pitch_pred: np.ndarray, ylim_range: Optional[Tuple[float, float]] = None
) -> np.ndarray:
fig, ax = plt.subplots(figsize=(12, 3))
plt.plot(pitch_gt, label="Ground truth")
plt.plot(pitch_pred, label="Predicted")
Expand All @@ -419,7 +434,7 @@ def plot_multipitch_to_numpy(pitch_gt, pitch_pred, ylim_range=None):
return data


def plot_spectrogram_to_numpy(spectrogram):
def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray:
spectrogram = spectrogram.astype(np.float32)
fig, ax = plt.subplots(figsize=(12, 3))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation='none')
Expand All @@ -434,7 +449,7 @@ def plot_spectrogram_to_numpy(spectrogram):
return data


def create_plot(data, x_axis, y_axis, output_filepath=None):
def create_plot(data: np.ndarray, x_axis: str, y_axis: str, output_filepath: Optional[str] = None) -> np.ndarray:
fig, ax = plt.subplots(figsize=(12, 3))
im = ax.imshow(data, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
Expand All @@ -451,7 +466,7 @@ def create_plot(data, x_axis, y_axis, output_filepath=None):
return data


def plot_gate_outputs_to_numpy(gate_targets, gate_outputs):
def plot_gate_outputs_to_numpy(gate_targets: np.ndarray, gate_outputs: np.ndarray) -> np.ndarray:
fig, ax = plt.subplots(figsize=(12, 3))
ax.scatter(
range(len(gate_targets)),
Expand Down Expand Up @@ -482,7 +497,7 @@ def plot_gate_outputs_to_numpy(gate_targets, gate_outputs):
return data


def save_figure_to_numpy(fig):
def save_figure_to_numpy(fig: Any) -> np.ndarray:
img_array = np.array(fig.canvas.renderer.buffer_rgba())
return img_array

Expand Down Expand Up @@ -557,13 +572,13 @@ def plot_expert_usage_heatmap_to_numpy(


def regulate_len(
durations,
enc_out,
pace=1.0,
mel_max_len=None,
group_size=1,
dur_lens: torch.tensor = None,
):
durations: torch.Tensor,
enc_out: torch.Tensor,
pace: float = 1.0,
mel_max_len: Optional[int] = None,
group_size: int = 1,
dur_lens: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A function that takes predicted durations per encoded token, and repeats enc_out according to the duration.
NOTE: durations.shape[1] == enc_out.shape[1]

Expand Down Expand Up @@ -605,7 +620,7 @@ def regulate_len(
return enc_rep, dec_lens


def slice_segments(x, ids_str, segment_size=4):
def slice_segments(x: torch.Tensor, ids_str: torch.Tensor, segment_size: int = 4) -> torch.Tensor:
"""
Time-wise slicing (patching) of bathches for audio/spectrogram
[B x C x T] -> [B x C x segment_size]
Expand All @@ -622,7 +637,9 @@ def slice_segments(x, ids_str, segment_size=4):
return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
def rand_slice_segments(
x: torch.Tensor, x_lengths: Optional[torch.Tensor] = None, segment_size: int = 4
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Chooses random indices and slices segments from batch
[B x C x T] -> [B x C x segment_size]
Expand All @@ -639,7 +656,11 @@ def rand_slice_segments(x, x_lengths=None, segment_size=4):
return ret, ids_str


def clip_grad_value_(parameters, clip_value, norm_type=2):
def clip_grad_value_(
parameters: Union[torch.Tensor, Iterable[torch.nn.Parameter]],
clip_value: Optional[float] = None,
norm_type: float = 2,
) -> float:
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
Expand All @@ -657,12 +678,12 @@ def clip_grad_value_(parameters, clip_value, norm_type=2):
return total_norm


def convert_pad_shape(pad_shape):
def convert_pad_shape(pad_shape: List[List[int]]) -> List[int]:
pad_shape = [item for sublist in pad_shape[::-1] for item in sublist]
return pad_shape


def generate_path(duration, mask):
def generate_path(duration: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
Expand All @@ -678,7 +699,7 @@ def generate_path(duration, mask):
return path


def process_batch(batch_data, sup_data_types_set):
def process_batch(batch_data: Tuple, sup_data_types_set: set) -> Dict[str, Any]:
batch_dict = {}
batch_index = 0
for name, datatype in DATA_STR2DATA_CLASS.items():
Expand Down Expand Up @@ -748,11 +769,11 @@ def batch_from_ragged(


def sample_tts_input(
export_config,
device,
max_batch=1,
max_dim=127,
):
export_config: Dict[str, Any],
device: torch.device,
max_batch: int = 1,
max_dim: int = 127,
) -> Dict[str, torch.Tensor]:
"""
Generates input examples for tracing etc.
Returns:
Expand Down Expand Up @@ -901,7 +922,9 @@ def transcribe_with_whisper_from_filepaths(
return transcripts


def get_speaker_embeddings_from_filepaths(filepaths, speaker_verification_model, device):
def get_speaker_embeddings_from_filepaths(
filepaths: List[str], speaker_verification_model: Any, device: torch.device
) -> torch.Tensor:
"""
Get speaker embeddings from audio filepaths using a speaker verification model.
"""
Expand Down
Loading