diff --git a/egs/ljspeech/TTS/vits2/generator.py b/egs/ljspeech/TTS/vits2/generator.py
index 5b64410a9..15f5f5187 100644
--- a/egs/ljspeech/TTS/vits2/generator.py
+++ b/egs/ljspeech/TTS/vits2/generator.py
@@ -63,6 +63,7 @@ class VITSGenerator(torch.nn.Module):
         flow_kernel_size: int = 5,
         flow_base_dilation: int = 1,
         flow_layers: int = 4,
+        flow_nheads: int = 2,
         flow_dropout_rate: float = 0.0,
         use_weight_norm_in_flow: bool = True,
         use_only_mean_in_flow: bool = True,
@@ -73,6 +74,7 @@ class VITSGenerator(torch.nn.Module):
         use_noised_mas: bool = True,
         noise_initial_mas: float = 0.01,
         noise_scale_mas: float = 2e-6,
+        use_transformer_in_flows: bool = True,
     ):
         """Initialize VITS generator module.
 
@@ -170,6 +172,7 @@ class VITSGenerator(torch.nn.Module):
         self.flow = ResidualAffineCouplingBlock(
             in_channels=hidden_channels,
             hidden_channels=hidden_channels,
+            num_heads=flow_nheads,
             flows=flow_flows,
             kernel_size=flow_kernel_size,
             base_dilation=flow_base_dilation,
@@ -178,6 +181,7 @@ class VITSGenerator(torch.nn.Module):
             dropout_rate=flow_dropout_rate,
             use_weight_norm=use_weight_norm_in_flow,
             use_only_mean=use_only_mean_in_flow,
+            use_transformer_in_flows=use_transformer_in_flows,
         )
         # TODO(kan-bayashi): Add deterministic version as an option
         self.duration_predictor = StochasticDurationPredictor(
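
The two new arguments are threaded straight through to the flow: `flow_nheads` sets the number of attention heads in each coupling layer's pre-transformer, and `use_transformer_in_flows` switches between the transformer-augmented and the original WaveNet-only coupling layers. A minimal round-trip sketch of the block in isolation (assuming the vits2 recipe directory is on PYTHONPATH and icefall is installed; shapes are illustrative):

    import torch
    from residual_coupling import ResidualAffineCouplingBlock

    flow = ResidualAffineCouplingBlock(
        in_channels=192,
        hidden_channels=192,
        num_heads=2,                    # wired from flow_nheads
        flows=4,
        kernel_size=3,
        layers=2,
        use_transformer_in_flows=True,  # transformer coupling layers
    )
    flow.eval()  # disable dropout so the round trip is exact
    z = torch.randn(2, 192, 100)    # (B, in_channels, T)
    z_mask = torch.ones(2, 1, 100)  # all frames valid
    z_p = flow(z, z_mask)                    # forward direction
    z_rec = flow(z_p, z_mask, inverse=True)  # inverse recovers the input
    assert torch.allclose(z, z_rec, atol=1e-4)
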
diff --git a/egs/ljspeech/TTS/vits2/residual_coupling.py b/egs/ljspeech/TTS/vits2/residual_coupling.py
index f9a2a3786..f3de17ddd 100644
--- a/egs/ljspeech/TTS/vits2/residual_coupling.py
+++ b/egs/ljspeech/TTS/vits2/residual_coupling.py
@@ -13,8 +13,11 @@ from typing import Optional, Tuple, Union
 
 import torch
 from flow import FlipFlow
+from text_encoder import Transformer
 from wavenet import WaveNet
 
+from icefall.utils import make_pad_mask
+
 
 class ResidualAffineCouplingBlock(torch.nn.Module):
     """Residual affine coupling block module.
@@ -32,6 +35,7 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
         self,
         in_channels: int = 192,
         hidden_channels: int = 192,
+        num_heads: int = 2,
         flows: int = 4,
         kernel_size: int = 5,
         base_dilation: int = 1,
@@ -41,6 +45,7 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
         use_weight_norm: bool = True,
         bias: bool = True,
         use_only_mean: bool = True,
+        use_transformer_in_flows: bool = True,
     ):
         """Initialize ResidualAffineCouplingBlock module.
 
@@ -63,21 +68,39 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
 
         self.flows = torch.nn.ModuleList()
         for i in range(flows):
-            self.flows += [
-                ResidualAffineCouplingLayer(
-                    in_channels=in_channels,
-                    hidden_channels=hidden_channels,
-                    kernel_size=kernel_size,
-                    base_dilation=base_dilation,
-                    layers=layers,
-                    stacks=1,
-                    global_channels=global_channels,
-                    dropout_rate=dropout_rate,
-                    use_weight_norm=use_weight_norm,
-                    bias=bias,
-                    use_only_mean=use_only_mean,
-                )
-            ]
+            if use_transformer_in_flows:
+                self.flows += [
+                    ResidualCouplingTransformersLayer(
+                        in_channels=in_channels,
+                        hidden_channels=hidden_channels,
+                        n_heads=num_heads,
+                        kernel_size=kernel_size,
+                        base_dilation=base_dilation,
+                        layers=layers,
+                        stacks=1,
+                        global_channels=global_channels,
+                        dropout_rate=dropout_rate,
+                        use_weight_norm=use_weight_norm,
+                        bias=bias,
+                        use_only_mean=use_only_mean,
+                    )
+                ]
+            else:
+                self.flows += [
+                    ResidualAffineCouplingLayer(
+                        in_channels=in_channels,
+                        hidden_channels=hidden_channels,
+                        kernel_size=kernel_size,
+                        base_dilation=base_dilation,
+                        layers=layers,
+                        stacks=1,
+                        global_channels=global_channels,
+                        dropout_rate=dropout_rate,
+                        use_weight_norm=use_weight_norm,
+                        bias=bias,
+                        use_only_mean=use_only_mean,
+                    )
+                ]
             self.flows += [FlipFlow()]
 
     def forward(
@@ -226,3 +249,138 @@ class ResidualAffineCouplingLayer(torch.nn.Module):
             xb = (xb - m) * torch.exp(-logs) * x_mask
             x = torch.cat([xa, xb], 1)
             return x
+
+
+class ResidualCouplingTransformersLayer(torch.nn.Module):
+    """Residual transformer coupling layer."""
+
+    def __init__(
+        self,
+        in_channels: int = 192,
+        hidden_channels: int = 192,
+        n_heads: int = 2,
+        n_layers: int = 2,
+        n_kernel_size: int = 5,
+        kernel_size: int = 5,
+        base_dilation: int = 1,
+        layers: int = 5,
+        stacks: int = 1,
+        global_channels: int = -1,
+        dropout_rate: float = 0.0,
+        use_weight_norm: bool = True,
+        bias: bool = True,
+        use_only_mean: bool = True,
+    ):
+        """Initialize ResidualCouplingTransformersLayer module.
+
+        Args:
+            in_channels (int): Number of input channels.
+            hidden_channels (int): Number of hidden channels.
+            n_heads (int): Number of attention heads in the pre-transformer.
+            n_layers (int): Number of layers in the pre-transformer.
+            n_kernel_size (int): Kernel size of the pre-transformer conv module.
+            kernel_size (int): Kernel size for WaveNet.
+            base_dilation (int): Base dilation factor for WaveNet.
+            layers (int): Number of layers of WaveNet.
+            stacks (int): Number of stacks of WaveNet.
+            global_channels (int): Number of global channels.
+            dropout_rate (float): Dropout rate.
+            use_weight_norm (bool): Whether to use weight normalization in WaveNet.
+            bias (bool): Whether to use bias parameters in WaveNet.
+            use_only_mean (bool): Whether to estimate only the mean.
+        """
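
Inside each `ResidualCouplingTransformersLayer`, the first half `xa` of the channels is refined by a small self-attention Transformer with a residual connection, and the result drives the WaveNet encoder that predicts the affine parameters `m` (and `logs`) applied to the second half `xb`. Because `proj` is zero-initialized, `m = logs = 0` at the start of training, so every coupling layer begins as an identity map. A standalone sketch of the affine coupling arithmetic itself, in plain torch with no recipe modules:

    import torch

    def couple(xb, m, logs, inverse=False):
        if not inverse:
            return m + xb * torch.exp(logs)  # forward; logdet = logs.sum()
        return (xb - m) * torch.exp(-logs)   # exact inverse

    xb = torch.randn(2, 96, 50)  # second half of a 192-channel input
    m, logs = torch.randn_like(xb), 0.1 * torch.randn_like(xb)
    yb = couple(xb, m, logs)
    assert torch.allclose(couple(yb, m, logs, inverse=True), xb, atol=1e-5)

With `use_only_mean=True` (the default), `logs` stays zero and the coupling reduces to a volume-preserving translation, which is why the log-determinant below is simply `torch.sum(logs, [1, 2])`.
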
+ """ + assert in_channels % 2 == 0, "in_channels should be divisible by 2" + super().__init__() + self.half_channels = in_channels // 2 + self.use_only_mean = use_only_mean + + self.pre_transformer = Transformer( + self.half_channels, + num_heads=n_heads, + num_layers=n_layers, + cnn_module_kernel=n_kernel_size, + ) + + # define modules + self.input_conv = torch.nn.Conv1d( + self.half_channels, + hidden_channels, + 1, + ) + + self.encoder = WaveNet( + in_channels=-1, + out_channels=-1, + kernel_size=kernel_size, + layers=layers, + stacks=stacks, + base_dilation=base_dilation, + residual_channels=hidden_channels, + aux_channels=-1, + gate_channels=hidden_channels * 2, + skip_channels=hidden_channels, + global_channels=global_channels, + dropout_rate=dropout_rate, + bias=bias, + use_weight_norm=use_weight_norm, + use_first_conv=False, + use_last_conv=False, + scale_residual=False, + scale_skip_connect=True, + ) + + if use_only_mean: + self.proj = torch.nn.Conv1d( + hidden_channels, + self.half_channels, + 1, + ) + else: + self.proj = torch.nn.Conv1d( + hidden_channels, + self.half_channels * 2, + 1, + ) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + inverse: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + Args: + x (Tensor): Input tensor (B, in_channels, T). + x_lengths (Tensor): Length tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + inverse (bool): Whether to inverse the flow. + Returns: + Tensor: Output tensor (B, in_channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. + """ + xa, xb = x.split(x.size(1) // 2, dim=1) + + x_trans_mask = make_pad_mask(torch.sum(x_mask, dim=[1, 2]).type(torch.int64)) + xa_trans = self.pre_transformer(xa.transpose(1, 2), x_trans_mask).transpose( + 1, 2 + ) + xa_ = xa + xa_trans + + h = self.input_conv(xa_) * x_mask + h = self.encoder(h, x_mask, g=g) + + stats = self.proj(h) * x_mask + if not self.use_only_mean: + m, logs = stats.split(stats.size(1) // 2, dim=1) + else: + m = stats + logs = torch.zeros_like(m) + + if not inverse: + xb = m + xb * torch.exp(logs) * x_mask + x = torch.cat([xa, xb], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + xb = (xb - m) * torch.exp(-logs) * x_mask + x = torch.cat([xa, xb], 1) + return x diff --git a/egs/ljspeech/TTS/vits2/train.py b/egs/ljspeech/TTS/vits2/train.py index 5414494e2..8cdbc4623 100755 --- a/egs/ljspeech/TTS/vits2/train.py +++ b/egs/ljspeech/TTS/vits2/train.py @@ -395,9 +395,7 @@ def train_one_epoch( # MAS with Gaussian Noise model.module.generator.noise_current_mas = max( model.module.generator.noise_initial_mas - - model.module.generator.noise_scale_mas - * params.batch_idx_train - * 0.25, + - model.module.generator.noise_scale_mas * params.batch_idx_train, 0.0, ) diff --git a/egs/ljspeech/TTS/vits2/vits.py b/egs/ljspeech/TTS/vits2/vits.py index f8ad7707f..02087d07c 100644 --- a/egs/ljspeech/TTS/vits2/vits.py +++ b/egs/ljspeech/TTS/vits2/vits.py @@ -75,12 +75,14 @@ class VITS(nn.Module): "posterior_encoder_dropout_rate": 0.0, "use_weight_norm_in_posterior_encoder": True, "flow_flows": 4, - "flow_kernel_size": 5, + "flow_kernel_size": 3, "flow_base_dilation": 1, - "flow_layers": 4, + "flow_layers": 2, + "flow_nheads": 2, "flow_dropout_rate": 0.0, "use_weight_norm_in_flow": True, "use_only_mean_in_flow": True, + "use_transformer_in_flows": True, 
"stochastic_duration_predictor_kernel_size": 3, "stochastic_duration_predictor_dropout_rate": 0.5, "stochastic_duration_predictor_flows": 4,