diff --git a/egs/ljspeech/TTS/vits2/duration_discriminator.py b/egs/ljspeech/TTS/vits2/duration_discriminator.py
new file mode 100644
index 000000000..3df4f8690
--- /dev/null
+++ b/egs/ljspeech/TTS/vits2/duration_discriminator.py
@@ -0,0 +1,138 @@
+from typing import Optional
+
+import torch
+from flow import Transpose
+
+
+class DurationDiscriminator(torch.nn.Module):  # vits2
+    def __init__(
+        self,
+        channels: int = 192,
+        hidden_channels: int = 192,
+        kernel_size: int = 3,
+        dropout_rate: float = 0.5,
+        eps: float = 1e-5,
+        global_channels: int = -1,
+    ):
+        super().__init__()
+
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dropout_rate = dropout_rate
+        self.global_channels = global_channels
+
+        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.conv_1 = torch.nn.Conv1d(
+            channels, hidden_channels, kernel_size, padding=kernel_size // 2
+        )
+        self.norm_1 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                hidden_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+
+        self.conv_2 = torch.nn.Conv1d(
+            hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+        )
+
+        self.norm_2 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                hidden_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+
+        self.dur_proj = torch.nn.Conv1d(1, hidden_channels, 1)
+
+        self.pre_out_conv_1 = torch.nn.Conv1d(
+            2 * hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+        )
+
+        self.pre_out_norm_1 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                hidden_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+
+        self.pre_out_conv_2 = torch.nn.Conv1d(
+            hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+        )
+
+        self.pre_out_norm_2 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                hidden_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+
+        if global_channels > 0:
+            self.cond_layer = torch.nn.Conv1d(global_channels, channels, 1)
+
+        self.output_layer = torch.nn.Sequential(
+            torch.nn.Linear(hidden_channels, 1), torch.nn.Sigmoid()
+        )
+
+    def forward_probability(
+        self, x: torch.Tensor, x_mask: torch.Tensor, dur: torch.Tensor
+    ):
+        dur = self.dur_proj(dur)
+        x = torch.cat([x, dur], dim=1)
+        x = self.pre_out_conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.pre_out_norm_1(x)
+        x = self.dropout(x)
+        x = self.pre_out_conv_2(x * x_mask)
+        x = torch.relu(x)
+        x = self.pre_out_norm_2(x)
+        x = self.dropout(x)
+        x = x * x_mask
+        x = x.transpose(1, 2)
+        output_prob = self.output_layer(x)
+
+        return output_prob
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_mask: torch.Tensor,
+        dur_r: torch.Tensor,
+        dur_hat: torch.Tensor,
+        g: Optional[torch.Tensor] = None,
+    ):
+
+        x = torch.detach(x)
+        if g is not None:
+            g = torch.detach(g)
+            x = x + self.cond_layer(g)
+
+        x = self.conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_1(x)
+        x = self.dropout(x)
+
+        x = self.conv_2(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_2(x)
+        x = self.dropout(x)
+
+        output_probs = []
+        for dur in [dur_r, dur_hat]:
+            output_prob = self.forward_probability(x, x_mask, dur)
+            output_probs.append(output_prob)
+
+        return output_probs
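A quick smoke test of the relocated module, based only on the signatures added above; the batch size, feature dimension, and sequence length are illustrative assumptions:

import torch

from duration_discriminator import DurationDiscriminator

disc = DurationDiscriminator(channels=192, hidden_channels=192)
x = torch.randn(4, 192, 50)     # encoder features (B, channels, T); detached internally
x_mask = torch.ones(4, 1, 50)   # all frames valid
dur_r = torch.rand(4, 1, 50)    # reference durations, one value per frame
dur_hat = torch.rand(4, 1, 50)  # durations from the duration predictor
p_real, p_fake = disc(x, x_mask, dur_r, dur_hat)
print(p_real.shape, p_fake.shape)  # torch.Size([4, 50, 1]) for both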
diff --git a/egs/ljspeech/TTS/vits2/generator.py b/egs/ljspeech/TTS/vits2/generator.py
index 345c06211..adb161ed2 100644
--- a/egs/ljspeech/TTS/vits2/generator.py
+++ b/egs/ljspeech/TTS/vits2/generator.py
@@ -19,7 +19,7 @@ import torch.nn.functional as F
 from duration_predictor import DurationPredictor, StochasticDurationPredictor
 from hifigan import HiFiGANGenerator
 from posterior_encoder import PosteriorEncoder
-from residual_coupling import ResidualAffineCouplingBlock
+from residual_coupling import ResidualCouplingBlock
 from text_encoder import TextEncoder
 from torch.cuda.amp import autocast
 from utils import get_random_segments
@@ -62,9 +62,11 @@ class VITSGenerator(torch.nn.Module):
         use_weight_norm_in_posterior_encoder: bool = True,
         flow_flows: int = 4,
         flow_kernel_size: int = 5,
+        flow_heads_transformer: int = 2,
+        flow_layers_transformer: int = 1,
+        flow_kernel_size_transformer: int = 3,
         flow_base_dilation: int = 1,
         flow_layers: int = 4,
-        flow_nheads: int = 2,
         flow_dropout_rate: float = 0.0,
         use_weight_norm_in_flow: bool = True,
         use_only_mean_in_flow: bool = True,
@@ -122,6 +124,9 @@ class VITSGenerator(torch.nn.Module):
                 normalization in posterior encoder.
             flow_flows (int): Number of flows in flow.
             flow_kernel_size (int): Kernel size in flow.
+            flow_heads_transformer (int): Number of heads for transformer in flow.
+            flow_layers_transformer (int): Number of layers for transformer in flow.
+            flow_kernel_size_transformer (int): Kernel size for transformer in flow.
             flow_base_dilation (int): Base dilation in flow.
             flow_layers (int): Number of layers in flow.
             flow_dropout_rate (float): Dropout rate in flow
@@ -181,12 +186,14 @@ class VITSGenerator(torch.nn.Module):
             dropout_rate=posterior_encoder_dropout_rate,
             use_weight_norm=use_weight_norm_in_posterior_encoder,
         )
-        self.flow = ResidualAffineCouplingBlock(
+        self.flow = ResidualCouplingBlock(
             in_channels=hidden_channels,
             hidden_channels=hidden_channels,
-            num_heads=flow_nheads,
             flows=flow_flows,
             kernel_size=flow_kernel_size,
+            heads_transformer=flow_heads_transformer,
+            layers_transformer=flow_layers_transformer,
+            kernel_size_transformer=flow_kernel_size_transformer,
             base_dilation=flow_base_dilation,
             layers=flow_layers,
             global_channels=global_channels,
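For reference, a minimal sketch of the flow the generator now builds, mirroring the constructor call in the hunk above with the recipe defaults. The tensor shapes and the forward(x, x_mask, g=None, inverse=False) signature are assumptions carried over from the vits1 recipe:

import torch

from residual_coupling import ResidualCouplingBlock

flow = ResidualCouplingBlock(
    in_channels=192,
    hidden_channels=192,
    flows=4,
    kernel_size=5,
    heads_transformer=2,        # vits2: attention heads per coupling layer
    layers_transformer=1,       # vits2: transformer depth per coupling layer
    kernel_size_transformer=3,  # vits2: conv kernel inside each transformer
    base_dilation=1,
    layers=4,
)
z = torch.randn(2, 192, 80)    # posterior latents (B, C, T)
x_mask = torch.ones(2, 1, 80)
z_p = flow(z, x_mask)                    # training direction
z_rec = flow(z_p, x_mask, inverse=True)  # inverse pass used at inference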
diff --git a/egs/ljspeech/TTS/vits2/hifigan.py b/egs/ljspeech/TTS/vits2/hifigan.py
index cb02a1494..589ac30f6 100644
--- a/egs/ljspeech/TTS/vits2/hifigan.py
+++ b/egs/ljspeech/TTS/vits2/hifigan.py
@@ -16,7 +16,6 @@ from typing import Any, Dict, List, Optional
 import numpy as np
 import torch
 import torch.nn.functional as F
-from flow import Transpose
 
 
 class HiFiGANGenerator(torch.nn.Module):
@@ -932,136 +931,3 @@ class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module):
         msd_outs = self.msd(x)
         mpd_outs = self.mpd(x)
         return msd_outs + mpd_outs
-
-
-class DurationDiscriminator(torch.nn.Module):  # vits2
-    def __init__(
-        self,
-        channels: int = 192,
-        hidden_channels: int = 192,
-        kernel_size: int = 3,
-        dropout_rate: float = 0.5,
-        eps: float = 1e-5,
-        global_channels: int = -1,
-    ):
-        super().__init__()
-
-        self.channels = channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dropout_rate = dropout_rate
-        self.global_channels = global_channels
-
-        self.dropout = torch.nn.Dropout(dropout_rate)
-        self.conv_1 = torch.nn.Conv1d(
-            channels, hidden_channels, kernel_size, padding=kernel_size // 2
-        )
-        self.norm_1 = torch.nn.Sequential(
-            Transpose(1, 2),
-            torch.nn.LayerNorm(
-                hidden_channels,
-                eps=eps,
-                elementwise_affine=True,
-            ),
-            Transpose(1, 2),
-        )
-
-        self.conv_2 = torch.nn.Conv1d(
-            hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
-        )
-
-        self.norm_2 = torch.nn.Sequential(
-            Transpose(1, 2),
-            torch.nn.LayerNorm(
-                hidden_channels,
-                eps=eps,
-                elementwise_affine=True,
-            ),
-            Transpose(1, 2),
-        )
-
-        self.dur_proj = torch.nn.Conv1d(1, hidden_channels, 1)
-
-        self.pre_out_conv_1 = torch.nn.Conv1d(
-            2 * hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
-        )
-
-        self.pre_out_norm_1 = torch.nn.Sequential(
-            Transpose(1, 2),
-            torch.nn.LayerNorm(
-                hidden_channels,
-                eps=eps,
-                elementwise_affine=True,
-            ),
-            Transpose(1, 2),
-        )
-
-        self.pre_out_conv_2 = torch.nn.Conv1d(
-            hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
-        )
-
-        self.pre_out_norm_2 = torch.nn.Sequential(
-            Transpose(1, 2),
-            torch.nn.LayerNorm(
-                hidden_channels,
-                eps=eps,
-                elementwise_affine=True,
-            ),
-            Transpose(1, 2),
-        )
-
-        if global_channels > 0:
-            self.cond_layer = torch.nn.Conv1d(global_channels, channels, 1)
-
-        self.output_layer = torch.nn.Sequential(
-            torch.nn.Linear(hidden_channels, 1), torch.nn.Sigmoid()
-        )
-
-    def forward_probability(
-        self, x: torch.Tensor, x_mask: torch.Tensor, dur: torch.Tensor
-    ):
-        dur = self.dur_proj(dur)
-        x = torch.cat([x, dur], dim=1)
-        x = self.pre_out_conv_1(x * x_mask)
-        x = torch.relu(x)
-        x = self.pre_out_norm_1(x)
-        x = self.dropout(x)
-        x = self.pre_out_conv_2(x * x_mask)
-        x = torch.relu(x)
-        x = self.pre_out_norm_2(x)
-        x = self.dropout(x)
-        x = x * x_mask
-        x = x.transpose(1, 2)
-        output_prob = self.output_layer(x)
-        return output_prob
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        x_mask: torch.Tensor,
-        dur_r: torch.Tensor,
-        dur_hat: torch.Tensor,
-        g: Optional[torch.Tensor] = None,
-    ):
-
-        x = torch.detach(x)
-        if g is not None:
-            g = torch.detach(g)
-            x = x + self.cond_layer(g)
-
-        x = self.conv_1(x * x_mask)
-        x = torch.relu(x)
-        x = self.norm_1(x)
-        x = self.dropout(x)
-
-        x = self.conv_2(x * x_mask)
-        x = torch.relu(x)
-        x = self.norm_2(x)
-        x = self.dropout(x)
-
-        output_probs = []
-        for dur in [dur_r, dur_hat]:
-            output_prob = self.forward_probability(x, x_mask, dur)
-            output_probs.append(output_prob)
-
-        return output_probs
diff --git a/egs/ljspeech/TTS/vits2/residual_coupling.py b/egs/ljspeech/TTS/vits2/residual_coupling.py
index d378d8509..1e95ab912 100644
--- a/egs/ljspeech/TTS/vits2/residual_coupling.py
+++ b/egs/ljspeech/TTS/vits2/residual_coupling.py
@@ -19,10 +19,10 @@ from wavenet import WaveNet
 
 from icefall.utils import make_pad_mask
 
 
-class ResidualAffineCouplingBlock(torch.nn.Module):
-    """Residual affine coupling block module.
+class ResidualCouplingBlock(torch.nn.Module):
+    """Residual coupling block module.
 
-    This is a module of residual affine coupling block, which used as "Flow" in
+    This is a module of residual affine/transformer coupling block, which is used as "Flow" in
     `Conditional Variational Autoencoder with Adversarial Learning for End-to-End
     Text-to-Speech`_.
@@ -35,7 +35,9 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
         self,
         in_channels: int = 192,
         hidden_channels: int = 192,
-        num_heads: int = 2,
+        heads_transformer: int = 2,
+        layers_transformer: int = 4,
+        kernel_size_transformer: int = 3,
         flows: int = 4,
         kernel_size: int = 5,
         base_dilation: int = 1,
@@ -47,7 +49,7 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
         use_only_mean: bool = True,
         use_transformer_in_flows: bool = True,
     ):
-        """Initilize ResidualAffineCouplingBlock module.
+        """Initialize ResidualCouplingBlock module.
 
         Args:
             in_channels (int): Number of input channels.
@@ -73,7 +75,9 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
                 ResidualCouplingTransformersLayer(
                     in_channels=in_channels,
                     hidden_channels=hidden_channels,
-                    n_heads=num_heads,
+                    heads_transformer=heads_transformer,
+                    layers_transformer=layers_transformer,
+                    kernel_size_transformer=kernel_size_transformer,
                     kernel_size=kernel_size,
                     base_dilation=base_dilation,
                     layers=layers,
@@ -258,9 +262,9 @@ class ResidualCouplingTransformersLayer(torch.nn.Module):
         self,
         in_channels: int = 192,
         hidden_channels: int = 192,
-        n_heads: int = 2,
-        n_layers: int = 2,
-        n_kernel_size: int = 5,
+        heads_transformer: int = 2,
+        layers_transformer: int = 2,
+        kernel_size_transformer: int = 5,
         kernel_size: int = 5,
         base_dilation: int = 1,
         layers: int = 5,
@@ -292,9 +296,9 @@ class ResidualCouplingTransformersLayer(torch.nn.Module):
 
         self.pre_transformer = Transformer(
             self.half_channels,
-            num_heads=n_heads,
-            num_layers=n_layers,
-            cnn_module_kernel=n_kernel_size,
+            num_heads=heads_transformer,
+            num_layers=layers_transformer,
+            cnn_module_kernel=kernel_size_transformer,
         )
 
         # define modules
diff --git a/egs/ljspeech/TTS/vits2/vits.py b/egs/ljspeech/TTS/vits2/vits.py
index 7f2e42a6a..5b5a248f5 100644
--- a/egs/ljspeech/TTS/vits2/vits.py
+++ b/egs/ljspeech/TTS/vits2/vits.py
@@ -9,9 +9,9 @@ from typing import Any, Dict, Optional, Tuple
 
 import torch
 import torch.nn as nn
+from duration_discriminator import DurationDiscriminator
 from generator import VITSGenerator
 from hifigan import (
-    DurationDiscriminator,
     HiFiGANMultiPeriodDiscriminator,
     HiFiGANMultiScaleDiscriminator,
     HiFiGANMultiScaleMultiPeriodDiscriminator,
@@ -76,10 +76,12 @@ class VITS(nn.Module):
             "posterior_encoder_dropout_rate": 0.0,
             "use_weight_norm_in_posterior_encoder": True,
             "flow_flows": 4,
-            "flow_kernel_size": 3,
+            "flow_kernel_size": 5,
+            "flow_heads_transformer": 2,  # vits2
+            "flow_layers_transformer": 1,  # vits2
+            "flow_kernel_size_transformer": 3,  # vits2
             "flow_base_dilation": 1,
-            "flow_layers": 2,
-            "flow_nheads": 2,
+            "flow_layers": 4,
             "flow_dropout_rate": 0.0,
             "use_weight_norm_in_flow": True,
             "use_only_mean_in_flow": True,
@@ -90,9 +92,9 @@ class VITS(nn.Module):
             "stochastic_duration_predictor_dds_conv_layers": 3,
             "duration_predictor_output_channels": 256,
             "use_stochastic_duration_predictor": True,
-            "use_noised_mas": True,
-            "noise_initial_mas": 0.01,
-            "noise_scale_mas": 2e-06,
+            "use_noised_mas": True,  # vits2
+            "noise_initial_mas": 0.01,  # vits2
+            "noise_scale_mas": 2e-06,  # vits2
         },
         # discriminator related
         discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator",
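The noised-MAS constants above implement the vits2 trick of adding Gaussian noise, with a decaying scale, to the scores searched by monotonic alignment search. A sketch of the schedule these defaults imply; the linear form is an assumption based on the vits2 paper, and the recipe's training loop may implement it elsewhere:

def mas_noise_scale(step: int, initial: float = 0.01, decay: float = 2e-06) -> float:
    # With the defaults above ("noise_initial_mas", "noise_scale_mas"),
    # the noise vanishes after 0.01 / 2e-06 = 5000 training steps.
    return max(initial - decay * step, 0.0)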