Add transformer block

Erwan 2024-02-08 17:36:49 +01:00
parent cafc33bac9
commit b9fdebaff2
4 changed files with 182 additions and 20 deletions

View File

@@ -63,6 +63,7 @@ class VITSGenerator(torch.nn.Module):
         flow_kernel_size: int = 5,
         flow_base_dilation: int = 1,
         flow_layers: int = 4,
+        flow_nheads: int = 2,
         flow_dropout_rate: float = 0.0,
         use_weight_norm_in_flow: bool = True,
         use_only_mean_in_flow: bool = True,
@@ -73,6 +74,7 @@ class VITSGenerator(torch.nn.Module):
         use_noised_mas: bool = True,
         noise_initial_mas: float = 0.01,
         noise_scale_mas: float = 2e-6,
+        use_transformer_in_flows: bool = True,
     ):
         """Initialize VITS generator module.
@@ -170,6 +172,7 @@ class VITSGenerator(torch.nn.Module):
         self.flow = ResidualAffineCouplingBlock(
             in_channels=hidden_channels,
             hidden_channels=hidden_channels,
+            num_heads=flow_nheads,
             flows=flow_flows,
             kernel_size=flow_kernel_size,
             base_dilation=flow_base_dilation,
@@ -178,6 +181,7 @@ class VITSGenerator(torch.nn.Module):
             dropout_rate=flow_dropout_rate,
             use_weight_norm=use_weight_norm_in_flow,
             use_only_mean=use_only_mean_in_flow,
+            use_transformer_in_flows=use_transformer_in_flows,
         )
         # TODO(kan-bayashi): Add deterministic version as an option
         self.duration_predictor = StochasticDurationPredictor(

View File

@@ -13,8 +13,11 @@ from typing import Optional, Tuple, Union

 import torch
 from flow import FlipFlow
+from text_encoder import Transformer
 from wavenet import WaveNet
+
+from icefall.utils import make_pad_mask


 class ResidualAffineCouplingBlock(torch.nn.Module):
     """Residual affine coupling block module.
@@ -32,6 +35,7 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
         self,
         in_channels: int = 192,
         hidden_channels: int = 192,
+        num_heads: int = 2,
         flows: int = 4,
         kernel_size: int = 5,
         base_dilation: int = 1,
@@ -41,6 +45,7 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
         use_weight_norm: bool = True,
         bias: bool = True,
         use_only_mean: bool = True,
+        use_transformer_in_flows: bool = True,
     ):
         """Initialize ResidualAffineCouplingBlock module.
@@ -63,21 +68,39 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
         self.flows = torch.nn.ModuleList()
         for i in range(flows):
-            self.flows += [
-                ResidualAffineCouplingLayer(
-                    in_channels=in_channels,
-                    hidden_channels=hidden_channels,
-                    kernel_size=kernel_size,
-                    base_dilation=base_dilation,
-                    layers=layers,
-                    stacks=1,
-                    global_channels=global_channels,
-                    dropout_rate=dropout_rate,
-                    use_weight_norm=use_weight_norm,
-                    bias=bias,
-                    use_only_mean=use_only_mean,
-                )
-            ]
+            if use_transformer_in_flows:
+                self.flows += [
+                    ResidualCouplingTransformersLayer(
+                        in_channels=in_channels,
+                        hidden_channels=hidden_channels,
+                        n_heads=num_heads,
+                        kernel_size=kernel_size,
+                        base_dilation=base_dilation,
+                        layers=layers,
+                        stacks=1,
+                        global_channels=global_channels,
+                        dropout_rate=dropout_rate,
+                        use_weight_norm=use_weight_norm,
+                        bias=bias,
+                        use_only_mean=use_only_mean,
+                    )
+                ]
+            else:
+                self.flows += [
+                    ResidualAffineCouplingLayer(
+                        in_channels=in_channels,
+                        hidden_channels=hidden_channels,
+                        kernel_size=kernel_size,
+                        base_dilation=base_dilation,
+                        layers=layers,
+                        stacks=1,
+                        global_channels=global_channels,
+                        dropout_rate=dropout_rate,
+                        use_weight_norm=use_weight_norm,
+                        bias=bias,
+                        use_only_mean=use_only_mean,
+                    )
+                ]
             self.flows += [FlipFlow()]

     def forward(
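
A quick sanity check on the resulting structure (a hedged sketch, assuming this module is importable; all other constructor arguments keep their defaults): each loop iteration appends one coupling layer followed by one FlipFlow, so flows=2 yields four modules.

    block = ResidualAffineCouplingBlock(flows=2, use_transformer_in_flows=True)
    # Expected: ['ResidualCouplingTransformersLayer', 'FlipFlow',
    #            'ResidualCouplingTransformersLayer', 'FlipFlow']
    print([type(m).__name__ for m in block.flows])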
@@ -226,3 +249,138 @@ class ResidualAffineCouplingLayer(torch.nn.Module):
             xb = (xb - m) * torch.exp(-logs) * x_mask
             x = torch.cat([xa, xb], 1)
             return x
+
+
+class ResidualCouplingTransformersLayer(torch.nn.Module):
+    """Residual transformer coupling layer."""
+
+    def __init__(
+        self,
+        in_channels: int = 192,
+        hidden_channels: int = 192,
+        n_heads: int = 2,
+        n_layers: int = 2,
+        n_kernel_size: int = 5,
+        kernel_size: int = 5,
+        base_dilation: int = 1,
+        layers: int = 5,
+        stacks: int = 1,
+        global_channels: int = -1,
+        dropout_rate: float = 0.0,
+        use_weight_norm: bool = True,
+        bias: bool = True,
+        use_only_mean: bool = True,
+    ):
+        """Initialize ResidualCouplingTransformersLayer module.
+
+        Args:
+            in_channels (int): Number of input channels.
+            hidden_channels (int): Number of hidden channels.
+            n_heads (int): Number of attention heads in the pre-transformer.
+            n_layers (int): Number of layers of the pre-transformer.
+            n_kernel_size (int): Kernel size of the convolution module
+                in the pre-transformer.
+            kernel_size (int): Kernel size for WaveNet.
+            base_dilation (int): Base dilation factor for WaveNet.
+            layers (int): Number of layers of WaveNet.
+            stacks (int): Number of stacks of WaveNet.
+            global_channels (int): Number of global channels.
+            dropout_rate (float): Dropout rate.
+            use_weight_norm (bool): Whether to use weight normalization in WaveNet.
+            bias (bool): Whether to use bias parameters in WaveNet.
+            use_only_mean (bool): Whether to estimate only the mean.
+
+        """
+        assert in_channels % 2 == 0, "in_channels should be divisible by 2"
+        super().__init__()
+        self.half_channels = in_channels // 2
+        self.use_only_mean = use_only_mean
+
+        self.pre_transformer = Transformer(
+            self.half_channels,
+            num_heads=n_heads,
+            num_layers=n_layers,
+            cnn_module_kernel=n_kernel_size,
+        )
+
+        # define modules
+        self.input_conv = torch.nn.Conv1d(
+            self.half_channels,
+            hidden_channels,
+            1,
+        )
+        self.encoder = WaveNet(
+            in_channels=-1,
+            out_channels=-1,
+            kernel_size=kernel_size,
+            layers=layers,
+            stacks=stacks,
+            base_dilation=base_dilation,
+            residual_channels=hidden_channels,
+            aux_channels=-1,
+            gate_channels=hidden_channels * 2,
+            skip_channels=hidden_channels,
+            global_channels=global_channels,
+            dropout_rate=dropout_rate,
+            bias=bias,
+            use_weight_norm=use_weight_norm,
+            use_first_conv=False,
+            use_last_conv=False,
+            scale_residual=False,
+            scale_skip_connect=True,
+        )
+        if use_only_mean:
+            self.proj = torch.nn.Conv1d(
+                hidden_channels,
+                self.half_channels,
+                1,
+            )
+        else:
+            self.proj = torch.nn.Conv1d(
+                hidden_channels,
+                self.half_channels * 2,
+                1,
+            )
+        self.proj.weight.data.zero_()
+        self.proj.bias.data.zero_()
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_mask: torch.Tensor,
+        g: Optional[torch.Tensor] = None,
+        inverse: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input tensor (B, in_channels, T).
+            x_mask (Tensor): Mask tensor (B, 1, T).
+            g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+            inverse (bool): Whether to inverse the flow.
+
+        Returns:
+            Tensor: Output tensor (B, in_channels, T).
+            Tensor: Log-determinant tensor for NLL (B,) if not inverse.
+
+        """
+        xa, xb = x.split(x.size(1) // 2, dim=1)
+
+        x_trans_mask = make_pad_mask(torch.sum(x_mask, dim=[1, 2]).type(torch.int64))
+        xa_trans = self.pre_transformer(xa.transpose(1, 2), x_trans_mask).transpose(
+            1, 2
+        )
+        xa_ = xa + xa_trans
+
+        h = self.input_conv(xa_) * x_mask
+        h = self.encoder(h, x_mask, g=g)
+        stats = self.proj(h) * x_mask
+        if not self.use_only_mean:
+            m, logs = stats.split(stats.size(1) // 2, dim=1)
+        else:
+            m = stats
+            logs = torch.zeros_like(m)
+
+        if not inverse:
+            xb = m + xb * torch.exp(logs) * x_mask
+            x = torch.cat([xa, xb], 1)
+            logdet = torch.sum(logs, [1, 2])
+            return x, logdet
+        else:
+            xb = (xb - m) * torch.exp(-logs) * x_mask
+            x = torch.cat([xa, xb], 1)
+            return x
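
For intuition, a hedged usage sketch of the new layer (it assumes the Transformer and WaveNet imports above resolve; the default use_only_mean=True makes logs zero, so logdet is zero): running the flow forward and then inverse should recover the input.

    import torch

    layer = ResidualCouplingTransformersLayer(in_channels=192, hidden_channels=192)
    layer.eval()  # disable any dropout inside the pre-transformer
    x = torch.randn(2, 192, 50)    # (B, in_channels, T)
    x_mask = torch.ones(2, 1, 50)  # toy batch with no padding
    z, logdet = layer(x, x_mask)            # forward returns (z, logdet)
    x_rec = layer(z, x_mask, inverse=True)  # inverse returns only x
    assert torch.allclose(x, x_rec, atol=1e-4)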

View File

@@ -395,9 +395,7 @@ def train_one_epoch(
             # MAS with Gaussian Noise
             model.module.generator.noise_current_mas = max(
                 model.module.generator.noise_initial_mas
-                - model.module.generator.noise_scale_mas
-                * params.batch_idx_train
-                * 0.25,
+                - model.module.generator.noise_scale_mas * params.batch_idx_train,
                 0.0,
             )
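
Removing the * 0.25 factor makes the linear MAS-noise anneal four times faster: with the generator defaults above (noise_initial_mas=0.01, noise_scale_mas=2e-6) the noise now reaches zero after 0.01 / 2e-6 = 5000 training batches rather than 20000. A minimal standalone sketch of the schedule:

    def noise_current_mas(step: int, initial: float = 0.01, scale: float = 2e-6) -> float:
        """Linearly annealed MAS noise, clamped at zero."""
        return max(initial - scale * step, 0.0)

    assert noise_current_mas(0) == 0.01
    assert noise_current_mas(10_000) == 0.0  # fully annealed well past ~5000 steps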

View File

@@ -75,12 +75,14 @@ class VITS(nn.Module):
         "posterior_encoder_dropout_rate": 0.0,
         "use_weight_norm_in_posterior_encoder": True,
         "flow_flows": 4,
-        "flow_kernel_size": 5,
+        "flow_kernel_size": 3,
         "flow_base_dilation": 1,
-        "flow_layers": 4,
+        "flow_layers": 2,
+        "flow_nheads": 2,
         "flow_dropout_rate": 0.0,
         "use_weight_norm_in_flow": True,
         "use_only_mean_in_flow": True,
+        "use_transformer_in_flows": True,
         "stochastic_duration_predictor_kernel_size": 3,
         "stochastic_duration_predictor_dropout_rate": 0.5,
         "stochastic_duration_predictor_flows": 4,