From 0377cccc6f58d697c7c01c1a16cd594642fb5825 Mon Sep 17 00:00:00 2001 From: Erwan Date: Wed, 14 Feb 2024 15:58:59 +0100 Subject: [PATCH] Add symbolic link --- egs/ljspeech/TTS/vits/train.py | 1 + egs/ljspeech/TTS/vits2/flow.py | 312 +---------------- egs/ljspeech/TTS/vits2/loss.py | 336 +------------------ egs/ljspeech/TTS/vits2/residual_coupling.py | 8 +- egs/ljspeech/TTS/vits2/tokenizer.py | 109 +----- egs/ljspeech/TTS/vits2/train.py | 2 +- egs/ljspeech/TTS/vits2/transform.py | 219 +----------- egs/ljspeech/TTS/vits2/tts_datamodule.py | 328 +----------------- egs/ljspeech/TTS/vits2/vits.py | 5 - egs/ljspeech/TTS/vits2/wavenet.py | 349 +------------------- 10 files changed, 12 insertions(+), 1657 deletions(-) mode change 100644 => 120000 egs/ljspeech/TTS/vits2/flow.py mode change 100644 => 120000 egs/ljspeech/TTS/vits2/loss.py mode change 100644 => 120000 egs/ljspeech/TTS/vits2/tokenizer.py mode change 100644 => 120000 egs/ljspeech/TTS/vits2/transform.py mode change 100644 => 120000 egs/ljspeech/TTS/vits2/tts_datamodule.py mode change 100644 => 120000 egs/ljspeech/TTS/vits2/wavenet.py diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py index 71c4224fa..cb73ac528 100755 --- a/egs/ljspeech/TTS/vits/train.py +++ b/egs/ljspeech/TTS/vits/train.py @@ -405,6 +405,7 @@ def train_one_epoch( ) for k, v in stats_d.items(): loss_info[k] = v * batch_size + # update discriminator optimizer_d.zero_grad() scaler.scale(loss_d).backward() diff --git a/egs/ljspeech/TTS/vits2/flow.py b/egs/ljspeech/TTS/vits2/flow.py deleted file mode 100644 index 2b84f6434..000000000 --- a/egs/ljspeech/TTS/vits2/flow.py +++ /dev/null @@ -1,311 +0,0 @@ -# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py - -# Copyright 2021 Tomoki Hayashi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""Basic Flow modules used in VITS. - -This code is based on https://github.com/jaywalnut310/vits. - -""" - -import math -from typing import Optional, Tuple, Union - -import torch -from transform import piecewise_rational_quadratic_transform - - -class FlipFlow(torch.nn.Module): - """Flip flow module.""" - - def forward( - self, x: torch.Tensor, *args, inverse: bool = False, **kwargs - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """Calculate forward propagation. - - Args: - x (Tensor): Input tensor (B, channels, T). - inverse (bool): Whether to inverse the flow. - - Returns: - Tensor: Flipped tensor (B, channels, T). - Tensor: Log-determinant tensor for NLL (B,) if not inverse. - - """ - x = torch.flip(x, [1]) - if not inverse: - logdet = x.new_zeros(x.size(0)) - return x, logdet - else: - return x - - -class LogFlow(torch.nn.Module): - """Log flow module.""" - - def forward( - self, - x: torch.Tensor, - x_mask: torch.Tensor, - inverse: bool = False, - eps: float = 1e-5, - **kwargs - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """Calculate forward propagation. - - Args: - x (Tensor): Input tensor (B, channels, T). - x_mask (Tensor): Mask tensor (B, 1, T). - inverse (bool): Whether to inverse the flow. - eps (float): Epsilon for log. - - Returns: - Tensor: Output tensor (B, channels, T). - Tensor: Log-determinant tensor for NLL (B,) if not inverse. 
- - """ - if not inverse: - y = torch.log(torch.clamp_min(x, eps)) * x_mask - logdet = torch.sum(-y, [1, 2]) - return y, logdet - else: - x = torch.exp(x) * x_mask - return x - - -class ElementwiseAffineFlow(torch.nn.Module): - """Elementwise affine flow module.""" - - def __init__(self, channels: int): - """Initialize ElementwiseAffineFlow module. - - Args: - channels (int): Number of channels. - - """ - super().__init__() - self.channels = channels - self.register_parameter("m", torch.nn.Parameter(torch.zeros(channels, 1))) - self.register_parameter("logs", torch.nn.Parameter(torch.zeros(channels, 1))) - - def forward( - self, x: torch.Tensor, x_mask: torch.Tensor, inverse: bool = False, **kwargs - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """Calculate forward propagation. - - Args: - x (Tensor): Input tensor (B, channels, T). - x_lengths (Tensor): Length tensor (B,). - inverse (bool): Whether to inverse the flow. - - Returns: - Tensor: Output tensor (B, channels, T). - Tensor: Log-determinant tensor for NLL (B,) if not inverse. - - """ - if not inverse: - y = self.m + torch.exp(self.logs) * x - y = y * x_mask - logdet = torch.sum(self.logs * x_mask, [1, 2]) - return y, logdet - else: - x = (x - self.m) * torch.exp(-self.logs) * x_mask - return x - - -class Transpose(torch.nn.Module): - """Transpose module for torch.nn.Sequential().""" - - def __init__(self, dim1: int, dim2: int): - """Initialize Transpose module.""" - super().__init__() - self.dim1 = dim1 - self.dim2 = dim2 - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Transpose.""" - return x.transpose(self.dim1, self.dim2) - - -class DilatedDepthSeparableConv(torch.nn.Module): - """Dilated depth-separable conv module.""" - - def __init__( - self, - channels: int, - kernel_size: int, - layers: int, - dropout_rate: float = 0.0, - eps: float = 1e-5, - ): - """Initialize DilatedDepthSeparableConv module. - - Args: - channels (int): Number of channels. - kernel_size (int): Kernel size. - layers (int): Number of layers. - dropout_rate (float): Dropout rate. - eps (float): Epsilon for layer norm. - - """ - super().__init__() - - self.convs = torch.nn.ModuleList() - for i in range(layers): - dilation = kernel_size**i - padding = (kernel_size * dilation - dilation) // 2 - self.convs += [ - torch.nn.Sequential( - torch.nn.Conv1d( - channels, - channels, - kernel_size, - groups=channels, - dilation=dilation, - padding=padding, - ), - Transpose(1, 2), - torch.nn.LayerNorm( - channels, - eps=eps, - elementwise_affine=True, - ), - Transpose(1, 2), - torch.nn.GELU(), - torch.nn.Conv1d( - channels, - channels, - 1, - ), - Transpose(1, 2), - torch.nn.LayerNorm( - channels, - eps=eps, - elementwise_affine=True, - ), - Transpose(1, 2), - torch.nn.GELU(), - torch.nn.Dropout(dropout_rate), - ) - ] - - def forward( - self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None - ) -> torch.Tensor: - """Calculate forward propagation. - - Args: - x (Tensor): Input tensor (B, in_channels, T). - x_mask (Tensor): Mask tensor (B, 1, T). - g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). - - Returns: - Tensor: Output tensor (B, channels, T). 
- - """ - if g is not None: - x = x + g - for f in self.convs: - y = f(x * x_mask) - x = x + y - return x * x_mask - - -class ConvFlow(torch.nn.Module): - """Convolutional flow module.""" - - def __init__( - self, - in_channels: int, - hidden_channels: int, - kernel_size: int, - layers: int, - bins: int = 10, - tail_bound: float = 5.0, - ): - """Initialize ConvFlow module. - - Args: - in_channels (int): Number of input channels. - hidden_channels (int): Number of hidden channels. - kernel_size (int): Kernel size. - layers (int): Number of layers. - bins (int): Number of bins. - tail_bound (float): Tail bound value. - - """ - super().__init__() - self.half_channels = in_channels // 2 - self.hidden_channels = hidden_channels - self.bins = bins - self.tail_bound = tail_bound - - self.input_conv = torch.nn.Conv1d( - self.half_channels, - hidden_channels, - 1, - ) - self.dds_conv = DilatedDepthSeparableConv( - hidden_channels, - kernel_size, - layers, - dropout_rate=0.0, - ) - self.proj = torch.nn.Conv1d( - hidden_channels, - self.half_channels * (bins * 3 - 1), - 1, - ) - self.proj.weight.data.zero_() - self.proj.bias.data.zero_() - - def forward( - self, - x: torch.Tensor, - x_mask: torch.Tensor, - g: Optional[torch.Tensor] = None, - inverse: bool = False, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """Calculate forward propagation. - - Args: - x (Tensor): Input tensor (B, channels, T). - x_mask (Tensor): Mask tensor (B,). - g (Optional[Tensor]): Global conditioning tensor (B, channels, 1). - inverse (bool): Whether to inverse the flow. - - Returns: - Tensor: Output tensor (B, channels, T). - Tensor: Log-determinant tensor for NLL (B,) if not inverse. - - """ - xa, xb = x.split(x.size(1) // 2, 1) - h = self.input_conv(xa) - h = self.dds_conv(h, x_mask, g=g) - h = self.proj(h) * x_mask # (B, half_channels * (bins * 3 - 1), T) - - b, c, t = xa.shape - # (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1) - h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) - - # TODO(kan-bayashi): Understand this calculation - denom = math.sqrt(self.hidden_channels) - unnorm_widths = h[..., : self.bins] / denom - unnorm_heights = h[..., self.bins : 2 * self.bins] / denom - unnorm_derivatives = h[..., 2 * self.bins :] - xb, logdet_abs = piecewise_rational_quadratic_transform( - xb, - unnorm_widths, - unnorm_heights, - unnorm_derivatives, - inverse=inverse, - tails="linear", - tail_bound=self.tail_bound, - ) - x = torch.cat([xa, xb], 1) * x_mask - logdet = torch.sum(logdet_abs * x_mask, [1, 2]) - if not inverse: - return x, logdet - else: - return x diff --git a/egs/ljspeech/TTS/vits2/flow.py b/egs/ljspeech/TTS/vits2/flow.py new file mode 120000 index 000000000..e65d91ea7 --- /dev/null +++ b/egs/ljspeech/TTS/vits2/flow.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/flow.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/vits2/loss.py b/egs/ljspeech/TTS/vits2/loss.py deleted file mode 100644 index 653e06c0f..000000000 --- a/egs/ljspeech/TTS/vits2/loss.py +++ /dev/null @@ -1,335 +0,0 @@ -# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py - -# Copyright 2021 Tomoki Hayashi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""HiFiGAN-related loss modules. - -This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. 
- -""" - -from typing import List, Tuple, Union - -import torch -import torch.distributions as D -import torch.nn.functional as F -from lhotse.features.kaldi import Wav2LogFilterBank - - -class GeneratorAdversarialLoss(torch.nn.Module): - """Generator adversarial loss module.""" - - def __init__( - self, - average_by_discriminators: bool = True, - loss_type: str = "mse", - ): - """Initialize GeneratorAversarialLoss module. - - Args: - average_by_discriminators (bool): Whether to average the loss by - the number of discriminators. - loss_type (str): Loss type, "mse" or "hinge". - - """ - super().__init__() - self.average_by_discriminators = average_by_discriminators - assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." - if loss_type == "mse": - self.criterion = self._mse_loss - else: - self.criterion = self._hinge_loss - - def forward( - self, - outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], - ) -> torch.Tensor: - """Calcualate generator adversarial loss. - - Args: - outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator - outputs, list of discriminator outputs, or list of list of discriminator - outputs.. - - Returns: - Tensor: Generator adversarial loss value. - - """ - if isinstance(outputs, (tuple, list)): - adv_loss = 0.0 - for i, outputs_ in enumerate(outputs): - if isinstance(outputs_, (tuple, list)): - # NOTE(kan-bayashi): case including feature maps - outputs_ = outputs_[-1] - adv_loss += self.criterion(outputs_) - if self.average_by_discriminators: - adv_loss /= i + 1 - else: - adv_loss = self.criterion(outputs) - - return adv_loss - - def _mse_loss(self, x): - return F.mse_loss(x, x.new_ones(x.size())) - - def _hinge_loss(self, x): - return -x.mean() - - -class DiscriminatorAdversarialLoss(torch.nn.Module): - """Discriminator adversarial loss module.""" - - def __init__( - self, - average_by_discriminators: bool = True, - loss_type: str = "mse", - ): - """Initialize DiscriminatorAversarialLoss module. - - Args: - average_by_discriminators (bool): Whether to average the loss by - the number of discriminators. - loss_type (str): Loss type, "mse" or "hinge". - - """ - super().__init__() - self.average_by_discriminators = average_by_discriminators - assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." - if loss_type == "mse": - self.fake_criterion = self._mse_fake_loss - self.real_criterion = self._mse_real_loss - else: - self.fake_criterion = self._hinge_fake_loss - self.real_criterion = self._hinge_real_loss - - def forward( - self, - outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], - outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Calcualate discriminator adversarial loss. - - Args: - outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator - outputs, list of discriminator outputs, or list of list of discriminator - outputs calculated from generator. - outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator - outputs, list of discriminator outputs, or list of list of discriminator - outputs calculated from groundtruth. - - Returns: - Tensor: Discriminator real loss value. - Tensor: Discriminator fake loss value. 
- - """ - if isinstance(outputs, (tuple, list)): - real_loss = 0.0 - fake_loss = 0.0 - for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): - if isinstance(outputs_hat_, (tuple, list)): - # NOTE(kan-bayashi): case including feature maps - outputs_hat_ = outputs_hat_[-1] - outputs_ = outputs_[-1] - real_loss += self.real_criterion(outputs_) - fake_loss += self.fake_criterion(outputs_hat_) - if self.average_by_discriminators: - fake_loss /= i + 1 - real_loss /= i + 1 - else: - real_loss = self.real_criterion(outputs) - fake_loss = self.fake_criterion(outputs_hat) - - return real_loss, fake_loss - - def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor: - return F.mse_loss(x, x.new_ones(x.size())) - - def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor: - return F.mse_loss(x, x.new_zeros(x.size())) - - def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor: - return -torch.mean(torch.min(x - 1, x.new_zeros(x.size()))) - - def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor: - return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size()))) - - -class FeatureMatchLoss(torch.nn.Module): - """Feature matching loss module.""" - - def __init__( - self, - average_by_layers: bool = True, - average_by_discriminators: bool = True, - include_final_outputs: bool = False, - ): - """Initialize FeatureMatchLoss module. - - Args: - average_by_layers (bool): Whether to average the loss by the number - of layers. - average_by_discriminators (bool): Whether to average the loss by - the number of discriminators. - include_final_outputs (bool): Whether to include the final output of - each discriminator for loss calculation. - - """ - super().__init__() - self.average_by_layers = average_by_layers - self.average_by_discriminators = average_by_discriminators - self.include_final_outputs = include_final_outputs - - def forward( - self, - feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]], - feats: Union[List[List[torch.Tensor]], List[torch.Tensor]], - ) -> torch.Tensor: - """Calculate feature matching loss. - - Args: - feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of - discriminator outputs or list of discriminator outputs calcuated - from generator's outputs. - feats (Union[List[List[Tensor]], List[Tensor]]): List of list of - discriminator outputs or list of discriminator outputs calcuated - from groundtruth. - - Returns: - Tensor: Feature matching loss value. 
- - """ - feat_match_loss = 0.0 - for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)): - feat_match_loss_ = 0.0 - if not self.include_final_outputs: - feats_hat_ = feats_hat_[:-1] - feats_ = feats_[:-1] - for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)): - feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach()) - if self.average_by_layers: - feat_match_loss_ /= j + 1 - feat_match_loss += feat_match_loss_ - if self.average_by_discriminators: - feat_match_loss /= i + 1 - - return feat_match_loss - - -class MelSpectrogramLoss(torch.nn.Module): - """Mel-spectrogram loss.""" - - def __init__( - self, - sampling_rate: int = 22050, - frame_length: int = 1024, # in samples - frame_shift: int = 256, # in samples - n_mels: int = 80, - use_fft_mag: bool = True, - ): - super().__init__() - self.wav_to_mel = Wav2LogFilterBank( - sampling_rate=sampling_rate, - frame_length=frame_length / sampling_rate, # in second - frame_shift=frame_shift / sampling_rate, # in second - use_fft_mag=use_fft_mag, - num_filters=n_mels, - ) - - def forward( - self, - y_hat: torch.Tensor, - y: torch.Tensor, - return_mel: bool = False, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]: - """Calculate Mel-spectrogram loss. - - Args: - y_hat (Tensor): Generated waveform tensor (B, 1, T). - y (Tensor): Groundtruth waveform tensor (B, 1, T). - spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor - (B, T, n_fft // 2 + 1). if provided, use it instead of groundtruth - waveform. - - Returns: - Tensor: Mel-spectrogram loss value. - - """ - mel_hat = self.wav_to_mel(y_hat.squeeze(1)) - mel = self.wav_to_mel(y.squeeze(1)) - mel_loss = F.l1_loss(mel_hat, mel) - - if return_mel: - return mel_loss, (mel_hat, mel) - - return mel_loss - - -# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py - -"""VITS-related loss modules. - -This code is based on https://github.com/jaywalnut310/vits. - -""" - - -class KLDivergenceLoss(torch.nn.Module): - """KL divergence loss.""" - - def forward( - self, - z_p: torch.Tensor, - logs_q: torch.Tensor, - m_p: torch.Tensor, - logs_p: torch.Tensor, - z_mask: torch.Tensor, - ) -> torch.Tensor: - """Calculate KL divergence loss. - - Args: - z_p (Tensor): Flow hidden representation (B, H, T_feats). - logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats). - m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats). - logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats). - z_mask (Tensor): Mask tensor (B, 1, T_feats). - - Returns: - Tensor: KL divergence loss. - - """ - z_p = z_p.float() - logs_q = logs_q.float() - m_p = m_p.float() - logs_p = logs_p.float() - z_mask = z_mask.float() - kl = logs_p - logs_q - 0.5 - kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) - kl = torch.sum(kl * z_mask) - loss = kl / torch.sum(z_mask) - - return loss - - -class KLDivergenceLossWithoutFlow(torch.nn.Module): - """KL divergence loss without flow.""" - - def forward( - self, - m_q: torch.Tensor, - logs_q: torch.Tensor, - m_p: torch.Tensor, - logs_p: torch.Tensor, - ) -> torch.Tensor: - """Calculate KL divergence loss without flow. - - Args: - m_q (Tensor): Posterior encoder projected mean (B, H, T_feats). - logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats). - m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats). - logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats). 
- """ - posterior_norm = D.Normal(m_q, torch.exp(logs_q)) - prior_norm = D.Normal(m_p, torch.exp(logs_p)) - loss = D.kl_divergence(posterior_norm, prior_norm).mean() - return loss diff --git a/egs/ljspeech/TTS/vits2/loss.py b/egs/ljspeech/TTS/vits2/loss.py new file mode 120000 index 000000000..672e5ff68 --- /dev/null +++ b/egs/ljspeech/TTS/vits2/loss.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/loss.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/vits2/residual_coupling.py b/egs/ljspeech/TTS/vits2/residual_coupling.py index f3de17ddd..d378d8509 100644 --- a/egs/ljspeech/TTS/vits2/residual_coupling.py +++ b/egs/ljspeech/TTS/vits2/residual_coupling.py @@ -360,10 +360,10 @@ class ResidualCouplingTransformersLayer(torch.nn.Module): xa, xb = x.split(x.size(1) // 2, dim=1) x_trans_mask = make_pad_mask(torch.sum(x_mask, dim=[1, 2]).type(torch.int64)) - xa_trans = self.pre_transformer(xa.transpose(1, 2), x_trans_mask).transpose( - 1, 2 - ) - xa_ = xa + xa_trans + xa_ = self.pre_transformer( + (xa * x_mask).transpose(1, 2), x_trans_mask + ).transpose(1, 2) + xa_ = xa + xa_ h = self.input_conv(xa_) * x_mask h = self.encoder(h, x_mask, g=g) diff --git a/egs/ljspeech/TTS/vits2/tokenizer.py b/egs/ljspeech/TTS/vits2/tokenizer.py deleted file mode 100644 index 70f1240b4..000000000 --- a/egs/ljspeech/TTS/vits2/tokenizer.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, List - -import g2p_en -import tacotron_cleaner.cleaners -from utils import intersperse - - -class Tokenizer(object): - def __init__(self, tokens: str): - """ - Args: - tokens: the file that maps tokens to ids - """ - # Parse token file - self.token2id: Dict[str, int] = {} - with open(tokens, "r", encoding="utf-8") as f: - for line in f.readlines(): - info = line.rstrip().split() - if len(info) == 1: - # case of space - token = " " - id = int(info[0]) - else: - token, id = info[0], int(info[1]) - self.token2id[token] = id - - self.blank_id = self.token2id[""] - self.oov_id = self.token2id[""] - self.vocab_size = len(self.token2id) - - self.g2p = g2p_en.G2p() - - def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True): - """ - Args: - texts: - A list of transcripts. - intersperse_blank: - Whether to intersperse blanks in the token sequence. 
- - Returns: - Return a list of token id list [utterance][token_id] - """ - token_ids_list = [] - - for text in texts: - # Text normalization - text = tacotron_cleaner.cleaners.custom_english_cleaners(text) - # Convert to phonemes - tokens = self.g2p(text) - token_ids = [] - for t in tokens: - if t in self.token2id: - token_ids.append(self.token2id[t]) - else: - token_ids.append(self.oov_id) - - if intersperse_blank: - token_ids = intersperse(token_ids, self.blank_id) - - token_ids_list.append(token_ids) - - return token_ids_list - - def tokens_to_token_ids( - self, tokens_list: List[str], intersperse_blank: bool = True - ): - """ - Args: - tokens_list: - A list of token list, each corresponding to one utterance. - intersperse_blank: - Whether to intersperse blanks in the token sequence. - - Returns: - Return a list of token id list [utterance][token_id] - """ - token_ids_list = [] - - for tokens in tokens_list: - token_ids = [] - for t in tokens: - if t in self.token2id: - token_ids.append(self.token2id[t]) - else: - token_ids.append(self.oov_id) - - if intersperse_blank: - token_ids = intersperse(token_ids, self.blank_id) - token_ids_list.append(token_ids) - - return token_ids_list diff --git a/egs/ljspeech/TTS/vits2/tokenizer.py b/egs/ljspeech/TTS/vits2/tokenizer.py new file mode 120000 index 000000000..057b0dc4b --- /dev/null +++ b/egs/ljspeech/TTS/vits2/tokenizer.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/tokenizer.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/vits2/train.py b/egs/ljspeech/TTS/vits2/train.py index 743720084..d11c1674e 100755 --- a/egs/ljspeech/TTS/vits2/train.py +++ b/egs/ljspeech/TTS/vits2/train.py @@ -433,7 +433,7 @@ def train_one_epoch( with autocast(enabled=params.use_fp16): # forward discriminator - loss_d, dur_loss, stats_d = model( + loss_d, stats_d = model( text=tokens, text_lengths=tokens_lens, feats=features, diff --git a/egs/ljspeech/TTS/vits2/transform.py b/egs/ljspeech/TTS/vits2/transform.py deleted file mode 100644 index c20d13130..000000000 --- a/egs/ljspeech/TTS/vits2/transform.py +++ /dev/null @@ -1,218 +0,0 @@ -# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py - -"""Flow-related transformation. - -This code is derived from https://github.com/bayesiains/nflows. 
- -""" - -import numpy as np -import torch -from torch.nn import functional as F - -DEFAULT_MIN_BIN_WIDTH = 1e-3 -DEFAULT_MIN_BIN_HEIGHT = 1e-3 -DEFAULT_MIN_DERIVATIVE = 1e-3 - - -# TODO(kan-bayashi): Documentation and type hint -def piecewise_rational_quadratic_transform( - inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - tails=None, - tail_bound=1.0, - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE, -): - if tails is None: - spline_fn = rational_quadratic_spline - spline_kwargs = {} - else: - spline_fn = unconstrained_rational_quadratic_spline - spline_kwargs = {"tails": tails, "tail_bound": tail_bound} - - outputs, logabsdet = spline_fn( - inputs=inputs, - unnormalized_widths=unnormalized_widths, - unnormalized_heights=unnormalized_heights, - unnormalized_derivatives=unnormalized_derivatives, - inverse=inverse, - min_bin_width=min_bin_width, - min_bin_height=min_bin_height, - min_derivative=min_derivative, - **spline_kwargs - ) - return outputs, logabsdet - - -# TODO(kan-bayashi): Documentation and type hint -def unconstrained_rational_quadratic_spline( - inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - tails="linear", - tail_bound=1.0, - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE, -): - inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) - outside_interval_mask = ~inside_interval_mask - - outputs = torch.zeros_like(inputs) - logabsdet = torch.zeros_like(inputs) - - if tails == "linear": - unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) - constant = np.log(np.exp(1 - min_derivative) - 1) - unnormalized_derivatives[..., 0] = constant - unnormalized_derivatives[..., -1] = constant - - outputs[outside_interval_mask] = inputs[outside_interval_mask] - logabsdet[outside_interval_mask] = 0 - else: - raise RuntimeError("{} tails are not implemented.".format(tails)) - - ( - outputs[inside_interval_mask], - logabsdet[inside_interval_mask], - ) = rational_quadratic_spline( - inputs=inputs[inside_interval_mask], - unnormalized_widths=unnormalized_widths[inside_interval_mask, :], - unnormalized_heights=unnormalized_heights[inside_interval_mask, :], - unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], - inverse=inverse, - left=-tail_bound, - right=tail_bound, - bottom=-tail_bound, - top=tail_bound, - min_bin_width=min_bin_width, - min_bin_height=min_bin_height, - min_derivative=min_derivative, - ) - - return outputs, logabsdet - - -# TODO(kan-bayashi): Documentation and type hint -def rational_quadratic_spline( - inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - left=0.0, - right=1.0, - bottom=0.0, - top=1.0, - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE, -): - if torch.min(inputs) < left or torch.max(inputs) > right: - raise ValueError("Input to a transform is not within its domain") - - num_bins = unnormalized_widths.shape[-1] - - if min_bin_width * num_bins > 1.0: - raise ValueError("Minimal bin width too large for the number of bins") - if min_bin_height * num_bins > 1.0: - raise ValueError("Minimal bin height too large for the number of bins") - - widths = F.softmax(unnormalized_widths, dim=-1) - widths = min_bin_width + (1 - min_bin_width * num_bins) * 
widths - cumwidths = torch.cumsum(widths, dim=-1) - cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) - cumwidths = (right - left) * cumwidths + left - cumwidths[..., 0] = left - cumwidths[..., -1] = right - widths = cumwidths[..., 1:] - cumwidths[..., :-1] - - derivatives = min_derivative + F.softplus(unnormalized_derivatives) - - heights = F.softmax(unnormalized_heights, dim=-1) - heights = min_bin_height + (1 - min_bin_height * num_bins) * heights - cumheights = torch.cumsum(heights, dim=-1) - cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) - cumheights = (top - bottom) * cumheights + bottom - cumheights[..., 0] = bottom - cumheights[..., -1] = top - heights = cumheights[..., 1:] - cumheights[..., :-1] - - if inverse: - bin_idx = _searchsorted(cumheights, inputs)[..., None] - else: - bin_idx = _searchsorted(cumwidths, inputs)[..., None] - - input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] - input_bin_widths = widths.gather(-1, bin_idx)[..., 0] - - input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] - delta = heights / widths - input_delta = delta.gather(-1, bin_idx)[..., 0] - - input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] - input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] - - input_heights = heights.gather(-1, bin_idx)[..., 0] - - if inverse: - a = (inputs - input_cumheights) * ( - input_derivatives + input_derivatives_plus_one - 2 * input_delta - ) + input_heights * (input_delta - input_derivatives) - b = input_heights * input_derivatives - (inputs - input_cumheights) * ( - input_derivatives + input_derivatives_plus_one - 2 * input_delta - ) - c = -input_delta * (inputs - input_cumheights) - - discriminant = b.pow(2) - 4 * a * c - assert (discriminant >= 0).all() - - root = (2 * c) / (-b - torch.sqrt(discriminant)) - outputs = root * input_bin_widths + input_cumwidths - - theta_one_minus_theta = root * (1 - root) - denominator = input_delta + ( - (input_derivatives + input_derivatives_plus_one - 2 * input_delta) - * theta_one_minus_theta - ) - derivative_numerator = input_delta.pow(2) * ( - input_derivatives_plus_one * root.pow(2) - + 2 * input_delta * theta_one_minus_theta - + input_derivatives * (1 - root).pow(2) - ) - logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) - - return outputs, -logabsdet - else: - theta = (inputs - input_cumwidths) / input_bin_widths - theta_one_minus_theta = theta * (1 - theta) - - numerator = input_heights * ( - input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta - ) - denominator = input_delta + ( - (input_derivatives + input_derivatives_plus_one - 2 * input_delta) - * theta_one_minus_theta - ) - outputs = input_cumheights + numerator / denominator - - derivative_numerator = input_delta.pow(2) * ( - input_derivatives_plus_one * theta.pow(2) - + 2 * input_delta * theta_one_minus_theta - + input_derivatives * (1 - theta).pow(2) - ) - logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) - - return outputs, logabsdet - - -def _searchsorted(bin_locations, inputs, eps=1e-6): - bin_locations[..., -1] += eps - return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 diff --git a/egs/ljspeech/TTS/vits2/transform.py b/egs/ljspeech/TTS/vits2/transform.py new file mode 120000 index 000000000..962647408 --- /dev/null +++ b/egs/ljspeech/TTS/vits2/transform.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/transform.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/vits2/tts_datamodule.py 
b/egs/ljspeech/TTS/vits2/tts_datamodule.py deleted file mode 100644 index 8ff868bc8..000000000 --- a/egs/ljspeech/TTS/vits2/tts_datamodule.py +++ /dev/null @@ -1,327 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, - SpeechSynthesisDataset, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class LJSpeechTtsDataModule: - """ - DataModule for tts experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="TTS data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/spectrogram"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. 
You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=False, - help="When enabled, each batch will have the " - "field: batch['cut'] with the cuts that " - "were used to construct it.", - ) - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - logging.info("About to create train dataset") - train = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - sampling_rate = 22050 - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - ) - train = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - sampling_rate = 22050 - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - ) - validate = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), - return_cuts=self.args.return_cuts, - ) - else: - validate = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create valid dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - if self.args.on_the_fly_feats: - sampling_rate = 22050 - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - ) - test = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), - return_cuts=self.args.return_cuts, - ) - else: - test = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - test_sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=test_sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz" - ) - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get validation cuts") - return load_manifest_lazy( - self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz" - ) - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz" - ) diff --git a/egs/ljspeech/TTS/vits2/tts_datamodule.py b/egs/ljspeech/TTS/vits2/tts_datamodule.py new file mode 120000 index 000000000..7293ee330 --- /dev/null +++ b/egs/ljspeech/TTS/vits2/tts_datamodule.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/tts_datamodule.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/vits2/vits.py b/egs/ljspeech/TTS/vits2/vits.py index 14b45ffba..7db737d4b 100644 --- a/egs/ljspeech/TTS/vits2/vits.py +++ b/egs/ljspeech/TTS/vits2/vits.py @@ -545,10 +545,6 @@ class VITS(nn.Module): discriminator_fake_loss=fake_loss.item(), ) - # reset cache - if reuse_cache or not self.training: - self._cache = None - return loss, 
stats def _forward_discrminator_duration( @@ -582,7 +578,6 @@ class VITS(nn.Module): """ # setup feats = feats.transpose(1, 2) - speech = speech.unsqueeze(1) # calculate generator outputs reuse_cache = True diff --git a/egs/ljspeech/TTS/vits2/wavenet.py b/egs/ljspeech/TTS/vits2/wavenet.py deleted file mode 100644 index 98fd775f5..000000000 --- a/egs/ljspeech/TTS/vits2/wavenet.py +++ /dev/null @@ -1,348 +0,0 @@ -# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py - -# Copyright 2021 Tomoki Hayashi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""WaveNet modules. - -This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. - -""" - -import logging -import math -from typing import Optional, Tuple - -import torch -import torch.nn.functional as F - - -class WaveNet(torch.nn.Module): - """WaveNet with global conditioning.""" - - def __init__( - self, - in_channels: int = 1, - out_channels: int = 1, - kernel_size: int = 3, - layers: int = 30, - stacks: int = 3, - base_dilation: int = 2, - residual_channels: int = 64, - aux_channels: int = -1, - gate_channels: int = 128, - skip_channels: int = 64, - global_channels: int = -1, - dropout_rate: float = 0.0, - bias: bool = True, - use_weight_norm: bool = True, - use_first_conv: bool = False, - use_last_conv: bool = False, - scale_residual: bool = False, - scale_skip_connect: bool = False, - ): - """Initialize WaveNet module. - - Args: - in_channels (int): Number of input channels. - out_channels (int): Number of output channels. - kernel_size (int): Kernel size of dilated convolution. - layers (int): Number of residual block layers. - stacks (int): Number of stacks i.e., dilation cycles. - base_dilation (int): Base dilation factor. - residual_channels (int): Number of channels in residual conv. - gate_channels (int): Number of channels in gated conv. - skip_channels (int): Number of channels in skip conv. - aux_channels (int): Number of channels for local conditioning feature. - global_channels (int): Number of channels for global conditioning feature. - dropout_rate (float): Dropout rate. 0.0 means no dropout applied. - bias (bool): Whether to use bias parameter in conv layer. - use_weight_norm (bool): Whether to use weight norm. If set to true, it will - be applied to all of the conv layers. - use_first_conv (bool): Whether to use the first conv layers. - use_last_conv (bool): Whether to use the last conv layers. - scale_residual (bool): Whether to scale the residual outputs. - scale_skip_connect (bool): Whether to scale the skip connection outputs. 
- - """ - super().__init__() - self.layers = layers - self.stacks = stacks - self.kernel_size = kernel_size - self.base_dilation = base_dilation - self.use_first_conv = use_first_conv - self.use_last_conv = use_last_conv - self.scale_skip_connect = scale_skip_connect - - # check the number of layers and stacks - assert layers % stacks == 0 - layers_per_stack = layers // stacks - - # define first convolution - if self.use_first_conv: - self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True) - - # define residual blocks - self.conv_layers = torch.nn.ModuleList() - for layer in range(layers): - dilation = base_dilation ** (layer % layers_per_stack) - conv = ResidualBlock( - kernel_size=kernel_size, - residual_channels=residual_channels, - gate_channels=gate_channels, - skip_channels=skip_channels, - aux_channels=aux_channels, - global_channels=global_channels, - dilation=dilation, - dropout_rate=dropout_rate, - bias=bias, - scale_residual=scale_residual, - ) - self.conv_layers += [conv] - - # define output layers - if self.use_last_conv: - self.last_conv = torch.nn.Sequential( - torch.nn.ReLU(inplace=False), - Conv1d1x1(skip_channels, skip_channels, bias=True), - torch.nn.ReLU(inplace=False), - Conv1d1x1(skip_channels, out_channels, bias=True), - ) - - # apply weight norm - if use_weight_norm: - self.apply_weight_norm() - - def forward( - self, - x: torch.Tensor, - x_mask: Optional[torch.Tensor] = None, - c: Optional[torch.Tensor] = None, - g: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Calculate forward propagation. - - Args: - x (Tensor): Input noise signal (B, 1, T) if use_first_conv else - (B, residual_channels, T). - x_mask (Optional[Tensor]): Mask tensor (B, 1, T). - c (Optional[Tensor]): Local conditioning features (B, aux_channels, T). - g (Optional[Tensor]): Global conditioning features (B, global_channels, 1). - - Returns: - Tensor: Output tensor (B, out_channels, T) if use_last_conv else - (B, residual_channels, T). 
- - """ - # encode to hidden representation - if self.use_first_conv: - x = self.first_conv(x) - - # residual block - skips = 0.0 - for f in self.conv_layers: - x, h = f(x, x_mask=x_mask, c=c, g=g) - skips = skips + h - x = skips - if self.scale_skip_connect: - x = x * math.sqrt(1.0 / len(self.conv_layers)) - - # apply final layers - if self.use_last_conv: - x = self.last_conv(x) - - return x - - def remove_weight_norm(self): - """Remove weight normalization module from all of the layers.""" - - def _remove_weight_norm(m: torch.nn.Module): - try: - logging.debug(f"Weight norm is removed from {m}.") - torch.nn.utils.remove_weight_norm(m) - except ValueError: # this module didn't have weight norm - return - - self.apply(_remove_weight_norm) - - def apply_weight_norm(self): - """Apply weight normalization module from all of the layers.""" - - def _apply_weight_norm(m: torch.nn.Module): - if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d): - torch.nn.utils.weight_norm(m) - logging.debug(f"Weight norm is applied to {m}.") - - self.apply(_apply_weight_norm) - - @staticmethod - def _get_receptive_field_size( - layers: int, - stacks: int, - kernel_size: int, - base_dilation: int, - ) -> int: - assert layers % stacks == 0 - layers_per_cycle = layers // stacks - dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)] - return (kernel_size - 1) * sum(dilations) + 1 - - @property - def receptive_field_size(self) -> int: - """Return receptive field size.""" - return self._get_receptive_field_size( - self.layers, self.stacks, self.kernel_size, self.base_dilation - ) - - -class Conv1d(torch.nn.Conv1d): - """Conv1d module with customized initialization.""" - - def __init__(self, *args, **kwargs): - """Initialize Conv1d module.""" - super().__init__(*args, **kwargs) - - def reset_parameters(self): - """Reset parameters.""" - torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu") - if self.bias is not None: - torch.nn.init.constant_(self.bias, 0.0) - - -class Conv1d1x1(Conv1d): - """1x1 Conv1d with customized initialization.""" - - def __init__(self, in_channels: int, out_channels: int, bias: bool): - """Initialize 1x1 Conv1d module.""" - super().__init__( - in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias - ) - - -class ResidualBlock(torch.nn.Module): - """Residual block module in WaveNet.""" - - def __init__( - self, - kernel_size: int = 3, - residual_channels: int = 64, - gate_channels: int = 128, - skip_channels: int = 64, - aux_channels: int = 80, - global_channels: int = -1, - dropout_rate: float = 0.0, - dilation: int = 1, - bias: bool = True, - scale_residual: bool = False, - ): - """Initialize ResidualBlock module. - - Args: - kernel_size (int): Kernel size of dilation convolution layer. - residual_channels (int): Number of channels for residual connection. - skip_channels (int): Number of channels for skip connection. - aux_channels (int): Number of local conditioning channels. - dropout (float): Dropout probability. - dilation (int): Dilation factor. - bias (bool): Whether to add bias parameter in convolution layers. - scale_residual (bool): Whether to scale the residual outputs. - - """ - super().__init__() - self.dropout_rate = dropout_rate - self.residual_channels = residual_channels - self.skip_channels = skip_channels - self.scale_residual = scale_residual - - # check - assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." 
- assert gate_channels % 2 == 0 - - # dilation conv - padding = (kernel_size - 1) // 2 * dilation - self.conv = Conv1d( - residual_channels, - gate_channels, - kernel_size, - padding=padding, - dilation=dilation, - bias=bias, - ) - - # local conditioning - if aux_channels > 0: - self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False) - else: - self.conv1x1_aux = None - - # global conditioning - if global_channels > 0: - self.conv1x1_glo = Conv1d1x1(global_channels, gate_channels, bias=False) - else: - self.conv1x1_glo = None - - # conv output is split into two groups - gate_out_channels = gate_channels // 2 - - # NOTE(kan-bayashi): concat two convs into a single conv for the efficiency - # (integrate res 1x1 + skip 1x1 convs) - self.conv1x1_out = Conv1d1x1( - gate_out_channels, residual_channels + skip_channels, bias=bias - ) - - def forward( - self, - x: torch.Tensor, - x_mask: Optional[torch.Tensor] = None, - c: Optional[torch.Tensor] = None, - g: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Calculate forward propagation. - - Args: - x (Tensor): Input tensor (B, residual_channels, T). - x_mask Optional[torch.Tensor]: Mask tensor (B, 1, T). - c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T). - g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). - - Returns: - Tensor: Output tensor for residual connection (B, residual_channels, T). - Tensor: Output tensor for skip connection (B, skip_channels, T). - - """ - residual = x - x = F.dropout(x, p=self.dropout_rate, training=self.training) - x = self.conv(x) - - # split into two part for gated activation - splitdim = 1 - xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim) - - # local conditioning - if c is not None: - c = self.conv1x1_aux(c) - ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) - xa, xb = xa + ca, xb + cb - - # global conditioning - if g is not None: - g = self.conv1x1_glo(g) - ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim) - xa, xb = xa + ga, xb + gb - - x = torch.tanh(xa) * torch.sigmoid(xb) - - # residual + skip 1x1 conv - x = self.conv1x1_out(x) - if x_mask is not None: - x = x * x_mask - - # split integrated conv results - x, s = x.split([self.residual_channels, self.skip_channels], dim=1) - - # for residual connection - x = x + residual - if self.scale_residual: - x = x * math.sqrt(0.5) - - return x, s diff --git a/egs/ljspeech/TTS/vits2/wavenet.py b/egs/ljspeech/TTS/vits2/wavenet.py new file mode 120000 index 000000000..28f0a78ee --- /dev/null +++ b/egs/ljspeech/TTS/vits2/wavenet.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/wavenet.py \ No newline at end of file
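
Note (illustrative, outside the patch): the net effect of the diff above is that six duplicated modules under egs/ljspeech/TTS/vits2/ — flow.py, loss.py, tokenizer.py, transform.py, tts_datamodule.py, and wavenet.py — become relative symlinks into the sibling vits/ recipe (mode 100644 -> 120000), alongside the small train.py, vits.py, and residual_coupling.py fixes. A minimal Python sketch of how those links could be recreated or verified is below; the script, the relink() helper, and the ICEFALL_ROOT variable are assumptions for illustration, not part of the repository.

    # relink_vits2.py -- illustrative sketch, not part of this patch.
    # Recreates the relative symlinks the diff introduces: each shared
    # module under egs/ljspeech/TTS/vits2/ points back at vits/.
    import os
    from pathlib import Path

    # Modules converted from regular files (100644) to symlinks (120000).
    SHARED_MODULES = [
        "flow.py",
        "loss.py",
        "tokenizer.py",
        "transform.py",
        "tts_datamodule.py",
        "wavenet.py",
    ]

    def relink(repo_root: str) -> None:
        vits2 = Path(repo_root) / "egs" / "ljspeech" / "TTS" / "vits2"
        for name in SHARED_MODULES:
            link = vits2 / name
            # Relative target exactly as recorded by the patch, e.g.
            #   ../../../ljspeech/TTS/vits/flow.py
            target = Path("../../../ljspeech/TTS/vits") / name
            if link.is_symlink() or link.exists():
                link.unlink()  # drop the duplicated copy (or a stale link)
            link.symlink_to(target)
            # Sanity check: the link must resolve to a real file in vits/.
            resolved = (link.parent / target).resolve()
            assert resolved.is_file(), f"dangling symlink: {link} -> {target}"

    if __name__ == "__main__":
        # ICEFALL_ROOT is an assumed convenience variable, not a repo convention.
        relink(os.environ.get("ICEFALL_ROOT", "."))

Run from the repository root, this reproduces the symlink targets the diff records (note that each target climbs three levels out of vits2/ to egs/ before descending back into ljspeech/TTS/vits/) and fails fast if any link would dangle.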