Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-26 10:16:14 +00:00
Expose coupling transformer parameters
commit 87e1d286cf (parent e96d91555b)
egs/ljspeech/TTS/vits2/duration_discriminator.py (new file, 138 lines)
@@ -0,0 +1,138 @@
+from typing import Optional
+
+import torch
+from flow import Transpose
+
+
+class DurationDiscriminator(torch.nn.Module):  # vits2
+    def __init__(
+        self,
+        channels: int = 192,
+        hidden_channels: int = 192,
+        kernel_size: int = 3,
+        dropout_rate: float = 0.5,
+        eps: float = 1e-5,
+        global_channels: int = -1,
+    ):
+        super().__init__()
+
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dropout_rate = dropout_rate
+        self.global_channels = global_channels
+
+        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.conv_1 = torch.nn.Conv1d(
+            channels, hidden_channels, kernel_size, padding=kernel_size // 2
+        )
+        self.norm_1 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                hidden_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+
+        self.conv_2 = torch.nn.Conv1d(
+            hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+        )
+
+        self.norm_2 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                hidden_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+
+        self.dur_proj = torch.nn.Conv1d(1, hidden_channels, 1)
+
+        self.pre_out_conv_1 = torch.nn.Conv1d(
+            2 * hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+        )
+
+        self.pre_out_norm_1 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                hidden_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+
+        self.pre_out_conv_2 = torch.nn.Conv1d(
+            hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+        )
+
+        self.pre_out_norm_2 = torch.nn.Sequential(
+            Transpose(1, 2),
+            torch.nn.LayerNorm(
+                hidden_channels,
+                eps=eps,
+                elementwise_affine=True,
+            ),
+            Transpose(1, 2),
+        )
+
+        if global_channels > 0:
+            self.cond_layer = torch.nn.Conv1d(global_channels, channels, 1)
+
+        self.output_layer = torch.nn.Sequential(
+            torch.nn.Linear(hidden_channels, 1), torch.nn.Sigmoid()
+        )
+
+    def forward_probability(
+        self, x: torch.Tensor, x_mask: torch.Tensor, dur: torch.Tensor
+    ):
+        dur = self.dur_proj(dur)
+        x = torch.cat([x, dur], dim=1)
+        x = self.pre_out_conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.pre_out_norm_1(x)
+        x = self.dropout(x)
+        x = self.pre_out_conv_2(x * x_mask)
+        x = torch.relu(x)
+        x = self.pre_out_norm_2(x)
+        x = self.dropout(x)
+        x = x * x_mask
+        x = x.transpose(1, 2)
+        output_prob = self.output_layer(x)
+
+        return output_prob
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_mask: torch.Tensor,
+        dur_r: torch.Tensor,
+        dur_hat: torch.Tensor,
+        g: Optional[torch.Tensor] = None,
+    ):
+
+        x = torch.detach(x)
+        if g is not None:
+            g = torch.detach(g)
+            x = x + self.cond_layer(g)
+
+        x = self.conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_1(x)
+        x = self.dropout(x)
+
+        x = self.conv_2(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_2(x)
+        x = self.dropout(x)
+
+        output_probs = []
+        for dur in [dur_r, dur_hat]:
+            output_prob = self.forward_probability(x, x_mask, dur)
+            output_probs.append(output_prob)
+
+        return output_probs
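For orientation, here is a minimal smoke test of the class above. It is a sketch, not part of the commit; the shapes simply follow the (batch, channels, frames) Conv1d layout used throughout the file, and all variable names are local assumptions:

    import torch

    from duration_discriminator import DurationDiscriminator

    disc = DurationDiscriminator(channels=192, hidden_channels=192)
    x = torch.randn(2, 192, 50)      # encoder hidden states: (batch, channels, frames)
    x_mask = torch.ones(2, 1, 50)    # 1.0 for valid frames, 0.0 for padding
    dur_r = torch.randn(2, 1, 50)    # reference durations, one channel per frame
    dur_hat = torch.randn(2, 1, 50)  # predicted durations
    # Returns [prob_for_dur_r, prob_for_dur_hat], each of shape (2, 50, 1),
    # squashed into (0, 1) by the final Sigmoid.
    probs = disc(x, x_mask, dur_r, dur_hat)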
egs/ljspeech/TTS/vits2/generator.py:
@@ -19,7 +19,7 @@ import torch.nn.functional as F
 from duration_predictor import DurationPredictor, StochasticDurationPredictor
 from hifigan import HiFiGANGenerator
 from posterior_encoder import PosteriorEncoder
-from residual_coupling import ResidualAffineCouplingBlock
+from residual_coupling import ResidualCouplingBlock
 from text_encoder import TextEncoder
 from torch.cuda.amp import autocast
 from utils import get_random_segments
@@ -62,9 +62,11 @@ class VITSGenerator(torch.nn.Module):
         use_weight_norm_in_posterior_encoder: bool = True,
         flow_flows: int = 4,
         flow_kernel_size: int = 5,
+        flow_heads_transformer: int = 2,
+        flow_layers_transformer: int = 1,
+        flow_kernel_size_transformer: int = 3,
         flow_base_dilation: int = 1,
         flow_layers: int = 4,
-        flow_nheads: int = 2,
         flow_dropout_rate: float = 0.0,
         use_weight_norm_in_flow: bool = True,
         use_only_mean_in_flow: bool = True,
@@ -122,6 +124,9 @@ class VITSGenerator(torch.nn.Module):
                 normalization in posterior encoder.
             flow_flows (int): Number of flows in flow.
             flow_kernel_size (int): Kernel size in flow.
+            flow_heads_transformer (int): Number of heads for transformer in flow.
+            flow_layers_transformer (int): Number of layers for transformer in flow.
+            flow_kernel_size_transformer (int): Kernel size for transformer in flow.
             flow_base_dilation (int): Base dilation in flow.
             flow_layers (int): Number of layers in flow.
             flow_dropout_rate (float): Dropout rate in flow
@@ -181,12 +186,14 @@ class VITSGenerator(torch.nn.Module):
             dropout_rate=posterior_encoder_dropout_rate,
             use_weight_norm=use_weight_norm_in_posterior_encoder,
         )
-        self.flow = ResidualAffineCouplingBlock(
+        self.flow = ResidualCouplingBlock(
             in_channels=hidden_channels,
             hidden_channels=hidden_channels,
-            num_heads=flow_nheads,
             flows=flow_flows,
             kernel_size=flow_kernel_size,
+            heads_transformer=flow_heads_transformer,
+            layers_transformer=flow_layers_transformer,
+            kernel_size_transformer=flow_kernel_size_transformer,
             base_dilation=flow_base_dilation,
             layers=flow_layers,
             global_channels=global_channels,
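To make the new wiring concrete, this is roughly what the generator now builds with its default argument values; a sketch assuming residual_coupling.py is on the import path, with global_channels, dropout, and the weight-norm/mean flags left at their defaults:

    from residual_coupling import ResidualCouplingBlock

    flow = ResidualCouplingBlock(
        in_channels=192,             # = hidden_channels of the generator
        hidden_channels=192,
        flows=4,                     # flow_flows
        kernel_size=5,               # flow_kernel_size
        heads_transformer=2,         # flow_heads_transformer (newly exposed)
        layers_transformer=1,        # flow_layers_transformer (newly exposed)
        kernel_size_transformer=3,   # flow_kernel_size_transformer (newly exposed)
        base_dilation=1,             # flow_base_dilation
        layers=4,                    # flow_layers
    )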
egs/ljspeech/TTS/vits2/hifigan.py:
@@ -16,7 +16,6 @@ from typing import Any, Dict, List, Optional
 import numpy as np
 import torch
 import torch.nn.functional as F
-from flow import Transpose


 class HiFiGANGenerator(torch.nn.Module):
@@ -932,136 +931,3 @@ class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module):
         msd_outs = self.msd(x)
         mpd_outs = self.mpd(x)
         return msd_outs + mpd_outs
-
-
-class DurationDiscriminator(torch.nn.Module):  # vits2
-    (the remaining 130 removed lines are the class body, identical line for
-    line to the new egs/ljspeech/TTS/vits2/duration_discriminator.py above)
egs/ljspeech/TTS/vits2/residual_coupling.py:
@@ -19,10 +19,10 @@ from wavenet import WaveNet
 from icefall.utils import make_pad_mask


-class ResidualAffineCouplingBlock(torch.nn.Module):
-    """Residual affine coupling block module.
+class ResidualCouplingBlock(torch.nn.Module):
+    """Residual coupling block module.

-    This is a module of residual affine coupling block, which used as "Flow" in
+    This is a module of a residual affine/transformer coupling block, which is used as "Flow" in
     `Conditional Variational Autoencoder with Adversarial Learning for End-to-End
     Text-to-Speech`_.
@@ -35,7 +35,9 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
         self,
         in_channels: int = 192,
         hidden_channels: int = 192,
-        num_heads: int = 2,
+        heads_transformer: int = 2,
+        layers_transformer: int = 4,
+        kernel_size_transformer: int = 3,
         flows: int = 4,
         kernel_size: int = 5,
         base_dilation: int = 1,
@@ -47,7 +49,7 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
         use_only_mean: bool = True,
         use_transformer_in_flows: bool = True,
     ):
-        """Initilize ResidualAffineCouplingBlock module.
+        """Initialize ResidualCouplingBlock module.

         Args:
             in_channels (int): Number of input channels.
@@ -73,7 +75,9 @@ class ResidualAffineCouplingBlock(torch.nn.Module):
                 ResidualCouplingTransformersLayer(
                     in_channels=in_channels,
                     hidden_channels=hidden_channels,
-                    n_heads=num_heads,
+                    heads_transformer=heads_transformer,
+                    layers_transformer=layers_transformer,
+                    kernel_size_transformer=kernel_size_transformer,
                     kernel_size=kernel_size,
                     base_dilation=base_dilation,
                     layers=layers,
@@ -258,9 +262,9 @@ class ResidualCouplingTransformersLayer(torch.nn.Module):
         self,
         in_channels: int = 192,
         hidden_channels: int = 192,
-        n_heads: int = 2,
-        n_layers: int = 2,
-        n_kernel_size: int = 5,
+        heads_transformer: int = 2,
+        layers_transformer: int = 2,
+        kernel_size_transformer: int = 5,
         kernel_size: int = 5,
         base_dilation: int = 1,
         layers: int = 5,
@@ -292,9 +296,9 @@ class ResidualCouplingTransformersLayer(torch.nn.Module):

         self.pre_transformer = Transformer(
             self.half_channels,
-            num_heads=n_heads,
-            num_layers=n_layers,
-            cnn_module_kernel=n_kernel_size,
+            num_heads=heads_transformer,
+            num_layers=layers_transformer,
+            cnn_module_kernel=kernel_size_transformer,
         )

         # define modules
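A short usage sketch of the renamed block follows. The forward signature (x, x_mask, g=None, inverse=False) is an assumption carried over from the ESPnet-derived VITS code and is not shown in this diff:

    import torch

    from residual_coupling import ResidualCouplingBlock

    flow = ResidualCouplingBlock(
        heads_transformer=2,
        layers_transformer=1,
        kernel_size_transformer=3,
    )
    z = torch.randn(2, 192, 50)    # (batch, in_channels, frames)
    z_mask = torch.ones(2, 1, 50)
    z_p = flow(z, z_mask)                    # training (encode) direction
    z_rec = flow(z_p, z_mask, inverse=True)  # inference (inverse) direction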
egs/ljspeech/TTS/vits2/vits.py:
@@ -9,9 +9,9 @@ from typing import Any, Dict, Optional, Tuple

 import torch
 import torch.nn as nn
+from duration_discriminator import DurationDiscriminator
 from generator import VITSGenerator
 from hifigan import (
-    DurationDiscriminator,
     HiFiGANMultiPeriodDiscriminator,
     HiFiGANMultiScaleDiscriminator,
     HiFiGANMultiScaleMultiPeriodDiscriminator,
@@ -76,10 +76,12 @@ class VITS(nn.Module):
             "posterior_encoder_dropout_rate": 0.0,
             "use_weight_norm_in_posterior_encoder": True,
             "flow_flows": 4,
-            "flow_kernel_size": 3,
+            "flow_kernel_size": 5,
+            "flow_heads_transformer": 2,  # vits2
+            "flow_layers_transformer": 1,  # vits2
+            "flow_kernel_size_transformer": 3,  # vits2
             "flow_base_dilation": 1,
-            "flow_layers": 2,
-            "flow_nheads": 2,
+            "flow_layers": 4,
             "flow_dropout_rate": 0.0,
             "use_weight_norm_in_flow": True,
             "use_only_mean_in_flow": True,
@@ -90,9 +92,9 @@ class VITS(nn.Module):
             "stochastic_duration_predictor_dds_conv_layers": 3,
             "duration_predictor_output_channels": 256,
             "use_stochastic_duration_predictor": True,
-            "use_noised_mas": True,
-            "noise_initial_mas": 0.01,
-            "noise_scale_mas": 2e-06,
+            "use_noised_mas": True,  # vits2
+            "noise_initial_mas": 0.01,  # vits2
+            "noise_scale_mas": 2e-06,  # vits2
         },
         # discriminator related
         discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator",
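Finally, a hypothetical training-script override of the new defaults; `flow_defaults` is a stand-in name for the generator_params defaults above, not an identifier from the repository:

    # Only the flow keys touched by this commit are spelled out here.
    flow_defaults = {
        "flow_flows": 4,
        "flow_kernel_size": 5,
        "flow_heads_transformer": 2,        # vits2
        "flow_layers_transformer": 1,       # vits2
        "flow_kernel_size_transformer": 3,  # vits2
        "flow_base_dilation": 1,
        "flow_layers": 4,
    }
    # Deepen the per-coupling transformer while keeping everything else:
    generator_params = {**flow_defaults, "flow_layers_transformer": 2}
    # vits = VITS(..., generator_params=generator_params)  # wiring as in vits.py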