# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py

# Copyright 2021 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Basic Flow modules used in VITS.

This code is based on https://github.com/jaywalnut310/vits.

"""

import math
from typing import Optional, Tuple, Union

import torch

from transform import piecewise_rational_quadratic_transform


class FlipFlow(torch.nn.Module):
    """Flip flow module."""

    def forward(
        self, x: torch.Tensor, *args, inverse: bool = False, **kwargs
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, channels, T).
            inverse (bool): Whether to invert the flow.

        Returns:
            Tensor: Flipped tensor (B, channels, T).
            Tensor: Log-determinant tensor for NLL (B,) if not inverse.

        """
        x = torch.flip(x, [1])
        if not inverse:
            # Flipping the channel order is volume-preserving, so the
            # log-determinant is zero.
            logdet = x.new_zeros(x.size(0))
            return x, logdet
        else:
            return x
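
# A minimal usage sketch (shapes assumed from the docstring above; not part of
# the original recipe):
#
#     flow = FlipFlow()
#     x = torch.randn(2, 4, 50)        # (B, channels, T)
#     y, logdet = flow(x)              # forward: flipped x, zero logdet
#     x_rec = flow(y, inverse=True)    # inverse: recovers the original x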


class LogFlow(torch.nn.Module):
    """Log flow module."""

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        inverse: bool = False,
        eps: float = 1e-5,
        **kwargs
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, channels, T).
            x_mask (Tensor): Mask tensor (B, 1, T).
            inverse (bool): Whether to invert the flow.
            eps (float): Epsilon to avoid log(0).

        Returns:
            Tensor: Output tensor (B, channels, T).
            Tensor: Log-determinant tensor for NLL (B,) if not inverse.

        """
        if not inverse:
            y = torch.log(torch.clamp_min(x, eps)) * x_mask
            # d/dx log(x) = 1/x, so the log-determinant is -sum(log x) = -sum(y).
            logdet = torch.sum(-y, [1, 2])
            return y, logdet
        else:
            x = torch.exp(x) * x_mask
            return x
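
# Usage sketch (assumptions: positive inputs such as durations, shapes as in
# the docstring; not part of the original recipe):
#
#     flow = LogFlow()
#     x = torch.rand(2, 1, 50) + 0.1        # positive (B, channels, T)
#     x_mask = torch.ones(2, 1, 50)         # (B, 1, T)
#     y, logdet = flow(x, x_mask)
#     x_rec = flow(y, x_mask, inverse=True)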


class ElementwiseAffineFlow(torch.nn.Module):
    """Elementwise affine flow module."""

    def __init__(self, channels: int):
        """Initialize ElementwiseAffineFlow module.

        Args:
            channels (int): Number of channels.

        """
        super().__init__()
        self.channels = channels
        self.register_parameter("m", torch.nn.Parameter(torch.zeros(channels, 1)))
        self.register_parameter("logs", torch.nn.Parameter(torch.zeros(channels, 1)))

    def forward(
        self, x: torch.Tensor, x_mask: torch.Tensor, inverse: bool = False, **kwargs
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, channels, T).
            x_mask (Tensor): Mask tensor (B, 1, T).
            inverse (bool): Whether to invert the flow.

        Returns:
            Tensor: Output tensor (B, channels, T).
            Tensor: Log-determinant tensor for NLL (B,) if not inverse.

        """
        if not inverse:
            y = self.m + torch.exp(self.logs) * x
            y = y * x_mask
            # The Jacobian of y = m + exp(logs) * x is diagonal, so the
            # log-determinant is the sum of logs over unmasked positions.
            logdet = torch.sum(self.logs * x_mask, [1, 2])
            return y, logdet
        else:
            x = (x - self.m) * torch.exp(-self.logs) * x_mask
            return x
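
# Usage sketch (shapes assumed from the docstring; not part of the original
# recipe). Forward applies y = m + exp(logs) * x per channel; inverse undoes it:
#
#     flow = ElementwiseAffineFlow(channels=4)
#     x = torch.randn(2, 4, 50)
#     x_mask = torch.ones(2, 1, 50)
#     y, logdet = flow(x, x_mask)
#     x_rec = flow(y, x_mask, inverse=True)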


class Transpose(torch.nn.Module):
    """Transpose module for torch.nn.Sequential().

    This is needed because torch.nn.LayerNorm normalizes over the last
    dimension, while conv features are laid out as (B, channels, T).
    """

    def __init__(self, dim1: int, dim2: int):
        """Initialize Transpose module."""
        super().__init__()
        self.dim1 = dim1
        self.dim2 = dim2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Transpose."""
        return x.transpose(self.dim1, self.dim2)
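
# For example (not part of the original recipe):
#
#     Transpose(1, 2)(torch.randn(2, 3, 4)).shape  # torch.Size([2, 4, 3])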


class DilatedDepthSeparableConv(torch.nn.Module):
    """Dilated depth-separable conv module."""

    def __init__(
        self,
        channels: int,
        kernel_size: int,
        layers: int,
        dropout_rate: float = 0.0,
        eps: float = 1e-5,
    ):
        """Initialize DilatedDepthSeparableConv module.

        Args:
            channels (int): Number of channels.
            kernel_size (int): Kernel size.
            layers (int): Number of layers.
            dropout_rate (float): Dropout rate.
            eps (float): Epsilon for layer norm.

        """
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for i in range(layers):
            # The dilation grows exponentially with depth; the padding keeps
            # the output length equal to the input length.
            dilation = kernel_size**i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs += [
                torch.nn.Sequential(
                    # Depthwise conv: one filter per channel (groups=channels).
                    torch.nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        groups=channels,
                        dilation=dilation,
                        padding=padding,
                    ),
                    Transpose(1, 2),
                    torch.nn.LayerNorm(
                        channels,
                        eps=eps,
                        elementwise_affine=True,
                    ),
                    Transpose(1, 2),
                    torch.nn.GELU(),
                    # Pointwise (1x1) conv: mixes information across channels.
                    torch.nn.Conv1d(
                        channels,
                        channels,
                        1,
                    ),
                    Transpose(1, 2),
                    torch.nn.LayerNorm(
                        channels,
                        eps=eps,
                        elementwise_affine=True,
                    ),
                    Transpose(1, 2),
                    torch.nn.GELU(),
                    torch.nn.Dropout(dropout_rate),
                )
            ]

    def forward(
        self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, channels, T).
            x_mask (Tensor): Mask tensor (B, 1, T).
            g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).

        Returns:
            Tensor: Output tensor (B, channels, T).

        """
        if g is not None:
            x = x + g
        for f in self.convs:
            y = f(x * x_mask)
            # Residual connection around each depth-separable block.
            x = x + y
        return x * x_mask
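
# Usage sketch (assumed shapes; not part of the original recipe):
#
#     conv = DilatedDepthSeparableConv(channels=64, kernel_size=3, layers=3)
#     x = torch.randn(2, 64, 50)
#     x_mask = torch.ones(2, 1, 50)
#     y = conv(x, x_mask)              # (B, channels, T), length preserved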


class ConvFlow(torch.nn.Module):
    """Convolutional flow module."""

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        kernel_size: int,
        layers: int,
        bins: int = 10,
        tail_bound: float = 5.0,
    ):
        """Initialize ConvFlow module.

        Args:
            in_channels (int): Number of input channels.
            hidden_channels (int): Number of hidden channels.
            kernel_size (int): Kernel size.
            layers (int): Number of layers.
            bins (int): Number of bins.
            tail_bound (float): Tail bound value.

        """
        super().__init__()
        self.half_channels = in_channels // 2
        self.hidden_channels = hidden_channels
        self.bins = bins
        self.tail_bound = tail_bound

        self.input_conv = torch.nn.Conv1d(
            self.half_channels,
            hidden_channels,
            1,
        )
        self.dds_conv = DilatedDepthSeparableConv(
            hidden_channels,
            kernel_size,
            layers,
            dropout_rate=0.0,
        )
        # Each of the half_channels predicts bins widths, bins heights, and
        # bins - 1 derivatives for the rational-quadratic spline.
        self.proj = torch.nn.Conv1d(
            hidden_channels,
            self.half_channels * (bins * 3 - 1),
            1,
        )
        # Zero-init so that the flow starts close to the identity transform.
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        g: Optional[torch.Tensor] = None,
        inverse: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, channels, T).
            x_mask (Tensor): Mask tensor (B, 1, T).
            g (Optional[Tensor]): Global conditioning tensor (B, channels, 1).
            inverse (bool): Whether to invert the flow.

        Returns:
            Tensor: Output tensor (B, channels, T).
            Tensor: Log-determinant tensor for NLL (B,) if not inverse.

        """
        # Split channels: xa conditions the spline that transforms xb.
        xa, xb = x.split(x.size(1) // 2, 1)
        h = self.input_conv(xa)
        h = self.dds_conv(h, x_mask, g=g)
        h = self.proj(h) * x_mask  # (B, half_channels * (bins * 3 - 1), T)

        b, c, t = xa.shape
        # (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1)
        h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)

        # TODO(kan-bayashi): Understand this calculation
        denom = math.sqrt(self.hidden_channels)
        unnorm_widths = h[..., : self.bins] / denom
        unnorm_heights = h[..., self.bins : 2 * self.bins] / denom
        unnorm_derivatives = h[..., 2 * self.bins :]
        xb, logdet_abs = piecewise_rational_quadratic_transform(
            xb,
            unnorm_widths,
            unnorm_heights,
            unnorm_derivatives,
            inverse=inverse,
            tails="linear",
            tail_bound=self.tail_bound,
        )
        x = torch.cat([xa, xb], 1) * x_mask
        logdet = torch.sum(logdet_abs * x_mask, [1, 2])
        if not inverse:
            return x, logdet
        else:
            return x
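
# Usage sketch (assumed shapes; not part of the original recipe). in_channels
# must be even, since the flow splits the input into two halves:
#
#     flow = ConvFlow(in_channels=2, hidden_channels=64, kernel_size=3, layers=3)
#     x = torch.randn(2, 2, 50)
#     x_mask = torch.ones(2, 1, 50)
#     y, logdet = flow(x, x_mask)
#     x_rec = flow(y, x_mask, inverse=True)   # approximate round trip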