# File: icefall/egs/ljspeech/TTS/vits2/duration_discriminator.py
# Last modified: 2024-03-01 09:57:14 +01:00 (139 lines, 3.8 KiB, Python)

from typing import Optional
import torch
from flow import Transpose
class DurationDiscriminator(torch.nn.Module):  # vits2
    """Duration discriminator proposed in VITS2.

    Given encoder features and a pair of duration sequences (ground-truth
    and predicted), it produces a per-frame probability for each sequence,
    allowing an adversarial loss on the duration predictor.
    """

    def __init__(
        self,
        channels: int = 192,
        hidden_channels: int = 192,
        kernel_size: int = 3,
        dropout_rate: float = 0.5,
        eps: float = 1e-5,
        global_channels: int = -1,
    ):
        """
        Args:
            channels: Number of channels of the input feature sequence.
            hidden_channels: Number of hidden channels.
            kernel_size: Kernel size of the 1-D convolutions.
            dropout_rate: Dropout probability.
            eps: Epsilon used by the layer norms.
            global_channels: Number of channels of the global conditioning
                vector ``g``; conditioning is disabled when <= 0.
        """
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate
        self.global_channels = global_channels

        pad = kernel_size // 2
        self.dropout = torch.nn.Dropout(dropout_rate)

        # Two conv -> relu -> norm -> dropout stages over the input features.
        self.conv_1 = torch.nn.Conv1d(
            channels, hidden_channels, kernel_size, padding=pad
        )
        self.norm_1 = self._channel_layer_norm(hidden_channels, eps)
        self.conv_2 = torch.nn.Conv1d(
            hidden_channels, hidden_channels, kernel_size, padding=pad
        )
        self.norm_2 = self._channel_layer_norm(hidden_channels, eps)

        # Projects the scalar duration sequence into the hidden space.
        self.dur_proj = torch.nn.Conv1d(1, hidden_channels, 1)

        # Two more stages applied after concatenating features and durations.
        self.pre_out_conv_1 = torch.nn.Conv1d(
            2 * hidden_channels, hidden_channels, kernel_size, padding=pad
        )
        self.pre_out_norm_1 = self._channel_layer_norm(hidden_channels, eps)
        self.pre_out_conv_2 = torch.nn.Conv1d(
            hidden_channels, hidden_channels, kernel_size, padding=pad
        )
        self.pre_out_norm_2 = self._channel_layer_norm(hidden_channels, eps)

        if global_channels > 0:
            # Optional global conditioning projected onto the input channels.
            self.cond_layer = torch.nn.Conv1d(global_channels, channels, 1)

        self.output_layer = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, 1), torch.nn.Sigmoid()
        )

    @staticmethod
    def _channel_layer_norm(num_channels: int, eps: float) -> torch.nn.Sequential:
        """Build a LayerNorm over the channel axis of a (B, C, T) tensor."""
        return torch.nn.Sequential(
            Transpose(1, 2),
            torch.nn.LayerNorm(
                num_channels,
                eps=eps,
                elementwise_affine=True,
            ),
            Transpose(1, 2),
        )

    def forward_probability(
        self, x: torch.Tensor, x_mask: torch.Tensor, dur: torch.Tensor
    ):
        """Return the per-frame probability for a single duration sequence.

        Args:
            x: Hidden feature sequence, shape (B, hidden_channels, T).
            x_mask: Mask broadcastable over ``x`` (1 = valid frame).
            dur: Duration sequence, shape (B, 1, T).

        Returns:
            Probability tensor of shape (B, T, 1), values in (0, 1).
        """
        hidden = torch.cat([x, self.dur_proj(dur)], dim=1)
        stages = (
            (self.pre_out_conv_1, self.pre_out_norm_1),
            (self.pre_out_conv_2, self.pre_out_norm_2),
        )
        for conv, norm in stages:
            hidden = self.dropout(norm(torch.relu(conv(hidden * x_mask))))
        hidden = (hidden * x_mask).transpose(1, 2)
        return self.output_layer(hidden)

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        dur_r: torch.Tensor,
        dur_hat: torch.Tensor,
        g: Optional[torch.Tensor] = None,
    ):
        """Score a (real, predicted) pair of duration sequences.

        Args:
            x: Input feature sequence, shape (B, channels, T); detached so
                no gradient flows back into the generator through it.
            x_mask: Mask broadcastable over ``x`` (1 = valid frame).
            dur_r: Ground-truth durations, shape (B, 1, T).
            dur_hat: Predicted durations, shape (B, 1, T).
            g: Optional global conditioning, shape (B, global_channels, 1);
                requires ``global_channels > 0`` at construction time.

        Returns:
            A list of two probability tensors, for ``dur_r`` and ``dur_hat``
            respectively, each of shape (B, T, 1).
        """
        x = torch.detach(x)
        if g is not None:
            x = x + self.cond_layer(torch.detach(g))
        for conv, norm in ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2)):
            x = self.dropout(norm(torch.relu(conv(x * x_mask))))
        return [self.forward_probability(x, x_mask, dur) for dur in (dur_r, dur_hat)]