diff --git a/egs/ljspeech/tts/local/split_subsets.py b/egs/ljspeech/tts/local/split_subsets.py deleted file mode 100755 index b2afca971..000000000 --- a/egs/ljspeech/tts/local/split_subsets.py +++ /dev/null @@ -1,80 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script split the LJSpeech dataset cuts into three sets: - - training, 12500 - - validation, 100 - - test, 500 -The numbers are from https://arxiv.org/pdf/2106.06103.pdf - -Usage example: - python3 ./local/split_subsets.py ./data/spectrogram -""" - -import argparse -import logging -import random -from pathlib import Path - -from lhotse import load_manifest_lazy - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "manifest_dir", - type=Path, - default=Path("data/spectrogram"), - help="Path to the manifest file", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - - manifest_dir = Path(args.manifest_dir) - prefix = "ljspeech" - suffix = "jsonl.gz" - # all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all.{suffix}") - all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all_phonemized.{suffix}") - - cut_ids = list(all_cuts.ids) - random.shuffle(cut_ids) - - train_cuts = all_cuts.subset(cut_ids=cut_ids[:12500]) - valid_cuts = all_cuts.subset(cut_ids=cut_ids[12500:12500 + 100]) - test_cuts = all_cuts.subset(cut_ids=cut_ids[12500 + 100:]) - assert len(train_cuts) == 12500, "expected 12500 cuts for training but got len(train_cuts)" - assert len(valid_cuts) == 100, "expected 100 cuts but for validation but got len(valid_cuts)" - assert len(test_cuts) == 500, "expected 500 cuts for test but got len(test_cuts)" - - train_cuts.to_file(manifest_dir / f"{prefix}_cuts_train.{suffix}") - valid_cuts.to_file(manifest_dir / f"{prefix}_cuts_valid.{suffix}") - test_cuts.to_file(manifest_dir / f"{prefix}_cuts_test.{suffix}") - - logging.info("Splitted into three sets: training (12500), validation (100), and test (500)") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/ljspeech/tts/prepare.sh b/egs/ljspeech/tts/prepare.sh index 4f4685951..613eb37d8 100755 --- a/egs/ljspeech/tts/prepare.sh +++ b/egs/ljspeech/tts/prepare.sh @@ -9,8 +9,7 @@ nj=1 stage=-1 stop_stage=100 -# dl_dir=$PWD/download -dl_dir=/star-data/zengwei/download/ljspeech/ +dl_dir=$PWD/download . shared/parse_options.sh || exit 1 @@ -66,22 +65,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi fi -# if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then -# log "Stage 3: Phonemize the transcripts for LJSpeech" -# if [ ! 
-e data/spectrogram/.ljspeech_phonemized.done ]; then -# ./local/phonemize_text.py data/spectrogram -# touch data/spectrogram/.ljspeech_phonemized.done -# fi -# fi - -# if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then -# log "Stage 4: Split the LJSpeech cuts into three sets" -# if [ ! -e data/spectrogram/.ljspeech_split.done ]; then -# ./local/split_subsets.py data/spectrogram -# touch data/spectrogram/.ljspeech_split.done -# fi -# fi - if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Split the LJSpeech cuts into train, valid and test sets" if [ ! -e data/spectrogram/.ljspeech_split.done ]; then @@ -94,6 +77,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then lhotse subset --last 500 \ data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \ data/spectrogram/ljspeech_cuts_test.jsonl.gz + rm data/spectrogram/ljspeech_cuts_validtest.jsonl.gz n=$(( $(gunzip -c data/spectrogram/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 )) diff --git a/egs/ljspeech/tts/vits/commons.py b/egs/ljspeech/tts/vits/commons.py deleted file mode 100644 index 9ad0444b6..000000000 --- a/egs/ljspeech/tts/vits/commons.py +++ /dev/null @@ -1,161 +0,0 @@ -import math -import numpy as np -import torch -from torch import nn -from torch.nn import functional as F - - -def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(mean, std) - - -def get_padding(kernel_size, dilation=1): - return int((kernel_size*dilation - dilation)/2) - - -def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape - - -def intersperse(lst, item): - result = [item] * (len(lst) * 2 + 1) - result[1::2] = lst - return result - - -def kl_divergence(m_p, logs_p, m_q, logs_q): - """KL(P||Q)""" - kl = (logs_q - logs_p) - 0.5 - kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. 
* logs_q) - return kl - - -def rand_gumbel(shape): - """Sample from the Gumbel distribution, protect from overflows.""" - uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 - return -torch.log(-torch.log(uniform_samples)) - - -def rand_gumbel_like(x): - g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) - return g - - -def slice_segments(x, ids_str, segment_size=4): - ret = torch.zeros_like(x[:, :, :segment_size]) - for i in range(x.size(0)): - idx_str = ids_str[i] - idx_end = idx_str + segment_size - ret[i] = x[i, :, idx_str:idx_end] - return ret - - -def rand_slice_segments(x, x_lengths=None, segment_size=4): - b, d, t = x.size() - if x_lengths is None: - x_lengths = t - ids_str_max = x_lengths - segment_size + 1 - ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) - ret = slice_segments(x, ids_str, segment_size) - return ret, ids_str - - -def get_timing_signal_1d( - length, channels, min_timescale=1.0, max_timescale=1.0e4): - position = torch.arange(length, dtype=torch.float) - num_timescales = channels // 2 - log_timescale_increment = ( - math.log(float(max_timescale) / float(min_timescale)) / - (num_timescales - 1)) - inv_timescales = min_timescale * torch.exp( - torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) - scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) - signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) - signal = F.pad(signal, [0, 0, 0, channels % 2]) - signal = signal.view(1, channels, length) - return signal - - -def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return x + signal.to(dtype=x.dtype, device=x.device) - - -def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) - - -def subsequent_mask(length): - mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) - return mask - - -@torch.jit.script -def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): - n_channels_int = n_channels[0] - in_act = input_a + input_b - t_act = torch.tanh(in_act[:, :n_channels_int, :]) - s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) - acts = t_act * s_act - return acts - - -def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape - - -def shift_1d(x): - x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] - return x - - -def sequence_mask(length, max_length=None): - if max_length is None: - max_length = length.max() - x = torch.arange(max_length, dtype=length.dtype, device=length.device) - return x.unsqueeze(0) < length.unsqueeze(1) - - -def generate_path(duration, mask): - """ - duration: [b, 1, t_x] - mask: [b, 1, t_y, t_x] - """ - device = duration.device - - b, _, t_y, t_x = mask.shape - cum_duration = torch.cumsum(duration, -1) - - cum_duration_flat = cum_duration.view(b * t_x) - path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) - path = path.view(b, t_x, t_y) - path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] - path = path.unsqueeze(1).transpose(2,3) * mask - return path - - -def clip_grad_value_(parameters, clip_value, norm_type=2): - if isinstance(parameters, torch.Tensor): 
- parameters = [parameters] - parameters = list(filter(lambda p: p.grad is not None, parameters)) - norm_type = float(norm_type) - if clip_value is not None: - clip_value = float(clip_value) - - total_norm = 0 - for p in parameters: - param_norm = p.grad.data.norm(norm_type) - total_norm += param_norm.item() ** norm_type - if clip_value is not None: - p.grad.data.clamp_(min=-clip_value, max=clip_value) - total_norm = total_norm ** (1. / norm_type) - return total_norm diff --git a/egs/ljspeech/tts/vits/duration_predictor.py b/egs/ljspeech/tts/vits/duration_predictor.py index 5e8d670bd..c29a28479 100644 --- a/egs/ljspeech/tts/vits/duration_predictor.py +++ b/egs/ljspeech/tts/vits/duration_predictor.py @@ -1,4 +1,4 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py # Copyright 2021 Tomoki Hayashi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) diff --git a/egs/ljspeech/tts/vits/features.py b/egs/ljspeech/tts/vits/features.py deleted file mode 100644 index b43c7cf46..000000000 --- a/egs/ljspeech/tts/vits/features.py +++ /dev/null @@ -1,416 +0,0 @@ -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Any, Dict, Optional, Tuple - -import librosa -import numpy as np -import torch -from torch import nn - -from icefall.utils import make_pad_mask - - -# From https://github.com/espnet/espnet/blob/master/espnet2/layers/stft.py -class Stft(nn.Module): - def __init__( - self, - n_fft: int = 512, - win_length: int = None, - hop_length: int = 128, - window: Optional[str] = "hann", - center: bool = True, - normalized: bool = False, - onesided: bool = True, - ): - super().__init__() - self.n_fft = n_fft - if win_length is None: - self.win_length = n_fft - else: - self.win_length = win_length - self.hop_length = hop_length - self.center = center - self.normalized = normalized - self.onesided = onesided - if window is not None and not hasattr(torch, f"{window}_window"): - raise ValueError(f"{window} window is not implemented") - self.window = window - - def extra_repr(self): - return ( - f"n_fft={self.n_fft}, " - f"win_length={self.win_length}, " - f"hop_length={self.hop_length}, " - f"center={self.center}, " - f"normalized={self.normalized}, " - f"onesided={self.onesided}" - ) - - def forward( - self, input: torch.Tensor, ilens: torch.Tensor = None - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """STFT forward function. 
- - Args: - input: (Batch, Nsamples) or (Batch, Nsample, Channels) - ilens: (Batch) - Returns: - output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2) - - """ - bs = input.size(0) - if input.dim() == 3: - multi_channel = True - # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample) - input = input.transpose(1, 2).reshape(-1, input.size(1)) - else: - multi_channel = False - - # NOTE(kamo): - # The default behaviour of torch.stft is compatible with librosa.stft - # about padding and scaling. - # Note that it's different from scipy.signal.stft - - # output: (Batch, Freq, Frames, 2=real_imag) - # or (Batch, Channel, Freq, Frames, 2=real_imag) - if self.window is not None: - window_func = getattr(torch, f"{self.window}_window") - window = window_func( - self.win_length, dtype=input.dtype, device=input.device - ) - else: - window = None - - # For the compatibility of ARM devices, which do not support - # torch.stft() due to the lack of MKL (on older pytorch versions), - # there is an alternative replacement implementation with librosa. - # Note: pytorch >= 1.10.0 now has native support for FFT and STFT - # on all cpu targets including ARM. - if input.is_cuda or torch.backends.mkl.is_available(): - stft_kwargs = dict( - n_fft=self.n_fft, - win_length=self.win_length, - hop_length=self.hop_length, - center=self.center, - window=window, - normalized=self.normalized, - onesided=self.onesided, - ) - stft_kwargs["return_complex"] = True - output = torch.stft(input, **stft_kwargs) - output = torch.view_as_real(output) - else: - if self.training: - raise NotImplementedError( - "stft is implemented with librosa on this device, which does not " - "support the training mode." - ) - - # use stft_kwargs to flexibly control different PyTorch versions' kwargs - # note: librosa does not support a win_length that is < n_ftt - # but the window can be manually padded (see below). 
- stft_kwargs = dict( - n_fft=self.n_fft, - win_length=self.n_fft, - hop_length=self.hop_length, - center=self.center, - window=window, - pad_mode="reflect", - ) - - if window is not None: - # pad the given window to n_fft - n_pad_left = (self.n_fft - window.shape[0]) // 2 - n_pad_right = self.n_fft - window.shape[0] - n_pad_left - stft_kwargs["window"] = torch.cat( - [torch.zeros(n_pad_left), window, torch.zeros(n_pad_right)], 0 - ).numpy() - else: - win_length = ( - self.win_length if self.win_length is not None else self.n_fft - ) - stft_kwargs["window"] = torch.ones(win_length) - - output = [] - # iterate over istances in a batch - for i, instance in enumerate(input): - stft = librosa.stft(input[i].numpy(), **stft_kwargs) - output.append(torch.tensor(np.stack([stft.real, stft.imag], -1))) - output = torch.stack(output, 0) - if not self.onesided: - len_conj = self.n_fft - output.shape[1] - conj = output[:, 1 : 1 + len_conj].flip(1) - conj[:, :, :, -1].data *= -1 - output = torch.cat([output, conj], 1) - if self.normalized: - output = output * (stft_kwargs["window"].shape[0] ** (-0.5)) - - # output: (Batch, Freq, Frames, 2=real_imag) - # -> (Batch, Frames, Freq, 2=real_imag) - output = output.transpose(1, 2) - if multi_channel: - # output: (Batch * Channel, Frames, Freq, 2=real_imag) - # -> (Batch, Frame, Channel, Freq, 2=real_imag) - output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose( - 1, 2 - ) - - if ilens is not None: - if self.center: - pad = self.n_fft // 2 - ilens = ilens + 2 * pad - - olens = ( - torch.div(ilens - self.n_fft, self.hop_length, rounding_mode="trunc") - + 1 - ) - output.masked_fill_(make_pad_mask(olens), 0.0) - else: - olens = None - - return output, olens - - -# From https://github.com/espnet/espnet/blob/master/espnet2/tts/feats_extract/linear_spectrogram.py -class LinearSpectrogram(nn.Module): - """Linear amplitude spectrogram. - - Stft -> amplitude-spec - """ - - def __init__( - self, - n_fft: int = 1024, - win_length: int = None, - hop_length: int = 256, - window: Optional[str] = "hann", - center: bool = True, - normalized: bool = False, - onesided: bool = True, - ): - super().__init__() - self.n_fft = n_fft - self.hop_length = hop_length - self.win_length = win_length - self.window = window - self.stft = Stft( - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - window=window, - center=center, - normalized=normalized, - onesided=onesided, - ) - self.n_fft = n_fft - - def output_size(self) -> int: - return self.n_fft // 2 + 1 - - def get_parameters(self) -> Dict[str, Any]: - """Return the parameters required by Vocoder.""" - return dict( - n_fft=self.n_fft, - n_shift=self.hop_length, - win_length=self.win_length, - window=self.window, - ) - - def forward( - self, input: torch.Tensor, input_lengths: torch.Tensor = None - ) -> Tuple[torch.Tensor, torch.Tensor]: - # 1. 
Stft: time -> time-freq - input_stft, feats_lens = self.stft(input, input_lengths) - - assert input_stft.dim() >= 4, input_stft.shape - # "2" refers to the real/imag parts of Complex - assert input_stft.shape[-1] == 2, input_stft.shape - - # STFT -> Power spectrum -> Amp spectrum - # input_stft: (..., F, 2) -> (..., F) - input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2 - input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10)) - return input_amp, feats_lens - - -# From https://github.com/espnet/espnet/blob/master/espnet2/layers/log_mel.py -class LogMel(nn.Module): - """Convert STFT to fbank feats - - The arguments is same as librosa.filters.mel - - Args: - fs: number > 0 [scalar] sampling rate of the incoming signal - n_fft: int > 0 [scalar] number of FFT components - n_mels: int > 0 [scalar] number of Mel bands to generate - fmin: float >= 0 [scalar] lowest frequency (in Hz) - fmax: float >= 0 [scalar] highest frequency (in Hz). - If `None`, use `fmax = fs / 2.0` - htk: use HTK formula instead of Slaney - """ - - def __init__( - self, - fs: int = 16000, - n_fft: int = 512, - n_mels: int = 80, - fmin: float = None, - fmax: float = None, - htk: bool = False, - log_base: float = None, - ): - super().__init__() - - fmin = 0 if fmin is None else fmin - fmax = fs / 2 if fmax is None else fmax - _mel_options = dict( - sr=fs, - n_fft=n_fft, - n_mels=n_mels, - fmin=fmin, - fmax=fmax, - htk=htk, - ) - self.mel_options = _mel_options - self.log_base = log_base - - # Note(kamo): The mel matrix of librosa is different from kaldi. - melmat = librosa.filters.mel(**_mel_options) - # melmat: (D2, D1) -> (D1, D2) - self.register_buffer("melmat", torch.from_numpy(melmat.T).float()) - - def extra_repr(self): - return ", ".join(f"{k}={v}" for k, v in self.mel_options.items()) - - def forward( - self, - feat: torch.Tensor, - ilens: torch.Tensor = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2) - mel_feat = torch.matmul(feat, self.melmat) - mel_feat = torch.clamp(mel_feat, min=1e-10) - - if self.log_base is None: - logmel_feat = mel_feat.log() - elif self.log_base == 2.0: - logmel_feat = mel_feat.log2() - elif self.log_base == 10.0: - logmel_feat = mel_feat.log10() - else: - logmel_feat = mel_feat.log() / torch.log(self.log_base) - - # Zero padding - if ilens is not None: - logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens), 0.0) - else: - ilens = feat.new_full( - [feat.size(0)], fill_value=feat.size(1), dtype=torch.long - ) - return logmel_feat, ilens - - -# From https://github.com/espnet/espnet/blob/master/espnet2/tts/feats_extract/log_mel_fbank.py -class LogMelFbank(nn.Module): - """Conventional frontend structure for TTS. 
- - Stft -> amplitude-spec -> Log-Mel-Fbank - """ - - def __init__( - self, - fs: int = 16000, - n_fft: int = 1024, - win_length: int = None, - hop_length: int = 256, - window: Optional[str] = "hann", - center: bool = True, - normalized: bool = False, - onesided: bool = True, - n_mels: int = 80, - fmin: Optional[int] = 80, - fmax: Optional[int] = 7600, - htk: bool = False, - log_base: Optional[float] = 10.0, - ): - super().__init__() - - self.fs = fs - self.n_mels = n_mels - self.n_fft = n_fft - self.hop_length = hop_length - self.win_length = win_length - self.window = window - self.fmin = fmin - self.fmax = fmax - - self.stft = Stft( - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - window=window, - center=center, - normalized=normalized, - onesided=onesided, - ) - - self.logmel = LogMel( - fs=fs, - n_fft=n_fft, - n_mels=n_mels, - fmin=fmin, - fmax=fmax, - htk=htk, - log_base=log_base, - ) - - def output_size(self) -> int: - return self.n_mels - - def get_parameters(self) -> Dict[str, Any]: - """Return the parameters required by Vocoder""" - return dict( - fs=self.fs, - n_fft=self.n_fft, - n_shift=self.hop_length, - window=self.window, - n_mels=self.n_mels, - win_length=self.win_length, - fmin=self.fmin, - fmax=self.fmax, - ) - - def forward( - self, input: torch.Tensor, input_lengths: torch.Tensor = None - ) -> Tuple[torch.Tensor, torch.Tensor]: - # 1. Domain-conversion: e.g. Stft: time -> time-freq - input_stft, feats_lens = self.stft(input, input_lengths) - - assert input_stft.dim() >= 4, input_stft.shape - # "2" refers to the real/imag parts of Complex - assert input_stft.shape[-1] == 2, input_stft.shape - - # NOTE(kamo): We use different definition for log-spec between TTS and ASR - # TTS: log_10(abs(stft)) - # ASR: log_e(power(stft)) - - # input_stft: (..., F, 2) -> (..., F) - input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2 - input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10)) - input_feats, _ = self.logmel(input_amp, feats_lens) - return input_feats, feats_lens diff --git a/egs/ljspeech/tts/vits/flow.py b/egs/ljspeech/tts/vits/flow.py index 04fb99b42..206bd5e3e 100644 --- a/egs/ljspeech/tts/vits/flow.py +++ b/egs/ljspeech/tts/vits/flow.py @@ -1,4 +1,5 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py + # Copyright 2021 Tomoki Hayashi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) diff --git a/egs/ljspeech/tts/vits/generator.py b/egs/ljspeech/tts/vits/generator.py index fc0d45cfd..664d8064f 100644 --- a/egs/ljspeech/tts/vits/generator.py +++ b/egs/ljspeech/tts/vits/generator.py @@ -1,4 +1,4 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py +# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py # Copyright 2021 Tomoki Hayashi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) diff --git a/egs/ljspeech/tts/vits/hifigan.py b/egs/ljspeech/tts/vits/hifigan.py index a87cb2fce..589ac30f6 100644 --- a/egs/ljspeech/tts/vits/hifigan.py +++ b/egs/ljspeech/tts/vits/hifigan.py @@ -1,4 +1,4 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py # Copyright 2021 Tomoki Hayashi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) diff --git a/egs/ljspeech/tts/vits/infer.py b/egs/ljspeech/tts/vits/infer.py index 623cc3ec9..f971f85ff 
100755 --- a/egs/ljspeech/tts/vits/infer.py +++ b/egs/ljspeech/tts/vits/infer.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 # -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -17,118 +16,34 @@ # See the License for the specific language governing permissions and # limitations under the License. """ +This script performs model inference on test set. + Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ +./vits/infer.py \ + --epoch 1000 \ --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 + --max-duration 500 """ import argparse import logging -import math -import os -from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import List import k2 import torch import torch.nn as nn import torchaudio -from train import get_model, get_params, prepare_input +from train import get_model, get_params from tokenizer import Tokenizer -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, setup_logger from tts_datamodule import LJSpeechTtsDataModule -LOG_EPS = math.log(1e-10) - def get_parser(): parser = argparse.ArgumentParser( @@ -138,35 +53,16 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=30, + default=1000, help="""It specifies the checkpoint to use for decoding. Note: Epoch counts from 1. 
- You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. """, ) - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - parser.add_argument( "--exp-dir", type=str, - default="zipformer/exp", + default="vits/exp", help="The experiment dir", ) @@ -174,7 +70,7 @@ def get_parser(): "--tokens", type=str, default="data/tokens.txt", - help="""Path to tokens.txt.""", + help="""Path to vocabulary.""", ) return parser @@ -185,8 +81,9 @@ def infer_dataset( params: AttributeDict, model: nn.Module, tokenizer: Tokenizer, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: +) -> None: """Decode dataset. + The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. Args: dl: @@ -195,20 +92,8 @@ def infer_dataset( It is returned by :func:`get_params`. model: The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. + tokenizer: + Used to convert text to phonemes. """ # Background worker save audios to disk. def _save_worker( @@ -233,7 +118,7 @@ def infer_dataset( device = next(model.parameters()).device num_cuts = 0 - log_interval = 10 + log_interval = 5 try: num_batches = len(dl) @@ -242,7 +127,6 @@ def infer_dataset( futures = [] with ThreadPoolExecutor(max_workers=1) as executor: - # We only want one background worker so that serialization is deterministic. 
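# A minimal sketch of the single-background-worker saving pattern used in this
# infer.py hunk, assuming 22050 Hz audio, one 1-D float32 tensor per utterance,
# and a hypothetical output directory; it is illustrative, not the recipe's
# exact code.
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import torch
import torchaudio

save_wav_dir = Path("vits/exp/infer/wav")  # assumed location for the sketch
save_wav_dir.mkdir(parents=True, exist_ok=True)
sampling_rate = 22050


def save_pair(cut_id: str, audio_gt: torch.Tensor, audio_pred: torch.Tensor) -> None:
    # torchaudio.save expects (channels, num_samples) tensors on CPU
    torchaudio.save(str(save_wav_dir / f"{cut_id}_gt.wav"), audio_gt.unsqueeze(0), sampling_rate)
    torchaudio.save(str(save_wav_dir / f"{cut_id}_pred.wav"), audio_pred.unsqueeze(0), sampling_rate)


# a single worker keeps the file writes ordered (deterministic serialization)
# while the main thread moves on to generating the next batch
with ThreadPoolExecutor(max_workers=1) as executor:
    futures = [
        executor.submit(save_pair, "LJ001-0001", torch.zeros(22050), torch.zeros(22050))
    ]
    for f in futures:
        f.result()  # surface any exception raised inside the worker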
for batch_idx, batch in enumerate(dl): batch_size = len(batch["text"]) @@ -253,7 +137,7 @@ def infer_dataset( tokens_lens = row_splits[1:] - row_splits[:-1] tokens = tokens.to(device) tokens_lens = tokens_lens.to(device) - # a tensor of shape (B, T) + # tensor of shape (B, T) tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) audio = batch["audio"] @@ -265,9 +149,6 @@ def infer_dataset( # convert to samples audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() - # import pdb - # pdb.set_trace() - futures.append( executor.submit( _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred @@ -295,10 +176,7 @@ def main(): params = get_params() params.update(vars(args)) - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + params.suffix = f"epoch-{params.epoch}" params.res_dir = params.exp_dir / "infer" / params.suffix params.save_wav_dir = params.res_dir / "wav" @@ -322,40 +200,16 @@ def main(): logging.info("About to create model") model = get_model(params) - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) model.to(device) model.eval() - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") + num_param_g = sum([p.numel() for p in model.generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in model.discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") # we need cut ids to display recognition results. args.return_cuts = True @@ -371,17 +225,8 @@ def main(): tokenizer=tokenizer, ) - # save_results( - # params=params, - # test_set_name=test_set, - # results_dict=results_dict, - # ) - logging.info("Done!") -# torch.set_num_threads(1) -# torch.set_num_interop_threads(1) - if __name__ == "__main__": main() diff --git a/egs/ljspeech/tts/vits/loss.py b/egs/ljspeech/tts/vits/loss.py index 0d27af643..21aaad6e7 100644 --- a/egs/ljspeech/tts/vits/loss.py +++ b/egs/ljspeech/tts/vits/loss.py @@ -1,4 +1,4 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py # Copyright 2021 Tomoki Hayashi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) @@ -9,7 +9,7 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. 
""" -from typing import List, Optional, Tuple, Union +from typing import List, Tuple, Union import torch import torch.distributions as D @@ -266,7 +266,7 @@ class MelSpectrogramLoss(torch.nn.Module): return mel_loss -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py """VITS-related loss modules. diff --git a/egs/ljspeech/tts/vits/models.py b/egs/ljspeech/tts/vits/models.py deleted file mode 100644 index f5acdeb2b..000000000 --- a/egs/ljspeech/tts/vits/models.py +++ /dev/null @@ -1,534 +0,0 @@ -import copy -import math -import torch -from torch import nn -from torch.nn import functional as F - -import commons -import modules -import attentions -import monotonic_align - -from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d -from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm -from commons import init_weights, get_padding - - -class StochasticDurationPredictor(nn.Module): - def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): - super().__init__() - filter_channels = in_channels # it needs to be removed from future version. - self.in_channels = in_channels - self.filter_channels = filter_channels - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.n_flows = n_flows - self.gin_channels = gin_channels - - self.log_flow = modules.Log() - self.flows = nn.ModuleList() - self.flows.append(modules.ElementwiseAffine(2)) - for i in range(n_flows): - self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) - self.flows.append(modules.Flip()) - - self.post_pre = nn.Conv1d(1, filter_channels, 1) - self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) - self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) - self.post_flows = nn.ModuleList() - self.post_flows.append(modules.ElementwiseAffine(2)) - for i in range(4): - self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) - self.post_flows.append(modules.Flip()) - - self.pre = nn.Conv1d(in_channels, filter_channels, 1) - self.proj = nn.Conv1d(filter_channels, filter_channels, 1) - self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) - if gin_channels != 0: - self.cond = nn.Conv1d(gin_channels, filter_channels, 1) - - def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): - x = torch.detach(x) - x = self.pre(x) - if g is not None: - g = torch.detach(g) - x = x + self.cond(g) - x = self.convs(x, x_mask) - x = self.proj(x) * x_mask - - if not reverse: - flows = self.flows - assert w is not None - - logdet_tot_q = 0 - h_w = self.post_pre(w) - h_w = self.post_convs(h_w, x_mask) - h_w = self.post_proj(h_w) * x_mask - e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask - z_q = e_q - for flow in self.post_flows: - z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) - logdet_tot_q += logdet_q - z_u, z1 = torch.split(z_q, [1, 1], 1) - u = torch.sigmoid(z_u) * x_mask - z0 = (w - u) * x_mask - logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2]) - logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q - - logdet_tot = 0 - z0, logdet = self.log_flow(z0, x_mask) - logdet_tot += logdet - z = torch.cat([z0, z1], 1) - for flow in flows: - z, logdet = flow(z, x_mask, g=x, reverse=reverse) - logdet_tot = 
logdet_tot + logdet - nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot - return nll + logq # [b] - else: - flows = list(reversed(self.flows)) - flows = flows[:-2] + [flows[-1]] # remove a useless vflow - z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale - for flow in flows: - z = flow(z, x_mask, g=x, reverse=reverse) - z0, z1 = torch.split(z, [1, 1], 1) - logw = z0 - return logw - - -class DurationPredictor(nn.Module): - def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): - super().__init__() - - self.in_channels = in_channels - self.filter_channels = filter_channels - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.gin_channels = gin_channels - - self.drop = nn.Dropout(p_dropout) - self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2) - self.norm_1 = modules.LayerNorm(filter_channels) - self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2) - self.norm_2 = modules.LayerNorm(filter_channels) - self.proj = nn.Conv1d(filter_channels, 1, 1) - - if gin_channels != 0: - self.cond = nn.Conv1d(gin_channels, in_channels, 1) - - def forward(self, x, x_mask, g=None): - x = torch.detach(x) - if g is not None: - g = torch.detach(g) - x = x + self.cond(g) - x = self.conv_1(x * x_mask) - x = torch.relu(x) - x = self.norm_1(x) - x = self.drop(x) - x = self.conv_2(x * x_mask) - x = torch.relu(x) - x = self.norm_2(x) - x = self.drop(x) - x = self.proj(x * x_mask) - return x * x_mask - - -class TextEncoder(nn.Module): - def __init__(self, - n_vocab, - out_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout): - super().__init__() - self.n_vocab = n_vocab - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - - self.emb = nn.Embedding(n_vocab, hidden_channels) - nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) - - self.encoder = attentions.Encoder( - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout) - self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1) - - def forward(self, x, x_lengths): - x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] - x = torch.transpose(x, 1, -1) # [b, h, t] - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) - - x = self.encoder(x * x_mask, x_mask) - stats = self.proj(x) * x_mask - - m, logs = torch.split(stats, self.out_channels, dim=1) - return x, m, logs, x_mask - - -class ResidualCouplingBlock(nn.Module): - def __init__(self, - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - n_flows=4, - gin_channels=0): - super().__init__() - self.channels = channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.n_flows = n_flows - self.gin_channels = gin_channels - - self.flows = nn.ModuleList() - for i in range(n_flows): - self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) - self.flows.append(modules.Flip()) - - def forward(self, x, x_mask, g=None, reverse=False): - if not reverse: - for flow in self.flows: - x, _ = flow(x, x_mask, g=g, 
reverse=reverse) - else: - for flow in reversed(self.flows): - x = flow(x, x_mask, g=g, reverse=reverse) - return x - - -class PosteriorEncoder(nn.Module): - def __init__(self, - in_channels, - out_channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=0): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.gin_channels = gin_channels - - self.pre = nn.Conv1d(in_channels, hidden_channels, 1) - self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - - def forward(self, x, x_lengths, g=None): - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) - x = self.pre(x) * x_mask - x = self.enc(x, x_mask, g=g) - stats = self.proj(x) * x_mask - m, logs = torch.split(stats, self.out_channels, dim=1) - z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask - return z, m, logs, x_mask - - -class Generator(torch.nn.Module): - def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): - super(Generator, self).__init__() - self.num_kernels = len(resblock_kernel_sizes) - self.num_upsamples = len(upsample_rates) - self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) - resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2 - - self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): - self.ups.append(weight_norm( - ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), - k, u, padding=(k-u)//2))) - - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = upsample_initial_channel//(2**(i+1)) - for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): - self.resblocks.append(resblock(ch, k, d)) - - self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) - self.ups.apply(init_weights) - - if gin_channels != 0: - self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) - - def forward(self, x, g=None): - x = self.conv_pre(x) - if g is not None: - x = x + self.cond(g) - - for i in range(self.num_upsamples): - x = F.leaky_relu(x, modules.LRELU_SLOPE) - x = self.ups[i](x) - xs = None - for j in range(self.num_kernels): - if xs is None: - xs = self.resblocks[i*self.num_kernels+j](x) - else: - xs += self.resblocks[i*self.num_kernels+j](x) - x = xs / self.num_kernels - x = F.leaky_relu(x) - x = self.conv_post(x) - x = torch.tanh(x) - - return x - - def remove_weight_norm(self): - print('Removing weight norm...') - for l in self.ups: - remove_weight_norm(l) - for l in self.resblocks: - l.remove_weight_norm() - - -class DiscriminatorP(torch.nn.Module): - def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): - super(DiscriminatorP, self).__init__() - self.period = period - self.use_spectral_norm = use_spectral_norm - norm_f = weight_norm if use_spectral_norm == False else spectral_norm - self.convs = nn.ModuleList([ - norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(128, 512, (kernel_size, 1), 
(stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), - ]) - self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) - - def forward(self, x): - fmap = [] - - # 1d to 2d - b, c, t = x.shape - if t % self.period != 0: # pad first - n_pad = self.period - (t % self.period) - x = F.pad(x, (0, n_pad), "reflect") - t = t + n_pad - x = x.view(b, c, t // self.period, self.period) - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, modules.LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class DiscriminatorS(torch.nn.Module): - def __init__(self, use_spectral_norm=False): - super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm == False else spectral_norm - self.convs = nn.ModuleList([ - norm_f(Conv1d(1, 16, 15, 1, padding=7)), - norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), - norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), - norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), - norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), - ]) - self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) - - def forward(self, x): - fmap = [] - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, modules.LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class MultiPeriodDiscriminator(torch.nn.Module): - def __init__(self, use_spectral_norm=False): - super(MultiPeriodDiscriminator, self).__init__() - periods = [2,3,5,7,11] - - discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] - discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] - self.discriminators = nn.ModuleList(discs) - - def forward(self, y, y_hat): - y_d_rs = [] - y_d_gs = [] - fmap_rs = [] - fmap_gs = [] - for i, d in enumerate(self.discriminators): - y_d_r, fmap_r = d(y) - y_d_g, fmap_g = d(y_hat) - y_d_rs.append(y_d_r) - y_d_gs.append(y_d_g) - fmap_rs.append(fmap_r) - fmap_gs.append(fmap_g) - - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - - -class SynthesizerTrn(nn.Module): - """ - Synthesizer for Training - """ - - def __init__(self, - n_vocab, - spec_channels, - segment_size, - inter_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - n_speakers=0, - gin_channels=0, - use_sdp=True, - **kwargs): - - super().__init__() - self.n_vocab = n_vocab - self.spec_channels = spec_channels - self.inter_channels = inter_channels - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.resblock = resblock - self.resblock_kernel_sizes = resblock_kernel_sizes - self.resblock_dilation_sizes = resblock_dilation_sizes - self.upsample_rates = upsample_rates - self.upsample_initial_channel = upsample_initial_channel - self.upsample_kernel_sizes = upsample_kernel_sizes - self.segment_size = segment_size - self.n_speakers = n_speakers - self.gin_channels = gin_channels - - self.use_sdp = use_sdp - - self.enc_p = TextEncoder(n_vocab, - 
inter_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout) - self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) - self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) - self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) - - if use_sdp: - self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) - else: - self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) - - if n_speakers > 1: - self.emb_g = nn.Embedding(n_speakers, gin_channels) - - def forward(self, x, x_lengths, y, y_lengths, sid=None): - - x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) - if self.n_speakers > 0: - g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] - else: - g = None - - z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) - z_p = self.flow(z, y_mask, g=g) - - with torch.no_grad(): - # negative cross-entropy - s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t] - neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s] - neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] - neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] - neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s] - neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 - - attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) - attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach() - - w = attn.sum(2) - if self.use_sdp: - l_length = self.dp(x, x_mask, w, g=g) - l_length = l_length / torch.sum(x_mask) - else: - logw_ = torch.log(w + 1e-6) * x_mask - logw = self.dp(x, x_mask, g=g) - l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging - - # expand prior - m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) - logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) - - z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size) - o = self.dec(z_slice, g=g) - return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) - - def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None): - x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) - if self.n_speakers > 0: - g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] - else: - g = None - - if self.use_sdp: - logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) - else: - logw = self.dp(x, x_mask, g=g) - w = torch.exp(logw) * x_mask * length_scale - w_ceil = torch.ceil(w) - y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() - y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype) - attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) - attn = commons.generate_path(w_ceil, attn_mask) - - m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] - logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] - - z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale - z = 
self.flow(z_p, y_mask, g=g, reverse=True) - o = self.dec((z * y_mask)[:,:,:max_len], g=g) - return o, attn, y_mask, (z, z_p, m_p, logs_p) - - def voice_conversion(self, y, y_lengths, sid_src, sid_tgt): - assert self.n_speakers > 0, "n_speakers have to be larger than 0." - g_src = self.emb_g(sid_src).unsqueeze(-1) - g_tgt = self.emb_g(sid_tgt).unsqueeze(-1) - z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src) - z_p = self.flow(z, y_mask, g=g_src) - z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) - o_hat = self.dec(z_hat * y_mask, g=g_tgt) - return o_hat, y_mask, (z, z_p, z_hat) - diff --git a/egs/ljspeech/tts/vits/posterior_encoder.py b/egs/ljspeech/tts/vits/posterior_encoder.py index c78fd647f..6b8a5be52 100644 --- a/egs/ljspeech/tts/vits/posterior_encoder.py +++ b/egs/ljspeech/tts/vits/posterior_encoder.py @@ -1,4 +1,4 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py # Copyright 2021 Tomoki Hayashi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) diff --git a/egs/ljspeech/tts/vits/residual_coupling.py b/egs/ljspeech/tts/vits/residual_coupling.py index 48e748316..2d6807cb7 100644 --- a/egs/ljspeech/tts/vits/residual_coupling.py +++ b/egs/ljspeech/tts/vits/residual_coupling.py @@ -1,4 +1,4 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py # Copyright 2021 Tomoki Hayashi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) diff --git a/egs/ljspeech/tts/vits/symbols.py b/egs/ljspeech/tts/vits/symbols.py deleted file mode 100644 index 70c2868f4..000000000 --- a/egs/ljspeech/tts/vits/symbols.py +++ /dev/null @@ -1,17 +0,0 @@ -# https://github.com/jaywalnut310/vits/blob/main/text/symbols.py -""" from https://github.com/keithito/tacotron """ - -''' -Defines the set of symbols used in text input to the model. 
-''' -_pad = '_' -_punctuation = ';:,.!?¡¿—…"«»“” ' -_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' -_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" - - -# Export all symbols: -symbol_table = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) - -# Special symbol ids -SPACE_ID = symbol_table.index(" ") diff --git a/egs/ljspeech/tts/vits/text_encoder.py b/egs/ljspeech/tts/vits/text_encoder.py index 9ba8e1768..419fd6162 100644 --- a/egs/ljspeech/tts/vits/text_encoder.py +++ b/egs/ljspeech/tts/vits/text_encoder.py @@ -20,6 +20,7 @@ This code is based on - https://github.com/jaywalnut310/vits - https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py + - https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/transducer_stateless/conformer.py """ import copy @@ -67,6 +68,7 @@ class TextEncoder(torch.nn.Module): self.emb = torch.nn.Embedding(vocabs, d_model) torch.nn.init.normal_(self.emb.weight, 0.0, d_model**-0.5) + # We use conformer as text encoder self.encoder = Transformer( d_model=d_model, num_heads=num_heads, diff --git a/egs/ljspeech/tts/vits/tokenizer.py b/egs/ljspeech/tts/vits/tokenizer.py index 5a513a0d9..8a61511ef 100644 --- a/egs/ljspeech/tts/vits/tokenizer.py +++ b/egs/ljspeech/tts/vits/tokenizer.py @@ -18,7 +18,6 @@ from typing import Dict, List import g2p_en import tacotron_cleaner.cleaners - from utils import intersperse diff --git a/egs/ljspeech/tts/vits/train.py b/egs/ljspeech/tts/vits/train.py index 01cd6137e..c8df3c5d0 100755 --- a/egs/ljspeech/tts/vits/train.py +++ b/egs/ljspeech/tts/vits/train.py @@ -1,9 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey) +# Copyright 2023 Xiaomi Corp. 
(authors: Zengwei Yao) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -22,9 +18,10 @@ import argparse import logging +import numpy as np from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import k2 import torch @@ -36,26 +33,17 @@ from torch.optim import Optimizer from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from tts_datamodule import LJSpeechTtsDataModule from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import load_checkpoint from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, setup_logger, str2bool from tokenizer import Tokenizer -from utils import ( - MetricsTracker, - plot_feature, - save_checkpoint, - save_checkpoint_with_global_batch_idx, -) +from tts_datamodule import LJSpeechTtsDataModule +from utils import MetricsTracker, plot_feature, save_checkpoint from vits import VITS LRSchedulerType = torch.optim.lr_scheduler._LRScheduler @@ -90,7 +78,7 @@ def get_parser(): parser.add_argument( "--num-epochs", type=int, - default=30, + default=1000, help="Number of epochs to train.", ) @@ -104,15 +92,6 @@ def get_parser(): """, ) - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - parser.add_argument( "--exp-dir", type=str, @@ -127,7 +106,7 @@ def get_parser(): "--tokens", type=str, default="data/tokens.txt", - help="""Path to tokens.txt.""", + help="""Path to vocabulary.""", ) parser.add_argument( @@ -158,24 +137,11 @@ def get_parser(): parser.add_argument( "--save-every-n", type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" + default=20, + help="""Save checkpoint after processing this number of epochs" periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. + params.cur_epoch % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt' """, ) @@ -218,8 +184,6 @@ def get_params() -> AttributeDict: - log_interval: Print training loss if batch_idx % log_interval` is 0 - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - valid_interval: Run validation if batch_idx % valid_interval is 0 - feature_dim: The model input dim. 
It has to match the one used @@ -242,18 +206,14 @@ def get_params() -> AttributeDict: "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": -1, # 0 - "log_interval": 10, - "draw_interval": 500, - # "reset_interval": 200, + "log_interval": 50, "valid_interval": 200, "env_info": get_env_info(), "sampling_rate": 22050, "frame_shift": 256, "frame_length": 1024, "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length - "mel_loss_params": { - "n_mels": 80, - }, + "n_mels": 80, "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss "lambda_mel": 45.0, # loss scaling coefficient for Mel loss "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss @@ -270,9 +230,7 @@ def load_checkpoint_if_available( ) -> Optional[Dict[str, Any]]: """Load checkpoint from file. - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from + If params.start_epoch is larger than 1, it will load the checkpoint from `params.start_epoch - 1`. Apart from loading state dict for `model` and `optimizer` it also updates @@ -287,9 +245,7 @@ def load_checkpoint_if_available( Returns: Return a dict containing previously saved training info. """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: + if params.start_epoch > 1: filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" else: return None @@ -308,19 +264,15 @@ def load_checkpoint_if_available( for k in keys: params[k] = saved_params[k] - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - return saved_params def get_model(params: AttributeDict) -> nn.Module: - mel_loss_params = params.mel_loss_params - mel_loss_params.update( - frame_length=params.frame_length, - frame_shift=params.frame_shift, - ) + mel_loss_params = { + "n_mels": params.n_mels, + "frame_length": params.frame_length, + "frame_shift": params.frame_shift, + } model = VITS( vocab_size=params.vocab_size, feature_dim=params.feature_dim, @@ -381,18 +333,22 @@ def train_one_epoch( It is returned by :func:`get_params`. model: The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. + tokenizer: + Used to convert text to phonemes. + optimizer_g: + The optimizer for generator. + optimizer_d: + The optimizer for discriminator. + scheduler_g: + The learning rate scheduler for generator, we call step() every epoch. + scheduler_d: + The learning rate scheduler for discriminator, we call step() every epoch. train_dl: Dataloader for the training dataset. valid_dl: Dataloader for the validation dataset. scaler: The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. tb_writer: Writer to write log messages to tensorboard. 
world_size: @@ -404,7 +360,7 @@ def train_one_epoch( model.train() device = model.device if isinstance(model, DDP) else next(model.parameters()).device - # used to summary the stats over iterations + # used to summary the stats over iterations in one epoch tot_loss = MetricsTracker() saved_bad_model = False @@ -433,7 +389,6 @@ def train_one_epoch( loss_info = MetricsTracker() loss_info['samples'] = batch_size - return_sample = params.batch_idx_train % params.log_interval == 0 try: with autocast(enabled=params.use_fp16): # forward discriminator @@ -463,13 +418,11 @@ def train_one_epoch( speech=audio, speech_lengths=audio_lens, forward_generator=True, - return_sample=return_sample, + return_sample=params.batch_idx_train % params.log_interval == 0, ) for k, v in stats_g.items(): - if "return_sample" not in k: + if "returned_sample" not in k: loss_info[k] = v * batch_size - if return_sample: - speech_hat_, speech_, mel_hat_, mel_ = stats_g["return_sample"] # update generator optimizer_g.zero_grad() scaler.scale(loss_g).backward() @@ -477,7 +430,6 @@ def train_one_epoch( scaler.update() # summary stats - # tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info tot_loss = tot_loss + loss_info except: # noqa save_bad_model() @@ -486,37 +438,12 @@ def train_one_epoch( if params.print_diagnostics and batch_idx == 5: return - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - params=params, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - # if batch_idx % 100 == 0 and params.use_fp16: if params.batch_idx_train % 100 == 0 and params.use_fp16: # If the grad scale was less than 1, try increasing it. The _growth_interval # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. 
cur_grad_scale = scaler._scale.item() - # if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: @@ -530,7 +457,6 @@ def train_one_epoch( f"grad_scale is too small, exiting: {cur_grad_scale}" ) - # if batch_idx % params.log_interval == 0: if params.batch_idx_train % params.log_interval == 0: cur_lr_g = max(scheduler_g.get_last_lr()) cur_lr_d = max(scheduler_d.get_last_lr()) @@ -561,7 +487,8 @@ def train_one_epoch( tb_writer.add_scalar( "train/grad_scale", cur_grad_scale, params.batch_idx_train ) - if return_sample: + if "returned_sample" in stats_g: + speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] tb_writer.add_audio( "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate ) @@ -575,7 +502,6 @@ def train_one_epoch( "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC' ) - # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") valid_info, (speech_hat, speech) = compute_validation_loss( @@ -615,14 +541,14 @@ def compute_validation_loss( valid_dl: torch.utils.data.DataLoader, world_size: int = 1, rank: int = 0, -) -> MetricsTracker: +) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: """Run the validation process.""" model.eval() device = model.device if isinstance(model, DDP) else next(model.parameters()).device # used to summary the stats over iterations tot_loss = MetricsTracker() - return_sample = None + returned_sample = None with torch.no_grad(): for batch_idx, batch in enumerate(valid_dl): @@ -667,12 +593,14 @@ def compute_validation_loss( # infer for first batch: if batch_idx == 0 and rank == 0: inner_model = model.module if isinstance(model, DDP) else model - audio_pred, _, duration = inner_model.inference(text=tokens[0, :tokens_lens[0].item()]) + audio_pred, _, duration = inner_model.inference( + text=tokens[0, :tokens_lens[0].item()] + ) audio_pred = audio_pred.data.cpu().numpy() audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred)) audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy() - return_sample = (audio_pred, audio_gt) + returned_sample = (audio_pred, audio_gt) if world_size > 1: tot_loss.reduce(device) @@ -682,7 +610,7 @@ def compute_validation_loss( params.best_valid_epoch = params.cur_epoch params.best_valid_loss = loss_value - return tot_loss, return_sample + return tot_loss, returned_sample def scan_pessimistic_batches_for_oom( @@ -805,18 +733,10 @@ def run(rank, world_size, args): model = DDP(model, device_ids=[rank], find_unused_parameters=True) optimizer_g = torch.optim.AdamW( - generator.parameters(), - lr=params.lr, - betas=(0.8, 0.99), - eps=1e-9, - # weight_decay=0, + generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 ) optimizer_d = torch.optim.AdamW( - discriminator.parameters(), - lr=params.lr, - betas=(0.8, 0.99), - eps=1e-9, - # weight_decay=0, + discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 ) scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) @@ -852,16 +772,8 @@ def run(rank, world_size, args): train_cuts = ljspeech.train_cuts() - if params.start_batch > 0 and checkpoints and 
"sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds - # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold @@ -870,13 +782,10 @@ def run(rank, world_size, args): # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" # ) return False - return True train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_dl = ljspeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) + train_dl = ljspeech.train_dataloaders(train_cuts) valid_cuts = ljspeech.valid_cuts() valid_dl = ljspeech.valid_dataloaders(valid_cuts) @@ -902,11 +811,11 @@ def run(rank, world_size, args): fix_random_seed(params.seed + epoch - 1) train_dl.sampler.set_epoch(epoch - 1) + params.cur_epoch = epoch + if tb_writer is not None: tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - params.cur_epoch = epoch - train_one_epoch( params=params, model=model, @@ -927,27 +836,28 @@ def run(rank, world_size, args): diagnostic.print_diagnostics() break - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint( - filename=filename, - params=params, - model=model, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - if rank == 0: - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) + if epoch % params.save_every_n == 0: + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) # step per epoch scheduler_g.step() diff --git a/egs/ljspeech/tts/vits/transform.py b/egs/ljspeech/tts/vits/transform.py index 6858de2ab..c20d13130 100644 --- a/egs/ljspeech/tts/vits/transform.py +++ b/egs/ljspeech/tts/vits/transform.py @@ -1,4 +1,5 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py + """Flow-related transformation. This code is derived from https://github.com/bayesiains/nflows. 
diff --git a/egs/ljspeech/tts/vits/utils.py b/egs/ljspeech/tts/vits/utils.py index 582856eee..2a3dae900 100644 --- a/egs/ljspeech/tts/vits/utils.py +++ b/egs/ljspeech/tts/vits/utils.py @@ -1,32 +1,35 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py - -# Copyright 2021 Tomoki Hayashi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""Function to get random segments.""" - +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Any, Dict, List, Optional, Tuple, Union import collections import logging -import re -import warnings -import numpy as np import torch import torch.nn as nn import torch.distributed as dist from lhotse.dataset.sampling.base import CutSampler from pathlib import Path -from phonemizer import phonemize -from symbols import symbol_table from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils.rnn import pad_sequence from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter -from unidecode import unidecode +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py def get_random_segments( x: torch.Tensor, x_lengths: torch.Tensor, @@ -55,6 +58,7 @@ def get_random_segments( return segments, start_idxs +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py def get_segments( x: torch.Tensor, start_idxs: torch.Tensor, @@ -78,195 +82,41 @@ def get_segments( return segments -# https://github.com/espnet/espnet/blob/master/espnet2/torch_utils/device_funcs.py -def force_gatherable(data, device): - """Change object to gatherable in torch.nn.DataParallel recursively - - The difference from to_device() is changing to torch.Tensor if float or int - value is found. - - The restriction to the returned value in DataParallel: - The object must be - - torch.cuda.Tensor - - 1 or more dimension. 0-dimension-tensor sends warning. - or a list, tuple, dict. 
- - """ - if isinstance(data, dict): - return {k: force_gatherable(v, device) for k, v in data.items()} - # DataParallel can't handle NamedTuple well - elif isinstance(data, tuple) and type(data) is not tuple: - return type(data)(*[force_gatherable(o, device) for o in data]) - elif isinstance(data, (list, tuple, set)): - return type(data)(force_gatherable(v, device) for v in data) - elif isinstance(data, np.ndarray): - return force_gatherable(torch.from_numpy(data), device) - elif isinstance(data, torch.Tensor): - if data.dim() == 0: - # To 1-dim array - data = data[None] - return data.to(device) - elif isinstance(data, float): - return torch.tensor([data], dtype=torch.float, device=device) - elif isinstance(data, int): - return torch.tensor([data], dtype=torch.long, device=device) - elif data is None: - return None - else: - warnings.warn(f"{type(data)} may not be gatherable by DataParallel") - return data - - -# The following codes are based on https://github.com/jaywalnut310/vits - -# Regular expression matching whitespace: -_whitespace_re = re.compile(r'\s+') - -# List of (regular expression, replacement) pairs for abbreviations: -_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ - ('mrs', 'misess'), - ('mr', 'mister'), - ('dr', 'doctor'), - ('st', 'saint'), - ('co', 'company'), - ('jr', 'junior'), - ('maj', 'major'), - ('gen', 'general'), - ('drs', 'doctors'), - ('rev', 'reverend'), - ('lt', 'lieutenant'), - ('hon', 'honorable'), - ('sgt', 'sergeant'), - ('capt', 'captain'), - ('esq', 'esquire'), - ('ltd', 'limited'), - ('col', 'colonel'), - ('ft', 'fort'), -]] - - -def expand_abbreviations(text): - for regex, replacement in _abbreviations: - text = re.sub(regex, replacement, text) - return text - - -def lowercase(text): - return text.lower() - - -def collapse_whitespace(text): - return re.sub(_whitespace_re, ' ', text) - - -def convert_to_ascii(text): - return unidecode(text) - - -def text_clean(text): - '''Pipeline for English text, including abbreviation expansion. + punctuation + stress. - - Returns: - A string of phonemes. - ''' - text = convert_to_ascii(text) - text = lowercase(text) - text = expand_abbreviations(text) - phonemes = phonemize( - text, - language='en-us', - backend='espeak', - strip=True, - preserve_punctuation=True, - with_stress=True, - ) - phonemes = collapse_whitespace(phonemes) - return phonemes - - -# Mappings from symbol to numeric ID and vice versa: -symbol_to_id = {s: i for i, s in enumerate(symbol_table)} -id_to_symbol = {i: s for i, s in enumerate(symbol_table)} - - -# def text_to_sequence(text: str) -> List[int]: -# '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. -# ''' -# cleaned_text = text_clean(text) -# sequence = [symbol_to_id[symbol] for symbol in cleaned_text] -# return sequence -# -# -# def sequence_to_text(sequence: List[int]) -> str: -# '''Converts a sequence of IDs back to a string''' -# result = ''.join(id_to_symbol[symbol_id] for symbol_id in sequence) -# return result - - +# from https://github.com/jaywalnut310/vit://github.com/jaywalnut310/vits/blob/main/commons.py def intersperse(sequence, item=0): result = [item] * (len(sequence) * 2 + 1) result[1::2] = sequence return result -def prepare_token_batch( - texts: List[str], - phonemes: Optional[List[str]] = None, - intersperse_blank: bool = True, - blank_id: int = 0, - pad_id: int = 0, -) -> torch.Tensor: - """Convert a list of text strings into a batch of symbol tokens with padding. 
- Args: - texts: list of text strings - intersperse_blank: whether to intersperse blank tokens in the converted token sequence. - blank_id: index of blank token - pad_id: padding index - """ - if phonemes is None: - # normalize text - normalized_texts = [] - for text in texts: - text = convert_to_ascii(text) - text = lowercase(text) - text = expand_abbreviations(text) - normalized_texts.append(text) +# from https://github.com/jaywalnut310/vits/blob/main/utils.py +MATPLOTLIB_FLAG = False - # convert to phonemes - phonemes = phonemize( - normalized_texts, - language='en-us', - backend='espeak', - strip=True, - preserve_punctuation=True, - with_stress=True, - ) - phonemes = [collapse_whitespace(sequence) for sequence in phonemes] - # convert to symbol ids - lengths = [] - sequences = [] - skip = False - for idx, sequence in enumerate(phonemes): - try: - sequence = [symbol_to_id[symbol] for symbol in sequence] - except Exception: - # print(texts[idx]) - # print(normalized_texts[idx]) - print(phonemes[idx]) - skip = True - if intersperse_blank: - sequence = intersperse(sequence, blank_id) - try: - sequences.append(torch.tensor(sequence, dtype=torch.int64)) - except Exception: - print(sequence) - skip = True - lengths.append(len(sequence)) +def plot_feature(spectrogram): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger('matplotlib') + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np - sequences = pad_sequence(sequences, batch_first=True, padding_value=pad_id) - lengths = torch.tensor(lengths, dtype=torch.int64) - return sequences, lengths, skip + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data class MetricsTracker(collections.defaultdict): @@ -413,106 +263,3 @@ def save_checkpoint( checkpoint[k] = v torch.save(checkpoint, filename) - - -def save_checkpoint_with_global_batch_idx( - out_dir: Path, - global_batch_idx: int, - model: Union[nn.Module, DDP], - params: Optional[Dict[str, Any]] = None, - optimizer_g: Optional[Optimizer] = None, - optimizer_d: Optional[Optimizer] = None, - scheduler_g: Optional[LRSchedulerType] = None, - scheduler_d: Optional[LRSchedulerType] = None, - scaler: Optional[GradScaler] = None, - sampler: Optional[CutSampler] = None, - rank: int = 0, -): - """Save training info after processing given number of batches. - - Args: - out_dir: - The directory to save the checkpoint. - global_batch_idx: - The number of batches processed so far from the very start of the - training. The saved checkpoint will have the following filename: - f'out_dir / checkpoint-{global_batch_idx}.pt' - model: - The neural network model whose `state_dict` will be saved in the - checkpoint. - params: - A dict of training configurations to be saved. - optimizer_g: - The optimizer for generator used in the training. - Its `state_dict` will be saved. - optimizer_d: - The optimizer for discriminator used in the training. - Its `state_dict` will be saved. - scheduler_g: - The learning rate scheduler for generator used in the training. - Its `state_dict` will be saved. 
- scheduler_d: - The learning rate scheduler for discriminator used in the training. - Its `state_dict` will be saved. - scaler: - The scaler used for mix precision training. Its `state_dict` will - be saved. - sampler: - The sampler used in the training dataset. - rank: - The rank ID used in DDP training of the current node. Set it to 0 - if DDP is not used. - """ - out_dir = Path(out_dir) - out_dir.mkdir(parents=True, exist_ok=True) - filename = out_dir / f"checkpoint-{global_batch_idx}.pt" - save_checkpoint( - filename=filename, - model=model, - params=params, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - scaler=scaler, - sampler=sampler, - rank=rank, - ) - - -# def plot_feature(feature): -# """ -# Display the feature matrix as an image. Requires matplotlib to be installed. -# """ -# import matplotlib.pyplot as plt -# -# feature = np.flip(feature.transpose(1, 0), 0) -# return plt.matshow(feature) - -MATPLOTLIB_FLAG = False - - -def plot_feature(spectrogram): - global MATPLOTLIB_FLAG - if not MATPLOTLIB_FLAG: - import matplotlib - matplotlib.use("Agg") - MATPLOTLIB_FLAG = True - mpl_logger = logging.getLogger('matplotlib') - mpl_logger.setLevel(logging.WARNING) - import matplotlib.pylab as plt - import numpy as np - - fig, ax = plt.subplots(figsize=(10, 2)) - im = ax.imshow(spectrogram, aspect="auto", origin="lower", - interpolation='none') - plt.colorbar(im, ax=ax) - plt.xlabel("Frames") - plt.ylabel("Channels") - plt.tight_layout() - - fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') - data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) - plt.close() - return data diff --git a/egs/ljspeech/tts/vits/vits.py b/egs/ljspeech/tts/vits/vits.py index 27d9b4c7a..aa26a012d 100644 --- a/egs/ljspeech/tts/vits/vits.py +++ b/egs/ljspeech/tts/vits/vits.py @@ -1,11 +1,11 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py +# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py # Copyright 2021 Tomoki Hayashi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) """VITS module for GAN-TTS task.""" -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn @@ -247,7 +247,7 @@ class VITS(nn.Module): spembs: Optional[torch.Tensor] = None, lids: Optional[torch.Tensor] = None, forward_generator: bool = True, - ) -> Dict[str, Any]: + ) -> Tuple[torch.Tensor, Dict[str, Any]]: """Perform generator forward. Args: @@ -263,12 +263,8 @@ class VITS(nn.Module): forward_generator (bool): Whether to forward generator. Returns: - Dict[str, Any]: - - loss (Tensor): Loss scalar tensor. - - stats (Dict[str, float]): Statistics to be monitored. - - weight (Tensor): Weight tensor to summarize losses. - - optim_idx (int): Optimizer index (0 for G and 1 for D). - + - loss (Tensor): Loss scalar tensor. + - stats (Dict[str, float]): Statistics to be monitored. """ if forward_generator: return self._forward_generator( @@ -308,7 +304,7 @@ class VITS(nn.Module): sids: Optional[torch.Tensor] = None, spembs: Optional[torch.Tensor] = None, lids: Optional[torch.Tensor] = None, - ) -> Dict[str, Any]: + ) -> Tuple[torch.Tensor, Dict[str, Any]]: """Perform generator forward. Args: @@ -323,12 +319,8 @@ class VITS(nn.Module): lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). Returns: - Dict[str, Any]: - * loss (Tensor): Loss scalar tensor. 
- * stats (Dict[str, float]): Statistics to be monitored. - * weight (Tensor): Weight tensor to summarize losses. - * optim_idx (int): Optimizer index (0 for G and 1 for D). - + * loss (Tensor): Loss scalar tensor. + * stats (Dict[str, float]): Statistics to be monitored. """ # setup feats = feats.transpose(1, 2) @@ -399,7 +391,7 @@ class VITS(nn.Module): ) if return_sample: - stats["return_sample"] = ( + stats["returned_sample"] = ( speech_hat_[0].data.cpu().numpy(), speech_[0].data.cpu().numpy(), mel_hat_[0].data.cpu().numpy(), @@ -423,7 +415,7 @@ class VITS(nn.Module): sids: Optional[torch.Tensor] = None, spembs: Optional[torch.Tensor] = None, lids: Optional[torch.Tensor] = None, - ) -> Dict[str, Any]: + ) -> Tuple[torch.Tensor, Dict[str, Any]]: """Perform discriminator forward. Args: @@ -438,12 +430,8 @@ class VITS(nn.Module): lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). Returns: - Dict[str, Any]: - * loss (Tensor): Loss scalar tensor. - * stats (Dict[str, float]): Statistics to be monitored. - * weight (Tensor): Weight tensor to summarize losses. - * optim_idx (int): Optimizer index (0 for G and 1 for D). - + * loss (Tensor): Loss scalar tensor. + * stats (Dict[str, float]): Statistics to be monitored. """ # setup feats = feats.transpose(1, 2) @@ -511,8 +499,8 @@ class VITS(nn.Module): alpha: float = 1.0, max_len: Optional[int] = None, use_teacher_forcing: bool = False, - ) -> Dict[str, torch.Tensor]: - """Run inference. + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run inference for single sample. Args: text (Tensor): Input text index tensor (T_text,). @@ -528,11 +516,9 @@ class VITS(nn.Module): use_teacher_forcing (bool): Whether to use teacher forcing. Returns: - Dict[str, Tensor]: - * wav (Tensor): Generated waveform tensor (T_wav,). - * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text). - * duration (Tensor): Predicted duration tensor (T_text,). - + * wav (Tensor): Generated waveform tensor (T_wav,). + * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text). + * duration (Tensor): Predicted duration tensor (T_text,). """ # setup text = text[None] @@ -593,8 +579,8 @@ class VITS(nn.Module): alpha: float = 1.0, max_len: Optional[int] = None, use_teacher_forcing: bool = False, - ) -> Dict[str, torch.Tensor]: - """Run inference. + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run inference for one batch. Args: text (Tensor): Input text index tensor (B, T_text). @@ -605,11 +591,9 @@ class VITS(nn.Module): max_len (Optional[int]): Maximum length. Returns: - Dict[str, Tensor]: - * wav (Tensor): Generated waveform tensor (B, T_wav). - * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text). - * duration (Tensor): Predicted duration tensor (B, T_text). - + * wav (Tensor): Generated waveform tensor (B, T_wav). + * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text). + * duration (Tensor): Predicted duration tensor (B, T_text). """ # inference wav, att_w, dur = self.generator.inference( diff --git a/egs/ljspeech/tts/vits/wavenet.py b/egs/ljspeech/tts/vits/wavenet.py index cbb44a8f4..fbe1be52b 100644 --- a/egs/ljspeech/tts/vits/wavenet.py +++ b/egs/ljspeech/tts/vits/wavenet.py @@ -1,4 +1,4 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py # Copyright 2021 Tomoki Hayashi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
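With VITS.inference (and inference_batch) now returning a plain (wav, att_w, duration) tuple instead of a dict, callers unpack the result directly, as compute_validation_loss does in train.py. Below is a hedged usage sketch under that assumption; model, tokens, and frame_shift are placeholders for a trained VITS instance, a 1-D token-id tensor, and params.frame_shift, and the function is not code from the recipe itself.

import torch


@torch.no_grad()
def synthesize_one(model: torch.nn.Module, tokens: torch.Tensor, frame_shift: int = 256):
    """Generate a waveform for one utterance via the tuple-returning inference API."""
    model.eval()
    # wav: (T_wav,), att_w: (T_feats, T_text), duration: (T_text,) in frames.
    wav, att_w, duration = model.inference(text=tokens)
    # As asserted in compute_validation_loss, the summed predicted durations
    # times the hop size (frame_shift) should match the generated sample count.
    expected_samples = int(duration.sum(0).item()) * frame_shift
    assert expected_samples == wav.shape[-1], (expected_samples, wav.shape[-1])
    return wav.cpu().numpy()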