Expose coupling transformer parameters

This commit is contained in:
Erwan 2024-03-01 09:47:38 +01:00
parent e96d91555b
commit 87e1d286cf
5 changed files with 174 additions and 157 deletions

View File

@ -0,0 +1,138 @@
from typing import Optional
import torch
from flow import Transpose
class DurationDiscriminator(torch.nn.Module): # vits2
def __init__(
self,
channels: int = 192,
hidden_channels: int = 192,
kernel_size: int = 3,
dropout_rate: float = 0.5,
eps: float = 1e-5,
global_channels: int = -1,
):
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dropout_rate = dropout_rate
self.global_channels = global_channels
self.dropout = torch.nn.Dropout(dropout_rate)
self.conv_1 = torch.nn.Conv1d(
channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
self.norm_1 = torch.nn.Sequential(
Transpose(1, 2),
torch.nn.LayerNorm(
hidden_channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
)
self.conv_2 = torch.nn.Conv1d(
hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
self.norm_2 = torch.nn.Sequential(
Transpose(1, 2),
torch.nn.LayerNorm(
hidden_channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
)
self.dur_proj = torch.nn.Conv1d(1, hidden_channels, 1)
self.pre_out_conv_1 = torch.nn.Conv1d(
2 * hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
self.pre_out_norm_1 = torch.nn.Sequential(
Transpose(1, 2),
torch.nn.LayerNorm(
hidden_channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
)
self.pre_out_conv_2 = torch.nn.Conv1d(
hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
self.pre_out_norm_2 = torch.nn.Sequential(
Transpose(1, 2),
torch.nn.LayerNorm(
hidden_channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
)
if global_channels > 0:
self.cond_layer = torch.nn.Conv1d(global_channels, channels, 1)
self.output_layer = torch.nn.Sequential(
torch.nn.Linear(hidden_channels, 1), torch.nn.Sigmoid()
)
def forward_probability(
self, x: torch.Tensor, x_mask: torch.Tensor, dur: torch.Tensor
):
dur = self.dur_proj(dur)
x = torch.cat([x, dur], dim=1)
x = self.pre_out_conv_1(x * x_mask)
x = torch.relu(x)
x = self.pre_out_norm_1(x)
x = self.dropout(x)
x = self.pre_out_conv_2(x * x_mask)
x = torch.relu(x)
x = self.pre_out_norm_2(x)
x = self.dropout(x)
x = x * x_mask
x = x.transpose(1, 2)
output_prob = self.output_layer(x)
return output_prob
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
dur_r: torch.Tensor,
dur_hat: torch.Tensor,
g: Optional[torch.Tensor] = None,
):
x = torch.detach(x)
if g is not None:
g = torch.detach(g)
x = x + self.cond_layer(g)
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.dropout(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.dropout(x)
output_probs = []
for dur in [dur_r, dur_hat]:
output_prob = self.forward_probability(x, x_mask, dur)
output_probs.append(output_prob)
return output_probs

View File

@ -19,7 +19,7 @@ import torch.nn.functional as F
from duration_predictor import DurationPredictor, StochasticDurationPredictor
from hifigan import HiFiGANGenerator
from posterior_encoder import PosteriorEncoder
from residual_coupling import ResidualAffineCouplingBlock
from residual_coupling import ResidualCouplingBlock
from text_encoder import TextEncoder
from torch.cuda.amp import autocast
from utils import get_random_segments
@ -62,9 +62,11 @@ class VITSGenerator(torch.nn.Module):
use_weight_norm_in_posterior_encoder: bool = True,
flow_flows: int = 4,
flow_kernel_size: int = 5,
flow_heads_transformer: int = 2,
flow_layers_transformer: int = 1,
flow_kernel_size_transformer: int = 3,
flow_base_dilation: int = 1,
flow_layers: int = 4,
flow_nheads: int = 2,
flow_dropout_rate: float = 0.0,
use_weight_norm_in_flow: bool = True,
use_only_mean_in_flow: bool = True,
@ -122,6 +124,9 @@ class VITSGenerator(torch.nn.Module):
normalization in posterior encoder.
flow_flows (int): Number of flows in flow.
flow_kernel_size (int): Kernel size in flow.
flow_heads_transformer (int): Number of heads for transformer in flow
flow_layers_transformer (int): Number of layers for transformer in flow
flow_kernel_size_transformer (int): Kernel size for transformer in flow
flow_base_dilation (int): Base dilation in flow.
flow_layers (int): Number of layers in flow.
flow_dropout_rate (float): Dropout rate in flow
@ -181,12 +186,14 @@ class VITSGenerator(torch.nn.Module):
dropout_rate=posterior_encoder_dropout_rate,
use_weight_norm=use_weight_norm_in_posterior_encoder,
)
self.flow = ResidualAffineCouplingBlock(
self.flow = ResidualCouplingBlock(
in_channels=hidden_channels,
hidden_channels=hidden_channels,
num_heads=flow_nheads,
flows=flow_flows,
kernel_size=flow_kernel_size,
heads_transformer=flow_heads_transformer,
layers_transformer=flow_layers_transformer,
kernel_size_transformer=flow_kernel_size_transformer,
base_dilation=flow_base_dilation,
layers=flow_layers,
global_channels=global_channels,

View File

@ -16,7 +16,6 @@ from typing import Any, Dict, List, Optional
import numpy as np
import torch
import torch.nn.functional as F
from flow import Transpose
class HiFiGANGenerator(torch.nn.Module):
@ -932,136 +931,3 @@ class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module):
msd_outs = self.msd(x)
mpd_outs = self.mpd(x)
return msd_outs + mpd_outs
class DurationDiscriminator(torch.nn.Module): # vits2
def __init__(
self,
channels: int = 192,
hidden_channels: int = 192,
kernel_size: int = 3,
dropout_rate: float = 0.5,
eps: float = 1e-5,
global_channels: int = -1,
):
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dropout_rate = dropout_rate
self.global_channels = global_channels
self.dropout = torch.nn.Dropout(dropout_rate)
self.conv_1 = torch.nn.Conv1d(
channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
self.norm_1 = torch.nn.Sequential(
Transpose(1, 2),
torch.nn.LayerNorm(
hidden_channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
)
self.conv_2 = torch.nn.Conv1d(
hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
self.norm_2 = torch.nn.Sequential(
Transpose(1, 2),
torch.nn.LayerNorm(
hidden_channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
)
self.dur_proj = torch.nn.Conv1d(1, hidden_channels, 1)
self.pre_out_conv_1 = torch.nn.Conv1d(
2 * hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
self.pre_out_norm_1 = torch.nn.Sequential(
Transpose(1, 2),
torch.nn.LayerNorm(
hidden_channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
)
self.pre_out_conv_2 = torch.nn.Conv1d(
hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
self.pre_out_norm_2 = torch.nn.Sequential(
Transpose(1, 2),
torch.nn.LayerNorm(
hidden_channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
)
if global_channels > 0:
self.cond_layer = torch.nn.Conv1d(global_channels, channels, 1)
self.output_layer = torch.nn.Sequential(
torch.nn.Linear(hidden_channels, 1), torch.nn.Sigmoid()
)
def forward_probability(
self, x: torch.Tensor, x_mask: torch.Tensor, dur: torch.Tensor
):
dur = self.dur_proj(dur)
x = torch.cat([x, dur], dim=1)
x = self.pre_out_conv_1(x * x_mask)
x = torch.relu(x)
x = self.pre_out_norm_1(x)
x = self.dropout(x)
x = self.pre_out_conv_2(x * x_mask)
x = torch.relu(x)
x = self.pre_out_norm_2(x)
x = self.dropout(x)
x = x * x_mask
x = x.transpose(1, 2)
output_prob = self.output_layer(x)
return output_prob
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
dur_r: torch.Tensor,
dur_hat: torch.Tensor,
g: Optional[torch.Tensor] = None,
):
x = torch.detach(x)
if g is not None:
g = torch.detach(g)
x = x + self.cond_layer(g)
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.dropout(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.dropout(x)
output_probs = []
for dur in [dur_r, dur_hat]:
output_prob = self.forward_probability(x, x_mask, dur)
output_probs.append(output_prob)
return output_probs

View File

@ -19,10 +19,10 @@ from wavenet import WaveNet
from icefall.utils import make_pad_mask
class ResidualAffineCouplingBlock(torch.nn.Module):
"""Residual affine coupling block module.
class ResidualCouplingBlock(torch.nn.Module):
"""Residual coupling block module.
This is a module of residual affine coupling block, which used as "Flow" in
This is a module of residual affine/transformer coupling block, which used as "Flow" in
`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`_.
@ -35,7 +35,9 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
self,
in_channels: int = 192,
hidden_channels: int = 192,
num_heads: int = 2,
heads_transformer: int = 2,
layers_transformer: int = 4,
kernel_size_transformer: int = 3,
flows: int = 4,
kernel_size: int = 5,
base_dilation: int = 1,
@ -47,7 +49,7 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
use_only_mean: bool = True,
use_transformer_in_flows: bool = True,
):
"""Initilize ResidualAffineCouplingBlock module.
"""Initilize ResidualCouplingBlock module.
Args:
in_channels (int): Number of input channels.
@ -73,7 +75,9 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
ResidualCouplingTransformersLayer(
in_channels=in_channels,
hidden_channels=hidden_channels,
n_heads=num_heads,
heads_transformer=heads_transformer,
layers_transformer=layers_transformer,
kernel_size_transformer=kernel_size_transformer,
kernel_size=kernel_size,
base_dilation=base_dilation,
layers=layers,
@ -258,9 +262,9 @@ class ResidualCouplingTransformersLayer(torch.nn.Module):
self,
in_channels: int = 192,
hidden_channels: int = 192,
n_heads: int = 2,
n_layers: int = 2,
n_kernel_size: int = 5,
heads_transformer: int = 2,
layers_transformer: int = 2,
kernel_size_transformer: int = 5,
kernel_size: int = 5,
base_dilation: int = 1,
layers: int = 5,
@ -292,9 +296,9 @@ class ResidualCouplingTransformersLayer(torch.nn.Module):
self.pre_transformer = Transformer(
self.half_channels,
num_heads=n_heads,
num_layers=n_layers,
cnn_module_kernel=n_kernel_size,
num_heads=heads_transformer,
num_layers=layers_transformer,
cnn_module_kernel=kernel_size_transformer,
)
# define modules

View File

@ -9,9 +9,9 @@ from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from duration_discriminator import DurationDiscriminator
from generator import VITSGenerator
from hifigan import (
DurationDiscriminator,
HiFiGANMultiPeriodDiscriminator,
HiFiGANMultiScaleDiscriminator,
HiFiGANMultiScaleMultiPeriodDiscriminator,
@ -76,10 +76,12 @@ class VITS(nn.Module):
"posterior_encoder_dropout_rate": 0.0,
"use_weight_norm_in_posterior_encoder": True,
"flow_flows": 4,
"flow_kernel_size": 3,
"flow_kernel_size": 5,
"flow_heads_transformer": 2, # vits2
"flow_layers_transformer": 1, # vits2
"flow_kernel_size_transformer": 3, # vits2
"flow_base_dilation": 1,
"flow_layers": 2,
"flow_nheads": 2,
"flow_layers": 4,
"flow_dropout_rate": 0.0,
"use_weight_norm_in_flow": True,
"use_only_mean_in_flow": True,
@ -90,9 +92,9 @@ class VITS(nn.Module):
"stochastic_duration_predictor_dds_conv_layers": 3,
"duration_predictor_output_channels": 256,
"use_stochastic_duration_predictor": True,
"use_noised_mas": True,
"noise_initial_mas": 0.01,
"noise_scale_mas": 2e-06,
"use_noised_mas": True, # vits2
"noise_initial_mas": 0.01, # vits2
"noise_scale_mas": 2e-06, # vits2
},
# discriminator related
discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator",