Add MAS to VIT-1

This commit is contained in:
Erwan 2024-02-08 16:29:10 +01:00
parent 027302c902
commit cafc33bac9
22 changed files with 6966 additions and 0 deletions

View File

@ -0,0 +1,3 @@
See https://k2-fsa.github.io/icefall/recipes/TTS/ljspeech/vits.html for detailed tutorials.
Training logs, Tensorboard logs, and checkpoints are uploaded to https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29.

View File

@ -0,0 +1,193 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Stochastic duration predictor modules in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import Optional
import torch
import torch.nn.functional as F
from flow import (
ConvFlow,
DilatedDepthSeparableConv,
ElementwiseAffineFlow,
FlipFlow,
LogFlow,
)
class StochasticDurationPredictor(torch.nn.Module):
"""Stochastic duration predictor module.
This is a module of stochastic duration predictor described in `Conditional
Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2006.04558
"""
def __init__(
self,
channels: int = 192,
kernel_size: int = 3,
dropout_rate: float = 0.5,
flows: int = 4,
dds_conv_layers: int = 3,
global_channels: int = -1,
):
"""Initialize StochasticDurationPredictor module.
Args:
channels (int): Number of channels.
kernel_size (int): Kernel size.
dropout_rate (float): Dropout rate.
flows (int): Number of flows.
dds_conv_layers (int): Number of conv layers in DDS conv.
global_channels (int): Number of global conditioning channels.
"""
super().__init__()
self.pre = torch.nn.Conv1d(channels, channels, 1)
self.dds = DilatedDepthSeparableConv(
channels,
kernel_size,
layers=dds_conv_layers,
dropout_rate=dropout_rate,
)
self.proj = torch.nn.Conv1d(channels, channels, 1)
self.log_flow = LogFlow()
self.flows = torch.nn.ModuleList()
self.flows += [ElementwiseAffineFlow(2)]
for i in range(flows):
self.flows += [
ConvFlow(
2,
channels,
kernel_size,
layers=dds_conv_layers,
)
]
self.flows += [FlipFlow()]
self.post_pre = torch.nn.Conv1d(1, channels, 1)
self.post_dds = DilatedDepthSeparableConv(
channels,
kernel_size,
layers=dds_conv_layers,
dropout_rate=dropout_rate,
)
self.post_proj = torch.nn.Conv1d(channels, channels, 1)
self.post_flows = torch.nn.ModuleList()
self.post_flows += [ElementwiseAffineFlow(2)]
for i in range(flows):
self.post_flows += [
ConvFlow(
2,
channels,
kernel_size,
layers=dds_conv_layers,
)
]
self.post_flows += [FlipFlow()]
if global_channels > 0:
self.global_conv = torch.nn.Conv1d(global_channels, channels, 1)
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
w: Optional[torch.Tensor] = None,
g: Optional[torch.Tensor] = None,
inverse: bool = False,
noise_scale: float = 1.0,
) -> torch.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T_text).
x_mask (Tensor): Mask tensor (B, 1, T_text).
w (Optional[Tensor]): Duration tensor (B, 1, T_text).
g (Optional[Tensor]): Global conditioning tensor (B, channels, 1)
inverse (bool): Whether to inverse the flow.
noise_scale (float): Noise scale value.
Returns:
Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,).
If inverse, log-duration tensor (B, 1, T_text).
"""
x = x.detach() # stop gradient
x = self.pre(x)
if g is not None:
x = x + self.global_conv(g.detach()) # stop gradient
x = self.dds(x, x_mask)
x = self.proj(x) * x_mask
if not inverse:
assert w is not None, "w must be provided."
h_w = self.post_pre(w)
h_w = self.post_dds(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = (
torch.randn(
w.size(0),
2,
w.size(2),
).to(device=x.device, dtype=x.dtype)
* x_mask
)
z_q = e_q
logdet_tot_q = 0.0
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
logdet_tot_q += logdet_q
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum(
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
)
logq = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
- logdet_tot_q
)
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
logdet_tot += logdet
z = torch.cat([z0, z1], 1)
for flow in self.flows:
z, logdet = flow(z, x_mask, g=x, inverse=inverse)
logdet_tot = logdet_tot + logdet
nll = (
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
- logdet_tot
)
return nll + logq # (B,)
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = (
torch.randn(
x.size(0),
2,
x.size(2),
).to(device=x.device, dtype=x.dtype)
* noise_scale
)
for flow in flows:
z = flow(z, x_mask, g=x, inverse=inverse)
z0, z1 = z.split(1, 1)
logw = z0
return logw

View File

@ -0,0 +1,271 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script exports a VITS model from PyTorch to ONNX.
Export the model to ONNX:
./vits/export-onnx.py \
--epoch 1000 \
--exp-dir vits/exp \
--tokens data/tokens.txt
It will generate two files inside vits/exp:
- vits-epoch-1000.onnx
- vits-epoch-1000.int8.onnx (quantizated model)
See ./test_onnx.py for how to use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple
import onnx
import torch
import torch.nn as nn
from onnxruntime.quantization import QuantType, quantize_dynamic
from tokenizer import Tokenizer
from train import get_model, get_params
from icefall.checkpoint import load_checkpoint
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=1000,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="vits/exp",
help="The experiment dir",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to vocabulary.""",
)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
class OnnxModel(nn.Module):
"""A wrapper for VITS generator."""
def __init__(self, model: nn.Module):
"""
Args:
model:
A VITS generator.
frame_shift:
The frame shift in samples.
"""
super().__init__()
self.model = model
def forward(
self,
tokens: torch.Tensor,
tokens_lens: torch.Tensor,
noise_scale: float = 0.667,
alpha: float = 1.0,
noise_scale_dur: float = 0.8,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of VITS.inference_batch
Args:
tokens:
Input text token indexes (1, T_text)
tokens_lens:
Number of tokens of shape (1,)
noise_scale (float):
Noise scale parameter for flow.
noise_scale_dur (float):
Noise scale parameter for duration predictor.
alpha (float):
Alpha parameter to control the speed of generated speech.
Returns:
Return a tuple containing:
- audio, generated wavform tensor, (B, T_wav)
"""
audio, _, _ = self.model.inference(
text=tokens,
text_lengths=tokens_lens,
noise_scale=noise_scale,
noise_scale_dur=noise_scale_dur,
alpha=alpha,
)
return audio
def export_model_onnx(
model: nn.Module,
model_filename: str,
vocab_size: int,
opset_version: int = 11,
) -> None:
"""Export the given generator model to ONNX format.
The exported model has one input:
- tokens, a tensor of shape (1, T_text); dtype is torch.int64
and it has one output:
- audio, a tensor of shape (1, T'); dtype is torch.float32
Args:
model:
The VITS generator.
model_filename:
The filename to save the exported ONNX model.
vocab_size:
Number of tokens used in training.
opset_version:
The opset version to use.
"""
tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
noise_scale = torch.tensor([1], dtype=torch.float32)
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
alpha = torch.tensor([1], dtype=torch.float32)
torch.onnx.export(
model,
(tokens, tokens_lens, noise_scale, alpha, noise_scale_dur),
model_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"tokens",
"tokens_lens",
"noise_scale",
"alpha",
"noise_scale_dur",
],
output_names=["audio"],
dynamic_axes={
"tokens": {0: "N", 1: "T"},
"tokens_lens": {0: "N"},
"audio": {0: "N", 1: "T"},
},
)
meta_data = {
"model_type": "VITS",
"version": "1",
"model_author": "k2-fsa",
"comment": "VITS generator",
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=model_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.blank_id
params.oov_id = tokenizer.oov_id
params.vocab_size = tokenizer.vocab_size
logging.info(params)
logging.info("About to create model")
model = get_model(params)
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model = model.generator
model.to("cpu")
model.eval()
model = OnnxModel(model=model)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"generator parameters: {num_param}")
suffix = f"epoch-{params.epoch}"
opset_version = 13
logging.info("Exporting encoder")
model_filename = params.exp_dir / f"vits-{suffix}.onnx"
export_model_onnx(
model,
model_filename,
params.vocab_size,
opset_version=opset_version,
)
logging.info(f"Exported generator to {model_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
quantize_dynamic(
model_input=model_filename,
model_output=model_filename_int8,
weight_type=QuantType.QUInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,311 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Basic Flow modules used in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import Optional, Tuple, Union
import torch
from transform import piecewise_rational_quadratic_transform
class FlipFlow(torch.nn.Module):
"""Flip flow module."""
def forward(
self, x: torch.Tensor, *args, inverse: bool = False, **kwargs
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Flipped tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
x = torch.flip(x, [1])
if not inverse:
logdet = x.new_zeros(x.size(0))
return x, logdet
else:
return x
class LogFlow(torch.nn.Module):
"""Log flow module."""
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
inverse: bool = False,
eps: float = 1e-5,
**kwargs
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
inverse (bool): Whether to inverse the flow.
eps (float): Epsilon for log.
Returns:
Tensor: Output tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
if not inverse:
y = torch.log(torch.clamp_min(x, eps)) * x_mask
logdet = torch.sum(-y, [1, 2])
return y, logdet
else:
x = torch.exp(x) * x_mask
return x
class ElementwiseAffineFlow(torch.nn.Module):
"""Elementwise affine flow module."""
def __init__(self, channels: int):
"""Initialize ElementwiseAffineFlow module.
Args:
channels (int): Number of channels.
"""
super().__init__()
self.channels = channels
self.register_parameter("m", torch.nn.Parameter(torch.zeros(channels, 1)))
self.register_parameter("logs", torch.nn.Parameter(torch.zeros(channels, 1)))
def forward(
self, x: torch.Tensor, x_mask: torch.Tensor, inverse: bool = False, **kwargs
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
x_lengths (Tensor): Length tensor (B,).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
if not inverse:
y = self.m + torch.exp(self.logs) * x
y = y * x_mask
logdet = torch.sum(self.logs * x_mask, [1, 2])
return y, logdet
else:
x = (x - self.m) * torch.exp(-self.logs) * x_mask
return x
class Transpose(torch.nn.Module):
"""Transpose module for torch.nn.Sequential()."""
def __init__(self, dim1: int, dim2: int):
"""Initialize Transpose module."""
super().__init__()
self.dim1 = dim1
self.dim2 = dim2
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Transpose."""
return x.transpose(self.dim1, self.dim2)
class DilatedDepthSeparableConv(torch.nn.Module):
"""Dilated depth-separable conv module."""
def __init__(
self,
channels: int,
kernel_size: int,
layers: int,
dropout_rate: float = 0.0,
eps: float = 1e-5,
):
"""Initialize DilatedDepthSeparableConv module.
Args:
channels (int): Number of channels.
kernel_size (int): Kernel size.
layers (int): Number of layers.
dropout_rate (float): Dropout rate.
eps (float): Epsilon for layer norm.
"""
super().__init__()
self.convs = torch.nn.ModuleList()
for i in range(layers):
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs += [
torch.nn.Sequential(
torch.nn.Conv1d(
channels,
channels,
kernel_size,
groups=channels,
dilation=dilation,
padding=padding,
),
Transpose(1, 2),
torch.nn.LayerNorm(
channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
torch.nn.GELU(),
torch.nn.Conv1d(
channels,
channels,
1,
),
Transpose(1, 2),
torch.nn.LayerNorm(
channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
torch.nn.GELU(),
torch.nn.Dropout(dropout_rate),
)
]
def forward(
self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Output tensor (B, channels, T).
"""
if g is not None:
x = x + g
for f in self.convs:
y = f(x * x_mask)
x = x + y
return x * x_mask
class ConvFlow(torch.nn.Module):
"""Convolutional flow module."""
def __init__(
self,
in_channels: int,
hidden_channels: int,
kernel_size: int,
layers: int,
bins: int = 10,
tail_bound: float = 5.0,
):
"""Initialize ConvFlow module.
Args:
in_channels (int): Number of input channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size.
layers (int): Number of layers.
bins (int): Number of bins.
tail_bound (float): Tail bound value.
"""
super().__init__()
self.half_channels = in_channels // 2
self.hidden_channels = hidden_channels
self.bins = bins
self.tail_bound = tail_bound
self.input_conv = torch.nn.Conv1d(
self.half_channels,
hidden_channels,
1,
)
self.dds_conv = DilatedDepthSeparableConv(
hidden_channels,
kernel_size,
layers,
dropout_rate=0.0,
)
self.proj = torch.nn.Conv1d(
hidden_channels,
self.half_channels * (bins * 3 - 1),
1,
)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
g: Optional[torch.Tensor] = None,
inverse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
x_mask (Tensor): Mask tensor (B,).
g (Optional[Tensor]): Global conditioning tensor (B, channels, 1).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
xa, xb = x.split(x.size(1) // 2, 1)
h = self.input_conv(xa)
h = self.dds_conv(h, x_mask, g=g)
h = self.proj(h) * x_mask # (B, half_channels * (bins * 3 - 1), T)
b, c, t = xa.shape
# (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1)
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)
# TODO(kan-bayashi): Understand this calculation
denom = math.sqrt(self.hidden_channels)
unnorm_widths = h[..., : self.bins] / denom
unnorm_heights = h[..., self.bins : 2 * self.bins] / denom
unnorm_derivatives = h[..., 2 * self.bins :]
xb, logdet_abs = piecewise_rational_quadratic_transform(
xb,
unnorm_widths,
unnorm_heights,
unnorm_derivatives,
inverse=inverse,
tails="linear",
tail_bound=self.tail_bound,
)
x = torch.cat([xa, xb], 1) * x_mask
logdet = torch.sum(logdet_abs * x_mask, [1, 2])
if not inverse:
return x, logdet
else:
return x

View File

@ -0,0 +1,549 @@
# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Generator module in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from duration_predictor import StochasticDurationPredictor
from hifigan import HiFiGANGenerator
from posterior_encoder import PosteriorEncoder
from residual_coupling import ResidualAffineCouplingBlock
from text_encoder import TextEncoder
from utils import get_random_segments
from icefall.utils import make_pad_mask
class VITSGenerator(torch.nn.Module):
"""Generator module in VITS, `Conditional Variational Autoencoder
with Adversarial Learning for End-to-End Text-to-Speech`.
"""
def __init__(
self,
vocabs: int,
aux_channels: int = 513,
hidden_channels: int = 192,
spks: Optional[int] = None,
langs: Optional[int] = None,
spk_embed_dim: Optional[int] = None,
global_channels: int = -1,
segment_size: int = 32,
text_encoder_attention_heads: int = 2,
text_encoder_ffn_expand: int = 4,
text_encoder_cnn_module_kernel: int = 5,
text_encoder_blocks: int = 6,
text_encoder_dropout_rate: float = 0.1,
decoder_kernel_size: int = 7,
decoder_channels: int = 512,
decoder_upsample_scales: List[int] = [8, 8, 2, 2],
decoder_upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
decoder_resblock_kernel_sizes: List[int] = [3, 7, 11],
decoder_resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
use_weight_norm_in_decoder: bool = True,
posterior_encoder_kernel_size: int = 5,
posterior_encoder_layers: int = 16,
posterior_encoder_stacks: int = 1,
posterior_encoder_base_dilation: int = 1,
posterior_encoder_dropout_rate: float = 0.0,
use_weight_norm_in_posterior_encoder: bool = True,
flow_flows: int = 4,
flow_kernel_size: int = 5,
flow_base_dilation: int = 1,
flow_layers: int = 4,
flow_dropout_rate: float = 0.0,
use_weight_norm_in_flow: bool = True,
use_only_mean_in_flow: bool = True,
stochastic_duration_predictor_kernel_size: int = 3,
stochastic_duration_predictor_dropout_rate: float = 0.5,
stochastic_duration_predictor_flows: int = 4,
stochastic_duration_predictor_dds_conv_layers: int = 3,
use_noised_mas: bool = True,
noise_initial_mas: float = 0.01,
noise_scale_mas: float = 2e-6,
):
"""Initialize VITS generator module.
Args:
vocabs (int): Input vocabulary size.
aux_channels (int): Number of acoustic feature channels.
hidden_channels (int): Number of hidden channels.
spks (Optional[int]): Number of speakers. If set to > 1, assume that the
sids will be provided as the input and use sid embedding layer.
langs (Optional[int]): Number of languages. If set to > 1, assume that the
lids will be provided as the input and use sid embedding layer.
spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0,
assume that spembs will be provided as the input.
global_channels (int): Number of global conditioning channels.
segment_size (int): Segment size for decoder.
text_encoder_attention_heads (int): Number of heads in conformer block
of text encoder.
text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block
of text encoder.
text_encoder_cnn_module_kernel (int): Convolution kernel size in text encoder.
text_encoder_blocks (int): Number of conformer blocks in text encoder.
text_encoder_dropout_rate (float): Dropout rate in conformer block of
text encoder.
decoder_kernel_size (int): Decoder kernel size.
decoder_channels (int): Number of decoder initial channels.
decoder_upsample_scales (List[int]): List of upsampling scales in decoder.
decoder_upsample_kernel_sizes (List[int]): List of kernel size for
upsampling layers in decoder.
decoder_resblock_kernel_sizes (List[int]): List of kernel size for resblocks
in decoder.
decoder_resblock_dilations (List[List[int]]): List of list of dilations for
resblocks in decoder.
use_weight_norm_in_decoder (bool): Whether to apply weight normalization in
decoder.
posterior_encoder_kernel_size (int): Posterior encoder kernel size.
posterior_encoder_layers (int): Number of layers of posterior encoder.
posterior_encoder_stacks (int): Number of stacks of posterior encoder.
posterior_encoder_base_dilation (int): Base dilation of posterior encoder.
posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder.
use_weight_norm_in_posterior_encoder (bool): Whether to apply weight
normalization in posterior encoder.
flow_flows (int): Number of flows in flow.
flow_kernel_size (int): Kernel size in flow.
flow_base_dilation (int): Base dilation in flow.
flow_layers (int): Number of layers in flow.
flow_dropout_rate (float): Dropout rate in flow
use_weight_norm_in_flow (bool): Whether to apply weight normalization in
flow.
use_only_mean_in_flow (bool): Whether to use only mean in flow.
stochastic_duration_predictor_kernel_size (int): Kernel size in stochastic
duration predictor.
stochastic_duration_predictor_dropout_rate (float): Dropout rate in
stochastic duration predictor.
stochastic_duration_predictor_flows (int): Number of flows in stochastic
duration predictor.
stochastic_duration_predictor_dds_conv_layers (int): Number of DDS conv
layers in stochastic duration predictor.
"""
super().__init__()
self.segment_size = segment_size
self.text_encoder = TextEncoder(
vocabs=vocabs,
d_model=hidden_channels,
num_heads=text_encoder_attention_heads,
dim_feedforward=hidden_channels * text_encoder_ffn_expand,
cnn_module_kernel=text_encoder_cnn_module_kernel,
num_layers=text_encoder_blocks,
dropout=text_encoder_dropout_rate,
)
self.decoder = HiFiGANGenerator(
in_channels=hidden_channels,
out_channels=1,
channels=decoder_channels,
global_channels=global_channels,
kernel_size=decoder_kernel_size,
upsample_scales=decoder_upsample_scales,
upsample_kernel_sizes=decoder_upsample_kernel_sizes,
resblock_kernel_sizes=decoder_resblock_kernel_sizes,
resblock_dilations=decoder_resblock_dilations,
use_weight_norm=use_weight_norm_in_decoder,
)
self.posterior_encoder = PosteriorEncoder(
in_channels=aux_channels,
out_channels=hidden_channels,
hidden_channels=hidden_channels,
kernel_size=posterior_encoder_kernel_size,
layers=posterior_encoder_layers,
stacks=posterior_encoder_stacks,
base_dilation=posterior_encoder_base_dilation,
global_channels=global_channels,
dropout_rate=posterior_encoder_dropout_rate,
use_weight_norm=use_weight_norm_in_posterior_encoder,
)
self.flow = ResidualAffineCouplingBlock(
in_channels=hidden_channels,
hidden_channels=hidden_channels,
flows=flow_flows,
kernel_size=flow_kernel_size,
base_dilation=flow_base_dilation,
layers=flow_layers,
global_channels=global_channels,
dropout_rate=flow_dropout_rate,
use_weight_norm=use_weight_norm_in_flow,
use_only_mean=use_only_mean_in_flow,
)
# TODO(kan-bayashi): Add deterministic version as an option
self.duration_predictor = StochasticDurationPredictor(
channels=hidden_channels,
kernel_size=stochastic_duration_predictor_kernel_size,
dropout_rate=stochastic_duration_predictor_dropout_rate,
flows=stochastic_duration_predictor_flows,
dds_conv_layers=stochastic_duration_predictor_dds_conv_layers,
global_channels=global_channels,
)
self.upsample_factor = int(np.prod(decoder_upsample_scales))
# MAS with Gaussian Noise params
self.use_noised_mas = use_noised_mas
self.noise_current_mas = noise_initial_mas
self.noise_scale_mas = noise_scale_mas
self.noise_initial_mas = noise_initial_mas
self.spks = None
if spks is not None and spks > 1:
assert global_channels > 0
self.spks = spks
self.global_emb = torch.nn.Embedding(spks, global_channels)
self.spk_embed_dim = None
if spk_embed_dim is not None and spk_embed_dim > 0:
assert global_channels > 0
self.spk_embed_dim = spk_embed_dim
self.spemb_proj = torch.nn.Linear(spk_embed_dim, global_channels)
self.langs = None
if langs is not None and langs > 1:
assert global_channels > 0
self.langs = langs
self.lang_emb = torch.nn.Embedding(langs, global_channels)
# delayed import
from monotonic_align import maximum_path
self.maximum_path = maximum_path
def forward(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
feats: torch.Tensor,
feats_lengths: torch.Tensor,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
],
]:
"""Calculate forward propagation.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, aux_channels, T_feats).
feats_lengths (Tensor): Feature length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
Returns:
Tensor: Waveform tensor (B, 1, segment_size * upsample_factor).
Tensor: Duration negative log-likelihood (NLL) tensor (B,).
Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text).
Tensor: Segments start index tensor (B,).
Tensor: Text mask tensor (B, 1, T_text).
Tensor: Feature mask tensor (B, 1, T_feats).
tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
- Tensor: Posterior encoder hidden representation (B, H, T_feats).
- Tensor: Flow hidden representation (B, H, T_feats).
- Tensor: Expanded text encoder projected mean (B, H, T_feats).
- Tensor: Expanded text encoder projected scale (B, H, T_feats).
- Tensor: Posterior encoder projected mean (B, H, T_feats).
- Tensor: Posterior encoder projected scale (B, H, T_feats).
"""
# forward text encoder
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
# calculate global conditioning
g = None
if self.spks is not None:
# speaker one-hot vector embedding: (B, global_channels, 1)
g = self.global_emb(sids.view(-1)).unsqueeze(-1)
if self.spk_embed_dim is not None:
# pretreined speaker embedding, e.g., X-vector (B, global_channels, 1)
g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
if self.langs is not None:
# language one-hot vector embedding: (B, global_channels, 1)
g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
# forward posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
# forward flow
z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats)
# monotonic alignment search
with torch.no_grad():
# negative cross-entropy
s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text)
# (B, 1, T_text)
neg_x_ent_1 = torch.sum(
-0.5 * math.log(2 * math.pi) - logs_p,
[1],
keepdim=True,
)
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_2 = torch.matmul(
-0.5 * (z_p**2).transpose(1, 2),
s_p_sq_r,
)
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_3 = torch.matmul(
z_p.transpose(1, 2),
(m_p * s_p_sq_r),
)
# (B, 1, T_text)
neg_x_ent_4 = torch.sum(
-0.5 * (m_p**2) * s_p_sq_r,
[1],
keepdim=True,
)
# (B, T_feats, T_text)
neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
if self.use_noised_mas:
epsilon = (
torch.std(neg_x_ent)
* torch.randn_like(neg_x_ent)
* self.noise_current_mas
)
neg_x_ent = neg_x_ent + epsilon
# (B, 1, T_feats, T_text)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
# monotonic attention weight: (B, 1, T_feats, T_text)
attn = (
self.maximum_path(
neg_x_ent,
attn_mask.squeeze(1),
)
.unsqueeze(1)
.detach()
)
# forward duration predictor
w = attn.sum(2) # (B, 1, T_text)
dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
dur_nll = dur_nll / torch.sum(x_mask)
# expand the length to match with the feature sequence
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
# get random segments
z_segments, z_start_idxs = get_random_segments(
z,
feats_lengths,
self.segment_size,
)
# forward decoder with random segments
wav = self.decoder(z_segments, g=g)
return (
wav,
dur_nll,
attn,
z_start_idxs,
x_mask,
y_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
)
def inference(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
feats: Optional[torch.Tensor] = None,
feats_lengths: Optional[torch.Tensor] = None,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
dur: Optional[torch.Tensor] = None,
noise_scale: float = 0.667,
noise_scale_dur: float = 0.8,
alpha: float = 1.0,
max_len: Optional[int] = None,
use_teacher_forcing: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Run inference.
Args:
text (Tensor): Input text index tensor (B, T_text,).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, aux_channels, T_feats,).
feats_lengths (Tensor): Feature length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided,
skip the prediction of durations (i.e., teacher forcing).
noise_scale (float): Noise scale parameter for flow.
noise_scale_dur (float): Noise scale parameter for duration predictor.
alpha (float): Alpha parameter to control the speed of generated speech.
max_len (Optional[int]): Maximum length of acoustic feature sequence.
use_teacher_forcing (bool): Whether to use teacher forcing.
Returns:
Tensor: Generated waveform tensor (B, T_wav).
Tensor: Monotonic attention weight tensor (B, T_feats, T_text).
Tensor: Duration tensor (B, T_text).
"""
# encoder
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
x_mask = x_mask.to(x.dtype)
g = None
if self.spks is not None:
# (B, global_channels, 1)
g = self.global_emb(sids.view(-1)).unsqueeze(-1)
if self.spk_embed_dim is not None:
# (B, global_channels, 1)
g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
if self.langs is not None:
# (B, global_channels, 1)
g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
if use_teacher_forcing:
# forward posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
# forward flow
z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats)
# monotonic alignment search
s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text)
# (B, 1, T_text)
neg_x_ent_1 = torch.sum(
-0.5 * math.log(2 * math.pi) - logs_p,
[1],
keepdim=True,
)
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_2 = torch.matmul(
-0.5 * (z_p**2).transpose(1, 2),
s_p_sq_r,
)
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_3 = torch.matmul(
z_p.transpose(1, 2),
(m_p * s_p_sq_r),
)
# (B, 1, T_text)
neg_x_ent_4 = torch.sum(
-0.5 * (m_p**2) * s_p_sq_r,
[1],
keepdim=True,
)
# (B, T_feats, T_text)
neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
# (B, 1, T_feats, T_text)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
# monotonic attention weight: (B, 1, T_feats, T_text)
attn = self.maximum_path(
neg_x_ent,
attn_mask.squeeze(1),
).unsqueeze(1)
dur = attn.sum(2) # (B, 1, T_text)
# forward decoder with random segments
wav = self.decoder(z * y_mask, g=g)
else:
# duration
if dur is None:
logw = self.duration_predictor(
x,
x_mask,
g=g,
inverse=True,
noise_scale=noise_scale_dur,
)
w = torch.exp(logw) * x_mask * alpha
dur = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long()
y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device)
y_mask = y_mask.to(x.dtype)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = self._generate_path(dur, attn_mask)
# expand the length to match with the feature sequence
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
m_p = torch.matmul(
attn.squeeze(1),
m_p.transpose(1, 2),
).transpose(1, 2)
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
logs_p = torch.matmul(
attn.squeeze(1),
logs_p.transpose(1, 2),
).transpose(1, 2)
# decoder
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=g, inverse=True)
wav = self.decoder((z * y_mask)[:, :, :max_len], g=g)
return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1)
def _generate_path(self, dur: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""Generate path a.k.a. monotonic attention.
Args:
dur (Tensor): Duration tensor (B, 1, T_text).
mask (Tensor): Attention mask tensor (B, 1, T_feats, T_text).
Returns:
Tensor: Path tensor (B, 1, T_feats, T_text).
"""
b, _, t_y, t_x = mask.shape
cum_dur = torch.cumsum(dur, -1)
cum_dur_flat = cum_dur.view(b * t_x)
path = torch.arange(t_y, dtype=dur.dtype, device=dur.device)
path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)
# path = path.view(b, t_x, t_y).to(dtype=mask.dtype)
path = path.view(b, t_x, t_y).to(dtype=torch.float)
# path will be like (t_x = 3, t_y = 5):
# [[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.],
# [1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.],
# [1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]]
path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
# path = path.to(dtype=mask.dtype)
return path.unsqueeze(1).transpose(2, 3) * mask

View File

@ -0,0 +1,933 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""HiFi-GAN Modules.
This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
"""
import copy
import logging
from typing import Any, Dict, List, Optional
import numpy as np
import torch
import torch.nn.functional as F
class HiFiGANGenerator(torch.nn.Module):
"""HiFiGAN generator module."""
def __init__(
self,
in_channels: int = 80,
out_channels: int = 1,
channels: int = 512,
global_channels: int = -1,
kernel_size: int = 7,
upsample_scales: List[int] = [8, 8, 2, 2],
upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
resblock_kernel_sizes: List[int] = [3, 7, 11],
resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
use_additional_convs: bool = True,
bias: bool = True,
nonlinear_activation: str = "LeakyReLU",
nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
use_weight_norm: bool = True,
):
"""Initialize HiFiGANGenerator module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
channels (int): Number of hidden representation channels.
global_channels (int): Number of global conditioning channels.
kernel_size (int): Kernel size of initial and final conv layer.
upsample_scales (List[int]): List of upsampling scales.
upsample_kernel_sizes (List[int]): List of kernel sizes for upsample layers.
resblock_kernel_sizes (List[int]): List of kernel sizes for residual blocks.
resblock_dilations (List[List[int]]): List of list of dilations for residual
blocks.
use_additional_convs (bool): Whether to use additional conv layers in
residual blocks.
bias (bool): Whether to add bias parameter in convolution layers.
nonlinear_activation (str): Activation function module name.
nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
function.
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
be applied to all of the conv layers.
"""
super().__init__()
# check hyperparameters are valid
assert kernel_size % 2 == 1, "Kernel size must be odd number."
assert len(upsample_scales) == len(upsample_kernel_sizes)
assert len(resblock_dilations) == len(resblock_kernel_sizes)
# define modules
self.upsample_factor = int(np.prod(upsample_scales) * out_channels)
self.num_upsamples = len(upsample_kernel_sizes)
self.num_blocks = len(resblock_kernel_sizes)
self.input_conv = torch.nn.Conv1d(
in_channels,
channels,
kernel_size,
1,
padding=(kernel_size - 1) // 2,
)
self.upsamples = torch.nn.ModuleList()
self.blocks = torch.nn.ModuleList()
for i in range(len(upsample_kernel_sizes)):
assert upsample_kernel_sizes[i] == 2 * upsample_scales[i]
self.upsamples += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
torch.nn.ConvTranspose1d(
channels // (2**i),
channels // (2 ** (i + 1)),
upsample_kernel_sizes[i],
upsample_scales[i],
padding=upsample_scales[i] // 2 + upsample_scales[i] % 2,
output_padding=upsample_scales[i] % 2,
),
)
]
for j in range(len(resblock_kernel_sizes)):
self.blocks += [
ResidualBlock(
kernel_size=resblock_kernel_sizes[j],
channels=channels // (2 ** (i + 1)),
dilations=resblock_dilations[j],
bias=bias,
use_additional_convs=use_additional_convs,
nonlinear_activation=nonlinear_activation,
nonlinear_activation_params=nonlinear_activation_params,
)
]
self.output_conv = torch.nn.Sequential(
# NOTE(kan-bayashi): follow official implementation but why
# using different slope parameter here? (0.1 vs. 0.01)
torch.nn.LeakyReLU(),
torch.nn.Conv1d(
channels // (2 ** (i + 1)),
out_channels,
kernel_size,
1,
padding=(kernel_size - 1) // 2,
),
torch.nn.Tanh(),
)
if global_channels > 0:
self.global_conv = torch.nn.Conv1d(global_channels, channels, 1)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
# reset parameters
self.reset_parameters()
def forward(
self, c: torch.Tensor, g: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Calculate forward propagation.
Args:
c (Tensor): Input tensor (B, in_channels, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Output tensor (B, out_channels, T).
"""
c = self.input_conv(c)
if g is not None:
c = c + self.global_conv(g)
for i in range(self.num_upsamples):
c = self.upsamples[i](c)
cs = 0.0 # initialize
for j in range(self.num_blocks):
cs += self.blocks[i * self.num_blocks + j](c)
c = cs / self.num_blocks
c = self.output_conv(c)
return c
def reset_parameters(self):
"""Reset parameters.
This initialization follows the official implementation manner.
https://github.com/jik876/hifi-gan/blob/master/models.py
"""
def _reset_parameters(m: torch.nn.Module):
if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
m.weight.data.normal_(0.0, 0.01)
logging.debug(f"Reset parameters in {m}.")
self.apply(_reset_parameters)
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m: torch.nn.Module):
try:
logging.debug(f"Weight norm is removed from {m}.")
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv1d) or isinstance(
m, torch.nn.ConvTranspose1d
):
torch.nn.utils.weight_norm(m)
logging.debug(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
def inference(
self, c: torch.Tensor, g: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Perform inference.
Args:
c (torch.Tensor): Input tensor (T, in_channels).
g (Optional[Tensor]): Global conditioning tensor (global_channels, 1).
Returns:
Tensor: Output tensor (T ** upsample_factor, out_channels).
"""
if g is not None:
g = g.unsqueeze(0)
c = self.forward(c.transpose(1, 0).unsqueeze(0), g=g)
return c.squeeze(0).transpose(1, 0)
class ResidualBlock(torch.nn.Module):
"""Residual block module in HiFiGAN."""
def __init__(
self,
kernel_size: int = 3,
channels: int = 512,
dilations: List[int] = [1, 3, 5],
bias: bool = True,
use_additional_convs: bool = True,
nonlinear_activation: str = "LeakyReLU",
nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
):
"""Initialize ResidualBlock module.
Args:
kernel_size (int): Kernel size of dilation convolution layer.
channels (int): Number of channels for convolution layer.
dilations (List[int]): List of dilation factors.
use_additional_convs (bool): Whether to use additional convolution layers.
bias (bool): Whether to add bias parameter in convolution layers.
nonlinear_activation (str): Activation function module name.
nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
function.
"""
super().__init__()
self.use_additional_convs = use_additional_convs
self.convs1 = torch.nn.ModuleList()
if use_additional_convs:
self.convs2 = torch.nn.ModuleList()
assert kernel_size % 2 == 1, "Kernel size must be odd number."
for dilation in dilations:
self.convs1 += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
torch.nn.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation,
bias=bias,
padding=(kernel_size - 1) // 2 * dilation,
),
)
]
if use_additional_convs:
self.convs2 += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
torch.nn.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
bias=bias,
padding=(kernel_size - 1) // 2,
),
)
]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
Returns:
Tensor: Output tensor (B, channels, T).
"""
for idx in range(len(self.convs1)):
xt = self.convs1[idx](x)
if self.use_additional_convs:
xt = self.convs2[idx](xt)
x = xt + x
return x
class HiFiGANPeriodDiscriminator(torch.nn.Module):
"""HiFiGAN period discriminator module."""
def __init__(
self,
in_channels: int = 1,
out_channels: int = 1,
period: int = 3,
kernel_sizes: List[int] = [5, 3],
channels: int = 32,
downsample_scales: List[int] = [3, 3, 3, 3, 1],
max_downsample_channels: int = 1024,
bias: bool = True,
nonlinear_activation: str = "LeakyReLU",
nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
use_weight_norm: bool = True,
use_spectral_norm: bool = False,
):
"""Initialize HiFiGANPeriodDiscriminator module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
period (int): Period.
kernel_sizes (list): Kernel sizes of initial conv layers and the final conv
layer.
channels (int): Number of initial channels.
downsample_scales (List[int]): List of downsampling scales.
max_downsample_channels (int): Number of maximum downsampling channels.
use_additional_convs (bool): Whether to use additional conv layers in
residual blocks.
bias (bool): Whether to add bias parameter in convolution layers.
nonlinear_activation (str): Activation function module name.
nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
function.
use_weight_norm (bool): Whether to use weight norm.
If set to true, it will be applied to all of the conv layers.
use_spectral_norm (bool): Whether to use spectral norm.
If set to true, it will be applied to all of the conv layers.
"""
super().__init__()
assert len(kernel_sizes) == 2
assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number."
assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number."
self.period = period
self.convs = torch.nn.ModuleList()
in_chs = in_channels
out_chs = channels
for downsample_scale in downsample_scales:
self.convs += [
torch.nn.Sequential(
torch.nn.Conv2d(
in_chs,
out_chs,
(kernel_sizes[0], 1),
(downsample_scale, 1),
padding=((kernel_sizes[0] - 1) // 2, 0),
),
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
)
]
in_chs = out_chs
# NOTE(kan-bayashi): Use downsample_scale + 1?
out_chs = min(out_chs * 4, max_downsample_channels)
self.output_conv = torch.nn.Conv2d(
out_chs,
out_channels,
(kernel_sizes[1] - 1, 1),
1,
padding=((kernel_sizes[1] - 1) // 2, 0),
)
if use_weight_norm and use_spectral_norm:
raise ValueError("Either use use_weight_norm or use_spectral_norm.")
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
# apply spectral norm
if use_spectral_norm:
self.apply_spectral_norm()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Calculate forward propagation.
Args:
c (Tensor): Input tensor (B, in_channels, T).
Returns:
list: List of each layer's tensors.
"""
# transform 1d to 2d -> (B, C, T/P, P)
b, c, t = x.shape
if t % self.period != 0:
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t += n_pad
x = x.view(b, c, t // self.period, self.period)
# forward conv
outs = []
for layer in self.convs:
x = layer(x)
outs += [x]
x = self.output_conv(x)
x = torch.flatten(x, 1, -1)
outs += [x]
return outs
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv2d):
torch.nn.utils.weight_norm(m)
logging.debug(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
def apply_spectral_norm(self):
"""Apply spectral normalization module from all of the layers."""
def _apply_spectral_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv2d):
torch.nn.utils.spectral_norm(m)
logging.debug(f"Spectral norm is applied to {m}.")
self.apply(_apply_spectral_norm)
class HiFiGANMultiPeriodDiscriminator(torch.nn.Module):
"""HiFiGAN multi-period discriminator module."""
def __init__(
self,
periods: List[int] = [2, 3, 5, 7, 11],
discriminator_params: Dict[str, Any] = {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [5, 3],
"channels": 32,
"downsample_scales": [3, 3, 3, 3, 1],
"max_downsample_channels": 1024,
"bias": True,
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
"use_weight_norm": True,
"use_spectral_norm": False,
},
):
"""Initialize HiFiGANMultiPeriodDiscriminator module.
Args:
periods (List[int]): List of periods.
discriminator_params (Dict[str, Any]): Parameters for hifi-gan period
discriminator module. The period parameter will be overwritten.
"""
super().__init__()
self.discriminators = torch.nn.ModuleList()
for period in periods:
params = copy.deepcopy(discriminator_params)
params["period"] = period
self.discriminators += [HiFiGANPeriodDiscriminator(**params)]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T).
Returns:
List: List of list of each discriminator outputs, which consists of each
layer output tensors.
"""
outs = []
for f in self.discriminators:
outs += [f(x)]
return outs
class HiFiGANScaleDiscriminator(torch.nn.Module):
"""HiFi-GAN scale discriminator module."""
def __init__(
self,
in_channels: int = 1,
out_channels: int = 1,
kernel_sizes: List[int] = [15, 41, 5, 3],
channels: int = 128,
max_downsample_channels: int = 1024,
max_groups: int = 16,
bias: int = True,
downsample_scales: List[int] = [2, 2, 4, 4, 1],
nonlinear_activation: str = "LeakyReLU",
nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
use_weight_norm: bool = True,
use_spectral_norm: bool = False,
):
"""Initilize HiFiGAN scale discriminator module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_sizes (List[int]): List of four kernel sizes. The first will be used
for the first conv layer, and the second is for downsampling part, and
the remaining two are for the last two output layers.
channels (int): Initial number of channels for conv layer.
max_downsample_channels (int): Maximum number of channels for downsampling
layers.
bias (bool): Whether to add bias parameter in convolution layers.
downsample_scales (List[int]): List of downsampling scales.
nonlinear_activation (str): Activation function module name.
nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
function.
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
be applied to all of the conv layers.
use_spectral_norm (bool): Whether to use spectral norm. If set to true, it
will be applied to all of the conv layers.
"""
super().__init__()
self.layers = torch.nn.ModuleList()
# check kernel size is valid
assert len(kernel_sizes) == 4
for ks in kernel_sizes:
assert ks % 2 == 1
# add first layer
self.layers += [
torch.nn.Sequential(
torch.nn.Conv1d(
in_channels,
channels,
# NOTE(kan-bayashi): Use always the same kernel size
kernel_sizes[0],
bias=bias,
padding=(kernel_sizes[0] - 1) // 2,
),
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
)
]
# add downsample layers
in_chs = channels
out_chs = channels
# NOTE(kan-bayashi): Remove hard coding?
groups = 4
for downsample_scale in downsample_scales:
self.layers += [
torch.nn.Sequential(
torch.nn.Conv1d(
in_chs,
out_chs,
kernel_size=kernel_sizes[1],
stride=downsample_scale,
padding=(kernel_sizes[1] - 1) // 2,
groups=groups,
bias=bias,
),
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
)
]
in_chs = out_chs
# NOTE(kan-bayashi): Remove hard coding?
out_chs = min(in_chs * 2, max_downsample_channels)
# NOTE(kan-bayashi): Remove hard coding?
groups = min(groups * 4, max_groups)
# add final layers
out_chs = min(in_chs * 2, max_downsample_channels)
self.layers += [
torch.nn.Sequential(
torch.nn.Conv1d(
in_chs,
out_chs,
kernel_size=kernel_sizes[2],
stride=1,
padding=(kernel_sizes[2] - 1) // 2,
bias=bias,
),
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
)
]
self.layers += [
torch.nn.Conv1d(
out_chs,
out_channels,
kernel_size=kernel_sizes[3],
stride=1,
padding=(kernel_sizes[3] - 1) // 2,
bias=bias,
),
]
if use_weight_norm and use_spectral_norm:
raise ValueError("Either use use_weight_norm or use_spectral_norm.")
# apply weight norm
self.use_weight_norm = use_weight_norm
if use_weight_norm:
self.apply_weight_norm()
# apply spectral norm
self.use_spectral_norm = use_spectral_norm
if use_spectral_norm:
self.apply_spectral_norm()
# backward compatibility
self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T).
Returns:
List[Tensor]: List of output tensors of each layer.
"""
outs = []
for f in self.layers:
x = f(x)
outs += [x]
return outs
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv1d):
torch.nn.utils.weight_norm(m)
logging.debug(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
def apply_spectral_norm(self):
"""Apply spectral normalization module from all of the layers."""
def _apply_spectral_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv1d):
torch.nn.utils.spectral_norm(m)
logging.debug(f"Spectral norm is applied to {m}.")
self.apply(_apply_spectral_norm)
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m):
try:
logging.debug(f"Weight norm is removed from {m}.")
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
def remove_spectral_norm(self):
"""Remove spectral normalization module from all of the layers."""
def _remove_spectral_norm(m):
try:
logging.debug(f"Spectral norm is removed from {m}.")
torch.nn.utils.remove_spectral_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_spectral_norm)
def _load_state_dict_pre_hook(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""Fix the compatibility of weight / spectral normalization issue.
Some pretrained models are trained with configs that use weight / spectral
normalization, but actually, the norm is not applied. This causes the mismatch
of the parameters with configs. To solve this issue, when parameter mismatch
happens in loading pretrained model, we remove the norm from the current model.
See also:
- https://github.com/espnet/espnet/pull/5240
- https://github.com/espnet/espnet/pull/5249
- https://github.com/kan-bayashi/ParallelWaveGAN/pull/409
"""
current_module_keys = [x for x in state_dict.keys() if x.startswith(prefix)]
if self.use_weight_norm and any(
[k.endswith("weight") for k in current_module_keys]
):
logging.warning(
"It seems weight norm is not applied in the pretrained model but the"
" current model uses it. To keep the compatibility, we remove the norm"
" from the current model. This may cause unexpected behavior due to the"
" parameter mismatch in finetuning. To avoid this issue, please change"
" the following parameters in config to false:\n"
" - discriminator_params.follow_official_norm\n"
" - discriminator_params.scale_discriminator_params.use_weight_norm\n"
" - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
"\n"
"See also:\n"
" - https://github.com/espnet/espnet/pull/5240\n"
" - https://github.com/espnet/espnet/pull/5249"
)
self.remove_weight_norm()
self.use_weight_norm = False
for k in current_module_keys:
if k.endswith("weight_g") or k.endswith("weight_v"):
del state_dict[k]
if self.use_spectral_norm and any(
[k.endswith("weight") for k in current_module_keys]
):
logging.warning(
"It seems spectral norm is not applied in the pretrained model but the"
" current model uses it. To keep the compatibility, we remove the norm"
" from the current model. This may cause unexpected behavior due to the"
" parameter mismatch in finetuning. To avoid this issue, please change"
" the following parameters in config to false:\n"
" - discriminator_params.follow_official_norm\n"
" - discriminator_params.scale_discriminator_params.use_weight_norm\n"
" - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
"\n"
"See also:\n"
" - https://github.com/espnet/espnet/pull/5240\n"
" - https://github.com/espnet/espnet/pull/5249"
)
self.remove_spectral_norm()
self.use_spectral_norm = False
for k in current_module_keys:
if (
k.endswith("weight_u")
or k.endswith("weight_v")
or k.endswith("weight_orig")
):
del state_dict[k]
class HiFiGANMultiScaleDiscriminator(torch.nn.Module):
"""HiFi-GAN multi-scale discriminator module."""
def __init__(
self,
scales: int = 3,
downsample_pooling: str = "AvgPool1d",
# follow the official implementation setting
downsample_pooling_params: Dict[str, Any] = {
"kernel_size": 4,
"stride": 2,
"padding": 2,
},
discriminator_params: Dict[str, Any] = {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [15, 41, 5, 3],
"channels": 128,
"max_downsample_channels": 1024,
"max_groups": 16,
"bias": True,
"downsample_scales": [2, 2, 4, 4, 1],
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
},
follow_official_norm: bool = False,
):
"""Initilize HiFiGAN multi-scale discriminator module.
Args:
scales (int): Number of multi-scales.
downsample_pooling (str): Pooling module name for downsampling of the
inputs.
downsample_pooling_params (Dict[str, Any]): Parameters for the above pooling
module.
discriminator_params (Dict[str, Any]): Parameters for hifi-gan scale
discriminator module.
follow_official_norm (bool): Whether to follow the norm setting of the
official implementaion. The first discriminator uses spectral norm
and the other discriminators use weight norm.
"""
super().__init__()
self.discriminators = torch.nn.ModuleList()
# add discriminators
for i in range(scales):
params = copy.deepcopy(discriminator_params)
if follow_official_norm:
if i == 0:
params["use_weight_norm"] = False
params["use_spectral_norm"] = True
else:
params["use_weight_norm"] = True
params["use_spectral_norm"] = False
self.discriminators += [HiFiGANScaleDiscriminator(**params)]
self.pooling = None
if scales > 1:
self.pooling = getattr(torch.nn, downsample_pooling)(
**downsample_pooling_params
)
def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T).
Returns:
List[List[torch.Tensor]]: List of list of each discriminator outputs,
which consists of eachlayer output tensors.
"""
outs = []
for f in self.discriminators:
outs += [f(x)]
if self.pooling is not None:
x = self.pooling(x)
return outs
class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module):
"""HiFi-GAN multi-scale + multi-period discriminator module."""
def __init__(
self,
# Multi-scale discriminator related
scales: int = 3,
scale_downsample_pooling: str = "AvgPool1d",
scale_downsample_pooling_params: Dict[str, Any] = {
"kernel_size": 4,
"stride": 2,
"padding": 2,
},
scale_discriminator_params: Dict[str, Any] = {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [15, 41, 5, 3],
"channels": 128,
"max_downsample_channels": 1024,
"max_groups": 16,
"bias": True,
"downsample_scales": [2, 2, 4, 4, 1],
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
},
follow_official_norm: bool = True,
# Multi-period discriminator related
periods: List[int] = [2, 3, 5, 7, 11],
period_discriminator_params: Dict[str, Any] = {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [5, 3],
"channels": 32,
"downsample_scales": [3, 3, 3, 3, 1],
"max_downsample_channels": 1024,
"bias": True,
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
"use_weight_norm": True,
"use_spectral_norm": False,
},
):
"""Initilize HiFiGAN multi-scale + multi-period discriminator module.
Args:
scales (int): Number of multi-scales.
scale_downsample_pooling (str): Pooling module name for downsampling of the
inputs.
scale_downsample_pooling_params (dict): Parameters for the above pooling
module.
scale_discriminator_params (dict): Parameters for hifi-gan scale
discriminator module.
follow_official_norm (bool): Whether to follow the norm setting of the
official implementaion. The first discriminator uses spectral norm and
the other discriminators use weight norm.
periods (list): List of periods.
period_discriminator_params (dict): Parameters for hifi-gan period
discriminator module. The period parameter will be overwritten.
"""
super().__init__()
self.msd = HiFiGANMultiScaleDiscriminator(
scales=scales,
downsample_pooling=scale_downsample_pooling,
downsample_pooling_params=scale_downsample_pooling_params,
discriminator_params=scale_discriminator_params,
follow_official_norm=follow_official_norm,
)
self.mpd = HiFiGANMultiPeriodDiscriminator(
periods=periods,
discriminator_params=period_discriminator_params,
)
def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T).
Returns:
List[List[Tensor]]: List of list of each discriminator outputs,
which consists of each layer output tensors. Multi scale and
multi period ones are concatenated.
"""
msd_outs = self.msd(x)
mpd_outs = self.mpd(x)
return msd_outs + mpd_outs

244
egs/ljspeech/TTS/vits2/infer.py Executable file
View File

@ -0,0 +1,244 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script performs model inference on test set.
Usage:
./vits/infer.py \
--epoch 1000 \
--exp-dir ./vits/exp \
--max-duration 500
"""
import argparse
import logging
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List
import k2
import torch
import torch.nn as nn
import torchaudio
from tokenizer import Tokenizer
from train import get_model, get_params
from tts_datamodule import LJSpeechTtsDataModule
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=1000,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="vits/exp",
help="The experiment dir",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to vocabulary.""",
)
return parser
def infer_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
tokenizer: Tokenizer,
) -> None:
"""Decode dataset.
The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
tokenizer:
Used to convert text to phonemes.
"""
# Background worker save audios to disk.
def _save_worker(
batch_size: int,
cut_ids: List[str],
audio: torch.Tensor,
audio_pred: torch.Tensor,
audio_lens: List[int],
audio_lens_pred: List[int],
):
for i in range(batch_size):
torchaudio.save(
str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
audio[i : i + 1, : audio_lens[i]],
sample_rate=params.sampling_rate,
)
torchaudio.save(
str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
audio_pred[i : i + 1, : audio_lens_pred[i]],
sample_rate=params.sampling_rate,
)
device = next(model.parameters()).device
num_cuts = 0
log_interval = 5
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
futures = []
with ThreadPoolExecutor(max_workers=1) as executor:
for batch_idx, batch in enumerate(dl):
batch_size = len(batch["tokens"])
tokens = batch["tokens"]
tokens = tokenizer.tokens_to_token_ids(tokens)
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# tensor of shape (B, T)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
audio = batch["audio"]
audio_lens = batch["audio_lens"].tolist()
cut_ids = [cut.id for cut in batch["cut"]]
audio_pred, _, durations = model.inference_batch(
text=tokens, text_lengths=tokens_lens
)
audio_pred = audio_pred.detach().cpu()
# convert to samples
audio_lens_pred = (
(durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
)
futures.append(
executor.submit(
_save_worker,
batch_size,
cut_ids,
audio,
audio_pred,
audio_lens,
audio_lens_pred,
)
)
num_cuts += batch_size
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
# return results
for f in futures:
f.result()
@torch.no_grad()
def main():
parser = get_parser()
LJSpeechTtsDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.suffix = f"epoch-{params.epoch}"
params.res_dir = params.exp_dir / "infer" / params.suffix
params.save_wav_dir = params.res_dir / "wav"
params.save_wav_dir.mkdir(parents=True, exist_ok=True)
setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
logging.info("Infer started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.blank_id
params.oov_id = tokenizer.oov_id
params.vocab_size = tokenizer.vocab_size
logging.info(f"Device: {device}")
logging.info(params)
logging.info("About to create model")
model = get_model(params)
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
model.eval()
num_param_g = sum([p.numel() for p in model.generator.parameters()])
logging.info(f"Number of parameters in generator: {num_param_g}")
num_param_d = sum([p.numel() for p in model.discriminator.parameters()])
logging.info(f"Number of parameters in discriminator: {num_param_d}")
logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
# we need cut ids to display recognition results.
args.return_cuts = True
ljspeech = LJSpeechTtsDataModule(args)
test_cuts = ljspeech.test_cuts()
test_dl = ljspeech.test_dataloaders(test_cuts)
infer_dataset(
dl=test_dl,
params=params,
model=model,
tokenizer=tokenizer,
)
logging.info(f"Wav files are saved to {params.save_wav_dir}")
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,335 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""HiFiGAN-related loss modules.
This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
"""
from typing import List, Tuple, Union
import torch
import torch.distributions as D
import torch.nn.functional as F
from lhotse.features.kaldi import Wav2LogFilterBank
class GeneratorAdversarialLoss(torch.nn.Module):
"""Generator adversarial loss module."""
def __init__(
self,
average_by_discriminators: bool = True,
loss_type: str = "mse",
):
"""Initialize GeneratorAversarialLoss module.
Args:
average_by_discriminators (bool): Whether to average the loss by
the number of discriminators.
loss_type (str): Loss type, "mse" or "hinge".
"""
super().__init__()
self.average_by_discriminators = average_by_discriminators
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
if loss_type == "mse":
self.criterion = self._mse_loss
else:
self.criterion = self._hinge_loss
def forward(
self,
outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
"""Calcualate generator adversarial loss.
Args:
outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
outputs, list of discriminator outputs, or list of list of discriminator
outputs..
Returns:
Tensor: Generator adversarial loss value.
"""
if isinstance(outputs, (tuple, list)):
adv_loss = 0.0
for i, outputs_ in enumerate(outputs):
if isinstance(outputs_, (tuple, list)):
# NOTE(kan-bayashi): case including feature maps
outputs_ = outputs_[-1]
adv_loss += self.criterion(outputs_)
if self.average_by_discriminators:
adv_loss /= i + 1
else:
adv_loss = self.criterion(outputs)
return adv_loss
def _mse_loss(self, x):
return F.mse_loss(x, x.new_ones(x.size()))
def _hinge_loss(self, x):
return -x.mean()
class DiscriminatorAdversarialLoss(torch.nn.Module):
"""Discriminator adversarial loss module."""
def __init__(
self,
average_by_discriminators: bool = True,
loss_type: str = "mse",
):
"""Initialize DiscriminatorAversarialLoss module.
Args:
average_by_discriminators (bool): Whether to average the loss by
the number of discriminators.
loss_type (str): Loss type, "mse" or "hinge".
"""
super().__init__()
self.average_by_discriminators = average_by_discriminators
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
if loss_type == "mse":
self.fake_criterion = self._mse_fake_loss
self.real_criterion = self._mse_real_loss
else:
self.fake_criterion = self._hinge_fake_loss
self.real_criterion = self._hinge_real_loss
def forward(
self,
outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Calcualate discriminator adversarial loss.
Args:
outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
outputs, list of discriminator outputs, or list of list of discriminator
outputs calculated from generator.
outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
outputs, list of discriminator outputs, or list of list of discriminator
outputs calculated from groundtruth.
Returns:
Tensor: Discriminator real loss value.
Tensor: Discriminator fake loss value.
"""
if isinstance(outputs, (tuple, list)):
real_loss = 0.0
fake_loss = 0.0
for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
if isinstance(outputs_hat_, (tuple, list)):
# NOTE(kan-bayashi): case including feature maps
outputs_hat_ = outputs_hat_[-1]
outputs_ = outputs_[-1]
real_loss += self.real_criterion(outputs_)
fake_loss += self.fake_criterion(outputs_hat_)
if self.average_by_discriminators:
fake_loss /= i + 1
real_loss /= i + 1
else:
real_loss = self.real_criterion(outputs)
fake_loss = self.fake_criterion(outputs_hat)
return real_loss, fake_loss
def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
return F.mse_loss(x, x.new_ones(x.size()))
def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
return F.mse_loss(x, x.new_zeros(x.size()))
def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))
def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
class FeatureMatchLoss(torch.nn.Module):
"""Feature matching loss module."""
def __init__(
self,
average_by_layers: bool = True,
average_by_discriminators: bool = True,
include_final_outputs: bool = False,
):
"""Initialize FeatureMatchLoss module.
Args:
average_by_layers (bool): Whether to average the loss by the number
of layers.
average_by_discriminators (bool): Whether to average the loss by
the number of discriminators.
include_final_outputs (bool): Whether to include the final output of
each discriminator for loss calculation.
"""
super().__init__()
self.average_by_layers = average_by_layers
self.average_by_discriminators = average_by_discriminators
self.include_final_outputs = include_final_outputs
def forward(
self,
feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
) -> torch.Tensor:
"""Calculate feature matching loss.
Args:
feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
discriminator outputs or list of discriminator outputs calcuated
from generator's outputs.
feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
discriminator outputs or list of discriminator outputs calcuated
from groundtruth..
Returns:
Tensor: Feature matching loss value.
"""
feat_match_loss = 0.0
for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
feat_match_loss_ = 0.0
if not self.include_final_outputs:
feats_hat_ = feats_hat_[:-1]
feats_ = feats_[:-1]
for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
if self.average_by_layers:
feat_match_loss_ /= j + 1
feat_match_loss += feat_match_loss_
if self.average_by_discriminators:
feat_match_loss /= i + 1
return feat_match_loss
class MelSpectrogramLoss(torch.nn.Module):
"""Mel-spectrogram loss."""
def __init__(
self,
sampling_rate: int = 22050,
frame_length: int = 1024, # in samples
frame_shift: int = 256, # in samples
n_mels: int = 80,
use_fft_mag: bool = True,
):
super().__init__()
self.wav_to_mel = Wav2LogFilterBank(
sampling_rate=sampling_rate,
frame_length=frame_length / sampling_rate, # in second
frame_shift=frame_shift / sampling_rate, # in second
use_fft_mag=use_fft_mag,
num_filters=n_mels,
)
def forward(
self,
y_hat: torch.Tensor,
y: torch.Tensor,
return_mel: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
"""Calculate Mel-spectrogram loss.
Args:
y_hat (Tensor): Generated waveform tensor (B, 1, T).
y (Tensor): Groundtruth waveform tensor (B, 1, T).
spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor
(B, T, n_fft // 2 + 1). if provided, use it instead of groundtruth
waveform.
Returns:
Tensor: Mel-spectrogram loss value.
"""
mel_hat = self.wav_to_mel(y_hat.squeeze(1))
mel = self.wav_to_mel(y.squeeze(1))
mel_loss = F.l1_loss(mel_hat, mel)
if return_mel:
return mel_loss, (mel_hat, mel)
return mel_loss
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py
"""VITS-related loss modules.
This code is based on https://github.com/jaywalnut310/vits.
"""
class KLDivergenceLoss(torch.nn.Module):
"""KL divergence loss."""
def forward(
self,
z_p: torch.Tensor,
logs_q: torch.Tensor,
m_p: torch.Tensor,
logs_p: torch.Tensor,
z_mask: torch.Tensor,
) -> torch.Tensor:
"""Calculate KL divergence loss.
Args:
z_p (Tensor): Flow hidden representation (B, H, T_feats).
logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
z_mask (Tensor): Mask tensor (B, 1, T_feats).
Returns:
Tensor: KL divergence loss.
"""
z_p = z_p.float()
logs_q = logs_q.float()
m_p = m_p.float()
logs_p = logs_p.float()
z_mask = z_mask.float()
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
kl = torch.sum(kl * z_mask)
loss = kl / torch.sum(z_mask)
return loss
class KLDivergenceLossWithoutFlow(torch.nn.Module):
"""KL divergence loss without flow."""
def forward(
self,
m_q: torch.Tensor,
logs_q: torch.Tensor,
m_p: torch.Tensor,
logs_p: torch.Tensor,
) -> torch.Tensor:
"""Calculate KL divergence loss without flow.
Args:
m_q (Tensor): Posterior encoder projected mean (B, H, T_feats).
logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
"""
posterior_norm = D.Normal(m_q, torch.exp(logs_q))
prior_norm = D.Normal(m_p, torch.exp(logs_p))
loss = D.kl_divergence(posterior_norm, prior_norm).mean()
return loss

View File

@ -0,0 +1,81 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/__init__.py
"""Maximum path calculation module.
This code is based on https://github.com/jaywalnut310/vits.
"""
import warnings
import numpy as np
import torch
from numba import njit, prange
try:
from .core import maximum_path_c
is_cython_avalable = True
except ImportError:
is_cython_avalable = False
warnings.warn(
"Cython version is not available. Fallback to 'EXPERIMETAL' numba version. "
"If you want to use the cython version, please build it as follows: "
"`cd espnet2/gan_tts/vits/monotonic_align; python setup.py build_ext --inplace`"
)
def maximum_path(neg_x_ent: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
"""Calculate maximum path.
Args:
neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text).
attn_mask (Tensor): Attention mask (B, T_feats, T_text).
Returns:
Tensor: Maximum path tensor (B, T_feats, T_text).
"""
device, dtype = neg_x_ent.device, neg_x_ent.dtype
neg_x_ent = neg_x_ent.cpu().numpy().astype(np.float32)
path = np.zeros(neg_x_ent.shape, dtype=np.int32)
t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32)
t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32)
if is_cython_avalable:
maximum_path_c(path, neg_x_ent, t_t_max, t_s_max)
else:
maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max)
return torch.from_numpy(path).to(device=device, dtype=dtype)
@njit
def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf):
"""Calculate a single maximum path with numba."""
index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[y - 1, x]
if x == 0:
if y == 0:
v_prev = 0.0
else:
v_prev = max_neg_val
else:
v_prev = value[y - 1, x - 1]
value[y, x] += max(v_prev, v_cur)
for y in range(t_y - 1, -1, -1):
path[y, index] = 1
if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
index = index - 1
@njit(parallel=True)
def maximum_path_numba(paths, values, t_ys, t_xs):
"""Calculate batch maximum path with numba."""
for i in prange(paths.shape[0]):
maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i])

View File

@ -0,0 +1,51 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/core.pyx
"""Maximum path calculation module with cython optimization.
This code is copied from https://github.com/jaywalnut310/vits and modifed code format.
"""
cimport cython
from cython.parallel import prange
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
cdef int x
cdef int y
cdef float v_prev
cdef float v_cur
cdef float tmp
cdef int index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[y - 1, x]
if x == 0:
if y == 0:
v_prev = 0.0
else:
v_prev = max_neg_val
else:
v_prev = value[y - 1, x - 1]
value[y, x] += max(v_prev, v_cur)
for y in range(t_y - 1, -1, -1):
path[y, index] = 1
if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
index = index - 1
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
cdef int b = paths.shape[0]
cdef int i
for i in prange(b, nogil=True):
maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])

View File

@ -0,0 +1,31 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/setup.py
"""Setup cython code."""
from Cython.Build import cythonize
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext as _build_ext
class build_ext(_build_ext):
"""Overwrite build_ext."""
def finalize_options(self):
"""Prevent numpy from thinking it is still in its setup process."""
_build_ext.finalize_options(self)
__builtins__.__NUMPY_SETUP__ = False
import numpy
self.include_dirs.append(numpy.get_include())
exts = [
Extension(
name="core",
sources=["core.pyx"],
)
]
setup(
name="monotonic_align",
ext_modules=cythonize(exts, language_level=3),
cmdclass={"build_ext": build_ext},
)

View File

@ -0,0 +1,117 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Posterior encoder module in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
from typing import Optional, Tuple
import torch
from wavenet import Conv1d, WaveNet
from icefall.utils import make_pad_mask
class PosteriorEncoder(torch.nn.Module):
"""Posterior encoder module in VITS.
This is a module of posterior encoder described in `Conditional Variational
Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2006.04558
"""
def __init__(
self,
in_channels: int = 513,
out_channels: int = 192,
hidden_channels: int = 192,
kernel_size: int = 5,
layers: int = 16,
stacks: int = 1,
base_dilation: int = 1,
global_channels: int = -1,
dropout_rate: float = 0.0,
bias: bool = True,
use_weight_norm: bool = True,
):
"""Initilialize PosteriorEncoder module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size in WaveNet.
layers (int): Number of layers of WaveNet.
stacks (int): Number of repeat stacking of WaveNet.
base_dilation (int): Base dilation factor.
global_channels (int): Number of global conditioning channels.
dropout_rate (float): Dropout rate.
bias (bool): Whether to use bias parameters in conv.
use_weight_norm (bool): Whether to apply weight norm.
"""
super().__init__()
# define modules
self.input_conv = Conv1d(in_channels, hidden_channels, 1)
self.encoder = WaveNet(
in_channels=-1,
out_channels=-1,
kernel_size=kernel_size,
layers=layers,
stacks=stacks,
base_dilation=base_dilation,
residual_channels=hidden_channels,
aux_channels=-1,
gate_channels=hidden_channels * 2,
skip_channels=hidden_channels,
global_channels=global_channels,
dropout_rate=dropout_rate,
bias=bias,
use_weight_norm=use_weight_norm,
use_first_conv=False,
use_last_conv=False,
scale_residual=False,
scale_skip_connect=True,
)
self.proj = Conv1d(hidden_channels, out_channels * 2, 1)
def forward(
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T_feats).
x_lengths (Tensor): Length tensor (B,).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Encoded hidden representation tensor (B, out_channels, T_feats).
Tensor: Projected mean tensor (B, out_channels, T_feats).
Tensor: Projected scale tensor (B, out_channels, T_feats).
Tensor: Mask tensor for input tensor (B, 1, T_feats).
"""
x_mask = (
(~make_pad_mask(x_lengths))
.unsqueeze(1)
.to(
dtype=x.dtype,
device=x.device,
)
)
x = self.input_conv(x) * x_mask
x = self.encoder(x, x_mask, g=g)
stats = self.proj(x) * x_mask
m, logs = stats.split(stats.size(1) // 2, dim=1)
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
return z, m, logs, x_mask

View File

@ -0,0 +1,228 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Residual affine coupling modules in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
from typing import Optional, Tuple, Union
import torch
from flow import FlipFlow
from wavenet import WaveNet
class ResidualAffineCouplingBlock(torch.nn.Module):
"""Residual affine coupling block module.
This is a module of residual affine coupling block, which used as "Flow" in
`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2006.04558
"""
def __init__(
self,
in_channels: int = 192,
hidden_channels: int = 192,
flows: int = 4,
kernel_size: int = 5,
base_dilation: int = 1,
layers: int = 4,
global_channels: int = -1,
dropout_rate: float = 0.0,
use_weight_norm: bool = True,
bias: bool = True,
use_only_mean: bool = True,
):
"""Initilize ResidualAffineCouplingBlock module.
Args:
in_channels (int): Number of input channels.
hidden_channels (int): Number of hidden channels.
flows (int): Number of flows.
kernel_size (int): Kernel size for WaveNet.
base_dilation (int): Base dilation factor for WaveNet.
layers (int): Number of layers of WaveNet.
stacks (int): Number of stacks of WaveNet.
global_channels (int): Number of global channels.
dropout_rate (float): Dropout rate.
use_weight_norm (bool): Whether to use weight normalization in WaveNet.
bias (bool): Whether to use bias paramters in WaveNet.
use_only_mean (bool): Whether to estimate only mean.
"""
super().__init__()
self.flows = torch.nn.ModuleList()
for i in range(flows):
self.flows += [
ResidualAffineCouplingLayer(
in_channels=in_channels,
hidden_channels=hidden_channels,
kernel_size=kernel_size,
base_dilation=base_dilation,
layers=layers,
stacks=1,
global_channels=global_channels,
dropout_rate=dropout_rate,
use_weight_norm=use_weight_norm,
bias=bias,
use_only_mean=use_only_mean,
)
]
self.flows += [FlipFlow()]
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
g: Optional[torch.Tensor] = None,
inverse: bool = False,
) -> torch.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T).
x_lengths (Tensor): Length tensor (B,).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, in_channels, T).
"""
if not inverse:
for flow in self.flows:
x, _ = flow(x, x_mask, g=g, inverse=inverse)
else:
for flow in reversed(self.flows):
x = flow(x, x_mask, g=g, inverse=inverse)
return x
class ResidualAffineCouplingLayer(torch.nn.Module):
"""Residual affine coupling layer."""
def __init__(
self,
in_channels: int = 192,
hidden_channels: int = 192,
kernel_size: int = 5,
base_dilation: int = 1,
layers: int = 5,
stacks: int = 1,
global_channels: int = -1,
dropout_rate: float = 0.0,
use_weight_norm: bool = True,
bias: bool = True,
use_only_mean: bool = True,
):
"""Initialzie ResidualAffineCouplingLayer module.
Args:
in_channels (int): Number of input channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size for WaveNet.
base_dilation (int): Base dilation factor for WaveNet.
layers (int): Number of layers of WaveNet.
stacks (int): Number of stacks of WaveNet.
global_channels (int): Number of global channels.
dropout_rate (float): Dropout rate.
use_weight_norm (bool): Whether to use weight normalization in WaveNet.
bias (bool): Whether to use bias paramters in WaveNet.
use_only_mean (bool): Whether to estimate only mean.
"""
assert in_channels % 2 == 0, "in_channels should be divisible by 2"
super().__init__()
self.half_channels = in_channels // 2
self.use_only_mean = use_only_mean
# define modules
self.input_conv = torch.nn.Conv1d(
self.half_channels,
hidden_channels,
1,
)
self.encoder = WaveNet(
in_channels=-1,
out_channels=-1,
kernel_size=kernel_size,
layers=layers,
stacks=stacks,
base_dilation=base_dilation,
residual_channels=hidden_channels,
aux_channels=-1,
gate_channels=hidden_channels * 2,
skip_channels=hidden_channels,
global_channels=global_channels,
dropout_rate=dropout_rate,
bias=bias,
use_weight_norm=use_weight_norm,
use_first_conv=False,
use_last_conv=False,
scale_residual=False,
scale_skip_connect=True,
)
if use_only_mean:
self.proj = torch.nn.Conv1d(
hidden_channels,
self.half_channels,
1,
)
else:
self.proj = torch.nn.Conv1d(
hidden_channels,
self.half_channels * 2,
1,
)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
g: Optional[torch.Tensor] = None,
inverse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T).
x_lengths (Tensor): Length tensor (B,).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, in_channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
xa, xb = x.split(x.size(1) // 2, dim=1)
h = self.input_conv(xa) * x_mask
h = self.encoder(h, x_mask, g=g)
stats = self.proj(h) * x_mask
if not self.use_only_mean:
m, logs = stats.split(stats.size(1) // 2, dim=1)
else:
m = stats
logs = torch.zeros_like(m)
if not inverse:
xb = m + xb * torch.exp(logs) * x_mask
x = torch.cat([xa, xb], 1)
logdet = torch.sum(logs, [1, 2])
return x, logdet
else:
xb = (xb - m) * torch.exp(-logs) * x_mask
x = torch.cat([xa, xb], 1)
return x

View File

@ -0,0 +1,123 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is used to test the exported onnx model by vits/export-onnx.py
Use the onnx model to generate a wav:
./vits/test_onnx.py \
--model-filename vits/exp/vits-epoch-1000.onnx \
--tokens data/tokens.txt
"""
import argparse
import logging
import onnxruntime as ort
import torch
import torchaudio
from tokenizer import Tokenizer
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--model-filename",
type=str,
required=True,
help="Path to the onnx model.",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to vocabulary.""",
)
return parser
class OnnxModel:
def __init__(self, model_filename: str):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 4
self.session_opts = session_opts
self.model = ort.InferenceSession(
model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
"""
Args:
tokens:
A 1-D tensor of shape (1, T)
Returns:
A tensor of shape (1, T')
"""
noise_scale = torch.tensor([0.667], dtype=torch.float32)
noise_scale_dur = torch.tensor([0.8], dtype=torch.float32)
alpha = torch.tensor([1.0], dtype=torch.float32)
out = self.model.run(
[
self.model.get_outputs()[0].name,
],
{
self.model.get_inputs()[0].name: tokens.numpy(),
self.model.get_inputs()[1].name: tokens_lens.numpy(),
self.model.get_inputs()[2].name: noise_scale.numpy(),
self.model.get_inputs()[3].name: alpha.numpy(),
self.model.get_inputs()[4].name: noise_scale_dur.numpy(),
},
)[0]
return torch.from_numpy(out)
def main():
args = get_parser().parse_args()
tokenizer = Tokenizer(args.tokens)
logging.info("About to create onnx model")
model = OnnxModel(args.model_filename)
text = "I went there to see the land, the people and how their system works, end quote."
tokens = tokenizer.texts_to_token_ids([text])
tokens = torch.tensor(tokens) # (1, T)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
audio = model(tokens, tokens_lens) # (1, T')
torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050)
logging.info("Saved to test_onnx.wav")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,684 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text encoder module in VITS.
This code is based on
- https://github.com/jaywalnut310/vits
- https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py
- https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/transducer_stateless/conformer.py
"""
import copy
import math
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from icefall.utils import is_jit_tracing, make_pad_mask
class TextEncoder(torch.nn.Module):
"""Text encoder module in VITS.
This is a module of text encoder described in `Conditional Variational Autoencoder
with Adversarial Learning for End-to-End Text-to-Speech`.
"""
def __init__(
self,
vocabs: int,
d_model: int = 192,
num_heads: int = 2,
dim_feedforward: int = 768,
cnn_module_kernel: int = 5,
num_layers: int = 6,
dropout: float = 0.1,
):
"""Initialize TextEncoder module.
Args:
vocabs (int): Vocabulary size.
d_model (int): attention dimension
num_heads (int): number of attention heads
dim_feedforward (int): feedforward dimention
cnn_module_kernel (int): convolution kernel size
num_layers (int): number of encoder layers
dropout (float): dropout rate
"""
super().__init__()
self.d_model = d_model
# define modules
self.emb = torch.nn.Embedding(vocabs, d_model)
torch.nn.init.normal_(self.emb.weight, 0.0, d_model**-0.5)
# We use conformer as text encoder
self.encoder = Transformer(
d_model=d_model,
num_heads=num_heads,
dim_feedforward=dim_feedforward,
cnn_module_kernel=cnn_module_kernel,
num_layers=num_layers,
dropout=dropout,
)
self.proj = torch.nn.Conv1d(d_model, d_model * 2, 1)
def forward(
self,
x: torch.Tensor,
x_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input index tensor (B, T_text).
x_lengths (Tensor): Length tensor (B,).
Returns:
Tensor: Encoded hidden representation (B, attention_dim, T_text).
Tensor: Projected mean tensor (B, attention_dim, T_text).
Tensor: Projected scale tensor (B, attention_dim, T_text).
Tensor: Mask tensor for input tensor (B, 1, T_text).
"""
# (B, T_text, embed_dim)
x = self.emb(x) * math.sqrt(self.d_model)
assert x.size(1) == x_lengths.max().item()
# (B, T_text)
pad_mask = make_pad_mask(x_lengths)
# encoder assume the channel last (B, T_text, embed_dim)
x = self.encoder(x, key_padding_mask=pad_mask)
# convert the channel first (B, embed_dim, T_text)
x = x.transpose(1, 2)
non_pad_mask = (~pad_mask).unsqueeze(1)
stats = self.proj(x) * non_pad_mask
m, logs = stats.split(stats.size(1) // 2, dim=1)
return x, m, logs, non_pad_mask
class Transformer(nn.Module):
"""
Args:
d_model (int): attention dimension
num_heads (int): number of attention heads
dim_feedforward (int): feedforward dimention
cnn_module_kernel (int): convolution kernel size
num_layers (int): number of encoder layers
dropout (float): dropout rate
"""
def __init__(
self,
d_model: int = 192,
num_heads: int = 2,
dim_feedforward: int = 768,
cnn_module_kernel: int = 5,
num_layers: int = 6,
dropout: float = 0.1,
) -> None:
super().__init__()
self.num_layers = num_layers
self.d_model = d_model
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_layer = TransformerEncoderLayer(
d_model=d_model,
num_heads=num_heads,
dim_feedforward=dim_feedforward,
cnn_module_kernel=cnn_module_kernel,
dropout=dropout,
)
self.encoder = TransformerEncoder(encoder_layer, num_layers)
self.after_norm = nn.LayerNorm(d_model)
def forward(
self, x: Tensor, key_padding_mask: Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
lengths:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
"""
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask) # (T, N, C)
x = self.after_norm(x)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x
class TransformerEncoderLayer(nn.Module):
"""
TransformerEncoderLayer is made up of self-attn and feedforward.
Args:
d_model: the number of expected features in the input.
num_heads: the number of heads in the multi-head attention models.
dim_feedforward: the dimension of the feed-forward network model.
dropout: the dropout value (default=0.1).
"""
def __init__(
self,
d_model: int,
num_heads: int,
dim_feedforward: int,
cnn_module_kernel: int,
dropout: float = 0.1,
) -> None:
super(TransformerEncoderLayer, self).__init__()
self.feed_forward_macaron = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
self.self_attn = RelPositionMultiheadAttention(
d_model, num_heads, dropout=dropout
)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.ff_scale = 0.5
self.dropout = nn.Dropout(dropout)
def forward(
self,
src: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Pass the input through the transformer encoder layer.
Args:
src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim).
pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
"""
# macaron style feed-forward module
src = src + self.ff_scale * self.dropout(
self.feed_forward_macaron(self.norm_ff_macaron(src))
)
# multi-head self-attention module
src_attn = self.self_attn(
self.norm_mha(src),
pos_emb=pos_emb,
key_padding_mask=key_padding_mask,
)
src = src + self.dropout(src_attn)
# convolution module
src = src + self.dropout(self.conv_module(self.norm_conv(src)))
# feed-forward module
src = src + self.dropout(self.feed_forward(self.norm_ff(src)))
src = self.norm_final(src)
return src
class TransformerEncoder(nn.Module):
r"""TransformerEncoder is a stack of N encoder layers
Args:
encoder_layer: an instance of the TransformerEncoderLayer class.
num_layers: the number of sub-encoder-layers in the encoder.
"""
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
super().__init__()
self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
)
self.num_layers = num_layers
def forward(
self,
src: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim).
pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
"""
output = src
for layer_index, mod in enumerate(self.layers):
output = mod(
output,
pos_emb,
key_padding_mask=key_padding_mask,
)
return output
class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module.
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
Args:
d_model: Embedding dimension.
dropout_rate: Dropout rate.
max_len: Maximum input length.
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: Tensor) -> None:
"""Reset the positional encodings."""
x_size = x.size(1)
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x_size * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vector and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x_size, self.d_model)
pe_negative = torch.zeros(x_size, self.d_model)
position = torch.arange(0, x_size, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
# Reserve the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
"""
self.extend_pe(x)
x = x * self.xscale
pos_emb = self.pe[
:,
self.pe.size(1) // 2
- x.size(1)
+ 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1),
]
return self.dropout(x), self.dropout(pos_emb)
class RelPositionMultiheadAttention(nn.Module):
r"""Multi-Head Attention layer with relative position encoding
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Args:
embed_dim: total dimension of the model.
num_heads: parallel attention heads.
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
# linear transformation for positional encoding.
self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self._reset_parameters()
def _reset_parameters(self) -> None:
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.constant_(self.in_proj.bias, 0.0)
nn.init.constant_(self.out_proj.bias, 0.0)
nn.init.xavier_uniform_(self.pos_bias_u)
nn.init.xavier_uniform_(self.pos_bias_v)
def rel_shift(self, x: Tensor) -> Tensor:
"""Compute relative positional encoding.
Args:
x: Input tensor (batch, head, seq_len, 2*seq_len-1).
Returns:
Tensor: tensor of shape (batch, head, seq_len, seq_len)
"""
(batch_size, num_heads, seq_len, n) = x.shape
if not is_jit_tracing():
assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1"
if is_jit_tracing():
rows = torch.arange(start=seq_len - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
x = x.reshape(-1, n)
x = torch.gather(x, dim=1, index=indexes)
x = x.reshape(batch_size, num_heads, seq_len, seq_len)
return x
else:
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, seq_len, seq_len),
(batch_stride, head_stride, time_stride - n_stride, n_stride),
storage_offset=n_stride * (seq_len - 1),
)
def forward(
self,
x: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Args:
x: Input tensor of shape (seq_len, batch_size, embed_dim)
pos_emb: Positional embedding tensor, (1, 2*seq_len-1, pos_dim)
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. This is an binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
Its shape is (batch_size, seq_len).
Outputs:
A tensor of shape (seq_len, batch_size, embed_dim).
"""
seq_len, batch_size, _ = x.shape
scaling = float(self.head_dim) ** -0.5
q, k, v = self.in_proj(x).chunk(3, dim=-1)
q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
v = (
v.contiguous()
.view(seq_len, batch_size * self.num_heads, self.head_dim)
.transpose(0, 1)
)
q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim)
p = self.linear_pos(pos_emb).view(
pos_emb.size(0), -1, self.num_heads, self.head_dim
)
# (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1)
p = p.permute(0, 2, 3, 1)
# (batch_size, num_head, seq_len, head_dim)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len)
matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch_size, num_head, seq_len, seq_len)
# compute matrix b and matrix d
matrix_bd = torch.matmul(
q_with_bias_v, p
) # (batch_size, num_head, seq_len, 2*seq_len-1)
matrix_bd = self.rel_shift(
matrix_bd
) # (batch_size, num_head, seq_len, seq_len)
# (batch_size, num_head, seq_len, seq_len)
attn_output_weights = (matrix_ac + matrix_bd) * scaling
attn_output_weights = attn_output_weights.view(
batch_size * self.num_heads, seq_len, seq_len
)
if key_padding_mask is not None:
assert key_padding_mask.shape == (batch_size, seq_len)
attn_output_weights = attn_output_weights.view(
batch_size, self.num_heads, seq_len, seq_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)
attn_output_weights = attn_output_weights.view(
batch_size * self.num_heads, seq_len, seq_len
)
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
attn_output_weights = nn.functional.dropout(
attn_output_weights, p=self.dropout, training=self.training
)
# (batch_size * num_head, seq_len, head_dim)
attn_output = torch.bmm(attn_output_weights, v)
assert attn_output.shape == (
batch_size * self.num_heads,
seq_len,
self.head_dim,
)
attn_output = (
attn_output.transpose(0, 1)
.contiguous()
.view(seq_len, batch_size, self.embed_dim)
)
# (seq_len, batch_size, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernerl size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True).
"""
def __init__(
self,
channels: int,
kernel_size: int,
bias: bool = True,
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
padding = (kernel_size - 1) // 2
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=padding,
groups=channels,
bias=bias,
)
self.norm = nn.LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.activation = Swish()
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.permute(1, 2, 0) # (#batch, channels, time).
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
# x is (batch, channels, time)
x = x.permute(0, 2, 1)
x = self.norm(x)
x = x.permute(0, 2, 1)
x = self.activation(x)
x = self.pointwise_conv2(x) # (batch, channel, time)
return x.permute(2, 0, 1)
class Swish(nn.Module):
"""Construct an Swish object."""
def forward(self, x: Tensor) -> Tensor:
"""Return Swich activation function."""
return x * torch.sigmoid(x)
def _test_text_encoder():
vocabs = 500
d_model = 192
batch_size = 5
seq_len = 100
m = TextEncoder(vocabs=vocabs, d_model=d_model)
x, m, logs, mask = m(
x=torch.randint(low=0, high=vocabs, size=(batch_size, seq_len)),
x_lengths=torch.full((batch_size,), seq_len),
)
print(x.shape, m.shape, logs.shape, mask.shape)
if __name__ == "__main__":
_test_text_encoder()

View File

@ -0,0 +1,108 @@
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List
import g2p_en
import tacotron_cleaner.cleaners
from utils import intersperse
class Tokenizer(object):
def __init__(self, tokens: str):
"""
Args:
tokens: the file that maps tokens to ids
"""
# Parse token file
self.token2id: Dict[str, int] = {}
with open(tokens, "r", encoding="utf-8") as f:
for line in f.readlines():
info = line.rstrip().split()
if len(info) == 1:
# case of space
token = " "
id = int(info[0])
else:
token, id = info[0], int(info[1])
self.token2id[token] = id
self.blank_id = self.token2id["<blk>"]
self.oov_id = self.token2id["<unk>"]
self.vocab_size = len(self.token2id)
self.g2p = g2p_en.G2p()
def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
"""
Args:
texts:
A list of transcripts.
intersperse_blank:
Whether to intersperse blanks in the token sequence.
Returns:
Return a list of token id list [utterance][token_id]
"""
token_ids_list = []
for text in texts:
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
tokens = self.g2p(text)
token_ids = []
for t in tokens:
if t in self.token2id:
token_ids.append(self.token2id[t])
else:
token_ids.append(self.oov_id)
if intersperse_blank:
token_ids = intersperse(token_ids, self.blank_id)
token_ids_list.append(token_ids)
return token_ids_list
def tokens_to_token_ids(
self, tokens_list: List[str], intersperse_blank: bool = True
):
"""
Args:
tokens_list:
A list of token list, each corresponding to one utterance.
intersperse_blank:
Whether to intersperse blanks in the token sequence.
Returns:
Return a list of token id list [utterance][token_id]
"""
token_ids_list = []
for tokens in tokens_list:
token_ids = []
for t in tokens:
if t in self.token2id:
token_ids.append(self.token2id[t])
else:
token_ids.append(self.oov_id)
if intersperse_blank:
token_ids = intersperse(token_ids, self.blank_id)
token_ids_list.append(token_ids)
return token_ids_list

936
egs/ljspeech/TTS/vits2/train.py Executable file
View File

@ -0,0 +1,936 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import k2
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from tokenizer import Tokenizer
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import LJSpeechTtsDataModule
from utils import MetricsTracker, plot_feature, save_checkpoint
from vits import VITS
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, setup_logger, str2bool
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=1000,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=1,
help="""Resume training from this epoch. It should be positive.
If larger than 1, it will load checkpoint from
exp-dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="vits/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to vocabulary.""",
)
parser.add_argument(
"--lr", type=float, default=2.0e-4, help="The base learning rate."
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--print-diagnostics",
type=str2bool,
default=False,
help="Accumulate stats on activations, print them and exit.",
)
parser.add_argument(
"--inf-check",
type=str2bool,
default=False,
help="Add hooks to check for infinite module outputs and gradients.",
)
parser.add_argument(
"--save-every-n",
type=int,
default=20,
help="""Save checkpoint after processing this number of epochs"
periodically. We save checkpoint to exp-dir/ whenever
params.cur_epoch % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
Since it will take around 1000 epochs, we suggest using a large
save_every_n to save disk space.
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=False,
help="Whether to use half precision training.",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
- encoder_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- warm_step: The warmup period that dictates the decay of the
scale on "simple" (un-pruned) loss.
"""
params = AttributeDict(
{
# training params
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": -1, # 0
"log_interval": 50,
"valid_interval": 200,
"env_info": get_env_info(),
"sampling_rate": 22050,
"frame_shift": 256,
"frame_length": 1024,
"feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length
"n_mels": 80,
"lambda_adv": 1.0, # loss scaling coefficient for adversarial loss
"lambda_mel": 45.0, # loss scaling coefficient for Mel loss
"lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss
"lambda_dur": 1.0, # loss scaling coefficient for duration loss
"lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict, model: nn.Module
) -> Optional[Dict[str, Any]]:
"""Load checkpoint from file.
If params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`.
Apart from loading state dict for `model` and `optimizer` it also updates
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
Returns:
Return a dict containing previously saved training info.
"""
if params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
assert filename.is_file(), f"{filename} does not exist!"
saved_params = load_checkpoint(filename, model=model)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def get_model(params: AttributeDict) -> nn.Module:
mel_loss_params = {
"n_mels": params.n_mels,
"frame_length": params.frame_length,
"frame_shift": params.frame_shift,
}
model = VITS(
vocab_size=params.vocab_size,
feature_dim=params.feature_dim,
sampling_rate=params.sampling_rate,
mel_loss_params=mel_loss_params,
lambda_adv=params.lambda_adv,
lambda_mel=params.lambda_mel,
lambda_feat_match=params.lambda_feat_match,
lambda_dur=params.lambda_dur,
lambda_kl=params.lambda_kl,
)
return model
def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
"""Parse batch data"""
audio = batch["audio"].to(device)
features = batch["features"].to(device)
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
tokens = batch["tokens"]
tokens = tokenizer.tokens_to_token_ids(tokens)
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# a tensor of shape (B, T)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
return audio, audio_lens, features, features_lens, tokens, tokens_lens
def train_one_epoch(
params: AttributeDict,
model: Union[nn.Module, DDP],
tokenizer: Tokenizer,
optimizer_g: Optimizer,
optimizer_d: Optimizer,
scheduler_g: LRSchedulerType,
scheduler_d: LRSchedulerType,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
tokenizer:
Used to convert text to phonemes.
optimizer_g:
The optimizer for generator.
optimizer_d:
The optimizer for discriminator.
scheduler_g:
The learning rate scheduler for generator, we call step() every epoch.
scheduler_d:
The learning rate scheduler for discriminator, we call step() every epoch.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
The scaler used for mix precision training.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
rank:
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
model.train()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
# used to summary the stats over iterations in one epoch
tot_loss = MetricsTracker()
saved_bad_model = False
def save_bad_model(suffix: str = ""):
save_checkpoint(
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
model=model,
params=params,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
sampler=train_dl.sampler,
scaler=scaler,
rank=0,
)
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["tokens"])
audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
batch, tokenizer, device
)
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
if model.module.generator.use_noised_mas:
# MAS with Gaussian Noise
model.module.generator.noise_current_mas = max(
model.module.generator.noise_initial_mas
- model.module.generator.noise_scale_mas
* params.batch_idx_train
* 0.25,
0.0,
)
try:
with autocast(enabled=params.use_fp16):
# forward discriminator
loss_d, stats_d = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=False,
)
for k, v in stats_d.items():
loss_info[k] = v * batch_size
# update discriminator
optimizer_d.zero_grad()
scaler.scale(loss_d).backward()
scaler.step(optimizer_d)
with autocast(enabled=params.use_fp16):
# forward generator
loss_g, stats_g = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=True,
return_sample=params.batch_idx_train % params.log_interval == 0,
)
for k, v in stats_g.items():
if "returned_sample" not in k:
loss_info[k] = v * batch_size
# update generator
optimizer_g.zero_grad()
scaler.scale(loss_g).backward()
scaler.step(optimizer_g)
scaler.update()
# summary stats
tot_loss = tot_loss + loss_info
except: # noqa
save_bad_model()
raise
if params.print_diagnostics and batch_idx == 5:
return
if params.batch_idx_train % 100 == 0 and params.use_fp16:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 8.0 or (
cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
save_bad_model(suffix="-first-warning")
saved_bad_model = True
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if params.batch_idx_train % params.log_interval == 0:
cur_lr_g = max(scheduler_g.get_last_lr())
cur_lr_d = max(scheduler_d.get_last_lr())
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
f"loss[{loss_info}], tot_loss[{tot_loss}], "
f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate_g", cur_lr_g, params.batch_idx_train
)
tb_writer.add_scalar(
"train/learning_rate_d", cur_lr_d, params.batch_idx_train
)
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
if "returned_sample" in stats_g:
speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
tb_writer.add_audio(
"train/speech_hat_",
speech_hat_,
params.batch_idx_train,
params.sampling_rate,
)
tb_writer.add_audio(
"train/speech_",
speech_,
params.batch_idx_train,
params.sampling_rate,
)
tb_writer.add_image(
"train/mel_hat_",
plot_feature(mel_hat_),
params.batch_idx_train,
dataformats="HWC",
)
tb_writer.add_image(
"train/mel_",
plot_feature(mel_),
params.batch_idx_train,
dataformats="HWC",
)
if (
params.batch_idx_train % params.valid_interval == 0
and not params.print_diagnostics
):
logging.info("Computing validation loss")
valid_info, (speech_hat, speech) = compute_validation_loss(
params=params,
model=model,
tokenizer=tokenizer,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
tb_writer.add_audio(
"train/valdi_speech_hat",
speech_hat,
params.batch_idx_train,
params.sampling_rate,
)
tb_writer.add_audio(
"train/valdi_speech",
speech,
params.batch_idx_train,
params.sampling_rate,
)
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
tokenizer: Tokenizer,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
rank: int = 0,
) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
"""Run the validation process."""
model.eval()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
# used to summary the stats over iterations
tot_loss = MetricsTracker()
returned_sample = None
with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl):
batch_size = len(batch["tokens"])
(
audio,
audio_lens,
features,
features_lens,
tokens,
tokens_lens,
) = prepare_input(batch, tokenizer, device)
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
# forward discriminator
loss_d, stats_d = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=False,
)
assert loss_d.requires_grad is False
for k, v in stats_d.items():
loss_info[k] = v * batch_size
# forward generator
loss_g, stats_g = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=True,
)
assert loss_g.requires_grad is False
for k, v in stats_g.items():
loss_info[k] = v * batch_size
# summary stats
tot_loss = tot_loss + loss_info
# infer for first batch:
if batch_idx == 0 and rank == 0:
inner_model = model.module if isinstance(model, DDP) else model
audio_pred, _, duration = inner_model.inference(
text=tokens[0, : tokens_lens[0].item()]
)
audio_pred = audio_pred.data.cpu().numpy()
audio_len_pred = (
(duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
)
assert audio_len_pred == len(audio_pred), (
audio_len_pred,
len(audio_pred),
)
audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
returned_sample = (audio_pred, audio_gt)
if world_size > 1:
tot_loss.reduce(device)
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss, returned_sample
def scan_pessimistic_batches_for_oom(
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,
tokenizer: Tokenizer,
optimizer_g: torch.optim.Optimizer,
optimizer_d: torch.optim.Optimizer,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
logging.info(
"Sanity check -- see if any of the batches in epoch 1 would cause OOM."
)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
batch, tokenizer, device
)
try:
# for discriminator
with autocast(enabled=params.use_fp16):
loss_d, stats_d = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=False,
)
optimizer_d.zero_grad()
loss_d.backward()
# for generator
with autocast(enabled=params.use_fp16):
loss_g, stats_g = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=True,
)
optimizer_g.zero_grad()
loss_g.backward()
except Exception as e:
if "CUDA out of memory" in str(e):
logging.error(
"Your GPU ran out of memory with the current "
"max_duration setting. We recommend decreasing "
"max_duration and trying again.\n"
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
raise
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.blank_id
params.oov_id = tokenizer.oov_id
params.vocab_size = tokenizer.vocab_size
logging.info(params)
logging.info("About to create model")
model = get_model(params)
generator = model.generator
discriminator = model.discriminator
num_param_g = sum([p.numel() for p in generator.parameters()])
logging.info(f"Number of parameters in generator: {num_param_g}")
num_param_d = sum([p.numel() for p in discriminator.parameters()])
logging.info(f"Number of parameters in discriminator: {num_param_d}")
logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer_g = torch.optim.AdamW(
generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
)
optimizer_d = torch.optim.AdamW(
discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)
if checkpoints is not None:
# load state_dict for optimizers
if "optimizer_g" in checkpoints:
logging.info("Loading optimizer_g state dict")
optimizer_g.load_state_dict(checkpoints["optimizer_g"])
if "optimizer_d" in checkpoints:
logging.info("Loading optimizer_d state dict")
optimizer_d.load_state_dict(checkpoints["optimizer_d"])
# load state_dict for schedulers
if "scheduler_g" in checkpoints:
logging.info("Loading scheduler_g state dict")
scheduler_g.load_state_dict(checkpoints["scheduler_g"])
if "scheduler_d" in checkpoints:
logging.info("Loading scheduler_d state dict")
scheduler_d.load_state_dict(checkpoints["scheduler_d"])
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
if params.inf_check:
register_inf_check_hooks(model)
ljspeech = LJSpeechTtsDataModule(args)
train_cuts = ljspeech.train_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 20.0:
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_dl = ljspeech.train_dataloaders(train_cuts)
valid_cuts = ljspeech.valid_cuts()
valid_dl = ljspeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
tokenizer=tokenizer,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1):
logging.info(f"Start epoch {epoch}")
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
params.cur_epoch = epoch
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
train_one_epoch(
params=params,
model=model,
tokenizer=tokenizer,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint(
filename=filename,
params=params,
model=model,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
if rank == 0:
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
# step per epoch
scheduler_g.step()
scheduler_d.step()
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
LJSpeechTtsDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,218 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py
"""Flow-related transformation.
This code is derived from https://github.com/bayesiains/nflows.
"""
import numpy as np
import torch
from torch.nn import functional as F
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
# TODO(kan-bayashi): Documentation and type hint
def piecewise_rational_quadratic_transform(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
else:
spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
outputs, logabsdet = spline_fn(
inputs=inputs,
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs
)
return outputs, logabsdet
# TODO(kan-bayashi): Documentation and type hint
def unconstrained_rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails="linear",
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs)
logabsdet = torch.zeros_like(inputs)
if tails == "linear":
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError("{} tails are not implemented.".format(tails))
(
outputs[inside_interval_mask],
logabsdet[inside_interval_mask],
) = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
)
return outputs, logabsdet
# TODO(kan-bayashi): Documentation and type hint
def rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0.0,
right=1.0,
bottom=0.0,
top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if torch.min(inputs) < left or torch.max(inputs) > right:
raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0:
raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
heights = F.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1)
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
heights = cumheights[..., 1:] - cumheights[..., :-1]
if inverse:
bin_idx = _searchsorted(cumheights, inputs)[..., None]
else:
bin_idx = _searchsorted(cumwidths, inputs)[..., None]
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
delta = heights / widths
input_delta = delta.gather(-1, bin_idx)[..., 0]
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
input_heights = heights.gather(-1, bin_idx)[..., 0]
if inverse:
a = (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
) + input_heights * (input_delta - input_derivatives)
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
)
c = -input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
root = (2 * c) / (-b - torch.sqrt(discriminant))
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -logabsdet
else:
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
)
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, logabsdet
def _searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1

