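"""Duration discriminator from VITS2, as used in icefall's VITS recipe.

Given text-encoder features and per-frame durations, it outputs the
probability that the durations are real rather than predicted, so that the
duration predictor can be trained adversarially (see the VITS2 paper).
"""
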
from typing import Optional

import torch

from flow import Transpose
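
# NOTE: `Transpose` comes from the recipe-local `flow` module; as used below,
# it is a small nn.Module that swaps two tensor dimensions, so LayerNorm can
# be applied over the channel axis of (batch, channels, time) tensors.
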
class DurationDiscriminator(torch.nn.Module):  # vits2
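    """Discriminator over (text features, durations) pairs.

    Args:
        channels: Input feature dimension of ``x``.
        hidden_channels: Hidden feature dimension of the conv stacks.
        kernel_size: Kernel size of the 1-D convolutions.
        dropout_rate: Dropout probability.
        eps: Epsilon passed to LayerNorm.
        global_channels: Dimension of the optional global conditioning
            vector ``g``; conditioning is disabled when <= 0.
    """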

    def __init__(
        self,
        channels: int = 192,
        hidden_channels: int = 192,
        kernel_size: int = 3,
        dropout_rate: float = 0.5,
        eps: float = 1e-5,
        global_channels: int = -1,
    ):
        super().__init__()

        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate
        self.global_channels = global_channels

        self.dropout = torch.nn.Dropout(dropout_rate)
        self.conv_1 = torch.nn.Conv1d(
            channels, hidden_channels, kernel_size, padding=kernel_size // 2
        )
        # LayerNorm normalizes the last dimension, so transpose around it.
        self.norm_1 = torch.nn.Sequential(
            Transpose(1, 2),
            torch.nn.LayerNorm(
                hidden_channels,
                eps=eps,
                elementwise_affine=True,
            ),
            Transpose(1, 2),
        )

        self.conv_2 = torch.nn.Conv1d(
            hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
        )

        self.norm_2 = torch.nn.Sequential(
            Transpose(1, 2),
            torch.nn.LayerNorm(
                hidden_channels,
                eps=eps,
                elementwise_affine=True,
            ),
            Transpose(1, 2),
        )

        # Project scalar per-frame durations (B, 1, T) into the hidden space.
        self.dur_proj = torch.nn.Conv1d(1, hidden_channels, 1)

        self.pre_out_conv_1 = torch.nn.Conv1d(
            2 * hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
        )

        self.pre_out_norm_1 = torch.nn.Sequential(
            Transpose(1, 2),
            torch.nn.LayerNorm(
                hidden_channels,
                eps=eps,
                elementwise_affine=True,
            ),
            Transpose(1, 2),
        )

        self.pre_out_conv_2 = torch.nn.Conv1d(
            hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
        )

        self.pre_out_norm_2 = torch.nn.Sequential(
            Transpose(1, 2),
            torch.nn.LayerNorm(
                hidden_channels,
                eps=eps,
                elementwise_affine=True,
            ),
            Transpose(1, 2),
        )

        # Optional global conditioning (e.g. a speaker embedding).
        if global_channels > 0:
            self.cond_layer = torch.nn.Conv1d(global_channels, channels, 1)

        # Per-frame probability in [0, 1].
        self.output_layer = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, 1), torch.nn.Sigmoid()
        )

    def forward_probability(
        self, x: torch.Tensor, x_mask: torch.Tensor, dur: torch.Tensor
    ):
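        """Score one duration sequence against the shared text features.

        Args:
            x: Text features of shape (batch, hidden_channels, time).
            x_mask: Mask of shape (batch, 1, time).
            dur: Durations of shape (batch, 1, time).

        Returns:
            Per-frame probabilities of shape (batch, time, 1).
        """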
        dur = self.dur_proj(dur)  # (B, 1, T) -> (B, H, T)
        # Concatenate text features and projected durations along channels.
        x = torch.cat([x, dur], dim=1)
        x = self.pre_out_conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.pre_out_norm_1(x)
        x = self.dropout(x)
        x = self.pre_out_conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.pre_out_norm_2(x)
        x = self.dropout(x)
        x = x * x_mask
        x = x.transpose(1, 2)  # (B, H, T) -> (B, T, H) for the Linear head
        output_prob = self.output_layer(x)

        return output_prob

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        dur_r: torch.Tensor,
        dur_hat: torch.Tensor,
        g: Optional[torch.Tensor] = None,
    ):
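        """Return probabilities for the real and the predicted durations.

        Args:
            x: Text features of shape (batch, channels, time); detached here.
            x_mask: Mask of shape (batch, 1, time).
            dur_r: Ground-truth durations of shape (batch, 1, time).
            dur_hat: Predicted durations of shape (batch, 1, time).
            g: Optional global conditioning of shape
                (batch, global_channels, 1).

        Returns:
            A list of two tensors of shape (batch, time, 1), for ``dur_r``
            and ``dur_hat`` respectively.
        """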
        # Detach so that discriminator gradients do not flow back into the
        # text encoder or the global conditioning.
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond_layer(g)

        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.dropout(x)

        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.dropout(x)

        # Score the real and the predicted durations with shared weights.
        output_probs = []
        for dur in [dur_r, dur_hat]:
            output_prob = self.forward_probability(x, x_mask, dur)
            output_probs.append(output_prob)

        return output_probs
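

# Minimal smoke test (a sketch, not part of the original file) documenting the
# expected tensor shapes. It assumes the recipe-local `flow` module is on the
# import path so that the `Transpose` import above succeeds.
if __name__ == "__main__":
    batch, time = 2, 17
    model = DurationDiscriminator()
    x = torch.randn(batch, 192, time)  # (batch, channels, time)
    x_mask = torch.ones(batch, 1, time)  # (batch, 1, time)
    dur_r = torch.rand(batch, 1, time)  # "real" durations
    dur_hat = torch.rand(batch, 1, time)  # "predicted" durations
    probs = model(x, x_mask, dur_r, dur_hat)
    assert [p.shape for p in probs] == [torch.Size([batch, time, 1])] * 2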