Add transformer block

This commit is contained in:
parent cafc33bac9
commit b9fdebaff2
@@ -63,6 +63,7 @@ class VITSGenerator(torch.nn.Module):
         flow_kernel_size: int = 5,
         flow_base_dilation: int = 1,
         flow_layers: int = 4,
+        flow_nheads: int = 2,
         flow_dropout_rate: float = 0.0,
         use_weight_norm_in_flow: bool = True,
         use_only_mean_in_flow: bool = True,
@@ -73,6 +74,7 @@ class VITSGenerator(torch.nn.Module):
         use_noised_mas: bool = True,
         noise_initial_mas: float = 0.01,
         noise_scale_mas: float = 2e-6,
+        use_transformer_in_flows: bool = True,
     ):
         """Initialize VITS generator module.

@@ -170,6 +172,7 @@ class VITSGenerator(torch.nn.Module):
         self.flow = ResidualAffineCouplingBlock(
             in_channels=hidden_channels,
             hidden_channels=hidden_channels,
+            num_heads=flow_nheads,
             flows=flow_flows,
             kernel_size=flow_kernel_size,
             base_dilation=flow_base_dilation,
@@ -178,6 +181,7 @@ class VITSGenerator(torch.nn.Module):
             dropout_rate=flow_dropout_rate,
             use_weight_norm=use_weight_norm_in_flow,
             use_only_mean=use_only_mean_in_flow,
+            use_transformer_in_flows=use_transformer_in_flows,
         )
         # TODO(kan-bayashi): Add deterministic version as an option
         self.duration_predictor = StochasticDurationPredictor(
@@ -13,8 +13,11 @@ from typing import Optional, Tuple, Union

 import torch
 from flow import FlipFlow
+from text_encoder import Transformer
 from wavenet import WaveNet

+from icefall.utils import make_pad_mask
+

 class ResidualAffineCouplingBlock(torch.nn.Module):
     """Residual affine coupling block module.
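The newly imported make_pad_mask converts a batch of sequence lengths into a boolean mask that is True at padded positions; the transformer layer added below consumes this mask. A self-contained sketch of that convention (a minimal reimplementation for illustration only; the real helper lives in icefall.utils):

import torch


def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Return a bool mask of shape (B, max_len), True at padded positions."""
    max_len = int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)


mask = make_pad_mask(torch.tensor([3, 5, 2]))
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False],
#         [False, False,  True,  True,  True]])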
@@ -32,6 +35,7 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
         self,
         in_channels: int = 192,
         hidden_channels: int = 192,
+        num_heads: int = 2,
         flows: int = 4,
         kernel_size: int = 5,
         base_dilation: int = 1,
@@ -41,6 +45,7 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
         use_weight_norm: bool = True,
         bias: bool = True,
         use_only_mean: bool = True,
+        use_transformer_in_flows: bool = True,
     ):
         """Initialize ResidualAffineCouplingBlock module.

@@ -63,21 +68,39 @@ class ResidualAffineCouplingBlock(torch.nn.Module):

         self.flows = torch.nn.ModuleList()
         for i in range(flows):
-            self.flows += [
-                ResidualAffineCouplingLayer(
-                    in_channels=in_channels,
-                    hidden_channels=hidden_channels,
-                    kernel_size=kernel_size,
-                    base_dilation=base_dilation,
-                    layers=layers,
-                    stacks=1,
-                    global_channels=global_channels,
-                    dropout_rate=dropout_rate,
-                    use_weight_norm=use_weight_norm,
-                    bias=bias,
-                    use_only_mean=use_only_mean,
-                )
-            ]
+            if use_transformer_in_flows:
+                self.flows += [
+                    ResidualCouplingTransformersLayer(
+                        in_channels=in_channels,
+                        hidden_channels=hidden_channels,
+                        n_heads=num_heads,
+                        kernel_size=kernel_size,
+                        base_dilation=base_dilation,
+                        layers=layers,
+                        stacks=1,
+                        global_channels=global_channels,
+                        dropout_rate=dropout_rate,
+                        use_weight_norm=use_weight_norm,
+                        bias=bias,
+                        use_only_mean=use_only_mean,
+                    )
+                ]
+            else:
+                self.flows += [
+                    ResidualAffineCouplingLayer(
+                        in_channels=in_channels,
+                        hidden_channels=hidden_channels,
+                        kernel_size=kernel_size,
+                        base_dilation=base_dilation,
+                        layers=layers,
+                        stacks=1,
+                        global_channels=global_channels,
+                        dropout_rate=dropout_rate,
+                        use_weight_norm=use_weight_norm,
+                        bias=bias,
+                        use_only_mean=use_only_mean,
+                    )
+                ]
             self.flows += [FlipFlow()]

     def forward(
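Each loop iteration appends one coupling layer (transformer-augmented or plain) followed by a FlipFlow, so successive couplings transform alternating channel halves. A toy sketch of the same alternating pattern and of why the whole stack stays invertible (ToyCoupling and Flip below are illustrative stand-ins, not the classes from this diff):

import torch


class Flip(torch.nn.Module):
    """Toy flip flow: swap the two channel halves (its own inverse)."""

    def forward(self, x, inverse=False):
        xa, xb = x.split(x.size(1) // 2, dim=1)
        return torch.cat([xb, xa], dim=1)


class ToyCoupling(torch.nn.Module):
    """Toy additive coupling: shift xb by a function of xa only."""

    def forward(self, x, inverse=False):
        xa, xb = x.split(x.size(1) // 2, dim=1)
        shift = torch.tanh(xa)  # stands in for the WaveNet/transformer branch
        xb = xb - shift if inverse else xb + shift
        return torch.cat([xa, xb], dim=1)


flows = torch.nn.ModuleList()
for _ in range(4):
    flows.append(ToyCoupling())
    flows.append(Flip())

x = torch.randn(2, 8, 16)  # (B, channels, T)
y = x
for f in flows:            # forward pass through the stack
    y = f(y)
for f in reversed(flows):  # inverse pass undoes it layer by layer
    y = f(y, inverse=True)
assert torch.allclose(y, x, atol=1e-6)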
@@ -226,3 +249,138 @@ class ResidualAffineCouplingLayer(torch.nn.Module):
             xb = (xb - m) * torch.exp(-logs) * x_mask
             x = torch.cat([xa, xb], 1)
             return x
+
+
+class ResidualCouplingTransformersLayer(torch.nn.Module):
+    """Residual transformer coupling layer."""
+
+    def __init__(
+        self,
+        in_channels: int = 192,
+        hidden_channels: int = 192,
+        n_heads: int = 2,
+        n_layers: int = 2,
+        n_kernel_size: int = 5,
+        kernel_size: int = 5,
+        base_dilation: int = 1,
+        layers: int = 5,
+        stacks: int = 1,
+        global_channels: int = -1,
+        dropout_rate: float = 0.0,
+        use_weight_norm: bool = True,
+        bias: bool = True,
+        use_only_mean: bool = True,
+    ):
+        """Initialize ResidualCouplingTransformersLayer module.
+
+        Args:
+            in_channels (int): Number of input channels.
+            hidden_channels (int): Number of hidden channels.
+            n_heads (int): Number of attention heads in the pre-transformer.
+            n_layers (int): Number of layers in the pre-transformer.
+            n_kernel_size (int): CNN-module kernel size in the pre-transformer.
+            kernel_size (int): Kernel size for WaveNet.
+            base_dilation (int): Base dilation factor for WaveNet.
+            layers (int): Number of layers of WaveNet.
+            stacks (int): Number of stacks of WaveNet.
+            global_channels (int): Number of global channels.
+            dropout_rate (float): Dropout rate.
+            use_weight_norm (bool): Whether to use weight normalization in WaveNet.
+            bias (bool): Whether to use bias parameters in WaveNet.
+            use_only_mean (bool): Whether to estimate only the mean.
+        """
+        assert in_channels % 2 == 0, "in_channels should be divisible by 2"
+        super().__init__()
+        self.half_channels = in_channels // 2
+        self.use_only_mean = use_only_mean
+
+        # Transformer applied to the conditioning half before the WaveNet encoder.
+        self.pre_transformer = Transformer(
+            self.half_channels,
+            num_heads=n_heads,
+            num_layers=n_layers,
+            cnn_module_kernel=n_kernel_size,
+        )
+
+        # define modules
+        self.input_conv = torch.nn.Conv1d(
+            self.half_channels,
+            hidden_channels,
+            1,
+        )
+
+        self.encoder = WaveNet(
+            in_channels=-1,
+            out_channels=-1,
+            kernel_size=kernel_size,
+            layers=layers,
+            stacks=stacks,
+            base_dilation=base_dilation,
+            residual_channels=hidden_channels,
+            aux_channels=-1,
+            gate_channels=hidden_channels * 2,
+            skip_channels=hidden_channels,
+            global_channels=global_channels,
+            dropout_rate=dropout_rate,
+            bias=bias,
+            use_weight_norm=use_weight_norm,
+            use_first_conv=False,
+            use_last_conv=False,
+            scale_residual=False,
+            scale_skip_connect=True,
+        )
+
+        if use_only_mean:
+            self.proj = torch.nn.Conv1d(
+                hidden_channels,
+                self.half_channels,
+                1,
+            )
+        else:
+            self.proj = torch.nn.Conv1d(
+                hidden_channels,
+                self.half_channels * 2,
+                1,
+            )
+        # Zero-init so the layer is the identity transform at the start of training.
+        self.proj.weight.data.zero_()
+        self.proj.bias.data.zero_()
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_mask: torch.Tensor,
+        g: Optional[torch.Tensor] = None,
+        inverse: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input tensor (B, in_channels, T).
+            x_mask (Tensor): Mask tensor (B, 1, T).
+            g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+            inverse (bool): Whether to inverse the flow.
+
+        Returns:
+            Tensor: Output tensor (B, in_channels, T).
+            Tensor: Log-determinant tensor for NLL (B,) if not inverse.
+        """
+        xa, xb = x.split(x.size(1) // 2, dim=1)
+
+        # Recover per-utterance lengths from x_mask (B, 1, T) to build the
+        # boolean pad mask expected by the transformer.
+        x_trans_mask = make_pad_mask(torch.sum(x_mask, dim=[1, 2]).type(torch.int64))
+        xa_trans = self.pre_transformer(xa.transpose(1, 2), x_trans_mask).transpose(
+            1, 2
+        )
+        xa_ = xa + xa_trans
+
+        h = self.input_conv(xa_) * x_mask
+        h = self.encoder(h, x_mask, g=g)
+
+        stats = self.proj(h) * x_mask
+        if not self.use_only_mean:
+            m, logs = stats.split(stats.size(1) // 2, dim=1)
+        else:
+            m = stats
+            logs = torch.zeros_like(m)
+
+        if not inverse:
+            xb = m + xb * torch.exp(logs) * x_mask
+            x = torch.cat([xa, xb], 1)
+            logdet = torch.sum(logs, [1, 2])
+            return x, logdet
+        else:
+            xb = (xb - m) * torch.exp(-logs) * x_mask
+            x = torch.cat([xa, xb], 1)
+            return x
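Two properties of the coupling above are worth noting: the affine transform on xb is invertible in closed form with log-determinant sum(logs), and because proj is zero-initialized the layer starts as the identity (m = 0 and, with use_only_mean, logs = 0). A quick numerical check with toy tensors (shapes arbitrary, not the actual module):

import torch

B, C, T = 2, 4, 10
xb = torch.randn(B, C, T)
m = torch.randn(B, C, T)           # stands in for the projected mean
logs = 0.1 * torch.randn(B, C, T)  # stands in for the projected log-scale

yb = m + xb * torch.exp(logs)          # forward branch of the layer
logdet = torch.sum(logs, [1, 2])       # per-sample log-determinant, shape (B,)
xb_rec = (yb - m) * torch.exp(-logs)   # inverse branch
assert torch.allclose(xb_rec, xb, atol=1e-5)

# Identity at initialization: a zeroed 1x1 conv projects any h to zeros,
# so m = 0 (and logs = 0), giving yb = xb.
proj = torch.nn.Conv1d(8, C, 1)
proj.weight.data.zero_()
proj.bias.data.zero_()
assert torch.equal(proj(torch.randn(B, 8, T)), torch.zeros(B, C, T))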
@@ -395,9 +395,7 @@ def train_one_epoch(
             # MAS with Gaussian Noise
             model.module.generator.noise_current_mas = max(
                 model.module.generator.noise_initial_mas
-                - model.module.generator.noise_scale_mas
-                * params.batch_idx_train
-                * 0.25,
+                - model.module.generator.noise_scale_mas * params.batch_idx_train,
                 0.0,
             )
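This edit drops the extra 0.25 factor, so the MAS noise now follows a plain linear decay: it starts at noise_initial_mas and decreases by noise_scale_mas per training batch, clamped at zero. With the defaults added in this commit (0.01 initial, 2e-6 per step), the noise reaches zero at step 5000. A standalone sketch of the schedule:

import math


def noise_current_mas(step: int, initial: float = 0.01, scale: float = 2e-6) -> float:
    """Linearly decayed MAS noise, clamped at zero (mirrors the edited expression)."""
    return max(initial - scale * step, 0.0)


assert noise_current_mas(0) == 0.01
assert math.isclose(noise_current_mas(2500), 0.005)
assert math.isclose(noise_current_mas(5000), 0.0, abs_tol=1e-12)
assert noise_current_mas(10000) == 0.0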
@@ -75,12 +75,14 @@ class VITS(nn.Module):
         "posterior_encoder_dropout_rate": 0.0,
         "use_weight_norm_in_posterior_encoder": True,
         "flow_flows": 4,
-        "flow_kernel_size": 5,
+        "flow_kernel_size": 3,
         "flow_base_dilation": 1,
-        "flow_layers": 4,
+        "flow_layers": 2,
+        "flow_nheads": 2,
         "flow_dropout_rate": 0.0,
         "use_weight_norm_in_flow": True,
         "use_only_mean_in_flow": True,
+        "use_transformer_in_flows": True,
         "stochastic_duration_predictor_kernel_size": 3,
         "stochastic_duration_predictor_dropout_rate": 0.5,
         "stochastic_duration_predictor_flows": 4,
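In summary, the new defaults shrink the per-flow WaveNet (kernel size 5 to 3, layers 4 to 2) while adding the transformer path. The flow-related keys after this commit (values copied from the diff above):

flow_defaults = {
    "flow_flows": 4,
    "flow_kernel_size": 3,   # was 5
    "flow_base_dilation": 1,
    "flow_layers": 2,        # was 4
    "flow_nheads": 2,        # new: attention heads for the transformer layer
    "flow_dropout_rate": 0.0,
    "use_weight_norm_in_flow": True,
    "use_only_mean_in_flow": True,
    "use_transformer_in_flows": True,  # new: enable ResidualCouplingTransformersLayer
}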