View File

@ -0,0 +1,327 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class LJSpeechTtsDataModule:
"""
DataModule for tts experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="TTS data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/spectrogram"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=False,
help="When enabled, each batch will have the "
"field: batch['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = SpectrogramConfig(
sampling_rate=sampling_rate,
frame_length=1024 / sampling_rate, # (in second),
frame_shift=256 / sampling_rate, # (in second)
use_fft_mag=True,
)
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = SpectrogramConfig(
sampling_rate=sampling_rate,
frame_length=1024 / sampling_rate, # (in second),
frame_shift=256 / sampling_rate, # (in second)
use_fft_mag=True,
)
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
return_cuts=self.args.return_cuts,
)
else:
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create valid dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = SpectrogramConfig(
sampling_rate=sampling_rate,
frame_length=1024 / sampling_rate, # (in second),
frame_shift=256 / sampling_rate, # (in second)
use_fft_mag=True,
)
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
return_cuts=self.args.return_cuts,
)
else:
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
test_sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=test_sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
)
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get validation cuts")
return load_manifest_lazy(
self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
)

View File

@ -0,0 +1,265 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from lhotse.dataset.sampling.base import CutSampler
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
def get_random_segments(
x: torch.Tensor,
x_lengths: torch.Tensor,
segment_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get random segments.
Args:
x (Tensor): Input tensor (B, C, T).
x_lengths (Tensor): Length tensor (B,).
segment_size (int): Segment size.
Returns:
Tensor: Segmented tensor (B, C, segment_size).
Tensor: Start index tensor (B,).
"""
b, c, t = x.size()
max_start_idx = x_lengths - segment_size
max_start_idx[max_start_idx < 0] = 0
start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to(
dtype=torch.long,
)
segments = get_segments(x, start_idxs, segment_size)
return segments, start_idxs
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
def get_segments(
x: torch.Tensor,
start_idxs: torch.Tensor,
segment_size: int,
) -> torch.Tensor:
"""Get segments.
Args:
x (Tensor): Input tensor (B, C, T).
start_idxs (Tensor): Start index tensor (B,).
segment_size (int): Segment size.
Returns:
Tensor: Segmented tensor (B, C, segment_size).
"""
b, c, t = x.size()
segments = x.new_zeros(b, c, segment_size)
for i, start_idx in enumerate(start_idxs):
segments[i] = x[i, :, start_idx : start_idx + segment_size]
return segments
# from https://github.com/jaywalnut310/vit://github.com/jaywalnut310/vits/blob/main/commons.py
def intersperse(sequence, item=0):
result = [item] * (len(sequence) * 2 + 1)
result[1::2] = sequence
return result
# from https://github.com/jaywalnut310/vits/blob/main/utils.py
MATPLOTLIB_FLAG = False
def plot_feature(spectrogram):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
class MetricsTracker(collections.defaultdict):
def __init__(self):
# Passing the type 'int' to the base-class constructor
# makes undefined items default to int() which is zero.
# This class will play a role as metrics tracker.
# It can record many metrics, including but not limited to loss.
super(MetricsTracker, self).__init__(int)
def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
ans = MetricsTracker()
for k, v in self.items():
ans[k] = v
for k, v in other.items():
ans[k] = ans[k] + v
return ans
def __mul__(self, alpha: float) -> "MetricsTracker":
ans = MetricsTracker()
for k, v in self.items():
ans[k] = v * alpha
return ans
def __str__(self) -> str:
ans = ""
for k, v in self.norm_items():
norm_value = "%.4g" % v
ans += str(k) + "=" + str(norm_value) + ", "
samples = "%.2f" % self["samples"]
ans += "over " + str(samples) + " samples."
return ans
def norm_items(self) -> List[Tuple[str, float]]:
"""
Returns a list of pairs, like:
[('loss_1', 0.1), ('loss_2', 0.07)]
"""
samples = self["samples"] if "samples" in self else 1
ans = []
for k, v in self.items():
if k == "samples":
continue
norm_value = float(v) / samples
ans.append((k, norm_value))
return ans
def reduce(self, device):
"""
Reduce using torch.distributed, which I believe ensures that
all processes get the total.
"""
keys = sorted(self.keys())
s = torch.tensor([float(self[k]) for k in keys], device=device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
for k, v in zip(keys, s.cpu().tolist()):
self[k] = v
def write_summary(
self,
tb_writer: SummaryWriter,
prefix: str,
batch_idx: int,
) -> None:
"""Add logging information to a TensorBoard writer.
Args:
tb_writer: a TensorBoard writer
prefix: a prefix for the name of the loss, e.g. "train/valid_",
or "train/current_"
batch_idx: The current batch index, used as the x-axis of the plot.
"""
for k, v in self.norm_items():
tb_writer.add_scalar(prefix + k, v, batch_idx)
# checkpoint saving and loading
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
def save_checkpoint(
filename: Path,
model: Union[nn.Module, DDP],
params: Optional[Dict[str, Any]] = None,
optimizer_g: Optional[Optimizer] = None,
optimizer_d: Optional[Optimizer] = None,
scheduler_g: Optional[LRSchedulerType] = None,
scheduler_d: Optional[LRSchedulerType] = None,
scaler: Optional[GradScaler] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
) -> None:
"""Save training information to a file.
Args:
filename:
The checkpoint filename.
model:
The model to be saved. We only save its `state_dict()`.
model_avg:
The stored model averaged from the start of training.
params:
User defined parameters, e.g., epoch, loss.
optimizer_g:
The optimizer for generator used in the training.
Its `state_dict` will be saved.
optimizer_d:
The optimizer for discriminator used in the training.
Its `state_dict` will be saved.
scheduler_g:
The learning rate scheduler for generator used in the training.
Its `state_dict` will be saved.
scheduler_d:
The learning rate scheduler for discriminator used in the training.
Its `state_dict` will be saved.
scalar:
The GradScaler to be saved. We only save its `state_dict()`.
rank:
Used in DDP. We save checkpoint only for the node whose rank is 0.
Returns:
Return None.
"""
if rank != 0:
return
logging.info(f"Saving checkpoint to {filename}")
if isinstance(model, DDP):
model = model.module
checkpoint = {
"model": model.state_dict(),
"optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None,
"optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None,
"scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None,
"scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None,
"grad_scaler": scaler.state_dict() if scaler is not None else None,
"sampler": sampler.state_dict() if sampler is not None else None,
}
if params:
for k, v in params.items():
assert k not in checkpoint
checkpoint[k] = v
torch.save(checkpoint, filename)

View File

@ -0,0 +1,610 @@
# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""VITS module for GAN-TTS task."""
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from generator import VITSGenerator
from hifigan import (
HiFiGANMultiPeriodDiscriminator,
HiFiGANMultiScaleDiscriminator,
HiFiGANMultiScaleMultiPeriodDiscriminator,
HiFiGANPeriodDiscriminator,
HiFiGANScaleDiscriminator,
)
from loss import (
DiscriminatorAdversarialLoss,
FeatureMatchLoss,
GeneratorAdversarialLoss,
KLDivergenceLoss,
MelSpectrogramLoss,
)
from torch.cuda.amp import autocast
from utils import get_segments
AVAILABLE_GENERATERS = {
"vits_generator": VITSGenerator,
}
AVAILABLE_DISCRIMINATORS = {
"hifigan_period_discriminator": HiFiGANPeriodDiscriminator,
"hifigan_scale_discriminator": HiFiGANScaleDiscriminator,
"hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator,
"hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator,
"hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA
}
class VITS(nn.Module):
"""Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`"""
def __init__(
self,
# generator related
vocab_size: int,
feature_dim: int = 513,
sampling_rate: int = 22050,
generator_type: str = "vits_generator",
generator_params: Dict[str, Any] = {
"hidden_channels": 192,
"spks": None,
"langs": None,
"spk_embed_dim": None,
"global_channels": -1,
"segment_size": 32,
"text_encoder_attention_heads": 2,
"text_encoder_ffn_expand": 4,
"text_encoder_cnn_module_kernel": 5,
"text_encoder_blocks": 6,
"text_encoder_dropout_rate": 0.1,
"decoder_kernel_size": 7,
"decoder_channels": 512,
"decoder_upsample_scales": [8, 8, 2, 2],
"decoder_upsample_kernel_sizes": [16, 16, 4, 4],
"decoder_resblock_kernel_sizes": [3, 7, 11],
"decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"use_weight_norm_in_decoder": True,
"posterior_encoder_kernel_size": 5,
"posterior_encoder_layers": 16,
"posterior_encoder_stacks": 1,
"posterior_encoder_base_dilation": 1,
"posterior_encoder_dropout_rate": 0.0,
"use_weight_norm_in_posterior_encoder": True,
"flow_flows": 4,
"flow_kernel_size": 5,
"flow_base_dilation": 1,
"flow_layers": 4,
"flow_dropout_rate": 0.0,
"use_weight_norm_in_flow": True,
"use_only_mean_in_flow": True,
"stochastic_duration_predictor_kernel_size": 3,
"stochastic_duration_predictor_dropout_rate": 0.5,
"stochastic_duration_predictor_flows": 4,
"stochastic_duration_predictor_dds_conv_layers": 3,
"use_noised_mas": True,
"noise_initial_mas": 0.01,
"noise_scale_mas": 2e-06,
},
# discriminator related
discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator",
discriminator_params: Dict[str, Any] = {
"scales": 1,
"scale_downsample_pooling": "AvgPool1d",
"scale_downsample_pooling_params": {
"kernel_size": 4,
"stride": 2,
"padding": 2,
},
"scale_discriminator_params": {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [15, 41, 5, 3],
"channels": 128,
"max_downsample_channels": 1024,
"max_groups": 16,
"bias": True,
"downsample_scales": [2, 2, 4, 4, 1],
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
"use_weight_norm": True,
"use_spectral_norm": False,
},
"follow_official_norm": False,
"periods": [2, 3, 5, 7, 11],
"period_discriminator_params": {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [5, 3],
"channels": 32,
"downsample_scales": [3, 3, 3, 3, 1],
"max_downsample_channels": 1024,
"bias": True,
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
"use_weight_norm": True,
"use_spectral_norm": False,
},
},
# loss related
generator_adv_loss_params: Dict[str, Any] = {
"average_by_discriminators": False,
"loss_type": "mse",
},
discriminator_adv_loss_params: Dict[str, Any] = {
"average_by_discriminators": False,
"loss_type": "mse",
},
feat_match_loss_params: Dict[str, Any] = {
"average_by_discriminators": False,
"average_by_layers": False,
"include_final_outputs": True,
},
mel_loss_params: Dict[str, Any] = {
"frame_shift": 256,
"frame_length": 1024,
"n_mels": 80,
},
lambda_adv: float = 1.0,
lambda_mel: float = 45.0,
lambda_feat_match: float = 2.0,
lambda_dur: float = 1.0,
lambda_kl: float = 1.0,
cache_generator_outputs: bool = True,
):
"""Initialize VITS module.
Args:
idim (int): Input vocabrary size.
odim (int): Acoustic feature dimension. The actual output channels will
be 1 since VITS is the end-to-end text-to-wave model but for the
compatibility odim is used to indicate the acoustic feature dimension.
sampling_rate (int): Sampling rate, not used for the training but it will
be referred in saving waveform during the inference.
generator_type (str): Generator type.
generator_params (Dict[str, Any]): Parameter dict for generator.
discriminator_type (str): Discriminator type.
discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator
adversarial loss.
discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for
discriminator adversarial loss.
feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss.
mel_loss_params (Dict[str, Any]): Parameter dict for mel loss.
lambda_adv (float): Loss scaling coefficient for adversarial loss.
lambda_mel (float): Loss scaling coefficient for mel spectrogram loss.
lambda_feat_match (float): Loss scaling coefficient for feat match loss.
lambda_dur (float): Loss scaling coefficient for duration loss.
lambda_kl (float): Loss scaling coefficient for KL divergence loss.
cache_generator_outputs (bool): Whether to cache generator outputs.
"""
super().__init__()
# define modules
generator_class = AVAILABLE_GENERATERS[generator_type]
if generator_type == "vits_generator":
# NOTE(kan-bayashi): Update parameters for the compatibility.
# The idim and odim is automatically decided from input data,
# where idim represents #vocabularies and odim represents
# the input acoustic feature dimension.
generator_params.update(vocabs=vocab_size, aux_channels=feature_dim)
self.generator = generator_class(
**generator_params,
)
discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
self.discriminator = discriminator_class(
**discriminator_params,
)
self.generator_adv_loss = GeneratorAdversarialLoss(
**generator_adv_loss_params,
)
self.discriminator_adv_loss = DiscriminatorAdversarialLoss(
**discriminator_adv_loss_params,
)
self.feat_match_loss = FeatureMatchLoss(
**feat_match_loss_params,
)
mel_loss_params.update(sampling_rate=sampling_rate)
self.mel_loss = MelSpectrogramLoss(
**mel_loss_params,
)
self.kl_loss = KLDivergenceLoss()
# coefficients
self.lambda_adv = lambda_adv
self.lambda_mel = lambda_mel
self.lambda_kl = lambda_kl
self.lambda_feat_match = lambda_feat_match
self.lambda_dur = lambda_dur
# cache
self.cache_generator_outputs = cache_generator_outputs
self._cache = None
# store sampling rate for saving wav file
# (not used for the training)
self.sampling_rate = sampling_rate
# store parameters for test compatibility
self.spks = self.generator.spks
self.langs = self.generator.langs
self.spk_embed_dim = self.generator.spk_embed_dim
def forward(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
feats: torch.Tensor,
feats_lengths: torch.Tensor,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
return_sample: bool = False,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
forward_generator: bool = True,
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Perform generator forward.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
feats_lengths (Tensor): Feature length tensor (B,).
speech (Tensor): Speech waveform tensor (B, T_wav).
speech_lengths (Tensor): Speech length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
forward_generator (bool): Whether to forward generator.
Returns:
- loss (Tensor): Loss scalar tensor.
- stats (Dict[str, float]): Statistics to be monitored.
"""
if forward_generator:
return self._forward_generator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
speech=speech,
speech_lengths=speech_lengths,
return_sample=return_sample,
sids=sids,
spembs=spembs,
lids=lids,
)
else:
return self._forward_discrminator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
speech=speech,
speech_lengths=speech_lengths,
sids=sids,
spembs=spembs,
lids=lids,
)
def _forward_generator(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
feats: torch.Tensor,
feats_lengths: torch.Tensor,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
return_sample: bool = False,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Perform generator forward.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
feats_lengths (Tensor): Feature length tensor (B,).
speech (Tensor): Speech waveform tensor (B, T_wav).
speech_lengths (Tensor): Speech length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
Returns:
* loss (Tensor): Loss scalar tensor.
* stats (Dict[str, float]): Statistics to be monitored.
"""
# setup
feats = feats.transpose(1, 2)
speech = speech.unsqueeze(1)
# calculate generator outputs
reuse_cache = True
if not self.cache_generator_outputs or self._cache is None:
reuse_cache = False
outs = self.generator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids,
)
else:
outs = self._cache
# store cache
if self.training and self.cache_generator_outputs and not reuse_cache:
self._cache = outs
# parse outputs
speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
_, z_p, m_p, logs_p, _, logs_q = outs_
speech_ = get_segments(
x=speech,
start_idxs=start_idxs * self.generator.upsample_factor,
segment_size=self.generator.segment_size * self.generator.upsample_factor,
)
# calculate discriminator outputs
p_hat = self.discriminator(speech_hat_)
with torch.no_grad():
# do not store discriminator gradient in generator turn
p = self.discriminator(speech_)
# calculate losses
with autocast(enabled=False):
if not return_sample:
mel_loss = self.mel_loss(speech_hat_, speech_)
else:
mel_loss, (mel_hat_, mel_) = self.mel_loss(
speech_hat_, speech_, return_mel=True
)
kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
dur_loss = torch.sum(dur_nll.float())
adv_loss = self.generator_adv_loss(p_hat)
feat_match_loss = self.feat_match_loss(p_hat, p)
mel_loss = mel_loss * self.lambda_mel
kl_loss = kl_loss * self.lambda_kl
dur_loss = dur_loss * self.lambda_dur
adv_loss = adv_loss * self.lambda_adv
feat_match_loss = feat_match_loss * self.lambda_feat_match
loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss
stats = dict(
generator_loss=loss.item(),
generator_mel_loss=mel_loss.item(),
generator_kl_loss=kl_loss.item(),
generator_dur_loss=dur_loss.item(),
generator_adv_loss=adv_loss.item(),
generator_feat_match_loss=feat_match_loss.item(),
)
if return_sample:
stats["returned_sample"] = (
speech_hat_[0].data.cpu().numpy(),
speech_[0].data.cpu().numpy(),
mel_hat_[0].data.cpu().numpy(),
mel_[0].data.cpu().numpy(),
)
# reset cache
if reuse_cache or not self.training:
self._cache = None
return loss, stats
def _forward_discrminator(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
feats: torch.Tensor,
feats_lengths: torch.Tensor,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Perform discriminator forward.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
feats_lengths (Tensor): Feature length tensor (B,).
speech (Tensor): Speech waveform tensor (B, T_wav).
speech_lengths (Tensor): Speech length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
Returns:
* loss (Tensor): Loss scalar tensor.
* stats (Dict[str, float]): Statistics to be monitored.
"""
# setup
feats = feats.transpose(1, 2)
speech = speech.unsqueeze(1)
# calculate generator outputs
reuse_cache = True
if not self.cache_generator_outputs or self._cache is None:
reuse_cache = False
outs = self.generator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids,
)
else:
outs = self._cache
# store cache
if self.cache_generator_outputs and not reuse_cache:
self._cache = outs
# parse outputs
speech_hat_, _, _, start_idxs, *_ = outs
speech_ = get_segments(
x=speech,
start_idxs=start_idxs * self.generator.upsample_factor,
segment_size=self.generator.segment_size * self.generator.upsample_factor,
)
# calculate discriminator outputs
p_hat = self.discriminator(speech_hat_.detach())
p = self.discriminator(speech_)
# calculate losses
with autocast(enabled=False):
real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
loss = real_loss + fake_loss
stats = dict(
discriminator_loss=loss.item(),
discriminator_real_loss=real_loss.item(),
discriminator_fake_loss=fake_loss.item(),
)
# reset cache
if reuse_cache or not self.training:
self._cache = None
return loss, stats
def inference(
self,
text: torch.Tensor,
feats: Optional[torch.Tensor] = None,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
durations: Optional[torch.Tensor] = None,
noise_scale: float = 0.667,
noise_scale_dur: float = 0.8,
alpha: float = 1.0,
max_len: Optional[int] = None,
use_teacher_forcing: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Run inference for single sample.
Args:
text (Tensor): Input text index tensor (T_text,).
feats (Tensor): Feature tensor (T_feats, aux_channels).
sids (Tensor): Speaker index tensor (1,).
spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,).
lids (Tensor): Language index tensor (1,).
durations (Tensor): Ground-truth duration tensor (T_text,).
noise_scale (float): Noise scale value for flow.
noise_scale_dur (float): Noise scale value for duration predictor.
alpha (float): Alpha parameter to control the speed of generated speech.
max_len (Optional[int]): Maximum length.
use_teacher_forcing (bool): Whether to use teacher forcing.
Returns:
* wav (Tensor): Generated waveform tensor (T_wav,).
* att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
* duration (Tensor): Predicted duration tensor (T_text,).
"""
# setup
text = text[None]
text_lengths = torch.tensor(
[text.size(1)],
dtype=torch.long,
device=text.device,
)
if sids is not None:
sids = sids.view(1)
if lids is not None:
lids = lids.view(1)
if durations is not None:
durations = durations.view(1, 1, -1)
# inference
if use_teacher_forcing:
assert feats is not None
feats = feats[None].transpose(1, 2)
feats_lengths = torch.tensor(
[feats.size(2)],
dtype=torch.long,
device=feats.device,
)
wav, att_w, dur = self.generator.inference(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids,
max_len=max_len,
use_teacher_forcing=use_teacher_forcing,
)
else:
wav, att_w, dur = self.generator.inference(
text=text,
text_lengths=text_lengths,
sids=sids,
spembs=spembs,
lids=lids,
dur=durations,
noise_scale=noise_scale,
noise_scale_dur=noise_scale_dur,
alpha=alpha,
max_len=max_len,
)
return wav.view(-1), att_w[0], dur[0]
def inference_batch(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
sids: Optional[torch.Tensor] = None,
durations: Optional[torch.Tensor] = None,
noise_scale: float = 0.667,
noise_scale_dur: float = 0.8,
alpha: float = 1.0,
max_len: Optional[int] = None,
use_teacher_forcing: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Run inference for one batch.
Args:
text (Tensor): Input text index tensor (B, T_text).
text_lengths (Tensor): Input text index tensor (B,).
sids (Tensor): Speaker index tensor (B,).
noise_scale (float): Noise scale value for flow.
noise_scale_dur (float): Noise scale value for duration predictor.
alpha (float): Alpha parameter to control the speed of generated speech.
max_len (Optional[int]): Maximum length.
Returns:
* wav (Tensor): Generated waveform tensor (B, T_wav).
* att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text).
* duration (Tensor): Predicted duration tensor (B, T_text).
"""
# inference
wav, att_w, dur = self.generator.inference(
text=text,
text_lengths=text_lengths,
sids=sids,
noise_scale=noise_scale,
noise_scale_dur=noise_scale_dur,
alpha=alpha,
max_len=max_len,
)
return wav, att_w, dur

View File

@ -0,0 +1,348 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""WaveNet modules.
This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
"""
import logging
import math
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
class WaveNet(torch.nn.Module):
"""WaveNet with global conditioning."""
def __init__(
self,
in_channels: int = 1,
out_channels: int = 1,
kernel_size: int = 3,
layers: int = 30,
stacks: int = 3,
base_dilation: int = 2,
residual_channels: int = 64,
aux_channels: int = -1,
gate_channels: int = 128,
skip_channels: int = 64,
global_channels: int = -1,
dropout_rate: float = 0.0,
bias: bool = True,
use_weight_norm: bool = True,
use_first_conv: bool = False,
use_last_conv: bool = False,
scale_residual: bool = False,
scale_skip_connect: bool = False,
):
"""Initialize WaveNet module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_size (int): Kernel size of dilated convolution.
layers (int): Number of residual block layers.
stacks (int): Number of stacks i.e., dilation cycles.
base_dilation (int): Base dilation factor.
residual_channels (int): Number of channels in residual conv.
gate_channels (int): Number of channels in gated conv.
skip_channels (int): Number of channels in skip conv.
aux_channels (int): Number of channels for local conditioning feature.
global_channels (int): Number of channels for global conditioning feature.
dropout_rate (float): Dropout rate. 0.0 means no dropout applied.
bias (bool): Whether to use bias parameter in conv layer.
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
be applied to all of the conv layers.
use_first_conv (bool): Whether to use the first conv layers.
use_last_conv (bool): Whether to use the last conv layers.
scale_residual (bool): Whether to scale the residual outputs.
scale_skip_connect (bool): Whether to scale the skip connection outputs.
"""
super().__init__()
self.layers = layers
self.stacks = stacks
self.kernel_size = kernel_size
self.base_dilation = base_dilation
self.use_first_conv = use_first_conv
self.use_last_conv = use_last_conv
self.scale_skip_connect = scale_skip_connect
# check the number of layers and stacks
assert layers % stacks == 0
layers_per_stack = layers // stacks
# define first convolution
if self.use_first_conv:
self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)
# define residual blocks
self.conv_layers = torch.nn.ModuleList()
for layer in range(layers):
dilation = base_dilation ** (layer % layers_per_stack)
conv = ResidualBlock(
kernel_size=kernel_size,
residual_channels=residual_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=aux_channels,
global_channels=global_channels,
dilation=dilation,
dropout_rate=dropout_rate,
bias=bias,
scale_residual=scale_residual,
)
self.conv_layers += [conv]
# define output layers
if self.use_last_conv:
self.last_conv = torch.nn.Sequential(
torch.nn.ReLU(inplace=True),
Conv1d1x1(skip_channels, skip_channels, bias=True),
torch.nn.ReLU(inplace=True),
Conv1d1x1(skip_channels, out_channels, bias=True),
)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
def forward(
self,
x: torch.Tensor,
x_mask: Optional[torch.Tensor] = None,
c: Optional[torch.Tensor] = None,
g: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T) if use_first_conv else
(B, residual_channels, T).
x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
c (Optional[Tensor]): Local conditioning features (B, aux_channels, T).
g (Optional[Tensor]): Global conditioning features (B, global_channels, 1).
Returns:
Tensor: Output tensor (B, out_channels, T) if use_last_conv else
(B, residual_channels, T).
"""
# encode to hidden representation
if self.use_first_conv:
x = self.first_conv(x)
# residual block
skips = 0.0
for f in self.conv_layers:
x, h = f(x, x_mask=x_mask, c=c, g=g)
skips = skips + h
x = skips
if self.scale_skip_connect:
x = x * math.sqrt(1.0 / len(self.conv_layers))
# apply final layers
if self.use_last_conv:
x = self.last_conv(x)
return x
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m: torch.nn.Module):
try:
logging.debug(f"Weight norm is removed from {m}.")
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
torch.nn.utils.weight_norm(m)
logging.debug(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
@staticmethod
def _get_receptive_field_size(
layers: int,
stacks: int,
kernel_size: int,
base_dilation: int,
) -> int:
assert layers % stacks == 0
layers_per_cycle = layers // stacks
dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)]
return (kernel_size - 1) * sum(dilations) + 1
@property
def receptive_field_size(self) -> int:
"""Return receptive field size."""
return self._get_receptive_field_size(
self.layers, self.stacks, self.kernel_size, self.base_dilation
)
class Conv1d(torch.nn.Conv1d):
"""Conv1d module with customized initialization."""
def __init__(self, *args, **kwargs):
"""Initialize Conv1d module."""
super().__init__(*args, **kwargs)
def reset_parameters(self):
"""Reset parameters."""
torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
if self.bias is not None:
torch.nn.init.constant_(self.bias, 0.0)
class Conv1d1x1(Conv1d):
"""1x1 Conv1d with customized initialization."""
def __init__(self, in_channels: int, out_channels: int, bias: bool):
"""Initialize 1x1 Conv1d module."""
super().__init__(
in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
)
class ResidualBlock(torch.nn.Module):
"""Residual block module in WaveNet."""
def __init__(
self,
kernel_size: int = 3,
residual_channels: int = 64,
gate_channels: int = 128,
skip_channels: int = 64,
aux_channels: int = 80,
global_channels: int = -1,
dropout_rate: float = 0.0,
dilation: int = 1,
bias: bool = True,
scale_residual: bool = False,
):
"""Initialize ResidualBlock module.
Args:
kernel_size (int): Kernel size of dilation convolution layer.
residual_channels (int): Number of channels for residual connection.
skip_channels (int): Number of channels for skip connection.
aux_channels (int): Number of local conditioning channels.
dropout (float): Dropout probability.
dilation (int): Dilation factor.
bias (bool): Whether to add bias parameter in convolution layers.
scale_residual (bool): Whether to scale the residual outputs.
"""
super().__init__()
self.dropout_rate = dropout_rate
self.residual_channels = residual_channels
self.skip_channels = skip_channels
self.scale_residual = scale_residual
# check
assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
assert gate_channels % 2 == 0
# dilation conv
padding = (kernel_size - 1) // 2 * dilation
self.conv = Conv1d(
residual_channels,
gate_channels,
kernel_size,
padding=padding,
dilation=dilation,
bias=bias,
)
# local conditioning
if aux_channels > 0:
self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
else:
self.conv1x1_aux = None
# global conditioning
if global_channels > 0:
self.conv1x1_glo = Conv1d1x1(global_channels, gate_channels, bias=False)
else:
self.conv1x1_glo = None
# conv output is split into two groups
gate_out_channels = gate_channels // 2
# NOTE(kan-bayashi): concat two convs into a single conv for the efficiency
# (integrate res 1x1 + skip 1x1 convs)
self.conv1x1_out = Conv1d1x1(
gate_out_channels, residual_channels + skip_channels, bias=bias
)
def forward(
self,
x: torch.Tensor,
x_mask: Optional[torch.Tensor] = None,
c: Optional[torch.Tensor] = None,
g: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, residual_channels, T).
x_mask Optional[torch.Tensor]: Mask tensor (B, 1, T).
c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Output tensor for residual connection (B, residual_channels, T).
Tensor: Output tensor for skip connection (B, skip_channels, T).
"""
residual = x
x = F.dropout(x, p=self.dropout_rate, training=self.training)
x = self.conv(x)
# split into two part for gated activation
splitdim = 1
xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)
# local conditioning
if c is not None:
c = self.conv1x1_aux(c)
ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
xa, xb = xa + ca, xb + cb
# global conditioning
if g is not None:
g = self.conv1x1_glo(g)
ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim)
xa, xb = xa + ga, xb + gb
x = torch.tanh(xa) * torch.sigmoid(xb)
# residual + skip 1x1 conv
x = self.conv1x1_out(x)
if x_mask is not None:
x = x * x_mask
# split integrated conv results
x, s = x.split([self.residual_channels, self.skip_channels], dim=1)
# for residual connection
x = x + residual
if self.scale_residual:
x = x * math.sqrt(0.5)
return x, s