Expose coupling transformer parameters

This commit is contained in:
Erwan 2024-03-01 09:47:38 +01:00
parent e96d91555b
commit 87e1d286cf
5 changed files with 174 additions and 157 deletions

View File

@ -0,0 +1,138 @@
from typing import Optional
import torch
from flow import Transpose
class DurationDiscriminator(torch.nn.Module): # vits2
def __init__(
self,
channels: int = 192,
hidden_channels: int = 192,
kernel_size: int = 3,
dropout_rate: float = 0.5,
eps: float = 1e-5,
global_channels: int = -1,
):
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dropout_rate = dropout_rate
self.global_channels = global_channels
self.dropout = torch.nn.Dropout(dropout_rate)
self.conv_1 = torch.nn.Conv1d(
channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
self.norm_1 = torch.nn.Sequential(
Transpose(1, 2),
torch.nn.LayerNorm(
hidden_channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
)
self.conv_2 = torch.nn.Conv1d(
hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
self.norm_2 = torch.nn.Sequential(
Transpose(1, 2),
torch.nn.LayerNorm(
hidden_channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
)
self.dur_proj = torch.nn.Conv1d(1, hidden_channels, 1)
self.pre_out_conv_1 = torch.nn.Conv1d(
2 * hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
self.pre_out_norm_1 = torch.nn.Sequential(
Transpose(1, 2),
torch.nn.LayerNorm(
hidden_channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
)
self.pre_out_conv_2 = torch.nn.Conv1d(
hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
self.pre_out_norm_2 = torch.nn.Sequential(
Transpose(1, 2),
torch.nn.LayerNorm(
hidden_channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
)
if global_channels > 0:
self.cond_layer = torch.nn.Conv1d(global_channels, channels, 1)
self.output_layer = torch.nn.Sequential(
torch.nn.Linear(hidden_channels, 1), torch.nn.Sigmoid()
)
def forward_probability(
self, x: torch.Tensor, x_mask: torch.Tensor, dur: torch.Tensor
):
dur = self.dur_proj(dur)
x = torch.cat([x, dur], dim=1)
x = self.pre_out_conv_1(x * x_mask)
x = torch.relu(x)
x = self.pre_out_norm_1(x)
x = self.dropout(x)
x = self.pre_out_conv_2(x * x_mask)
x = torch.relu(x)
x = self.pre_out_norm_2(x)
x = self.dropout(x)
x = x * x_mask
x = x.transpose(1, 2)
output_prob = self.output_layer(x)
return output_prob
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
dur_r: torch.Tensor,
dur_hat: torch.Tensor,
g: Optional[torch.Tensor] = None,
):
x = torch.detach(x)
if g is not None:
g = torch.detach(g)
x = x + self.cond_layer(g)
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.dropout(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.dropout(x)
output_probs = []
for dur in [dur_r, dur_hat]:
output_prob = self.forward_probability(x, x_mask, dur)
output_probs.append(output_prob)
return output_probs

View File

@ -19,7 +19,7 @@ import torch.nn.functional as F
from duration_predictor import DurationPredictor, StochasticDurationPredictor from duration_predictor import DurationPredictor, StochasticDurationPredictor
from hifigan import HiFiGANGenerator from hifigan import HiFiGANGenerator
from posterior_encoder import PosteriorEncoder from posterior_encoder import PosteriorEncoder
from residual_coupling import ResidualAffineCouplingBlock from residual_coupling import ResidualCouplingBlock
from text_encoder import TextEncoder from text_encoder import TextEncoder
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
from utils import get_random_segments from utils import get_random_segments
@ -62,9 +62,11 @@ class VITSGenerator(torch.nn.Module):
use_weight_norm_in_posterior_encoder: bool = True, use_weight_norm_in_posterior_encoder: bool = True,
flow_flows: int = 4, flow_flows: int = 4,
flow_kernel_size: int = 5, flow_kernel_size: int = 5,
flow_heads_transformer: int = 2,
flow_layers_transformer: int = 1,
flow_kernel_size_transformer: int = 3,
flow_base_dilation: int = 1, flow_base_dilation: int = 1,
flow_layers: int = 4, flow_layers: int = 4,
flow_nheads: int = 2,
flow_dropout_rate: float = 0.0, flow_dropout_rate: float = 0.0,
use_weight_norm_in_flow: bool = True, use_weight_norm_in_flow: bool = True,
use_only_mean_in_flow: bool = True, use_only_mean_in_flow: bool = True,
@ -122,6 +124,9 @@ class VITSGenerator(torch.nn.Module):
normalization in posterior encoder. normalization in posterior encoder.
flow_flows (int): Number of flows in flow. flow_flows (int): Number of flows in flow.
flow_kernel_size (int): Kernel size in flow. flow_kernel_size (int): Kernel size in flow.
flow_heads_transformer (int): Number of heads for transformer in flow
flow_layers_transformer (int): Number of layers for transformer in flow
flow_kernel_size_transformer (int): Kernel size for transformer in flow
flow_base_dilation (int): Base dilation in flow. flow_base_dilation (int): Base dilation in flow.
flow_layers (int): Number of layers in flow. flow_layers (int): Number of layers in flow.
flow_dropout_rate (float): Dropout rate in flow flow_dropout_rate (float): Dropout rate in flow
@ -181,12 +186,14 @@ class VITSGenerator(torch.nn.Module):
dropout_rate=posterior_encoder_dropout_rate, dropout_rate=posterior_encoder_dropout_rate,
use_weight_norm=use_weight_norm_in_posterior_encoder, use_weight_norm=use_weight_norm_in_posterior_encoder,
) )
self.flow = ResidualAffineCouplingBlock( self.flow = ResidualCouplingBlock(
in_channels=hidden_channels, in_channels=hidden_channels,
hidden_channels=hidden_channels, hidden_channels=hidden_channels,
num_heads=flow_nheads,
flows=flow_flows, flows=flow_flows,
kernel_size=flow_kernel_size, kernel_size=flow_kernel_size,
heads_transformer=flow_heads_transformer,
layers_transformer=flow_layers_transformer,
kernel_size_transformer=flow_kernel_size_transformer,
base_dilation=flow_base_dilation, base_dilation=flow_base_dilation,
layers=flow_layers, layers=flow_layers,
global_channels=global_channels, global_channels=global_channels,

View File

@ -16,7 +16,6 @@ from typing import Any, Dict, List, Optional
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from flow import Transpose
class HiFiGANGenerator(torch.nn.Module): class HiFiGANGenerator(torch.nn.Module):
@ -932,136 +931,3 @@ class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module):
msd_outs = self.msd(x) msd_outs = self.msd(x)
mpd_outs = self.mpd(x) mpd_outs = self.mpd(x)
return msd_outs + mpd_outs return msd_outs + mpd_outs
class DurationDiscriminator(torch.nn.Module): # vits2
def __init__(
self,
channels: int = 192,
hidden_channels: int = 192,
kernel_size: int = 3,
dropout_rate: float = 0.5,
eps: float = 1e-5,
global_channels: int = -1,
):
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dropout_rate = dropout_rate
self.global_channels = global_channels
self.dropout = torch.nn.Dropout(dropout_rate)
self.conv_1 = torch.nn.Conv1d(
channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
self.norm_1 = torch.nn.Sequential(
Transpose(1, 2),
torch.nn.LayerNorm(
hidden_channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
)
self.conv_2 = torch.nn.Conv1d(
hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
self.norm_2 = torch.nn.Sequential(
Transpose(1, 2),
torch.nn.LayerNorm(
hidden_channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
)
self.dur_proj = torch.nn.Conv1d(1, hidden_channels, 1)
self.pre_out_conv_1 = torch.nn.Conv1d(
2 * hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
self.pre_out_norm_1 = torch.nn.Sequential(
Transpose(1, 2),
torch.nn.LayerNorm(
hidden_channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
)
self.pre_out_conv_2 = torch.nn.Conv1d(
hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
self.pre_out_norm_2 = torch.nn.Sequential(
Transpose(1, 2),
torch.nn.LayerNorm(
hidden_channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
)
if global_channels > 0:
self.cond_layer = torch.nn.Conv1d(global_channels, channels, 1)
self.output_layer = torch.nn.Sequential(
torch.nn.Linear(hidden_channels, 1), torch.nn.Sigmoid()
)
def forward_probability(
self, x: torch.Tensor, x_mask: torch.Tensor, dur: torch.Tensor
):
dur = self.dur_proj(dur)
x = torch.cat([x, dur], dim=1)
x = self.pre_out_conv_1(x * x_mask)
x = torch.relu(x)
x = self.pre_out_norm_1(x)
x = self.dropout(x)
x = self.pre_out_conv_2(x * x_mask)
x = torch.relu(x)
x = self.pre_out_norm_2(x)
x = self.dropout(x)
x = x * x_mask
x = x.transpose(1, 2)
output_prob = self.output_layer(x)
return output_prob
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
dur_r: torch.Tensor,
dur_hat: torch.Tensor,
g: Optional[torch.Tensor] = None,
):
x = torch.detach(x)
if g is not None:
g = torch.detach(g)
x = x + self.cond_layer(g)
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.dropout(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.dropout(x)
output_probs = []
for dur in [dur_r, dur_hat]:
output_prob = self.forward_probability(x, x_mask, dur)
output_probs.append(output_prob)
return output_probs

View File

@ -19,10 +19,10 @@ from wavenet import WaveNet
from icefall.utils import make_pad_mask from icefall.utils import make_pad_mask
class ResidualAffineCouplingBlock(torch.nn.Module): class ResidualCouplingBlock(torch.nn.Module):
"""Residual affine coupling block module. """Residual coupling block module.
This is a module of residual affine coupling block, which used as "Flow" in This is a module of residual affine/transformer coupling block, which used as "Flow" in
`Conditional Variational Autoencoder with Adversarial Learning for End-to-End `Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`_. Text-to-Speech`_.
@ -35,7 +35,9 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
self, self,
in_channels: int = 192, in_channels: int = 192,
hidden_channels: int = 192, hidden_channels: int = 192,
num_heads: int = 2, heads_transformer: int = 2,
layers_transformer: int = 4,
kernel_size_transformer: int = 3,
flows: int = 4, flows: int = 4,
kernel_size: int = 5, kernel_size: int = 5,
base_dilation: int = 1, base_dilation: int = 1,
@ -47,7 +49,7 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
use_only_mean: bool = True, use_only_mean: bool = True,
use_transformer_in_flows: bool = True, use_transformer_in_flows: bool = True,
): ):
"""Initilize ResidualAffineCouplingBlock module. """Initilize ResidualCouplingBlock module.
Args: Args:
in_channels (int): Number of input channels. in_channels (int): Number of input channels.
@ -73,7 +75,9 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
ResidualCouplingTransformersLayer( ResidualCouplingTransformersLayer(
in_channels=in_channels, in_channels=in_channels,
hidden_channels=hidden_channels, hidden_channels=hidden_channels,
n_heads=num_heads, heads_transformer=heads_transformer,
layers_transformer=layers_transformer,
kernel_size_transformer=kernel_size_transformer,
kernel_size=kernel_size, kernel_size=kernel_size,
base_dilation=base_dilation, base_dilation=base_dilation,
layers=layers, layers=layers,
@ -258,9 +262,9 @@ class ResidualCouplingTransformersLayer(torch.nn.Module):
self, self,
in_channels: int = 192, in_channels: int = 192,
hidden_channels: int = 192, hidden_channels: int = 192,
n_heads: int = 2, heads_transformer: int = 2,
n_layers: int = 2, layers_transformer: int = 2,
n_kernel_size: int = 5, kernel_size_transformer: int = 5,
kernel_size: int = 5, kernel_size: int = 5,
base_dilation: int = 1, base_dilation: int = 1,
layers: int = 5, layers: int = 5,
@ -292,9 +296,9 @@ class ResidualCouplingTransformersLayer(torch.nn.Module):
self.pre_transformer = Transformer( self.pre_transformer = Transformer(
self.half_channels, self.half_channels,
num_heads=n_heads, num_heads=heads_transformer,
num_layers=n_layers, num_layers=layers_transformer,
cnn_module_kernel=n_kernel_size, cnn_module_kernel=kernel_size_transformer,
) )
# define modules # define modules

View File

@ -9,9 +9,9 @@ from typing import Any, Dict, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from duration_discriminator import DurationDiscriminator
from generator import VITSGenerator from generator import VITSGenerator
from hifigan import ( from hifigan import (
DurationDiscriminator,
HiFiGANMultiPeriodDiscriminator, HiFiGANMultiPeriodDiscriminator,
HiFiGANMultiScaleDiscriminator, HiFiGANMultiScaleDiscriminator,
HiFiGANMultiScaleMultiPeriodDiscriminator, HiFiGANMultiScaleMultiPeriodDiscriminator,
@ -76,10 +76,12 @@ class VITS(nn.Module):
"posterior_encoder_dropout_rate": 0.0, "posterior_encoder_dropout_rate": 0.0,
"use_weight_norm_in_posterior_encoder": True, "use_weight_norm_in_posterior_encoder": True,
"flow_flows": 4, "flow_flows": 4,
"flow_kernel_size": 3, "flow_kernel_size": 5,
"flow_heads_transformer": 2, # vits2
"flow_layers_transformer": 1, # vits2
"flow_kernel_size_transformer": 3, # vits2
"flow_base_dilation": 1, "flow_base_dilation": 1,
"flow_layers": 2, "flow_layers": 4,
"flow_nheads": 2,
"flow_dropout_rate": 0.0, "flow_dropout_rate": 0.0,
"use_weight_norm_in_flow": True, "use_weight_norm_in_flow": True,
"use_only_mean_in_flow": True, "use_only_mean_in_flow": True,
@ -90,9 +92,9 @@ class VITS(nn.Module):
"stochastic_duration_predictor_dds_conv_layers": 3, "stochastic_duration_predictor_dds_conv_layers": 3,
"duration_predictor_output_channels": 256, "duration_predictor_output_channels": 256,
"use_stochastic_duration_predictor": True, "use_stochastic_duration_predictor": True,
"use_noised_mas": True, "use_noised_mas": True, # vits2
"noise_initial_mas": 0.01, "noise_initial_mas": 0.01, # vits2
"noise_scale_mas": 2e-06, "noise_scale_mas": 2e-06, # vits2
}, },
# discriminator related # discriminator related
discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator", discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator",