icefall/egs/ljspeech/TTS/vits/wavenet.py
zr_jin 735fb9a73d
A TTS recipe VITS on VCTK dataset (#1380)
* init

* isort formatted

* minor updates

* Create shared

* Update prepare_tokens_vctk.py

* Update prepare_tokens_vctk.py

* Update prepare_tokens_vctk.py

* Update prepare.sh

* updated

* Update train.py

* Update train.py

* Update tts_datamodule.py

* Update train.py

* Update train.py

* Update train.py

* Update train.py

* Update train.py

* Update train.py

* fixed formatting issue

* Update infer.py

* removed redundant files

* Create monotonic_align

* removed redundant files

* created symlinks

* Update prepare.sh

* minor adjustments

* Create requirements_tts.txt

* Update requirements_tts.txt

added version constraints

* Update infer.py

* Update infer.py

* Update infer.py

* updated docs

* Update export-onnx.py

* Update export-onnx.py

* Update test_onnx.py

* updated requirements.txt

* Update test_onnx.py

* Update test_onnx.py

* docs updated

* docs fixed

* minor updates
2023-12-06 09:59:19 +08:00

349 lines
12 KiB
Python

# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""WaveNet modules.
This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
"""
import logging
import math
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
class WaveNet(torch.nn.Module):
"""WaveNet with global conditioning."""
def __init__(
self,
in_channels: int = 1,
out_channels: int = 1,
kernel_size: int = 3,
layers: int = 30,
stacks: int = 3,
base_dilation: int = 2,
residual_channels: int = 64,
aux_channels: int = -1,
gate_channels: int = 128,
skip_channels: int = 64,
global_channels: int = -1,
dropout_rate: float = 0.0,
bias: bool = True,
use_weight_norm: bool = True,
use_first_conv: bool = False,
use_last_conv: bool = False,
scale_residual: bool = False,
scale_skip_connect: bool = False,
):
"""Initialize WaveNet module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_size (int): Kernel size of dilated convolution.
layers (int): Number of residual block layers.
stacks (int): Number of stacks i.e., dilation cycles.
base_dilation (int): Base dilation factor.
residual_channels (int): Number of channels in residual conv.
gate_channels (int): Number of channels in gated conv.
skip_channels (int): Number of channels in skip conv.
aux_channels (int): Number of channels for local conditioning feature.
global_channels (int): Number of channels for global conditioning feature.
dropout_rate (float): Dropout rate. 0.0 means no dropout applied.
bias (bool): Whether to use bias parameter in conv layer.
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
be applied to all of the conv layers.
use_first_conv (bool): Whether to use the first conv layers.
use_last_conv (bool): Whether to use the last conv layers.
scale_residual (bool): Whether to scale the residual outputs.
scale_skip_connect (bool): Whether to scale the skip connection outputs.
"""
super().__init__()
self.layers = layers
self.stacks = stacks
self.kernel_size = kernel_size
self.base_dilation = base_dilation
self.use_first_conv = use_first_conv
self.use_last_conv = use_last_conv
self.scale_skip_connect = scale_skip_connect
# check the number of layers and stacks
assert layers % stacks == 0
layers_per_stack = layers // stacks
# define first convolution
if self.use_first_conv:
self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)
# define residual blocks
self.conv_layers = torch.nn.ModuleList()
for layer in range(layers):
dilation = base_dilation ** (layer % layers_per_stack)
conv = ResidualBlock(
kernel_size=kernel_size,
residual_channels=residual_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=aux_channels,
global_channels=global_channels,
dilation=dilation,
dropout_rate=dropout_rate,
bias=bias,
scale_residual=scale_residual,
)
self.conv_layers += [conv]
# define output layers
if self.use_last_conv:
self.last_conv = torch.nn.Sequential(
torch.nn.ReLU(inplace=True),
Conv1d1x1(skip_channels, skip_channels, bias=True),
torch.nn.ReLU(inplace=True),
Conv1d1x1(skip_channels, out_channels, bias=True),
)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
def forward(
self,
x: torch.Tensor,
x_mask: Optional[torch.Tensor] = None,
c: Optional[torch.Tensor] = None,
g: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T) if use_first_conv else
(B, residual_channels, T).
x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
c (Optional[Tensor]): Local conditioning features (B, aux_channels, T).
g (Optional[Tensor]): Global conditioning features (B, global_channels, 1).
Returns:
Tensor: Output tensor (B, out_channels, T) if use_last_conv else
(B, residual_channels, T).
"""
# encode to hidden representation
if self.use_first_conv:
x = self.first_conv(x)
# residual block
skips = 0.0
for f in self.conv_layers:
x, h = f(x, x_mask=x_mask, c=c, g=g)
skips = skips + h
x = skips
if self.scale_skip_connect:
x = x * math.sqrt(1.0 / len(self.conv_layers))
# apply final layers
if self.use_last_conv:
x = self.last_conv(x)
return x
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m: torch.nn.Module):
try:
logging.debug(f"Weight norm is removed from {m}.")
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
torch.nn.utils.weight_norm(m)
logging.debug(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
@staticmethod
def _get_receptive_field_size(
layers: int,
stacks: int,
kernel_size: int,
base_dilation: int,
) -> int:
assert layers % stacks == 0
layers_per_cycle = layers // stacks
dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)]
return (kernel_size - 1) * sum(dilations) + 1
@property
def receptive_field_size(self) -> int:
"""Return receptive field size."""
return self._get_receptive_field_size(
self.layers, self.stacks, self.kernel_size, self.base_dilation
)
class Conv1d(torch.nn.Conv1d):
"""Conv1d module with customized initialization."""
def __init__(self, *args, **kwargs):
"""Initialize Conv1d module."""
super().__init__(*args, **kwargs)
def reset_parameters(self):
"""Reset parameters."""
torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
if self.bias is not None:
torch.nn.init.constant_(self.bias, 0.0)
class Conv1d1x1(Conv1d):
"""1x1 Conv1d with customized initialization."""
def __init__(self, in_channels: int, out_channels: int, bias: bool):
"""Initialize 1x1 Conv1d module."""
super().__init__(
in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
)
class ResidualBlock(torch.nn.Module):
"""Residual block module in WaveNet."""
def __init__(
self,
kernel_size: int = 3,
residual_channels: int = 64,
gate_channels: int = 128,
skip_channels: int = 64,
aux_channels: int = 80,
global_channels: int = -1,
dropout_rate: float = 0.0,
dilation: int = 1,
bias: bool = True,
scale_residual: bool = False,
):
"""Initialize ResidualBlock module.
Args:
kernel_size (int): Kernel size of dilation convolution layer.
residual_channels (int): Number of channels for residual connection.
skip_channels (int): Number of channels for skip connection.
aux_channels (int): Number of local conditioning channels.
dropout (float): Dropout probability.
dilation (int): Dilation factor.
bias (bool): Whether to add bias parameter in convolution layers.
scale_residual (bool): Whether to scale the residual outputs.
"""
super().__init__()
self.dropout_rate = dropout_rate
self.residual_channels = residual_channels
self.skip_channels = skip_channels
self.scale_residual = scale_residual
# check
assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
assert gate_channels % 2 == 0
# dilation conv
padding = (kernel_size - 1) // 2 * dilation
self.conv = Conv1d(
residual_channels,
gate_channels,
kernel_size,
padding=padding,
dilation=dilation,
bias=bias,
)
# local conditioning
if aux_channels > 0:
self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
else:
self.conv1x1_aux = None
# global conditioning
if global_channels > 0:
self.conv1x1_glo = Conv1d1x1(global_channels, gate_channels, bias=False)
else:
self.conv1x1_glo = None
# conv output is split into two groups
gate_out_channels = gate_channels // 2
# NOTE(kan-bayashi): concat two convs into a single conv for the efficiency
# (integrate res 1x1 + skip 1x1 convs)
self.conv1x1_out = Conv1d1x1(
gate_out_channels, residual_channels + skip_channels, bias=bias
)
def forward(
self,
x: torch.Tensor,
x_mask: Optional[torch.Tensor] = None,
c: Optional[torch.Tensor] = None,
g: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, residual_channels, T).
x_mask Optional[torch.Tensor]: Mask tensor (B, 1, T).
c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Output tensor for residual connection (B, residual_channels, T).
Tensor: Output tensor for skip connection (B, skip_channels, T).
"""
residual = x
x = F.dropout(x, p=self.dropout_rate, training=self.training)
x = self.conv(x)
# split into two part for gated activation
splitdim = 1
xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)
# local conditioning
if c is not None:
c = self.conv1x1_aux(c)
ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
xa, xb = xa + ca, xb + cb
# global conditioning
if g is not None:
g = self.conv1x1_glo(g)
ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim)
xa, xb = xa + ga, xb + gb
x = torch.tanh(xa) * torch.sigmoid(xb)
# residual + skip 1x1 conv
x = self.conv1x1_out(x)
if x_mask is not None:
x = x * x_mask
# split integrated conv results
x, s = x.split([self.residual_channels, self.skip_channels], dim=1)
# for residual connection
x = x + residual
if self.scale_residual:
x = x * math.sqrt(0.5)
return x, s