mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
Add symbolic link
This commit is contained in:
parent
a4e4f8080a
commit
0377cccc6f
@ -405,6 +405,7 @@ def train_one_epoch(
|
|||||||
)
|
)
|
||||||
for k, v in stats_d.items():
|
for k, v in stats_d.items():
|
||||||
loss_info[k] = v * batch_size
|
loss_info[k] = v * batch_size
|
||||||
|
|
||||||
# update discriminator
|
# update discriminator
|
||||||
optimizer_d.zero_grad()
|
optimizer_d.zero_grad()
|
||||||
scaler.scale(loss_d).backward()
|
scaler.scale(loss_d).backward()
|
||||||
|
@ -1,311 +0,0 @@
|
|||||||
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py
|
|
||||||
|
|
||||||
# Copyright 2021 Tomoki Hayashi
|
|
||||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
|
||||||
|
|
||||||
"""Basic Flow modules used in VITS.
|
|
||||||
|
|
||||||
This code is based on https://github.com/jaywalnut310/vits.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from transform import piecewise_rational_quadratic_transform
|
|
||||||
|
|
||||||
|
|
||||||
class FlipFlow(torch.nn.Module):
|
|
||||||
"""Flip flow module."""
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self, x: torch.Tensor, *args, inverse: bool = False, **kwargs
|
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
|
||||||
"""Calculate forward propagation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (Tensor): Input tensor (B, channels, T).
|
|
||||||
inverse (bool): Whether to inverse the flow.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Flipped tensor (B, channels, T).
|
|
||||||
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
|
|
||||||
|
|
||||||
"""
|
|
||||||
x = torch.flip(x, [1])
|
|
||||||
if not inverse:
|
|
||||||
logdet = x.new_zeros(x.size(0))
|
|
||||||
return x, logdet
|
|
||||||
else:
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class LogFlow(torch.nn.Module):
|
|
||||||
"""Log flow module."""
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
x_mask: torch.Tensor,
|
|
||||||
inverse: bool = False,
|
|
||||||
eps: float = 1e-5,
|
|
||||||
**kwargs
|
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
|
||||||
"""Calculate forward propagation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (Tensor): Input tensor (B, channels, T).
|
|
||||||
x_mask (Tensor): Mask tensor (B, 1, T).
|
|
||||||
inverse (bool): Whether to inverse the flow.
|
|
||||||
eps (float): Epsilon for log.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Output tensor (B, channels, T).
|
|
||||||
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if not inverse:
|
|
||||||
y = torch.log(torch.clamp_min(x, eps)) * x_mask
|
|
||||||
logdet = torch.sum(-y, [1, 2])
|
|
||||||
return y, logdet
|
|
||||||
else:
|
|
||||||
x = torch.exp(x) * x_mask
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseAffineFlow(torch.nn.Module):
|
|
||||||
"""Elementwise affine flow module."""
|
|
||||||
|
|
||||||
def __init__(self, channels: int):
|
|
||||||
"""Initialize ElementwiseAffineFlow module.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channels (int): Number of channels.
|
|
||||||
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.channels = channels
|
|
||||||
self.register_parameter("m", torch.nn.Parameter(torch.zeros(channels, 1)))
|
|
||||||
self.register_parameter("logs", torch.nn.Parameter(torch.zeros(channels, 1)))
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self, x: torch.Tensor, x_mask: torch.Tensor, inverse: bool = False, **kwargs
|
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
|
||||||
"""Calculate forward propagation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (Tensor): Input tensor (B, channels, T).
|
|
||||||
x_lengths (Tensor): Length tensor (B,).
|
|
||||||
inverse (bool): Whether to inverse the flow.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Output tensor (B, channels, T).
|
|
||||||
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if not inverse:
|
|
||||||
y = self.m + torch.exp(self.logs) * x
|
|
||||||
y = y * x_mask
|
|
||||||
logdet = torch.sum(self.logs * x_mask, [1, 2])
|
|
||||||
return y, logdet
|
|
||||||
else:
|
|
||||||
x = (x - self.m) * torch.exp(-self.logs) * x_mask
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Transpose(torch.nn.Module):
|
|
||||||
"""Transpose module for torch.nn.Sequential()."""
|
|
||||||
|
|
||||||
def __init__(self, dim1: int, dim2: int):
|
|
||||||
"""Initialize Transpose module."""
|
|
||||||
super().__init__()
|
|
||||||
self.dim1 = dim1
|
|
||||||
self.dim2 = dim2
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""Transpose."""
|
|
||||||
return x.transpose(self.dim1, self.dim2)
|
|
||||||
|
|
||||||
|
|
||||||
class DilatedDepthSeparableConv(torch.nn.Module):
|
|
||||||
"""Dilated depth-separable conv module."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
channels: int,
|
|
||||||
kernel_size: int,
|
|
||||||
layers: int,
|
|
||||||
dropout_rate: float = 0.0,
|
|
||||||
eps: float = 1e-5,
|
|
||||||
):
|
|
||||||
"""Initialize DilatedDepthSeparableConv module.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channels (int): Number of channels.
|
|
||||||
kernel_size (int): Kernel size.
|
|
||||||
layers (int): Number of layers.
|
|
||||||
dropout_rate (float): Dropout rate.
|
|
||||||
eps (float): Epsilon for layer norm.
|
|
||||||
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.convs = torch.nn.ModuleList()
|
|
||||||
for i in range(layers):
|
|
||||||
dilation = kernel_size**i
|
|
||||||
padding = (kernel_size * dilation - dilation) // 2
|
|
||||||
self.convs += [
|
|
||||||
torch.nn.Sequential(
|
|
||||||
torch.nn.Conv1d(
|
|
||||||
channels,
|
|
||||||
channels,
|
|
||||||
kernel_size,
|
|
||||||
groups=channels,
|
|
||||||
dilation=dilation,
|
|
||||||
padding=padding,
|
|
||||||
),
|
|
||||||
Transpose(1, 2),
|
|
||||||
torch.nn.LayerNorm(
|
|
||||||
channels,
|
|
||||||
eps=eps,
|
|
||||||
elementwise_affine=True,
|
|
||||||
),
|
|
||||||
Transpose(1, 2),
|
|
||||||
torch.nn.GELU(),
|
|
||||||
torch.nn.Conv1d(
|
|
||||||
channels,
|
|
||||||
channels,
|
|
||||||
1,
|
|
||||||
),
|
|
||||||
Transpose(1, 2),
|
|
||||||
torch.nn.LayerNorm(
|
|
||||||
channels,
|
|
||||||
eps=eps,
|
|
||||||
elementwise_affine=True,
|
|
||||||
),
|
|
||||||
Transpose(1, 2),
|
|
||||||
torch.nn.GELU(),
|
|
||||||
torch.nn.Dropout(dropout_rate),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Calculate forward propagation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (Tensor): Input tensor (B, in_channels, T).
|
|
||||||
x_mask (Tensor): Mask tensor (B, 1, T).
|
|
||||||
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Output tensor (B, channels, T).
|
|
||||||
|
|
||||||
"""
|
|
||||||
if g is not None:
|
|
||||||
x = x + g
|
|
||||||
for f in self.convs:
|
|
||||||
y = f(x * x_mask)
|
|
||||||
x = x + y
|
|
||||||
return x * x_mask
|
|
||||||
|
|
||||||
|
|
||||||
class ConvFlow(torch.nn.Module):
|
|
||||||
"""Convolutional flow module."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels: int,
|
|
||||||
hidden_channels: int,
|
|
||||||
kernel_size: int,
|
|
||||||
layers: int,
|
|
||||||
bins: int = 10,
|
|
||||||
tail_bound: float = 5.0,
|
|
||||||
):
|
|
||||||
"""Initialize ConvFlow module.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_channels (int): Number of input channels.
|
|
||||||
hidden_channels (int): Number of hidden channels.
|
|
||||||
kernel_size (int): Kernel size.
|
|
||||||
layers (int): Number of layers.
|
|
||||||
bins (int): Number of bins.
|
|
||||||
tail_bound (float): Tail bound value.
|
|
||||||
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.half_channels = in_channels // 2
|
|
||||||
self.hidden_channels = hidden_channels
|
|
||||||
self.bins = bins
|
|
||||||
self.tail_bound = tail_bound
|
|
||||||
|
|
||||||
self.input_conv = torch.nn.Conv1d(
|
|
||||||
self.half_channels,
|
|
||||||
hidden_channels,
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
self.dds_conv = DilatedDepthSeparableConv(
|
|
||||||
hidden_channels,
|
|
||||||
kernel_size,
|
|
||||||
layers,
|
|
||||||
dropout_rate=0.0,
|
|
||||||
)
|
|
||||||
self.proj = torch.nn.Conv1d(
|
|
||||||
hidden_channels,
|
|
||||||
self.half_channels * (bins * 3 - 1),
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
self.proj.weight.data.zero_()
|
|
||||||
self.proj.bias.data.zero_()
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
x_mask: torch.Tensor,
|
|
||||||
g: Optional[torch.Tensor] = None,
|
|
||||||
inverse: bool = False,
|
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
|
||||||
"""Calculate forward propagation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (Tensor): Input tensor (B, channels, T).
|
|
||||||
x_mask (Tensor): Mask tensor (B,).
|
|
||||||
g (Optional[Tensor]): Global conditioning tensor (B, channels, 1).
|
|
||||||
inverse (bool): Whether to inverse the flow.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Output tensor (B, channels, T).
|
|
||||||
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
|
|
||||||
|
|
||||||
"""
|
|
||||||
xa, xb = x.split(x.size(1) // 2, 1)
|
|
||||||
h = self.input_conv(xa)
|
|
||||||
h = self.dds_conv(h, x_mask, g=g)
|
|
||||||
h = self.proj(h) * x_mask # (B, half_channels * (bins * 3 - 1), T)
|
|
||||||
|
|
||||||
b, c, t = xa.shape
|
|
||||||
# (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1)
|
|
||||||
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)
|
|
||||||
|
|
||||||
# TODO(kan-bayashi): Understand this calculation
|
|
||||||
denom = math.sqrt(self.hidden_channels)
|
|
||||||
unnorm_widths = h[..., : self.bins] / denom
|
|
||||||
unnorm_heights = h[..., self.bins : 2 * self.bins] / denom
|
|
||||||
unnorm_derivatives = h[..., 2 * self.bins :]
|
|
||||||
xb, logdet_abs = piecewise_rational_quadratic_transform(
|
|
||||||
xb,
|
|
||||||
unnorm_widths,
|
|
||||||
unnorm_heights,
|
|
||||||
unnorm_derivatives,
|
|
||||||
inverse=inverse,
|
|
||||||
tails="linear",
|
|
||||||
tail_bound=self.tail_bound,
|
|
||||||
)
|
|
||||||
x = torch.cat([xa, xb], 1) * x_mask
|
|
||||||
logdet = torch.sum(logdet_abs * x_mask, [1, 2])
|
|
||||||
if not inverse:
|
|
||||||
return x, logdet
|
|
||||||
else:
|
|
||||||
return x
|
|
1
egs/ljspeech/TTS/vits2/flow.py
Symbolic link
1
egs/ljspeech/TTS/vits2/flow.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../ljspeech/TTS/vits/flow.py
|
@ -1,335 +0,0 @@
|
|||||||
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py
|
|
||||||
|
|
||||||
# Copyright 2021 Tomoki Hayashi
|
|
||||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
|
||||||
|
|
||||||
"""HiFiGAN-related loss modules.
|
|
||||||
|
|
||||||
This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import List, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributions as D
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from lhotse.features.kaldi import Wav2LogFilterBank
|
|
||||||
|
|
||||||
|
|
||||||
class GeneratorAdversarialLoss(torch.nn.Module):
|
|
||||||
"""Generator adversarial loss module."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
average_by_discriminators: bool = True,
|
|
||||||
loss_type: str = "mse",
|
|
||||||
):
|
|
||||||
"""Initialize GeneratorAversarialLoss module.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
average_by_discriminators (bool): Whether to average the loss by
|
|
||||||
the number of discriminators.
|
|
||||||
loss_type (str): Loss type, "mse" or "hinge".
|
|
||||||
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.average_by_discriminators = average_by_discriminators
|
|
||||||
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
|
|
||||||
if loss_type == "mse":
|
|
||||||
self.criterion = self._mse_loss
|
|
||||||
else:
|
|
||||||
self.criterion = self._hinge_loss
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Calcualate generator adversarial loss.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
|
|
||||||
outputs, list of discriminator outputs, or list of list of discriminator
|
|
||||||
outputs..
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Generator adversarial loss value.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if isinstance(outputs, (tuple, list)):
|
|
||||||
adv_loss = 0.0
|
|
||||||
for i, outputs_ in enumerate(outputs):
|
|
||||||
if isinstance(outputs_, (tuple, list)):
|
|
||||||
# NOTE(kan-bayashi): case including feature maps
|
|
||||||
outputs_ = outputs_[-1]
|
|
||||||
adv_loss += self.criterion(outputs_)
|
|
||||||
if self.average_by_discriminators:
|
|
||||||
adv_loss /= i + 1
|
|
||||||
else:
|
|
||||||
adv_loss = self.criterion(outputs)
|
|
||||||
|
|
||||||
return adv_loss
|
|
||||||
|
|
||||||
def _mse_loss(self, x):
|
|
||||||
return F.mse_loss(x, x.new_ones(x.size()))
|
|
||||||
|
|
||||||
def _hinge_loss(self, x):
|
|
||||||
return -x.mean()
|
|
||||||
|
|
||||||
|
|
||||||
class DiscriminatorAdversarialLoss(torch.nn.Module):
|
|
||||||
"""Discriminator adversarial loss module."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
average_by_discriminators: bool = True,
|
|
||||||
loss_type: str = "mse",
|
|
||||||
):
|
|
||||||
"""Initialize DiscriminatorAversarialLoss module.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
average_by_discriminators (bool): Whether to average the loss by
|
|
||||||
the number of discriminators.
|
|
||||||
loss_type (str): Loss type, "mse" or "hinge".
|
|
||||||
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.average_by_discriminators = average_by_discriminators
|
|
||||||
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
|
|
||||||
if loss_type == "mse":
|
|
||||||
self.fake_criterion = self._mse_fake_loss
|
|
||||||
self.real_criterion = self._mse_real_loss
|
|
||||||
else:
|
|
||||||
self.fake_criterion = self._hinge_fake_loss
|
|
||||||
self.real_criterion = self._hinge_real_loss
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
|
|
||||||
outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""Calcualate discriminator adversarial loss.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
|
|
||||||
outputs, list of discriminator outputs, or list of list of discriminator
|
|
||||||
outputs calculated from generator.
|
|
||||||
outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
|
|
||||||
outputs, list of discriminator outputs, or list of list of discriminator
|
|
||||||
outputs calculated from groundtruth.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Discriminator real loss value.
|
|
||||||
Tensor: Discriminator fake loss value.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if isinstance(outputs, (tuple, list)):
|
|
||||||
real_loss = 0.0
|
|
||||||
fake_loss = 0.0
|
|
||||||
for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
|
|
||||||
if isinstance(outputs_hat_, (tuple, list)):
|
|
||||||
# NOTE(kan-bayashi): case including feature maps
|
|
||||||
outputs_hat_ = outputs_hat_[-1]
|
|
||||||
outputs_ = outputs_[-1]
|
|
||||||
real_loss += self.real_criterion(outputs_)
|
|
||||||
fake_loss += self.fake_criterion(outputs_hat_)
|
|
||||||
if self.average_by_discriminators:
|
|
||||||
fake_loss /= i + 1
|
|
||||||
real_loss /= i + 1
|
|
||||||
else:
|
|
||||||
real_loss = self.real_criterion(outputs)
|
|
||||||
fake_loss = self.fake_criterion(outputs_hat)
|
|
||||||
|
|
||||||
return real_loss, fake_loss
|
|
||||||
|
|
||||||
def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
return F.mse_loss(x, x.new_ones(x.size()))
|
|
||||||
|
|
||||||
def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
return F.mse_loss(x, x.new_zeros(x.size()))
|
|
||||||
|
|
||||||
def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))
|
|
||||||
|
|
||||||
def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
|
|
||||||
|
|
||||||
|
|
||||||
class FeatureMatchLoss(torch.nn.Module):
|
|
||||||
"""Feature matching loss module."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
average_by_layers: bool = True,
|
|
||||||
average_by_discriminators: bool = True,
|
|
||||||
include_final_outputs: bool = False,
|
|
||||||
):
|
|
||||||
"""Initialize FeatureMatchLoss module.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
average_by_layers (bool): Whether to average the loss by the number
|
|
||||||
of layers.
|
|
||||||
average_by_discriminators (bool): Whether to average the loss by
|
|
||||||
the number of discriminators.
|
|
||||||
include_final_outputs (bool): Whether to include the final output of
|
|
||||||
each discriminator for loss calculation.
|
|
||||||
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.average_by_layers = average_by_layers
|
|
||||||
self.average_by_discriminators = average_by_discriminators
|
|
||||||
self.include_final_outputs = include_final_outputs
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
|
|
||||||
feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Calculate feature matching loss.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
|
|
||||||
discriminator outputs or list of discriminator outputs calcuated
|
|
||||||
from generator's outputs.
|
|
||||||
feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
|
|
||||||
discriminator outputs or list of discriminator outputs calcuated
|
|
||||||
from groundtruth.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Feature matching loss value.
|
|
||||||
|
|
||||||
"""
|
|
||||||
feat_match_loss = 0.0
|
|
||||||
for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
|
|
||||||
feat_match_loss_ = 0.0
|
|
||||||
if not self.include_final_outputs:
|
|
||||||
feats_hat_ = feats_hat_[:-1]
|
|
||||||
feats_ = feats_[:-1]
|
|
||||||
for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
|
|
||||||
feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
|
|
||||||
if self.average_by_layers:
|
|
||||||
feat_match_loss_ /= j + 1
|
|
||||||
feat_match_loss += feat_match_loss_
|
|
||||||
if self.average_by_discriminators:
|
|
||||||
feat_match_loss /= i + 1
|
|
||||||
|
|
||||||
return feat_match_loss
|
|
||||||
|
|
||||||
|
|
||||||
class MelSpectrogramLoss(torch.nn.Module):
|
|
||||||
"""Mel-spectrogram loss."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
sampling_rate: int = 22050,
|
|
||||||
frame_length: int = 1024, # in samples
|
|
||||||
frame_shift: int = 256, # in samples
|
|
||||||
n_mels: int = 80,
|
|
||||||
use_fft_mag: bool = True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.wav_to_mel = Wav2LogFilterBank(
|
|
||||||
sampling_rate=sampling_rate,
|
|
||||||
frame_length=frame_length / sampling_rate, # in second
|
|
||||||
frame_shift=frame_shift / sampling_rate, # in second
|
|
||||||
use_fft_mag=use_fft_mag,
|
|
||||||
num_filters=n_mels,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
y_hat: torch.Tensor,
|
|
||||||
y: torch.Tensor,
|
|
||||||
return_mel: bool = False,
|
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
|
|
||||||
"""Calculate Mel-spectrogram loss.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
y_hat (Tensor): Generated waveform tensor (B, 1, T).
|
|
||||||
y (Tensor): Groundtruth waveform tensor (B, 1, T).
|
|
||||||
spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor
|
|
||||||
(B, T, n_fft // 2 + 1). if provided, use it instead of groundtruth
|
|
||||||
waveform.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Mel-spectrogram loss value.
|
|
||||||
|
|
||||||
"""
|
|
||||||
mel_hat = self.wav_to_mel(y_hat.squeeze(1))
|
|
||||||
mel = self.wav_to_mel(y.squeeze(1))
|
|
||||||
mel_loss = F.l1_loss(mel_hat, mel)
|
|
||||||
|
|
||||||
if return_mel:
|
|
||||||
return mel_loss, (mel_hat, mel)
|
|
||||||
|
|
||||||
return mel_loss
|
|
||||||
|
|
||||||
|
|
||||||
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py
|
|
||||||
|
|
||||||
"""VITS-related loss modules.
|
|
||||||
|
|
||||||
This code is based on https://github.com/jaywalnut310/vits.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class KLDivergenceLoss(torch.nn.Module):
|
|
||||||
"""KL divergence loss."""
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
z_p: torch.Tensor,
|
|
||||||
logs_q: torch.Tensor,
|
|
||||||
m_p: torch.Tensor,
|
|
||||||
logs_p: torch.Tensor,
|
|
||||||
z_mask: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Calculate KL divergence loss.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
z_p (Tensor): Flow hidden representation (B, H, T_feats).
|
|
||||||
logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
|
|
||||||
m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
|
|
||||||
logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
|
|
||||||
z_mask (Tensor): Mask tensor (B, 1, T_feats).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: KL divergence loss.
|
|
||||||
|
|
||||||
"""
|
|
||||||
z_p = z_p.float()
|
|
||||||
logs_q = logs_q.float()
|
|
||||||
m_p = m_p.float()
|
|
||||||
logs_p = logs_p.float()
|
|
||||||
z_mask = z_mask.float()
|
|
||||||
kl = logs_p - logs_q - 0.5
|
|
||||||
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
|
|
||||||
kl = torch.sum(kl * z_mask)
|
|
||||||
loss = kl / torch.sum(z_mask)
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
|
|
||||||
class KLDivergenceLossWithoutFlow(torch.nn.Module):
|
|
||||||
"""KL divergence loss without flow."""
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
m_q: torch.Tensor,
|
|
||||||
logs_q: torch.Tensor,
|
|
||||||
m_p: torch.Tensor,
|
|
||||||
logs_p: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Calculate KL divergence loss without flow.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
m_q (Tensor): Posterior encoder projected mean (B, H, T_feats).
|
|
||||||
logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
|
|
||||||
m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
|
|
||||||
logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
|
|
||||||
"""
|
|
||||||
posterior_norm = D.Normal(m_q, torch.exp(logs_q))
|
|
||||||
prior_norm = D.Normal(m_p, torch.exp(logs_p))
|
|
||||||
loss = D.kl_divergence(posterior_norm, prior_norm).mean()
|
|
||||||
return loss
|
|
1
egs/ljspeech/TTS/vits2/loss.py
Symbolic link
1
egs/ljspeech/TTS/vits2/loss.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../ljspeech/TTS/vits/loss.py
|
@ -360,10 +360,10 @@ class ResidualCouplingTransformersLayer(torch.nn.Module):
|
|||||||
xa, xb = x.split(x.size(1) // 2, dim=1)
|
xa, xb = x.split(x.size(1) // 2, dim=1)
|
||||||
|
|
||||||
x_trans_mask = make_pad_mask(torch.sum(x_mask, dim=[1, 2]).type(torch.int64))
|
x_trans_mask = make_pad_mask(torch.sum(x_mask, dim=[1, 2]).type(torch.int64))
|
||||||
xa_trans = self.pre_transformer(xa.transpose(1, 2), x_trans_mask).transpose(
|
xa_ = self.pre_transformer(
|
||||||
1, 2
|
(xa * x_mask).transpose(1, 2), x_trans_mask
|
||||||
)
|
).transpose(1, 2)
|
||||||
xa_ = xa + xa_trans
|
xa_ = xa + xa_
|
||||||
|
|
||||||
h = self.input_conv(xa_) * x_mask
|
h = self.input_conv(xa_) * x_mask
|
||||||
h = self.encoder(h, x_mask, g=g)
|
h = self.encoder(h, x_mask, g=g)
|
||||||
|
@ -1,108 +0,0 @@
|
|||||||
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
|
|
||||||
#
|
|
||||||
# See ../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
import g2p_en
|
|
||||||
import tacotron_cleaner.cleaners
|
|
||||||
from utils import intersperse
|
|
||||||
|
|
||||||
|
|
||||||
class Tokenizer(object):
|
|
||||||
def __init__(self, tokens: str):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
tokens: the file that maps tokens to ids
|
|
||||||
"""
|
|
||||||
# Parse token file
|
|
||||||
self.token2id: Dict[str, int] = {}
|
|
||||||
with open(tokens, "r", encoding="utf-8") as f:
|
|
||||||
for line in f.readlines():
|
|
||||||
info = line.rstrip().split()
|
|
||||||
if len(info) == 1:
|
|
||||||
# case of space
|
|
||||||
token = " "
|
|
||||||
id = int(info[0])
|
|
||||||
else:
|
|
||||||
token, id = info[0], int(info[1])
|
|
||||||
self.token2id[token] = id
|
|
||||||
|
|
||||||
self.blank_id = self.token2id["<blk>"]
|
|
||||||
self.oov_id = self.token2id["<unk>"]
|
|
||||||
self.vocab_size = len(self.token2id)
|
|
||||||
|
|
||||||
self.g2p = g2p_en.G2p()
|
|
||||||
|
|
||||||
def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
texts:
|
|
||||||
A list of transcripts.
|
|
||||||
intersperse_blank:
|
|
||||||
Whether to intersperse blanks in the token sequence.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Return a list of token id list [utterance][token_id]
|
|
||||||
"""
|
|
||||||
token_ids_list = []
|
|
||||||
|
|
||||||
for text in texts:
|
|
||||||
# Text normalization
|
|
||||||
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
|
|
||||||
# Convert to phonemes
|
|
||||||
tokens = self.g2p(text)
|
|
||||||
token_ids = []
|
|
||||||
for t in tokens:
|
|
||||||
if t in self.token2id:
|
|
||||||
token_ids.append(self.token2id[t])
|
|
||||||
else:
|
|
||||||
token_ids.append(self.oov_id)
|
|
||||||
|
|
||||||
if intersperse_blank:
|
|
||||||
token_ids = intersperse(token_ids, self.blank_id)
|
|
||||||
|
|
||||||
token_ids_list.append(token_ids)
|
|
||||||
|
|
||||||
return token_ids_list
|
|
||||||
|
|
||||||
def tokens_to_token_ids(
|
|
||||||
self, tokens_list: List[str], intersperse_blank: bool = True
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
tokens_list:
|
|
||||||
A list of token list, each corresponding to one utterance.
|
|
||||||
intersperse_blank:
|
|
||||||
Whether to intersperse blanks in the token sequence.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Return a list of token id list [utterance][token_id]
|
|
||||||
"""
|
|
||||||
token_ids_list = []
|
|
||||||
|
|
||||||
for tokens in tokens_list:
|
|
||||||
token_ids = []
|
|
||||||
for t in tokens:
|
|
||||||
if t in self.token2id:
|
|
||||||
token_ids.append(self.token2id[t])
|
|
||||||
else:
|
|
||||||
token_ids.append(self.oov_id)
|
|
||||||
|
|
||||||
if intersperse_blank:
|
|
||||||
token_ids = intersperse(token_ids, self.blank_id)
|
|
||||||
token_ids_list.append(token_ids)
|
|
||||||
|
|
||||||
return token_ids_list
|
|
1
egs/ljspeech/TTS/vits2/tokenizer.py
Symbolic link
1
egs/ljspeech/TTS/vits2/tokenizer.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../ljspeech/TTS/vits/tokenizer.py
|
@ -433,7 +433,7 @@ def train_one_epoch(
|
|||||||
|
|
||||||
with autocast(enabled=params.use_fp16):
|
with autocast(enabled=params.use_fp16):
|
||||||
# forward discriminator
|
# forward discriminator
|
||||||
loss_d, dur_loss, stats_d = model(
|
loss_d, stats_d = model(
|
||||||
text=tokens,
|
text=tokens,
|
||||||
text_lengths=tokens_lens,
|
text_lengths=tokens_lens,
|
||||||
feats=features,
|
feats=features,
|
||||||
|
@ -1,218 +0,0 @@
|
|||||||
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py
|
|
||||||
|
|
||||||
"""Flow-related transformation.
|
|
||||||
|
|
||||||
This code is derived from https://github.com/bayesiains/nflows.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
DEFAULT_MIN_BIN_WIDTH = 1e-3
|
|
||||||
DEFAULT_MIN_BIN_HEIGHT = 1e-3
|
|
||||||
DEFAULT_MIN_DERIVATIVE = 1e-3
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(kan-bayashi): Documentation and type hint
|
|
||||||
def piecewise_rational_quadratic_transform(
|
|
||||||
inputs,
|
|
||||||
unnormalized_widths,
|
|
||||||
unnormalized_heights,
|
|
||||||
unnormalized_derivatives,
|
|
||||||
inverse=False,
|
|
||||||
tails=None,
|
|
||||||
tail_bound=1.0,
|
|
||||||
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
|
||||||
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
|
||||||
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
|
||||||
):
|
|
||||||
if tails is None:
|
|
||||||
spline_fn = rational_quadratic_spline
|
|
||||||
spline_kwargs = {}
|
|
||||||
else:
|
|
||||||
spline_fn = unconstrained_rational_quadratic_spline
|
|
||||||
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
|
|
||||||
|
|
||||||
outputs, logabsdet = spline_fn(
|
|
||||||
inputs=inputs,
|
|
||||||
unnormalized_widths=unnormalized_widths,
|
|
||||||
unnormalized_heights=unnormalized_heights,
|
|
||||||
unnormalized_derivatives=unnormalized_derivatives,
|
|
||||||
inverse=inverse,
|
|
||||||
min_bin_width=min_bin_width,
|
|
||||||
min_bin_height=min_bin_height,
|
|
||||||
min_derivative=min_derivative,
|
|
||||||
**spline_kwargs
|
|
||||||
)
|
|
||||||
return outputs, logabsdet
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(kan-bayashi): Documentation and type hint
|
|
||||||
def unconstrained_rational_quadratic_spline(
|
|
||||||
inputs,
|
|
||||||
unnormalized_widths,
|
|
||||||
unnormalized_heights,
|
|
||||||
unnormalized_derivatives,
|
|
||||||
inverse=False,
|
|
||||||
tails="linear",
|
|
||||||
tail_bound=1.0,
|
|
||||||
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
|
||||||
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
|
||||||
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
|
||||||
):
|
|
||||||
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
|
|
||||||
outside_interval_mask = ~inside_interval_mask
|
|
||||||
|
|
||||||
outputs = torch.zeros_like(inputs)
|
|
||||||
logabsdet = torch.zeros_like(inputs)
|
|
||||||
|
|
||||||
if tails == "linear":
|
|
||||||
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
|
|
||||||
constant = np.log(np.exp(1 - min_derivative) - 1)
|
|
||||||
unnormalized_derivatives[..., 0] = constant
|
|
||||||
unnormalized_derivatives[..., -1] = constant
|
|
||||||
|
|
||||||
outputs[outside_interval_mask] = inputs[outside_interval_mask]
|
|
||||||
logabsdet[outside_interval_mask] = 0
|
|
||||||
else:
|
|
||||||
raise RuntimeError("{} tails are not implemented.".format(tails))
|
|
||||||
|
|
||||||
(
|
|
||||||
outputs[inside_interval_mask],
|
|
||||||
logabsdet[inside_interval_mask],
|
|
||||||
) = rational_quadratic_spline(
|
|
||||||
inputs=inputs[inside_interval_mask],
|
|
||||||
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
|
|
||||||
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
|
|
||||||
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
|
|
||||||
inverse=inverse,
|
|
||||||
left=-tail_bound,
|
|
||||||
right=tail_bound,
|
|
||||||
bottom=-tail_bound,
|
|
||||||
top=tail_bound,
|
|
||||||
min_bin_width=min_bin_width,
|
|
||||||
min_bin_height=min_bin_height,
|
|
||||||
min_derivative=min_derivative,
|
|
||||||
)
|
|
||||||
|
|
||||||
return outputs, logabsdet
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(kan-bayashi): Documentation and type hint
|
|
||||||
def rational_quadratic_spline(
|
|
||||||
inputs,
|
|
||||||
unnormalized_widths,
|
|
||||||
unnormalized_heights,
|
|
||||||
unnormalized_derivatives,
|
|
||||||
inverse=False,
|
|
||||||
left=0.0,
|
|
||||||
right=1.0,
|
|
||||||
bottom=0.0,
|
|
||||||
top=1.0,
|
|
||||||
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
|
||||||
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
|
||||||
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
|
||||||
):
|
|
||||||
if torch.min(inputs) < left or torch.max(inputs) > right:
|
|
||||||
raise ValueError("Input to a transform is not within its domain")
|
|
||||||
|
|
||||||
num_bins = unnormalized_widths.shape[-1]
|
|
||||||
|
|
||||||
if min_bin_width * num_bins > 1.0:
|
|
||||||
raise ValueError("Minimal bin width too large for the number of bins")
|
|
||||||
if min_bin_height * num_bins > 1.0:
|
|
||||||
raise ValueError("Minimal bin height too large for the number of bins")
|
|
||||||
|
|
||||||
widths = F.softmax(unnormalized_widths, dim=-1)
|
|
||||||
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
|
|
||||||
cumwidths = torch.cumsum(widths, dim=-1)
|
|
||||||
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
|
|
||||||
cumwidths = (right - left) * cumwidths + left
|
|
||||||
cumwidths[..., 0] = left
|
|
||||||
cumwidths[..., -1] = right
|
|
||||||
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
|
|
||||||
|
|
||||||
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
|
|
||||||
|
|
||||||
heights = F.softmax(unnormalized_heights, dim=-1)
|
|
||||||
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
|
|
||||||
cumheights = torch.cumsum(heights, dim=-1)
|
|
||||||
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
|
|
||||||
cumheights = (top - bottom) * cumheights + bottom
|
|
||||||
cumheights[..., 0] = bottom
|
|
||||||
cumheights[..., -1] = top
|
|
||||||
heights = cumheights[..., 1:] - cumheights[..., :-1]
|
|
||||||
|
|
||||||
if inverse:
|
|
||||||
bin_idx = _searchsorted(cumheights, inputs)[..., None]
|
|
||||||
else:
|
|
||||||
bin_idx = _searchsorted(cumwidths, inputs)[..., None]
|
|
||||||
|
|
||||||
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
|
|
||||||
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
|
|
||||||
|
|
||||||
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
|
|
||||||
delta = heights / widths
|
|
||||||
input_delta = delta.gather(-1, bin_idx)[..., 0]
|
|
||||||
|
|
||||||
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
|
|
||||||
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
|
|
||||||
|
|
||||||
input_heights = heights.gather(-1, bin_idx)[..., 0]
|
|
||||||
|
|
||||||
if inverse:
|
|
||||||
a = (inputs - input_cumheights) * (
|
|
||||||
input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
|
||||||
) + input_heights * (input_delta - input_derivatives)
|
|
||||||
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
|
|
||||||
input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
|
||||||
)
|
|
||||||
c = -input_delta * (inputs - input_cumheights)
|
|
||||||
|
|
||||||
discriminant = b.pow(2) - 4 * a * c
|
|
||||||
assert (discriminant >= 0).all()
|
|
||||||
|
|
||||||
root = (2 * c) / (-b - torch.sqrt(discriminant))
|
|
||||||
outputs = root * input_bin_widths + input_cumwidths
|
|
||||||
|
|
||||||
theta_one_minus_theta = root * (1 - root)
|
|
||||||
denominator = input_delta + (
|
|
||||||
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
|
||||||
* theta_one_minus_theta
|
|
||||||
)
|
|
||||||
derivative_numerator = input_delta.pow(2) * (
|
|
||||||
input_derivatives_plus_one * root.pow(2)
|
|
||||||
+ 2 * input_delta * theta_one_minus_theta
|
|
||||||
+ input_derivatives * (1 - root).pow(2)
|
|
||||||
)
|
|
||||||
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
|
||||||
|
|
||||||
return outputs, -logabsdet
|
|
||||||
else:
|
|
||||||
theta = (inputs - input_cumwidths) / input_bin_widths
|
|
||||||
theta_one_minus_theta = theta * (1 - theta)
|
|
||||||
|
|
||||||
numerator = input_heights * (
|
|
||||||
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
|
|
||||||
)
|
|
||||||
denominator = input_delta + (
|
|
||||||
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
|
||||||
* theta_one_minus_theta
|
|
||||||
)
|
|
||||||
outputs = input_cumheights + numerator / denominator
|
|
||||||
|
|
||||||
derivative_numerator = input_delta.pow(2) * (
|
|
||||||
input_derivatives_plus_one * theta.pow(2)
|
|
||||||
+ 2 * input_delta * theta_one_minus_theta
|
|
||||||
+ input_derivatives * (1 - theta).pow(2)
|
|
||||||
)
|
|
||||||
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
|
||||||
|
|
||||||
return outputs, logabsdet
|
|
||||||
|
|
||||||
|
|
||||||
def _searchsorted(bin_locations, inputs, eps=1e-6):
|
|
||||||
bin_locations[..., -1] += eps
|
|
||||||
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
|
|
1
egs/ljspeech/TTS/vits2/transform.py
Symbolic link
1
egs/ljspeech/TTS/vits2/transform.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../ljspeech/TTS/vits/transform.py
|
@ -1,327 +0,0 @@
|
|||||||
# Copyright 2021 Piotr Żelasko
|
|
||||||
# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
|
|
||||||
# Zengwei Yao)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
from functools import lru_cache
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
|
|
||||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
|
||||||
CutConcatenate,
|
|
||||||
CutMix,
|
|
||||||
DynamicBucketingSampler,
|
|
||||||
PrecomputedFeatures,
|
|
||||||
SimpleCutSampler,
|
|
||||||
SpecAugment,
|
|
||||||
SpeechSynthesisDataset,
|
|
||||||
)
|
|
||||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
|
||||||
AudioSamples,
|
|
||||||
OnTheFlyFeatures,
|
|
||||||
)
|
|
||||||
from lhotse.utils import fix_random_seed
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
|
||||||
|
|
||||||
|
|
||||||
class _SeedWorkers:
|
|
||||||
def __init__(self, seed: int):
|
|
||||||
self.seed = seed
|
|
||||||
|
|
||||||
def __call__(self, worker_id: int):
|
|
||||||
fix_random_seed(self.seed + worker_id)
|
|
||||||
|
|
||||||
|
|
||||||
class LJSpeechTtsDataModule:
|
|
||||||
"""
|
|
||||||
DataModule for tts experiments.
|
|
||||||
It assumes there is always one train and valid dataloader,
|
|
||||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
|
||||||
and test-other).
|
|
||||||
|
|
||||||
It contains all the common data pipeline modules used in ASR
|
|
||||||
experiments, e.g.:
|
|
||||||
- dynamic batch size,
|
|
||||||
- bucketing samplers,
|
|
||||||
- cut concatenation,
|
|
||||||
- on-the-fly feature extraction
|
|
||||||
|
|
||||||
This class should be derived for specific corpora used in ASR tasks.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, args: argparse.Namespace):
|
|
||||||
self.args = args
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
|
||||||
group = parser.add_argument_group(
|
|
||||||
title="TTS data related options",
|
|
||||||
description="These options are used for the preparation of "
|
|
||||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
|
||||||
"effective batch sizes, sampling strategies, applied data "
|
|
||||||
"augmentations, etc.",
|
|
||||||
)
|
|
||||||
|
|
||||||
group.add_argument(
|
|
||||||
"--manifest-dir",
|
|
||||||
type=Path,
|
|
||||||
default=Path("data/spectrogram"),
|
|
||||||
help="Path to directory with train/valid/test cuts.",
|
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"--max-duration",
|
|
||||||
type=int,
|
|
||||||
default=200.0,
|
|
||||||
help="Maximum pooled recordings duration (seconds) in a "
|
|
||||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"--bucketing-sampler",
|
|
||||||
type=str2bool,
|
|
||||||
default=True,
|
|
||||||
help="When enabled, the batches will come from buckets of "
|
|
||||||
"similar duration (saves padding frames).",
|
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"--num-buckets",
|
|
||||||
type=int,
|
|
||||||
default=30,
|
|
||||||
help="The number of buckets for the DynamicBucketingSampler"
|
|
||||||
"(you might want to increase it for larger datasets).",
|
|
||||||
)
|
|
||||||
|
|
||||||
group.add_argument(
|
|
||||||
"--on-the-fly-feats",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="When enabled, use on-the-fly cut mixing and feature "
|
|
||||||
"extraction. Will drop existing precomputed feature manifests "
|
|
||||||
"if available.",
|
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"--shuffle",
|
|
||||||
type=str2bool,
|
|
||||||
default=True,
|
|
||||||
help="When enabled (=default), the examples will be "
|
|
||||||
"shuffled for each epoch.",
|
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"--drop-last",
|
|
||||||
type=str2bool,
|
|
||||||
default=True,
|
|
||||||
help="Whether to drop last batch. Used by sampler.",
|
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"--return-cuts",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="When enabled, each batch will have the "
|
|
||||||
"field: batch['cut'] with the cuts that "
|
|
||||||
"were used to construct it.",
|
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"--num-workers",
|
|
||||||
type=int,
|
|
||||||
default=2,
|
|
||||||
help="The number of training dataloader workers that "
|
|
||||||
"collect the batches.",
|
|
||||||
)
|
|
||||||
|
|
||||||
group.add_argument(
|
|
||||||
"--input-strategy",
|
|
||||||
type=str,
|
|
||||||
default="PrecomputedFeatures",
|
|
||||||
help="AudioSamples or PrecomputedFeatures",
|
|
||||||
)
|
|
||||||
|
|
||||||
def train_dataloaders(
|
|
||||||
self,
|
|
||||||
cuts_train: CutSet,
|
|
||||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> DataLoader:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
cuts_train:
|
|
||||||
CutSet for training.
|
|
||||||
sampler_state_dict:
|
|
||||||
The state dict for the training sampler.
|
|
||||||
"""
|
|
||||||
logging.info("About to create train dataset")
|
|
||||||
train = SpeechSynthesisDataset(
|
|
||||||
return_text=False,
|
|
||||||
return_tokens=True,
|
|
||||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
|
||||||
return_cuts=self.args.return_cuts,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.args.on_the_fly_feats:
|
|
||||||
sampling_rate = 22050
|
|
||||||
config = SpectrogramConfig(
|
|
||||||
sampling_rate=sampling_rate,
|
|
||||||
frame_length=1024 / sampling_rate, # (in second),
|
|
||||||
frame_shift=256 / sampling_rate, # (in second)
|
|
||||||
use_fft_mag=True,
|
|
||||||
)
|
|
||||||
train = SpeechSynthesisDataset(
|
|
||||||
return_text=False,
|
|
||||||
return_tokens=True,
|
|
||||||
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
|
|
||||||
return_cuts=self.args.return_cuts,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.args.bucketing_sampler:
|
|
||||||
logging.info("Using DynamicBucketingSampler.")
|
|
||||||
train_sampler = DynamicBucketingSampler(
|
|
||||||
cuts_train,
|
|
||||||
max_duration=self.args.max_duration,
|
|
||||||
shuffle=self.args.shuffle,
|
|
||||||
num_buckets=self.args.num_buckets,
|
|
||||||
buffer_size=self.args.num_buckets * 2000,
|
|
||||||
shuffle_buffer_size=self.args.num_buckets * 5000,
|
|
||||||
drop_last=self.args.drop_last,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logging.info("Using SimpleCutSampler.")
|
|
||||||
train_sampler = SimpleCutSampler(
|
|
||||||
cuts_train,
|
|
||||||
max_duration=self.args.max_duration,
|
|
||||||
shuffle=self.args.shuffle,
|
|
||||||
)
|
|
||||||
logging.info("About to create train dataloader")
|
|
||||||
|
|
||||||
if sampler_state_dict is not None:
|
|
||||||
logging.info("Loading sampler state dict")
|
|
||||||
train_sampler.load_state_dict(sampler_state_dict)
|
|
||||||
|
|
||||||
# 'seed' is derived from the current random state, which will have
|
|
||||||
# previously been set in the main process.
|
|
||||||
seed = torch.randint(0, 100000, ()).item()
|
|
||||||
worker_init_fn = _SeedWorkers(seed)
|
|
||||||
|
|
||||||
train_dl = DataLoader(
|
|
||||||
train,
|
|
||||||
sampler=train_sampler,
|
|
||||||
batch_size=None,
|
|
||||||
num_workers=self.args.num_workers,
|
|
||||||
persistent_workers=False,
|
|
||||||
worker_init_fn=worker_init_fn,
|
|
||||||
)
|
|
||||||
|
|
||||||
return train_dl
|
|
||||||
|
|
||||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
|
||||||
logging.info("About to create dev dataset")
|
|
||||||
if self.args.on_the_fly_feats:
|
|
||||||
sampling_rate = 22050
|
|
||||||
config = SpectrogramConfig(
|
|
||||||
sampling_rate=sampling_rate,
|
|
||||||
frame_length=1024 / sampling_rate, # (in second),
|
|
||||||
frame_shift=256 / sampling_rate, # (in second)
|
|
||||||
use_fft_mag=True,
|
|
||||||
)
|
|
||||||
validate = SpeechSynthesisDataset(
|
|
||||||
return_text=False,
|
|
||||||
return_tokens=True,
|
|
||||||
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
|
|
||||||
return_cuts=self.args.return_cuts,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
validate = SpeechSynthesisDataset(
|
|
||||||
return_text=False,
|
|
||||||
return_tokens=True,
|
|
||||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
|
||||||
return_cuts=self.args.return_cuts,
|
|
||||||
)
|
|
||||||
valid_sampler = DynamicBucketingSampler(
|
|
||||||
cuts_valid,
|
|
||||||
max_duration=self.args.max_duration,
|
|
||||||
shuffle=False,
|
|
||||||
)
|
|
||||||
logging.info("About to create valid dataloader")
|
|
||||||
valid_dl = DataLoader(
|
|
||||||
validate,
|
|
||||||
sampler=valid_sampler,
|
|
||||||
batch_size=None,
|
|
||||||
num_workers=2,
|
|
||||||
persistent_workers=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
return valid_dl
|
|
||||||
|
|
||||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
|
||||||
logging.info("About to create test dataset")
|
|
||||||
if self.args.on_the_fly_feats:
|
|
||||||
sampling_rate = 22050
|
|
||||||
config = SpectrogramConfig(
|
|
||||||
sampling_rate=sampling_rate,
|
|
||||||
frame_length=1024 / sampling_rate, # (in second),
|
|
||||||
frame_shift=256 / sampling_rate, # (in second)
|
|
||||||
use_fft_mag=True,
|
|
||||||
)
|
|
||||||
test = SpeechSynthesisDataset(
|
|
||||||
return_text=False,
|
|
||||||
return_tokens=True,
|
|
||||||
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
|
|
||||||
return_cuts=self.args.return_cuts,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
test = SpeechSynthesisDataset(
|
|
||||||
return_text=False,
|
|
||||||
return_tokens=True,
|
|
||||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
|
||||||
return_cuts=self.args.return_cuts,
|
|
||||||
)
|
|
||||||
test_sampler = DynamicBucketingSampler(
|
|
||||||
cuts,
|
|
||||||
max_duration=self.args.max_duration,
|
|
||||||
shuffle=False,
|
|
||||||
)
|
|
||||||
logging.info("About to create test dataloader")
|
|
||||||
test_dl = DataLoader(
|
|
||||||
test,
|
|
||||||
batch_size=None,
|
|
||||||
sampler=test_sampler,
|
|
||||||
num_workers=self.args.num_workers,
|
|
||||||
)
|
|
||||||
return test_dl
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def train_cuts(self) -> CutSet:
|
|
||||||
logging.info("About to get train cuts")
|
|
||||||
return load_manifest_lazy(
|
|
||||||
self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def valid_cuts(self) -> CutSet:
|
|
||||||
logging.info("About to get validation cuts")
|
|
||||||
return load_manifest_lazy(
|
|
||||||
self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def test_cuts(self) -> CutSet:
|
|
||||||
logging.info("About to get test cuts")
|
|
||||||
return load_manifest_lazy(
|
|
||||||
self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
|
|
||||||
)
|
|
1
egs/ljspeech/TTS/vits2/tts_datamodule.py
Symbolic link
1
egs/ljspeech/TTS/vits2/tts_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../ljspeech/TTS/vits/tts_datamodule.py
|
@ -545,10 +545,6 @@ class VITS(nn.Module):
|
|||||||
discriminator_fake_loss=fake_loss.item(),
|
discriminator_fake_loss=fake_loss.item(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# reset cache
|
|
||||||
if reuse_cache or not self.training:
|
|
||||||
self._cache = None
|
|
||||||
|
|
||||||
return loss, stats
|
return loss, stats
|
||||||
|
|
||||||
def _forward_discrminator_duration(
|
def _forward_discrminator_duration(
|
||||||
@ -582,7 +578,6 @@ class VITS(nn.Module):
|
|||||||
"""
|
"""
|
||||||
# setup
|
# setup
|
||||||
feats = feats.transpose(1, 2)
|
feats = feats.transpose(1, 2)
|
||||||
speech = speech.unsqueeze(1)
|
|
||||||
|
|
||||||
# calculate generator outputs
|
# calculate generator outputs
|
||||||
reuse_cache = True
|
reuse_cache = True
|
||||||
|
@ -1,348 +0,0 @@
|
|||||||
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py
|
|
||||||
|
|
||||||
# Copyright 2021 Tomoki Hayashi
|
|
||||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
|
||||||
|
|
||||||
"""WaveNet modules.
|
|
||||||
|
|
||||||
This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class WaveNet(torch.nn.Module):
|
|
||||||
"""WaveNet with global conditioning."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels: int = 1,
|
|
||||||
out_channels: int = 1,
|
|
||||||
kernel_size: int = 3,
|
|
||||||
layers: int = 30,
|
|
||||||
stacks: int = 3,
|
|
||||||
base_dilation: int = 2,
|
|
||||||
residual_channels: int = 64,
|
|
||||||
aux_channels: int = -1,
|
|
||||||
gate_channels: int = 128,
|
|
||||||
skip_channels: int = 64,
|
|
||||||
global_channels: int = -1,
|
|
||||||
dropout_rate: float = 0.0,
|
|
||||||
bias: bool = True,
|
|
||||||
use_weight_norm: bool = True,
|
|
||||||
use_first_conv: bool = False,
|
|
||||||
use_last_conv: bool = False,
|
|
||||||
scale_residual: bool = False,
|
|
||||||
scale_skip_connect: bool = False,
|
|
||||||
):
|
|
||||||
"""Initialize WaveNet module.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_channels (int): Number of input channels.
|
|
||||||
out_channels (int): Number of output channels.
|
|
||||||
kernel_size (int): Kernel size of dilated convolution.
|
|
||||||
layers (int): Number of residual block layers.
|
|
||||||
stacks (int): Number of stacks i.e., dilation cycles.
|
|
||||||
base_dilation (int): Base dilation factor.
|
|
||||||
residual_channels (int): Number of channels in residual conv.
|
|
||||||
gate_channels (int): Number of channels in gated conv.
|
|
||||||
skip_channels (int): Number of channels in skip conv.
|
|
||||||
aux_channels (int): Number of channels for local conditioning feature.
|
|
||||||
global_channels (int): Number of channels for global conditioning feature.
|
|
||||||
dropout_rate (float): Dropout rate. 0.0 means no dropout applied.
|
|
||||||
bias (bool): Whether to use bias parameter in conv layer.
|
|
||||||
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
|
|
||||||
be applied to all of the conv layers.
|
|
||||||
use_first_conv (bool): Whether to use the first conv layers.
|
|
||||||
use_last_conv (bool): Whether to use the last conv layers.
|
|
||||||
scale_residual (bool): Whether to scale the residual outputs.
|
|
||||||
scale_skip_connect (bool): Whether to scale the skip connection outputs.
|
|
||||||
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.layers = layers
|
|
||||||
self.stacks = stacks
|
|
||||||
self.kernel_size = kernel_size
|
|
||||||
self.base_dilation = base_dilation
|
|
||||||
self.use_first_conv = use_first_conv
|
|
||||||
self.use_last_conv = use_last_conv
|
|
||||||
self.scale_skip_connect = scale_skip_connect
|
|
||||||
|
|
||||||
# check the number of layers and stacks
|
|
||||||
assert layers % stacks == 0
|
|
||||||
layers_per_stack = layers // stacks
|
|
||||||
|
|
||||||
# define first convolution
|
|
||||||
if self.use_first_conv:
|
|
||||||
self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)
|
|
||||||
|
|
||||||
# define residual blocks
|
|
||||||
self.conv_layers = torch.nn.ModuleList()
|
|
||||||
for layer in range(layers):
|
|
||||||
dilation = base_dilation ** (layer % layers_per_stack)
|
|
||||||
conv = ResidualBlock(
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
residual_channels=residual_channels,
|
|
||||||
gate_channels=gate_channels,
|
|
||||||
skip_channels=skip_channels,
|
|
||||||
aux_channels=aux_channels,
|
|
||||||
global_channels=global_channels,
|
|
||||||
dilation=dilation,
|
|
||||||
dropout_rate=dropout_rate,
|
|
||||||
bias=bias,
|
|
||||||
scale_residual=scale_residual,
|
|
||||||
)
|
|
||||||
self.conv_layers += [conv]
|
|
||||||
|
|
||||||
# define output layers
|
|
||||||
if self.use_last_conv:
|
|
||||||
self.last_conv = torch.nn.Sequential(
|
|
||||||
torch.nn.ReLU(inplace=False),
|
|
||||||
Conv1d1x1(skip_channels, skip_channels, bias=True),
|
|
||||||
torch.nn.ReLU(inplace=False),
|
|
||||||
Conv1d1x1(skip_channels, out_channels, bias=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
# apply weight norm
|
|
||||||
if use_weight_norm:
|
|
||||||
self.apply_weight_norm()
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
x_mask: Optional[torch.Tensor] = None,
|
|
||||||
c: Optional[torch.Tensor] = None,
|
|
||||||
g: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Calculate forward propagation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (Tensor): Input noise signal (B, 1, T) if use_first_conv else
|
|
||||||
(B, residual_channels, T).
|
|
||||||
x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
|
|
||||||
c (Optional[Tensor]): Local conditioning features (B, aux_channels, T).
|
|
||||||
g (Optional[Tensor]): Global conditioning features (B, global_channels, 1).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Output tensor (B, out_channels, T) if use_last_conv else
|
|
||||||
(B, residual_channels, T).
|
|
||||||
|
|
||||||
"""
|
|
||||||
# encode to hidden representation
|
|
||||||
if self.use_first_conv:
|
|
||||||
x = self.first_conv(x)
|
|
||||||
|
|
||||||
# residual block
|
|
||||||
skips = 0.0
|
|
||||||
for f in self.conv_layers:
|
|
||||||
x, h = f(x, x_mask=x_mask, c=c, g=g)
|
|
||||||
skips = skips + h
|
|
||||||
x = skips
|
|
||||||
if self.scale_skip_connect:
|
|
||||||
x = x * math.sqrt(1.0 / len(self.conv_layers))
|
|
||||||
|
|
||||||
# apply final layers
|
|
||||||
if self.use_last_conv:
|
|
||||||
x = self.last_conv(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
|
||||||
"""Remove weight normalization module from all of the layers."""
|
|
||||||
|
|
||||||
def _remove_weight_norm(m: torch.nn.Module):
|
|
||||||
try:
|
|
||||||
logging.debug(f"Weight norm is removed from {m}.")
|
|
||||||
torch.nn.utils.remove_weight_norm(m)
|
|
||||||
except ValueError: # this module didn't have weight norm
|
|
||||||
return
|
|
||||||
|
|
||||||
self.apply(_remove_weight_norm)
|
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
|
||||||
"""Apply weight normalization module from all of the layers."""
|
|
||||||
|
|
||||||
def _apply_weight_norm(m: torch.nn.Module):
|
|
||||||
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
|
|
||||||
torch.nn.utils.weight_norm(m)
|
|
||||||
logging.debug(f"Weight norm is applied to {m}.")
|
|
||||||
|
|
||||||
self.apply(_apply_weight_norm)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_receptive_field_size(
|
|
||||||
layers: int,
|
|
||||||
stacks: int,
|
|
||||||
kernel_size: int,
|
|
||||||
base_dilation: int,
|
|
||||||
) -> int:
|
|
||||||
assert layers % stacks == 0
|
|
||||||
layers_per_cycle = layers // stacks
|
|
||||||
dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)]
|
|
||||||
return (kernel_size - 1) * sum(dilations) + 1
|
|
||||||
|
|
||||||
@property
|
|
||||||
def receptive_field_size(self) -> int:
|
|
||||||
"""Return receptive field size."""
|
|
||||||
return self._get_receptive_field_size(
|
|
||||||
self.layers, self.stacks, self.kernel_size, self.base_dilation
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Conv1d(torch.nn.Conv1d):
|
|
||||||
"""Conv1d module with customized initialization."""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
"""Initialize Conv1d module."""
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
|
||||||
"""Reset parameters."""
|
|
||||||
torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
|
|
||||||
if self.bias is not None:
|
|
||||||
torch.nn.init.constant_(self.bias, 0.0)
|
|
||||||
|
|
||||||
|
|
||||||
class Conv1d1x1(Conv1d):
|
|
||||||
"""1x1 Conv1d with customized initialization."""
|
|
||||||
|
|
||||||
def __init__(self, in_channels: int, out_channels: int, bias: bool):
|
|
||||||
"""Initialize 1x1 Conv1d module."""
|
|
||||||
super().__init__(
|
|
||||||
in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualBlock(torch.nn.Module):
|
|
||||||
"""Residual block module in WaveNet."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
kernel_size: int = 3,
|
|
||||||
residual_channels: int = 64,
|
|
||||||
gate_channels: int = 128,
|
|
||||||
skip_channels: int = 64,
|
|
||||||
aux_channels: int = 80,
|
|
||||||
global_channels: int = -1,
|
|
||||||
dropout_rate: float = 0.0,
|
|
||||||
dilation: int = 1,
|
|
||||||
bias: bool = True,
|
|
||||||
scale_residual: bool = False,
|
|
||||||
):
|
|
||||||
"""Initialize ResidualBlock module.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
kernel_size (int): Kernel size of dilation convolution layer.
|
|
||||||
residual_channels (int): Number of channels for residual connection.
|
|
||||||
skip_channels (int): Number of channels for skip connection.
|
|
||||||
aux_channels (int): Number of local conditioning channels.
|
|
||||||
dropout (float): Dropout probability.
|
|
||||||
dilation (int): Dilation factor.
|
|
||||||
bias (bool): Whether to add bias parameter in convolution layers.
|
|
||||||
scale_residual (bool): Whether to scale the residual outputs.
|
|
||||||
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.dropout_rate = dropout_rate
|
|
||||||
self.residual_channels = residual_channels
|
|
||||||
self.skip_channels = skip_channels
|
|
||||||
self.scale_residual = scale_residual
|
|
||||||
|
|
||||||
# check
|
|
||||||
assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
|
|
||||||
assert gate_channels % 2 == 0
|
|
||||||
|
|
||||||
# dilation conv
|
|
||||||
padding = (kernel_size - 1) // 2 * dilation
|
|
||||||
self.conv = Conv1d(
|
|
||||||
residual_channels,
|
|
||||||
gate_channels,
|
|
||||||
kernel_size,
|
|
||||||
padding=padding,
|
|
||||||
dilation=dilation,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
# local conditioning
|
|
||||||
if aux_channels > 0:
|
|
||||||
self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
|
|
||||||
else:
|
|
||||||
self.conv1x1_aux = None
|
|
||||||
|
|
||||||
# global conditioning
|
|
||||||
if global_channels > 0:
|
|
||||||
self.conv1x1_glo = Conv1d1x1(global_channels, gate_channels, bias=False)
|
|
||||||
else:
|
|
||||||
self.conv1x1_glo = None
|
|
||||||
|
|
||||||
# conv output is split into two groups
|
|
||||||
gate_out_channels = gate_channels // 2
|
|
||||||
|
|
||||||
# NOTE(kan-bayashi): concat two convs into a single conv for the efficiency
|
|
||||||
# (integrate res 1x1 + skip 1x1 convs)
|
|
||||||
self.conv1x1_out = Conv1d1x1(
|
|
||||||
gate_out_channels, residual_channels + skip_channels, bias=bias
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
x_mask: Optional[torch.Tensor] = None,
|
|
||||||
c: Optional[torch.Tensor] = None,
|
|
||||||
g: Optional[torch.Tensor] = None,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""Calculate forward propagation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (Tensor): Input tensor (B, residual_channels, T).
|
|
||||||
x_mask Optional[torch.Tensor]: Mask tensor (B, 1, T).
|
|
||||||
c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T).
|
|
||||||
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Output tensor for residual connection (B, residual_channels, T).
|
|
||||||
Tensor: Output tensor for skip connection (B, skip_channels, T).
|
|
||||||
|
|
||||||
"""
|
|
||||||
residual = x
|
|
||||||
x = F.dropout(x, p=self.dropout_rate, training=self.training)
|
|
||||||
x = self.conv(x)
|
|
||||||
|
|
||||||
# split into two part for gated activation
|
|
||||||
splitdim = 1
|
|
||||||
xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)
|
|
||||||
|
|
||||||
# local conditioning
|
|
||||||
if c is not None:
|
|
||||||
c = self.conv1x1_aux(c)
|
|
||||||
ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
|
|
||||||
xa, xb = xa + ca, xb + cb
|
|
||||||
|
|
||||||
# global conditioning
|
|
||||||
if g is not None:
|
|
||||||
g = self.conv1x1_glo(g)
|
|
||||||
ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim)
|
|
||||||
xa, xb = xa + ga, xb + gb
|
|
||||||
|
|
||||||
x = torch.tanh(xa) * torch.sigmoid(xb)
|
|
||||||
|
|
||||||
# residual + skip 1x1 conv
|
|
||||||
x = self.conv1x1_out(x)
|
|
||||||
if x_mask is not None:
|
|
||||||
x = x * x_mask
|
|
||||||
|
|
||||||
# split integrated conv results
|
|
||||||
x, s = x.split([self.residual_channels, self.skip_channels], dim=1)
|
|
||||||
|
|
||||||
# for residual connection
|
|
||||||
x = x + residual
|
|
||||||
if self.scale_residual:
|
|
||||||
x = x * math.sqrt(0.5)
|
|
||||||
|
|
||||||
return x, s
|
|
1
egs/ljspeech/TTS/vits2/wavenet.py
Symbolic link
1
egs/ljspeech/TTS/vits2/wavenet.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../ljspeech/TTS/vits/wavenet.py
|
Loading…
x
Reference in New Issue
Block a user