Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-27 10:44:19 +00:00
Add MAS to VITS2

commit cafc33bac9 (parent 027302c902)
egs/ljspeech/TTS/vits2/README.md (Normal file, 3 lines added)
@@ -0,0 +1,3 @@
See https://k2-fsa.github.io/icefall/recipes/TTS/ljspeech/vits.html for detailed tutorials.

Training logs, Tensorboard logs, and checkpoints are uploaded to https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29.
egs/ljspeech/TTS/vits2/duration_predictor.py (Normal file, 193 lines added)
@@ -0,0 +1,193 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py

# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Stochastic duration predictor modules in VITS.

This code is based on https://github.com/jaywalnut310/vits.

"""

import math
from typing import Optional

import torch
import torch.nn.functional as F
from flow import (
    ConvFlow,
    DilatedDepthSeparableConv,
    ElementwiseAffineFlow,
    FlipFlow,
    LogFlow,
)


class StochasticDurationPredictor(torch.nn.Module):
    """Stochastic duration predictor module.

    This is a module of stochastic duration predictor described in `Conditional
    Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.

    .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
        Text-to-Speech`: https://arxiv.org/abs/2006.04558

    """

    def __init__(
        self,
        channels: int = 192,
        kernel_size: int = 3,
        dropout_rate: float = 0.5,
        flows: int = 4,
        dds_conv_layers: int = 3,
        global_channels: int = -1,
    ):
        """Initialize StochasticDurationPredictor module.

        Args:
            channels (int): Number of channels.
            kernel_size (int): Kernel size.
            dropout_rate (float): Dropout rate.
            flows (int): Number of flows.
            dds_conv_layers (int): Number of conv layers in DDS conv.
            global_channels (int): Number of global conditioning channels.

        """
        super().__init__()

        self.pre = torch.nn.Conv1d(channels, channels, 1)
        self.dds = DilatedDepthSeparableConv(
            channels,
            kernel_size,
            layers=dds_conv_layers,
            dropout_rate=dropout_rate,
        )
        self.proj = torch.nn.Conv1d(channels, channels, 1)

        self.log_flow = LogFlow()
        self.flows = torch.nn.ModuleList()
        self.flows += [ElementwiseAffineFlow(2)]
        for i in range(flows):
            self.flows += [
                ConvFlow(
                    2,
                    channels,
                    kernel_size,
                    layers=dds_conv_layers,
                )
            ]
            self.flows += [FlipFlow()]

        self.post_pre = torch.nn.Conv1d(1, channels, 1)
        self.post_dds = DilatedDepthSeparableConv(
            channels,
            kernel_size,
            layers=dds_conv_layers,
            dropout_rate=dropout_rate,
        )
        self.post_proj = torch.nn.Conv1d(channels, channels, 1)
        self.post_flows = torch.nn.ModuleList()
        self.post_flows += [ElementwiseAffineFlow(2)]
        for i in range(flows):
            self.post_flows += [
                ConvFlow(
                    2,
                    channels,
                    kernel_size,
                    layers=dds_conv_layers,
                )
            ]
            self.post_flows += [FlipFlow()]

        if global_channels > 0:
            self.global_conv = torch.nn.Conv1d(global_channels, channels, 1)

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        w: Optional[torch.Tensor] = None,
        g: Optional[torch.Tensor] = None,
        inverse: bool = False,
        noise_scale: float = 1.0,
    ) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, channels, T_text).
            x_mask (Tensor): Mask tensor (B, 1, T_text).
            w (Optional[Tensor]): Duration tensor (B, 1, T_text).
            g (Optional[Tensor]): Global conditioning tensor (B, channels, 1)
            inverse (bool): Whether to inverse the flow.
            noise_scale (float): Noise scale value.

        Returns:
            Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,).
                If inverse, log-duration tensor (B, 1, T_text).

        """
        x = x.detach()  # stop gradient
        x = self.pre(x)
        if g is not None:
            x = x + self.global_conv(g.detach())  # stop gradient
        x = self.dds(x, x_mask)
        x = self.proj(x) * x_mask

        if not inverse:
            assert w is not None, "w must be provided."
            h_w = self.post_pre(w)
            h_w = self.post_dds(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            e_q = (
                torch.randn(
                    w.size(0),
                    2,
                    w.size(2),
                ).to(device=x.device, dtype=x.dtype)
                * x_mask
            )
            z_q = e_q
            logdet_tot_q = 0.0
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            logdet_tot_q += torch.sum(
                (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
            )
            logq = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
                - logdet_tot_q
            )

            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in self.flows:
                z, logdet = flow(z, x_mask, g=x, inverse=inverse)
                logdet_tot = logdet_tot + logdet
            nll = (
                torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
                - logdet_tot
            )
            return nll + logq  # (B,)
        else:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
            z = (
                torch.randn(
                    x.size(0),
                    2,
                    x.size(2),
                ).to(device=x.device, dtype=x.dtype)
                * noise_scale
            )
            for flow in flows:
                z = flow(z, x_mask, g=x, inverse=inverse)
            z0, z1 = z.split(1, 1)
            logw = z0
            return logw
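Note: the following is a minimal usage sketch (not part of the commit) showing the two calling conventions of StochasticDurationPredictor; the random tensors are placeholders, and it assumes flow.py and its dependencies are importable from the same directory.

import torch

predictor = StochasticDurationPredictor(channels=192)
x = torch.randn(2, 192, 17)       # text-encoder hidden states (B, channels, T_text)
x_mask = torch.ones(2, 1, 17)     # all positions valid
w = torch.rand(2, 1, 17) * 5 + 1  # placeholder ground-truth durations (B, 1, T_text)

# Training mode: returns the duration NLL per utterance, shape (B,).
nll = predictor(x, x_mask, w=w)
# Inference mode: samples log-durations, shape (B, 1, T_text).
logw = predictor(x, x_mask, inverse=True, noise_scale=0.8)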
egs/ljspeech/TTS/vits2/export-onnx.py (Executable file, 271 lines added)
@@ -0,0 +1,271 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script exports a VITS model from PyTorch to ONNX.

Export the model to ONNX:
./vits/export-onnx.py \
  --epoch 1000 \
  --exp-dir vits/exp \
  --tokens data/tokens.txt

It will generate two files inside vits/exp:
  - vits-epoch-1000.onnx
  - vits-epoch-1000.int8.onnx (quantized model)

See ./test_onnx.py for how to use the exported ONNX models.
"""

import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple

import onnx
import torch
import torch.nn as nn
from onnxruntime.quantization import QuantType, quantize_dynamic
from tokenizer import Tokenizer
from train import get_model, get_params

from icefall.checkpoint import load_checkpoint


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=1000,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="vits/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/tokens.txt",
        help="""Path to vocabulary.""",
    )

    return parser


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = value

    onnx.save(model, filename)


class OnnxModel(nn.Module):
    """A wrapper for VITS generator."""

    def __init__(self, model: nn.Module):
        """
        Args:
          model:
            A VITS generator.
        """
        super().__init__()
        self.model = model

    def forward(
        self,
        tokens: torch.Tensor,
        tokens_lens: torch.Tensor,
        noise_scale: float = 0.667,
        alpha: float = 1.0,
        noise_scale_dur: float = 0.8,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Please see the help information of VITS.inference

        Args:
          tokens:
            Input text token indexes (1, T_text)
          tokens_lens:
            Number of tokens of shape (1,)
          noise_scale (float):
            Noise scale parameter for flow.
          noise_scale_dur (float):
            Noise scale parameter for duration predictor.
          alpha (float):
            Alpha parameter to control the speed of generated speech.

        Returns:
          Return a tuple containing:
            - audio, generated waveform tensor, (B, T_wav)
        """
        audio, _, _ = self.model.inference(
            text=tokens,
            text_lengths=tokens_lens,
            noise_scale=noise_scale,
            noise_scale_dur=noise_scale_dur,
            alpha=alpha,
        )
        return audio


def export_model_onnx(
    model: nn.Module,
    model_filename: str,
    vocab_size: int,
    opset_version: int = 11,
) -> None:
    """Export the given generator model to ONNX format.

    The exported model has one input:

      - tokens, a tensor of shape (1, T_text); dtype is torch.int64

    and it has one output:

      - audio, a tensor of shape (1, T'); dtype is torch.float32

    Args:
      model:
        The VITS generator.
      model_filename:
        The filename to save the exported ONNX model.
      vocab_size:
        Number of tokens used in training.
      opset_version:
        The opset version to use.
    """
    tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64)
    tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
    noise_scale = torch.tensor([1], dtype=torch.float32)
    noise_scale_dur = torch.tensor([1], dtype=torch.float32)
    alpha = torch.tensor([1], dtype=torch.float32)

    torch.onnx.export(
        model,
        (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur),
        model_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=[
            "tokens",
            "tokens_lens",
            "noise_scale",
            "alpha",
            "noise_scale_dur",
        ],
        output_names=["audio"],
        dynamic_axes={
            "tokens": {0: "N", 1: "T"},
            "tokens_lens": {0: "N"},
            "audio": {0: "N", 1: "T"},
        },
    )

    meta_data = {
        "model_type": "VITS",
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "VITS generator",
    }
    logging.info(f"meta_data: {meta_data}")

    add_meta_data(filename=model_filename, meta_data=meta_data)


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.blank_id
    params.oov_id = tokenizer.oov_id
    params.vocab_size = tokenizer.vocab_size

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

    model = model.generator
    model.to("cpu")
    model.eval()

    model = OnnxModel(model=model)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"generator parameters: {num_param}")

    suffix = f"epoch-{params.epoch}"

    opset_version = 13

    logging.info("Exporting generator")
    model_filename = params.exp_dir / f"vits-{suffix}.onnx"
    export_model_onnx(
        model,
        model_filename,
        params.vocab_size,
        opset_version=opset_version,
    )
    logging.info(f"Exported generator to {model_filename}")

    # Generate int8 quantization models
    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

    logging.info("Generate int8 quantization models")

    model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
    quantize_dynamic(
        model_input=model_filename,
        model_output=model_filename_int8,
        weight_type=QuantType.QUInt8,
    )


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
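Note: the commit points to ./test_onnx.py for usage of the exported models. As a rough sketch (not from this commit), the generated file can be driven with onnxruntime along these lines; the token IDs are placeholders, and depending on how the tracer handled the scale arguments, the three scale tensors may have been folded into constants and need to be dropped from the feed:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("vits/exp/vits-epoch-1000.onnx")
tokens = np.array([[10, 42, 7, 23]], dtype=np.int64)       # placeholder token IDs (1, T_text)
tokens_lens = np.array([tokens.shape[1]], dtype=np.int64)  # (1,)
(audio,) = session.run(
    ["audio"],
    {
        "tokens": tokens,
        "tokens_lens": tokens_lens,
        "noise_scale": np.array([0.667], dtype=np.float32),
        "noise_scale_dur": np.array([0.8], dtype=np.float32),
        "alpha": np.array([1.0], dtype=np.float32),
    },
)
# audio: float32 waveform of shape (1, T_wav)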
egs/ljspeech/TTS/vits2/flow.py (Normal file, 311 lines added)
@@ -0,0 +1,311 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py

# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Basic Flow modules used in VITS.

This code is based on https://github.com/jaywalnut310/vits.

"""

import math
from typing import Optional, Tuple, Union

import torch
from transform import piecewise_rational_quadratic_transform


class FlipFlow(torch.nn.Module):
    """Flip flow module."""

    def forward(
        self, x: torch.Tensor, *args, inverse: bool = False, **kwargs
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, channels, T).
            inverse (bool): Whether to inverse the flow.

        Returns:
            Tensor: Flipped tensor (B, channels, T).
            Tensor: Log-determinant tensor for NLL (B,) if not inverse.

        """
        x = torch.flip(x, [1])
        if not inverse:
            logdet = x.new_zeros(x.size(0))
            return x, logdet
        else:
            return x


class LogFlow(torch.nn.Module):
    """Log flow module."""

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        inverse: bool = False,
        eps: float = 1e-5,
        **kwargs
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, channels, T).
            x_mask (Tensor): Mask tensor (B, 1, T).
            inverse (bool): Whether to inverse the flow.
            eps (float): Epsilon for log.

        Returns:
            Tensor: Output tensor (B, channels, T).
            Tensor: Log-determinant tensor for NLL (B,) if not inverse.

        """
        if not inverse:
            y = torch.log(torch.clamp_min(x, eps)) * x_mask
            logdet = torch.sum(-y, [1, 2])
            return y, logdet
        else:
            x = torch.exp(x) * x_mask
            return x


class ElementwiseAffineFlow(torch.nn.Module):
    """Elementwise affine flow module."""

    def __init__(self, channels: int):
        """Initialize ElementwiseAffineFlow module.

        Args:
            channels (int): Number of channels.

        """
        super().__init__()
        self.channels = channels
        self.register_parameter("m", torch.nn.Parameter(torch.zeros(channels, 1)))
        self.register_parameter("logs", torch.nn.Parameter(torch.zeros(channels, 1)))

    def forward(
        self, x: torch.Tensor, x_mask: torch.Tensor, inverse: bool = False, **kwargs
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, channels, T).
            x_mask (Tensor): Mask tensor (B, 1, T).
            inverse (bool): Whether to inverse the flow.

        Returns:
            Tensor: Output tensor (B, channels, T).
            Tensor: Log-determinant tensor for NLL (B,) if not inverse.

        """
        if not inverse:
            y = self.m + torch.exp(self.logs) * x
            y = y * x_mask
            logdet = torch.sum(self.logs * x_mask, [1, 2])
            return y, logdet
        else:
            x = (x - self.m) * torch.exp(-self.logs) * x_mask
            return x


class Transpose(torch.nn.Module):
    """Transpose module for torch.nn.Sequential()."""

    def __init__(self, dim1: int, dim2: int):
        """Initialize Transpose module."""
        super().__init__()
        self.dim1 = dim1
        self.dim2 = dim2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Transpose."""
        return x.transpose(self.dim1, self.dim2)


class DilatedDepthSeparableConv(torch.nn.Module):
    """Dilated depth-separable conv module."""

    def __init__(
        self,
        channels: int,
        kernel_size: int,
        layers: int,
        dropout_rate: float = 0.0,
        eps: float = 1e-5,
    ):
        """Initialize DilatedDepthSeparableConv module.

        Args:
            channels (int): Number of channels.
            kernel_size (int): Kernel size.
            layers (int): Number of layers.
            dropout_rate (float): Dropout rate.
            eps (float): Epsilon for layer norm.

        """
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for i in range(layers):
            dilation = kernel_size**i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs += [
                torch.nn.Sequential(
                    torch.nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        groups=channels,
                        dilation=dilation,
                        padding=padding,
                    ),
                    Transpose(1, 2),
                    torch.nn.LayerNorm(
                        channels,
                        eps=eps,
                        elementwise_affine=True,
                    ),
                    Transpose(1, 2),
                    torch.nn.GELU(),
                    torch.nn.Conv1d(
                        channels,
                        channels,
                        1,
                    ),
                    Transpose(1, 2),
                    torch.nn.LayerNorm(
                        channels,
                        eps=eps,
                        elementwise_affine=True,
                    ),
                    Transpose(1, 2),
                    torch.nn.GELU(),
                    torch.nn.Dropout(dropout_rate),
                )
            ]

    def forward(
        self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).
            x_mask (Tensor): Mask tensor (B, 1, T).
            g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).

        Returns:
            Tensor: Output tensor (B, channels, T).

        """
        if g is not None:
            x = x + g
        for f in self.convs:
            y = f(x * x_mask)
            x = x + y
        return x * x_mask


class ConvFlow(torch.nn.Module):
    """Convolutional flow module."""

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        kernel_size: int,
        layers: int,
        bins: int = 10,
        tail_bound: float = 5.0,
    ):
        """Initialize ConvFlow module.

        Args:
            in_channels (int): Number of input channels.
            hidden_channels (int): Number of hidden channels.
            kernel_size (int): Kernel size.
            layers (int): Number of layers.
            bins (int): Number of bins.
            tail_bound (float): Tail bound value.

        """
        super().__init__()
        self.half_channels = in_channels // 2
        self.hidden_channels = hidden_channels
        self.bins = bins
        self.tail_bound = tail_bound

        self.input_conv = torch.nn.Conv1d(
            self.half_channels,
            hidden_channels,
            1,
        )
        self.dds_conv = DilatedDepthSeparableConv(
            hidden_channels,
            kernel_size,
            layers,
            dropout_rate=0.0,
        )
        self.proj = torch.nn.Conv1d(
            hidden_channels,
            self.half_channels * (bins * 3 - 1),
            1,
        )
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        g: Optional[torch.Tensor] = None,
        inverse: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, channels, T).
            x_mask (Tensor): Mask tensor (B, 1, T).
            g (Optional[Tensor]): Global conditioning tensor (B, channels, 1).
            inverse (bool): Whether to inverse the flow.

        Returns:
            Tensor: Output tensor (B, channels, T).
            Tensor: Log-determinant tensor for NLL (B,) if not inverse.

        """
        xa, xb = x.split(x.size(1) // 2, 1)
        h = self.input_conv(xa)
        h = self.dds_conv(h, x_mask, g=g)
        h = self.proj(h) * x_mask  # (B, half_channels * (bins * 3 - 1), T)

        b, c, t = xa.shape
        # (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1)
        h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)

        # TODO(kan-bayashi): Understand this calculation
        denom = math.sqrt(self.hidden_channels)
        unnorm_widths = h[..., : self.bins] / denom
        unnorm_heights = h[..., self.bins : 2 * self.bins] / denom
        unnorm_derivatives = h[..., 2 * self.bins :]
        xb, logdet_abs = piecewise_rational_quadratic_transform(
            xb,
            unnorm_widths,
            unnorm_heights,
            unnorm_derivatives,
            inverse=inverse,
            tails="linear",
            tail_bound=self.tail_bound,
        )
        x = torch.cat([xa, xb], 1) * x_mask
        logdet = torch.sum(logdet_abs * x_mask, [1, 2])
        if not inverse:
            return x, logdet
        else:
            return x
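Note: every flow above returns (output, logdet) in the forward direction and only the output when inverse=True, so a quick round-trip check (a sketch, not part of the commit) can confirm invertibility up to numerical error:

import torch

flow = ElementwiseAffineFlow(channels=2)
x = torch.randn(3, 2, 11)
x_mask = torch.ones(3, 1, 11)

y, logdet = flow(x, x_mask)            # forward pass: output and log-determinant
x_rec = flow(y, x_mask, inverse=True)  # inverse pass: reconstructed input
assert torch.allclose(x * x_mask, x_rec, atol=1e-6)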
egs/ljspeech/TTS/vits2/generator.py (Normal file, 549 lines added)
@@ -0,0 +1,549 @@
# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py

# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Generator module in VITS.

This code is based on https://github.com/jaywalnut310/vits.

"""


import math
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from duration_predictor import StochasticDurationPredictor
from hifigan import HiFiGANGenerator
from posterior_encoder import PosteriorEncoder
from residual_coupling import ResidualAffineCouplingBlock
from text_encoder import TextEncoder
from utils import get_random_segments

from icefall.utils import make_pad_mask


class VITSGenerator(torch.nn.Module):
    """Generator module in VITS, `Conditional Variational Autoencoder
    with Adversarial Learning for End-to-End Text-to-Speech`.
    """

    def __init__(
        self,
        vocabs: int,
        aux_channels: int = 513,
        hidden_channels: int = 192,
        spks: Optional[int] = None,
        langs: Optional[int] = None,
        spk_embed_dim: Optional[int] = None,
        global_channels: int = -1,
        segment_size: int = 32,
        text_encoder_attention_heads: int = 2,
        text_encoder_ffn_expand: int = 4,
        text_encoder_cnn_module_kernel: int = 5,
        text_encoder_blocks: int = 6,
        text_encoder_dropout_rate: float = 0.1,
        decoder_kernel_size: int = 7,
        decoder_channels: int = 512,
        decoder_upsample_scales: List[int] = [8, 8, 2, 2],
        decoder_upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
        decoder_resblock_kernel_sizes: List[int] = [3, 7, 11],
        decoder_resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        use_weight_norm_in_decoder: bool = True,
        posterior_encoder_kernel_size: int = 5,
        posterior_encoder_layers: int = 16,
        posterior_encoder_stacks: int = 1,
        posterior_encoder_base_dilation: int = 1,
        posterior_encoder_dropout_rate: float = 0.0,
        use_weight_norm_in_posterior_encoder: bool = True,
        flow_flows: int = 4,
        flow_kernel_size: int = 5,
        flow_base_dilation: int = 1,
        flow_layers: int = 4,
        flow_dropout_rate: float = 0.0,
        use_weight_norm_in_flow: bool = True,
        use_only_mean_in_flow: bool = True,
        stochastic_duration_predictor_kernel_size: int = 3,
        stochastic_duration_predictor_dropout_rate: float = 0.5,
        stochastic_duration_predictor_flows: int = 4,
        stochastic_duration_predictor_dds_conv_layers: int = 3,
        use_noised_mas: bool = True,
        noise_initial_mas: float = 0.01,
        noise_scale_mas: float = 2e-6,
    ):
        """Initialize VITS generator module.

        Args:
            vocabs (int): Input vocabulary size.
            aux_channels (int): Number of acoustic feature channels.
            hidden_channels (int): Number of hidden channels.
            spks (Optional[int]): Number of speakers. If set to > 1, assume that the
                sids will be provided as the input and use sid embedding layer.
            langs (Optional[int]): Number of languages. If set to > 1, assume that the
                lids will be provided as the input and use lid embedding layer.
            spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0,
                assume that spembs will be provided as the input.
            global_channels (int): Number of global conditioning channels.
            segment_size (int): Segment size for decoder.
            text_encoder_attention_heads (int): Number of heads in conformer block
                of text encoder.
            text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block
                of text encoder.
            text_encoder_cnn_module_kernel (int): Convolution kernel size in text encoder.
            text_encoder_blocks (int): Number of conformer blocks in text encoder.
            text_encoder_dropout_rate (float): Dropout rate in conformer block of
                text encoder.
            decoder_kernel_size (int): Decoder kernel size.
            decoder_channels (int): Number of decoder initial channels.
            decoder_upsample_scales (List[int]): List of upsampling scales in decoder.
            decoder_upsample_kernel_sizes (List[int]): List of kernel size for
                upsampling layers in decoder.
            decoder_resblock_kernel_sizes (List[int]): List of kernel size for resblocks
                in decoder.
            decoder_resblock_dilations (List[List[int]]): List of list of dilations for
                resblocks in decoder.
            use_weight_norm_in_decoder (bool): Whether to apply weight normalization in
                decoder.
            posterior_encoder_kernel_size (int): Posterior encoder kernel size.
            posterior_encoder_layers (int): Number of layers of posterior encoder.
            posterior_encoder_stacks (int): Number of stacks of posterior encoder.
            posterior_encoder_base_dilation (int): Base dilation of posterior encoder.
            posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder.
            use_weight_norm_in_posterior_encoder (bool): Whether to apply weight
                normalization in posterior encoder.
            flow_flows (int): Number of flows in flow.
            flow_kernel_size (int): Kernel size in flow.
            flow_base_dilation (int): Base dilation in flow.
            flow_layers (int): Number of layers in flow.
            flow_dropout_rate (float): Dropout rate in flow.
            use_weight_norm_in_flow (bool): Whether to apply weight normalization in
                flow.
            use_only_mean_in_flow (bool): Whether to use only mean in flow.
            stochastic_duration_predictor_kernel_size (int): Kernel size in stochastic
                duration predictor.
            stochastic_duration_predictor_dropout_rate (float): Dropout rate in
                stochastic duration predictor.
            stochastic_duration_predictor_flows (int): Number of flows in stochastic
                duration predictor.
            stochastic_duration_predictor_dds_conv_layers (int): Number of DDS conv
                layers in stochastic duration predictor.
            use_noised_mas (bool): Whether to add Gaussian noise to the alignment
                scores used in monotonic alignment search (MAS).
            noise_initial_mas (float): Initial scale of the MAS noise.
            noise_scale_mas (float): Amount by which the MAS noise scale is
                annealed per training step.

        """
        super().__init__()
        self.segment_size = segment_size
        self.text_encoder = TextEncoder(
            vocabs=vocabs,
            d_model=hidden_channels,
            num_heads=text_encoder_attention_heads,
            dim_feedforward=hidden_channels * text_encoder_ffn_expand,
            cnn_module_kernel=text_encoder_cnn_module_kernel,
            num_layers=text_encoder_blocks,
            dropout=text_encoder_dropout_rate,
        )
        self.decoder = HiFiGANGenerator(
            in_channels=hidden_channels,
            out_channels=1,
            channels=decoder_channels,
            global_channels=global_channels,
            kernel_size=decoder_kernel_size,
            upsample_scales=decoder_upsample_scales,
            upsample_kernel_sizes=decoder_upsample_kernel_sizes,
            resblock_kernel_sizes=decoder_resblock_kernel_sizes,
            resblock_dilations=decoder_resblock_dilations,
            use_weight_norm=use_weight_norm_in_decoder,
        )
        self.posterior_encoder = PosteriorEncoder(
            in_channels=aux_channels,
            out_channels=hidden_channels,
            hidden_channels=hidden_channels,
            kernel_size=posterior_encoder_kernel_size,
            layers=posterior_encoder_layers,
            stacks=posterior_encoder_stacks,
            base_dilation=posterior_encoder_base_dilation,
            global_channels=global_channels,
            dropout_rate=posterior_encoder_dropout_rate,
            use_weight_norm=use_weight_norm_in_posterior_encoder,
        )
        self.flow = ResidualAffineCouplingBlock(
            in_channels=hidden_channels,
            hidden_channels=hidden_channels,
            flows=flow_flows,
            kernel_size=flow_kernel_size,
            base_dilation=flow_base_dilation,
            layers=flow_layers,
            global_channels=global_channels,
            dropout_rate=flow_dropout_rate,
            use_weight_norm=use_weight_norm_in_flow,
            use_only_mean=use_only_mean_in_flow,
        )
        # TODO(kan-bayashi): Add deterministic version as an option
        self.duration_predictor = StochasticDurationPredictor(
            channels=hidden_channels,
            kernel_size=stochastic_duration_predictor_kernel_size,
            dropout_rate=stochastic_duration_predictor_dropout_rate,
            flows=stochastic_duration_predictor_flows,
            dds_conv_layers=stochastic_duration_predictor_dds_conv_layers,
            global_channels=global_channels,
        )

        self.upsample_factor = int(np.prod(decoder_upsample_scales))

        # MAS with Gaussian Noise params
        self.use_noised_mas = use_noised_mas
        self.noise_current_mas = noise_initial_mas
        self.noise_scale_mas = noise_scale_mas
        self.noise_initial_mas = noise_initial_mas

        self.spks = None
        if spks is not None and spks > 1:
            assert global_channels > 0
            self.spks = spks
            self.global_emb = torch.nn.Embedding(spks, global_channels)
        self.spk_embed_dim = None
        if spk_embed_dim is not None and spk_embed_dim > 0:
            assert global_channels > 0
            self.spk_embed_dim = spk_embed_dim
            self.spemb_proj = torch.nn.Linear(spk_embed_dim, global_channels)
        self.langs = None
        if langs is not None and langs > 1:
            assert global_channels > 0
            self.langs = langs
            self.lang_emb = torch.nn.Embedding(langs, global_channels)

        # delayed import
        from monotonic_align import maximum_path

        self.maximum_path = maximum_path

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        Tuple[
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
        ],
    ]:
        """Calculate forward propagation.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, aux_channels, T_feats).
            feats_lengths (Tensor): Feature length tensor (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            Tensor: Waveform tensor (B, 1, segment_size * upsample_factor).
            Tensor: Duration negative log-likelihood (NLL) tensor (B,).
            Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text).
            Tensor: Segments start index tensor (B,).
            Tensor: Text mask tensor (B, 1, T_text).
            Tensor: Feature mask tensor (B, 1, T_feats).
            Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
                - Tensor: Posterior encoder hidden representation (B, H, T_feats).
                - Tensor: Flow hidden representation (B, H, T_feats).
                - Tensor: Expanded text encoder projected mean (B, H, T_feats).
                - Tensor: Expanded text encoder projected scale (B, H, T_feats).
                - Tensor: Posterior encoder projected mean (B, H, T_feats).
                - Tensor: Posterior encoder projected scale (B, H, T_feats).

        """
        # forward text encoder
        x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)

        # calculate global conditioning
        g = None
        if self.spks is not None:
            # speaker one-hot vector embedding: (B, global_channels, 1)
            g = self.global_emb(sids.view(-1)).unsqueeze(-1)
        if self.spk_embed_dim is not None:
            # pretrained speaker embedding, e.g., X-vector (B, global_channels, 1)
            g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1)
            if g is None:
                g = g_
            else:
                g = g + g_
        if self.langs is not None:
            # language one-hot vector embedding: (B, global_channels, 1)
            g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
            if g is None:
                g = g_
            else:
                g = g + g_

        # forward posterior encoder
        z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)

        # forward flow
        z_p = self.flow(z, y_mask, g=g)  # (B, H, T_feats)

        # monotonic alignment search
        with torch.no_grad():
            # negative cross-entropy
            s_p_sq_r = torch.exp(-2 * logs_p)  # (B, H, T_text)
            # (B, 1, T_text)
            neg_x_ent_1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - logs_p,
                [1],
                keepdim=True,
            )
            # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
            neg_x_ent_2 = torch.matmul(
                -0.5 * (z_p**2).transpose(1, 2),
                s_p_sq_r,
            )
            # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
            neg_x_ent_3 = torch.matmul(
                z_p.transpose(1, 2),
                (m_p * s_p_sq_r),
            )
            # (B, 1, T_text)
            neg_x_ent_4 = torch.sum(
                -0.5 * (m_p**2) * s_p_sq_r,
                [1],
                keepdim=True,
            )
            # (B, T_feats, T_text)
            neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4

            if self.use_noised_mas:
                epsilon = (
                    torch.std(neg_x_ent)
                    * torch.randn_like(neg_x_ent)
                    * self.noise_current_mas
                )
                neg_x_ent = neg_x_ent + epsilon

            # (B, 1, T_feats, T_text)
            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            # monotonic attention weight: (B, 1, T_feats, T_text)
            attn = (
                self.maximum_path(
                    neg_x_ent,
                    attn_mask.squeeze(1),
                )
                .unsqueeze(1)
                .detach()
            )

        # forward duration predictor
        w = attn.sum(2)  # (B, 1, T_text)
        dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
        dur_nll = dur_nll / torch.sum(x_mask)

        # expand the length to match with the feature sequence
        # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
        # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)

        # get random segments
        z_segments, z_start_idxs = get_random_segments(
            z,
            feats_lengths,
            self.segment_size,
        )

        # forward decoder with random segments
        wav = self.decoder(z_segments, g=g)

        return (
            wav,
            dur_nll,
            attn,
            z_start_idxs,
            x_mask,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
        )

    def inference(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: Optional[torch.Tensor] = None,
        feats_lengths: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        dur: Optional[torch.Tensor] = None,
        noise_scale: float = 0.667,
        noise_scale_dur: float = 0.8,
        alpha: float = 1.0,
        max_len: Optional[int] = None,
        use_teacher_forcing: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run inference.

        Args:
            text (Tensor): Input text index tensor (B, T_text,).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, aux_channels, T_feats,).
            feats_lengths (Tensor): Feature length tensor (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
            dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided,
                skip the prediction of durations (i.e., teacher forcing).
            noise_scale (float): Noise scale parameter for flow.
            noise_scale_dur (float): Noise scale parameter for duration predictor.
            alpha (float): Alpha parameter to control the speed of generated speech.
            max_len (Optional[int]): Maximum length of acoustic feature sequence.
            use_teacher_forcing (bool): Whether to use teacher forcing.

        Returns:
            Tensor: Generated waveform tensor (B, T_wav).
            Tensor: Monotonic attention weight tensor (B, T_feats, T_text).
            Tensor: Duration tensor (B, T_text).

        """
        # encoder
        x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
        x_mask = x_mask.to(x.dtype)
        g = None
        if self.spks is not None:
            # (B, global_channels, 1)
            g = self.global_emb(sids.view(-1)).unsqueeze(-1)
        if self.spk_embed_dim is not None:
            # (B, global_channels, 1)
            g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
            if g is None:
                g = g_
            else:
                g = g + g_
        if self.langs is not None:
            # (B, global_channels, 1)
            g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
            if g is None:
                g = g_
            else:
                g = g + g_

        if use_teacher_forcing:
            # forward posterior encoder
            z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)

            # forward flow
            z_p = self.flow(z, y_mask, g=g)  # (B, H, T_feats)

            # monotonic alignment search
            s_p_sq_r = torch.exp(-2 * logs_p)  # (B, H, T_text)
            # (B, 1, T_text)
            neg_x_ent_1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - logs_p,
                [1],
                keepdim=True,
            )
            # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
            neg_x_ent_2 = torch.matmul(
                -0.5 * (z_p**2).transpose(1, 2),
                s_p_sq_r,
            )
            # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
            neg_x_ent_3 = torch.matmul(
                z_p.transpose(1, 2),
                (m_p * s_p_sq_r),
            )
            # (B, 1, T_text)
            neg_x_ent_4 = torch.sum(
                -0.5 * (m_p**2) * s_p_sq_r,
                [1],
                keepdim=True,
            )
            # (B, T_feats, T_text)
            neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
            # (B, 1, T_feats, T_text)
            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            # monotonic attention weight: (B, 1, T_feats, T_text)
            attn = self.maximum_path(
                neg_x_ent,
                attn_mask.squeeze(1),
            ).unsqueeze(1)
            dur = attn.sum(2)  # (B, 1, T_text)

            # forward decoder
            wav = self.decoder(z * y_mask, g=g)
        else:
            # duration
            if dur is None:
                logw = self.duration_predictor(
                    x,
                    x_mask,
                    g=g,
                    inverse=True,
                    noise_scale=noise_scale_dur,
                )
                w = torch.exp(logw) * x_mask * alpha
                dur = torch.ceil(w)
            y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long()
            y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device)
            y_mask = y_mask.to(x.dtype)
            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = self._generate_path(dur, attn_mask)

            # expand the length to match with the feature sequence
            # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
            m_p = torch.matmul(
                attn.squeeze(1),
                m_p.transpose(1, 2),
            ).transpose(1, 2)
            # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
            logs_p = torch.matmul(
                attn.squeeze(1),
                logs_p.transpose(1, 2),
            ).transpose(1, 2)

            # decoder
            z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
            z = self.flow(z_p, y_mask, g=g, inverse=True)
            wav = self.decoder((z * y_mask)[:, :, :max_len], g=g)

        return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1)

    def _generate_path(self, dur: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Generate path a.k.a. monotonic attention.

        Args:
            dur (Tensor): Duration tensor (B, 1, T_text).
            mask (Tensor): Attention mask tensor (B, 1, T_feats, T_text).

        Returns:
            Tensor: Path tensor (B, 1, T_feats, T_text).

        """
        b, _, t_y, t_x = mask.shape
        cum_dur = torch.cumsum(dur, -1)
        cum_dur_flat = cum_dur.view(b * t_x)
        path = torch.arange(t_y, dtype=dur.dtype, device=dur.device)
        path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)
        # path = path.view(b, t_x, t_y).to(dtype=mask.dtype)
        path = path.view(b, t_x, t_y).to(dtype=torch.float)
        # path will be like (t_x = 3, t_y = 5):
        # [[[1., 1., 0., 0., 0.],      [[[1., 1., 0., 0., 0.],
        #   [1., 1., 1., 1., 0.],  -->   [0., 0., 1., 1., 0.],
        #   [1., 1., 1., 1., 1.]]]       [0., 0., 0., 0., 1.]]]
        path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
        # path = path.to(dtype=mask.dtype)
        return path.unsqueeze(1).transpose(2, 3) * mask
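Note: the generator stores noise_initial_mas, noise_scale_mas, and the mutable noise_current_mas, but does not update them itself; the training loop is expected to anneal the noise. A plausible per-step update (an assumption, not code from this commit) that decays the noise linearly to zero, in the spirit of the VITS2 noised-MAS schedule:

def update_mas_noise(generator: VITSGenerator, step: int) -> None:
    """Linearly decay the Gaussian-noise scale used in noised MAS."""
    generator.noise_current_mas = max(
        generator.noise_initial_mas - generator.noise_scale_mas * step, 0.0
    )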
egs/ljspeech/TTS/vits2/hifigan.py (Normal file, 933 lines added)
@@ -0,0 +1,933 @@
|
|||||||
|
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py
|
||||||
|
|
||||||
|
# Copyright 2021 Tomoki Hayashi
|
||||||
|
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||||
|
|
||||||
|
"""HiFi-GAN Modules.
|
||||||
|
|
||||||
|
This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class HiFiGANGenerator(torch.nn.Module):
|
||||||
|
"""HiFiGAN generator module."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 80,
|
||||||
|
out_channels: int = 1,
|
||||||
|
channels: int = 512,
|
||||||
|
global_channels: int = -1,
|
||||||
|
kernel_size: int = 7,
|
||||||
|
upsample_scales: List[int] = [8, 8, 2, 2],
|
||||||
|
upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
|
||||||
|
resblock_kernel_sizes: List[int] = [3, 7, 11],
|
||||||
|
resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||||
|
use_additional_convs: bool = True,
|
||||||
|
bias: bool = True,
|
||||||
|
nonlinear_activation: str = "LeakyReLU",
|
||||||
|
nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
|
||||||
|
use_weight_norm: bool = True,
|
||||||
|
):
|
||||||
|
"""Initialize HiFiGANGenerator module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Number of input channels.
|
||||||
|
out_channels (int): Number of output channels.
|
||||||
|
channels (int): Number of hidden representation channels.
|
||||||
|
global_channels (int): Number of global conditioning channels.
|
||||||
|
kernel_size (int): Kernel size of initial and final conv layer.
|
||||||
|
upsample_scales (List[int]): List of upsampling scales.
|
||||||
|
upsample_kernel_sizes (List[int]): List of kernel sizes for upsample layers.
|
||||||
|
resblock_kernel_sizes (List[int]): List of kernel sizes for residual blocks.
|
||||||
|
resblock_dilations (List[List[int]]): List of list of dilations for residual
|
||||||
|
blocks.
|
||||||
|
use_additional_convs (bool): Whether to use additional conv layers in
|
||||||
|
residual blocks.
|
||||||
|
bias (bool): Whether to add bias parameter in convolution layers.
|
||||||
|
nonlinear_activation (str): Activation function module name.
|
||||||
|
nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
|
||||||
|
function.
|
||||||
|
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
|
||||||
|
be applied to all of the conv layers.
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__()
|
        # check hyperparameters are valid
        assert kernel_size % 2 == 1, "Kernel size must be odd number."
        assert len(upsample_scales) == len(upsample_kernel_sizes)
        assert len(resblock_dilations) == len(resblock_kernel_sizes)

        # define modules
        self.upsample_factor = int(np.prod(upsample_scales) * out_channels)
        self.num_upsamples = len(upsample_kernel_sizes)
        self.num_blocks = len(resblock_kernel_sizes)
        self.input_conv = torch.nn.Conv1d(
            in_channels,
            channels,
            kernel_size,
            1,
            padding=(kernel_size - 1) // 2,
        )
        self.upsamples = torch.nn.ModuleList()
        self.blocks = torch.nn.ModuleList()
        for i in range(len(upsample_kernel_sizes)):
            assert upsample_kernel_sizes[i] == 2 * upsample_scales[i]
            self.upsamples += [
                torch.nn.Sequential(
                    getattr(torch.nn, nonlinear_activation)(
                        **nonlinear_activation_params
                    ),
                    torch.nn.ConvTranspose1d(
                        channels // (2**i),
                        channels // (2 ** (i + 1)),
                        upsample_kernel_sizes[i],
                        upsample_scales[i],
                        padding=upsample_scales[i] // 2 + upsample_scales[i] % 2,
                        output_padding=upsample_scales[i] % 2,
                    ),
                )
            ]
            for j in range(len(resblock_kernel_sizes)):
                self.blocks += [
                    ResidualBlock(
                        kernel_size=resblock_kernel_sizes[j],
                        channels=channels // (2 ** (i + 1)),
                        dilations=resblock_dilations[j],
                        bias=bias,
                        use_additional_convs=use_additional_convs,
                        nonlinear_activation=nonlinear_activation,
                        nonlinear_activation_params=nonlinear_activation_params,
                    )
                ]
        self.output_conv = torch.nn.Sequential(
            # NOTE(kan-bayashi): follow official implementation but why
            #   using different slope parameter here? (0.1 vs. 0.01)
            torch.nn.LeakyReLU(),
            torch.nn.Conv1d(
                channels // (2 ** (i + 1)),
                out_channels,
                kernel_size,
                1,
                padding=(kernel_size - 1) // 2,
            ),
            torch.nn.Tanh(),
        )
        if global_channels > 0:
            self.global_conv = torch.nn.Conv1d(global_channels, channels, 1)

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

        # reset parameters
        self.reset_parameters()

    def forward(
        self, c: torch.Tensor, g: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            c (Tensor): Input tensor (B, in_channels, T).
            g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).

        Returns:
            Tensor: Output tensor (B, out_channels, T).

        """
        c = self.input_conv(c)
        if g is not None:
            c = c + self.global_conv(g)
        for i in range(self.num_upsamples):
            c = self.upsamples[i](c)
            cs = 0.0  # initialize
            for j in range(self.num_blocks):
                cs += self.blocks[i * self.num_blocks + j](c)
            c = cs / self.num_blocks
        c = self.output_conv(c)

        return c

    def reset_parameters(self):
        """Reset parameters.

        This initialization follows the official implementation manner.
        https://github.com/jik876/hifi-gan/blob/master/models.py

        """

        def _reset_parameters(m: torch.nn.Module):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
                m.weight.data.normal_(0.0, 0.01)
                logging.debug(f"Reset parameters in {m}.")

        self.apply(_reset_parameters)

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""

        def _remove_weight_norm(m: torch.nn.Module):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""

        def _apply_weight_norm(m: torch.nn.Module):
            if isinstance(m, torch.nn.Conv1d) or isinstance(
                m, torch.nn.ConvTranspose1d
            ):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def inference(
        self, c: torch.Tensor, g: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Perform inference.

        Args:
            c (torch.Tensor): Input tensor (T, in_channels).
            g (Optional[Tensor]): Global conditioning tensor (global_channels, 1).

        Returns:
            Tensor: Output tensor (T * upsample_factor, out_channels).

        """
        if g is not None:
            g = g.unsqueeze(0)
        c = self.forward(c.transpose(1, 0).unsqueeze(0), g=g)
        return c.squeeze(0).transpose(1, 0)

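# A quick sanity check of the upsampling arithmetic above; a minimal sketch
# assuming the usual HiFiGAN-style defaults (upsample_scales=[8, 8, 2, 2],
# out_channels=1), which are assumed here and not read from this diff:
# each ConvTranspose1d stage stretches the time axis by its scale, so a mel
# sequence of T frames becomes T * prod(upsample_scales) waveform samples.
import numpy as np

upsample_scales = [8, 8, 2, 2]  # assumed default configuration
out_channels = 1
upsample_factor = int(np.prod(upsample_scales) * out_channels)
assert upsample_factor == 256  # one 256-sample hop per input mel frame
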
class ResidualBlock(torch.nn.Module):
    """Residual block module in HiFiGAN."""

    def __init__(
        self,
        kernel_size: int = 3,
        channels: int = 512,
        dilations: List[int] = [1, 3, 5],
        bias: bool = True,
        use_additional_convs: bool = True,
        nonlinear_activation: str = "LeakyReLU",
        nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
    ):
        """Initialize ResidualBlock module.

        Args:
            kernel_size (int): Kernel size of dilation convolution layer.
            channels (int): Number of channels for convolution layer.
            dilations (List[int]): List of dilation factors.
            use_additional_convs (bool): Whether to use additional convolution layers.
            bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
                function.

        """
        super().__init__()
        self.use_additional_convs = use_additional_convs
        self.convs1 = torch.nn.ModuleList()
        if use_additional_convs:
            self.convs2 = torch.nn.ModuleList()
        assert kernel_size % 2 == 1, "Kernel size must be odd number."
        for dilation in dilations:
            self.convs1 += [
                torch.nn.Sequential(
                    getattr(torch.nn, nonlinear_activation)(
                        **nonlinear_activation_params
                    ),
                    torch.nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation,
                        bias=bias,
                        padding=(kernel_size - 1) // 2 * dilation,
                    ),
                )
            ]
            if use_additional_convs:
                self.convs2 += [
                    torch.nn.Sequential(
                        getattr(torch.nn, nonlinear_activation)(
                            **nonlinear_activation_params
                        ),
                        torch.nn.Conv1d(
                            channels,
                            channels,
                            kernel_size,
                            1,
                            dilation=1,
                            bias=bias,
                            padding=(kernel_size - 1) // 2,
                        ),
                    )
                ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, channels, T).

        Returns:
            Tensor: Output tensor (B, channels, T).

        """
        for idx in range(len(self.convs1)):
            xt = self.convs1[idx](x)
            if self.use_additional_convs:
                xt = self.convs2[idx](xt)
            x = xt + x
        return x

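# Why `padding=(kernel_size - 1) // 2 * dilation` above preserves the time
# length: a dilated kernel spans (kernel_size - 1) * dilation + 1 samples, so
# symmetric padding of half that span restores the input length. A minimal,
# self-contained check (a sketch, not part of the recipe):
import torch

kernel_size, dilation = 3, 5
conv = torch.nn.Conv1d(
    4, 4, kernel_size, dilation=dilation, padding=(kernel_size - 1) // 2 * dilation
)
x = torch.randn(1, 4, 100)
assert conv(x).shape == x.shape  # (B, C, T) is preserved
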
class HiFiGANPeriodDiscriminator(torch.nn.Module):
    """HiFiGAN period discriminator module."""

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        period: int = 3,
        kernel_sizes: List[int] = [5, 3],
        channels: int = 32,
        downsample_scales: List[int] = [3, 3, 3, 3, 1],
        max_downsample_channels: int = 1024,
        bias: bool = True,
        nonlinear_activation: str = "LeakyReLU",
        nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
        use_weight_norm: bool = True,
        use_spectral_norm: bool = False,
    ):
        """Initialize HiFiGANPeriodDiscriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            period (int): Period.
            kernel_sizes (list): Kernel sizes of initial conv layers and the final conv
                layer.
            channels (int): Number of initial channels.
            downsample_scales (List[int]): List of downsampling scales.
            max_downsample_channels (int): Number of maximum downsampling channels.
            bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
                function.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
            use_spectral_norm (bool): Whether to use spectral norm.
                If set to true, it will be applied to all of the conv layers.

        """
        super().__init__()
        assert len(kernel_sizes) == 2
        assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number."
        assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number."

        self.period = period
        self.convs = torch.nn.ModuleList()
        in_chs = in_channels
        out_chs = channels
        for downsample_scale in downsample_scales:
            self.convs += [
                torch.nn.Sequential(
                    torch.nn.Conv2d(
                        in_chs,
                        out_chs,
                        (kernel_sizes[0], 1),
                        (downsample_scale, 1),
                        padding=((kernel_sizes[0] - 1) // 2, 0),
                    ),
                    getattr(torch.nn, nonlinear_activation)(
                        **nonlinear_activation_params
                    ),
                )
            ]
            in_chs = out_chs
            # NOTE(kan-bayashi): Use downsample_scale + 1?
            out_chs = min(out_chs * 4, max_downsample_channels)
        self.output_conv = torch.nn.Conv2d(
            out_chs,
            out_channels,
            (kernel_sizes[1] - 1, 1),
            1,
            padding=((kernel_sizes[1] - 1) // 2, 0),
        )

        if use_weight_norm and use_spectral_norm:
            raise ValueError("Either use use_weight_norm or use_spectral_norm.")

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

        # apply spectral norm
        if use_spectral_norm:
            self.apply_spectral_norm()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).

        Returns:
            list: List of each layer's tensors.

        """
        # transform 1d to 2d -> (B, C, T/P, P)
        b, c, t = x.shape
        if t % self.period != 0:
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t += n_pad
        x = x.view(b, c, t // self.period, self.period)

        # forward conv
        outs = []
        for layer in self.convs:
            x = layer(x)
            outs += [x]
        x = self.output_conv(x)
        x = torch.flatten(x, 1, -1)
        outs += [x]

        return outs

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""

        def _apply_weight_norm(m: torch.nn.Module):
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def apply_spectral_norm(self):
        """Apply spectral normalization module to all of the layers."""

        def _apply_spectral_norm(m: torch.nn.Module):
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.spectral_norm(m)
                logging.debug(f"Spectral norm is applied to {m}.")

        self.apply(_apply_spectral_norm)

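# The 1D-to-2D reshaping in `forward` above is the core of the period
# discriminator: samples exactly `period` apart end up in the same column, so
# the 2D convolutions see periodic structure directly. A minimal sketch of just
# that step (illustrative values only):
import torch
import torch.nn.functional as F

period = 3
x = torch.randn(2, 1, 100)               # (B, C, T), T not divisible by period
n_pad = period - (x.shape[-1] % period)  # pad T=100 up to 102
x = F.pad(x, (0, n_pad), "reflect")
x = x.view(2, 1, x.shape[-1] // period, period)  # -> (2, 1, 34, 3)
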
class HiFiGANMultiPeriodDiscriminator(torch.nn.Module):
    """HiFiGAN multi-period discriminator module."""

    def __init__(
        self,
        periods: List[int] = [2, 3, 5, 7, 11],
        discriminator_params: Dict[str, Any] = {
            "in_channels": 1,
            "out_channels": 1,
            "kernel_sizes": [5, 3],
            "channels": 32,
            "downsample_scales": [3, 3, 3, 3, 1],
            "max_downsample_channels": 1024,
            "bias": True,
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.1},
            "use_weight_norm": True,
            "use_spectral_norm": False,
        },
    ):
        """Initialize HiFiGANMultiPeriodDiscriminator module.

        Args:
            periods (List[int]): List of periods.
            discriminator_params (Dict[str, Any]): Parameters for hifi-gan period
                discriminator module. The period parameter will be overwritten.

        """
        super().__init__()
        self.discriminators = torch.nn.ModuleList()
        for period in periods:
            params = copy.deepcopy(discriminator_params)
            params["period"] = period
            self.discriminators += [HiFiGANPeriodDiscriminator(**params)]

    def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            List: List of lists of each discriminator's outputs, each of which
                consists of the per-layer output tensors.

        """
        outs = []
        for f in self.discriminators:
            outs += [f(x)]

        return outs

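# Usage sketch for the wrapper above (not part of the recipe): with the default
# prime periods, the module returns one list of per-layer feature maps per
# sub-discriminator, and the final logit map of the k-th one is outs[k][-1].
import torch

mpd = HiFiGANMultiPeriodDiscriminator()  # defaults: periods = [2, 3, 5, 7, 11]
outs = mpd(torch.randn(2, 1, 8192))
assert len(outs) == 5                    # one output list per period
final_logits = [o[-1] for o in outs]     # last entry is the flattened logit map
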
class HiFiGANScaleDiscriminator(torch.nn.Module):
    """HiFi-GAN scale discriminator module."""

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        kernel_sizes: List[int] = [15, 41, 5, 3],
        channels: int = 128,
        max_downsample_channels: int = 1024,
        max_groups: int = 16,
        bias: bool = True,
        downsample_scales: List[int] = [2, 2, 4, 4, 1],
        nonlinear_activation: str = "LeakyReLU",
        nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
        use_weight_norm: bool = True,
        use_spectral_norm: bool = False,
    ):
        """Initialize HiFiGAN scale discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_sizes (List[int]): List of four kernel sizes. The first will be used
                for the first conv layer, and the second is for downsampling part, and
                the remaining two are for the last two output layers.
            channels (int): Initial number of channels for conv layer.
            max_downsample_channels (int): Maximum number of channels for downsampling
                layers.
            max_groups (int): Maximum number of groups in downsampling conv layers.
            bias (bool): Whether to add bias parameter in convolution layers.
            downsample_scales (List[int]): List of downsampling scales.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
                function.
            use_weight_norm (bool): Whether to use weight norm. If set to true, it will
                be applied to all of the conv layers.
            use_spectral_norm (bool): Whether to use spectral norm. If set to true, it
                will be applied to all of the conv layers.

        """
        super().__init__()
        self.layers = torch.nn.ModuleList()

        # check kernel size is valid
        assert len(kernel_sizes) == 4
        for ks in kernel_sizes:
            assert ks % 2 == 1

        # add first layer
        self.layers += [
            torch.nn.Sequential(
                torch.nn.Conv1d(
                    in_channels,
                    channels,
                    # NOTE(kan-bayashi): Use always the same kernel size
                    kernel_sizes[0],
                    bias=bias,
                    padding=(kernel_sizes[0] - 1) // 2,
                ),
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
            )
        ]

        # add downsample layers
        in_chs = channels
        out_chs = channels
        # NOTE(kan-bayashi): Remove hard coding?
        groups = 4
        for downsample_scale in downsample_scales:
            self.layers += [
                torch.nn.Sequential(
                    torch.nn.Conv1d(
                        in_chs,
                        out_chs,
                        kernel_size=kernel_sizes[1],
                        stride=downsample_scale,
                        padding=(kernel_sizes[1] - 1) // 2,
                        groups=groups,
                        bias=bias,
                    ),
                    getattr(torch.nn, nonlinear_activation)(
                        **nonlinear_activation_params
                    ),
                )
            ]
            in_chs = out_chs
            # NOTE(kan-bayashi): Remove hard coding?
            out_chs = min(in_chs * 2, max_downsample_channels)
            # NOTE(kan-bayashi): Remove hard coding?
            groups = min(groups * 4, max_groups)

        # add final layers
        out_chs = min(in_chs * 2, max_downsample_channels)
        self.layers += [
            torch.nn.Sequential(
                torch.nn.Conv1d(
                    in_chs,
                    out_chs,
                    kernel_size=kernel_sizes[2],
                    stride=1,
                    padding=(kernel_sizes[2] - 1) // 2,
                    bias=bias,
                ),
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
            )
        ]
        self.layers += [
            torch.nn.Conv1d(
                out_chs,
                out_channels,
                kernel_size=kernel_sizes[3],
                stride=1,
                padding=(kernel_sizes[3] - 1) // 2,
                bias=bias,
            ),
        ]

        if use_weight_norm and use_spectral_norm:
            raise ValueError("Either use use_weight_norm or use_spectral_norm.")

        # apply weight norm
        self.use_weight_norm = use_weight_norm
        if use_weight_norm:
            self.apply_weight_norm()

        # apply spectral norm
        self.use_spectral_norm = use_spectral_norm
        if use_spectral_norm:
            self.apply_spectral_norm()

        # backward compatibility
        self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            List[Tensor]: List of output tensors of each layer.

        """
        outs = []
        for f in self.layers:
            x = f(x)
            outs += [x]

        return outs

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""

        def _apply_weight_norm(m: torch.nn.Module):
            if isinstance(m, torch.nn.Conv1d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def apply_spectral_norm(self):
        """Apply spectral normalization module to all of the layers."""

        def _apply_spectral_norm(m: torch.nn.Module):
            if isinstance(m, torch.nn.Conv1d):
                torch.nn.utils.spectral_norm(m)
                logging.debug(f"Spectral norm is applied to {m}.")

        self.apply(_apply_spectral_norm)

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""

        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def remove_spectral_norm(self):
        """Remove spectral normalization module from all of the layers."""

        def _remove_spectral_norm(m):
            try:
                logging.debug(f"Spectral norm is removed from {m}.")
                torch.nn.utils.remove_spectral_norm(m)
            except ValueError:  # this module didn't have spectral norm
                return

        self.apply(_remove_spectral_norm)

    def _load_state_dict_pre_hook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Fix the compatibility of weight / spectral normalization issue.

        Some pretrained models are trained with configs that use weight / spectral
        normalization, but actually, the norm is not applied. This causes the mismatch
        of the parameters with configs. To solve this issue, when parameter mismatch
        happens in loading pretrained model, we remove the norm from the current model.

        See also:
            - https://github.com/espnet/espnet/pull/5240
            - https://github.com/espnet/espnet/pull/5249
            - https://github.com/kan-bayashi/ParallelWaveGAN/pull/409

        """
        current_module_keys = [x for x in state_dict.keys() if x.startswith(prefix)]
        if self.use_weight_norm and any(
            [k.endswith("weight") for k in current_module_keys]
        ):
            logging.warning(
                "It seems weight norm is not applied in the pretrained model but the"
                " current model uses it. To keep the compatibility, we remove the norm"
                " from the current model. This may cause unexpected behavior due to the"
                " parameter mismatch in finetuning. To avoid this issue, please change"
                " the following parameters in config to false:\n"
                " - discriminator_params.follow_official_norm\n"
                " - discriminator_params.scale_discriminator_params.use_weight_norm\n"
                " - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
                "\n"
                "See also:\n"
                " - https://github.com/espnet/espnet/pull/5240\n"
                " - https://github.com/espnet/espnet/pull/5249"
            )
            self.remove_weight_norm()
            self.use_weight_norm = False
            for k in current_module_keys:
                if k.endswith("weight_g") or k.endswith("weight_v"):
                    del state_dict[k]

        if self.use_spectral_norm and any(
            [k.endswith("weight") for k in current_module_keys]
        ):
            logging.warning(
                "It seems spectral norm is not applied in the pretrained model but the"
                " current model uses it. To keep the compatibility, we remove the norm"
                " from the current model. This may cause unexpected behavior due to the"
                " parameter mismatch in finetuning. To avoid this issue, please change"
                " the following parameters in config to false:\n"
                " - discriminator_params.follow_official_norm\n"
                " - discriminator_params.scale_discriminator_params.use_weight_norm\n"
                " - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
                "\n"
                "See also:\n"
                " - https://github.com/espnet/espnet/pull/5240\n"
                " - https://github.com/espnet/espnet/pull/5249"
            )
            self.remove_spectral_norm()
            self.use_spectral_norm = False
            for k in current_module_keys:
                if (
                    k.endswith("weight_u")
                    or k.endswith("weight_v")
                    or k.endswith("weight_orig")
                ):
                    del state_dict[k]

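# A note on the grouped downsampling convolutions above: `groups=g` splits the
# channels into g independent groups, cutting the weight count by a factor of
# g. A minimal check, assuming nothing beyond standard PyTorch semantics:
import torch

full = torch.nn.Conv1d(128, 128, kernel_size=41, groups=1, bias=False)
grouped = torch.nn.Conv1d(128, 128, kernel_size=41, groups=4, bias=False)
assert grouped.weight.numel() * 4 == full.weight.numel()
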
class HiFiGANMultiScaleDiscriminator(torch.nn.Module):
    """HiFi-GAN multi-scale discriminator module."""

    def __init__(
        self,
        scales: int = 3,
        downsample_pooling: str = "AvgPool1d",
        # follow the official implementation setting
        downsample_pooling_params: Dict[str, Any] = {
            "kernel_size": 4,
            "stride": 2,
            "padding": 2,
        },
        discriminator_params: Dict[str, Any] = {
            "in_channels": 1,
            "out_channels": 1,
            "kernel_sizes": [15, 41, 5, 3],
            "channels": 128,
            "max_downsample_channels": 1024,
            "max_groups": 16,
            "bias": True,
            "downsample_scales": [2, 2, 4, 4, 1],
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.1},
        },
        follow_official_norm: bool = False,
    ):
        """Initialize HiFiGAN multi-scale discriminator module.

        Args:
            scales (int): Number of multi-scales.
            downsample_pooling (str): Pooling module name for downsampling of the
                inputs.
            downsample_pooling_params (Dict[str, Any]): Parameters for the above pooling
                module.
            discriminator_params (Dict[str, Any]): Parameters for hifi-gan scale
                discriminator module.
            follow_official_norm (bool): Whether to follow the norm setting of the
                official implementation. The first discriminator uses spectral norm
                and the other discriminators use weight norm.

        """
        super().__init__()
        self.discriminators = torch.nn.ModuleList()

        # add discriminators
        for i in range(scales):
            params = copy.deepcopy(discriminator_params)
            if follow_official_norm:
                if i == 0:
                    params["use_weight_norm"] = False
                    params["use_spectral_norm"] = True
                else:
                    params["use_weight_norm"] = True
                    params["use_spectral_norm"] = False
            self.discriminators += [HiFiGANScaleDiscriminator(**params)]
        self.pooling = None
        if scales > 1:
            self.pooling = getattr(torch.nn, downsample_pooling)(
                **downsample_pooling_params
            )

    def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            List[List[torch.Tensor]]: List of lists of each discriminator's
                outputs, each of which consists of the per-layer output tensors.

        """
        outs = []
        for f in self.discriminators:
            outs += [f(x)]
            if self.pooling is not None:
                x = self.pooling(x)

        return outs

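# The pooling used between scales above (AvgPool1d with kernel_size=4,
# stride=2, padding=2) roughly halves the time axis, so each successive scale
# discriminator sees the waveform at half the previous resolution. A quick
# check of the output-length arithmetic:
import torch

pool = torch.nn.AvgPool1d(kernel_size=4, stride=2, padding=2)
x = torch.randn(1, 1, 8192)
assert pool(x).shape[-1] == 4097  # floor((8192 + 2 * 2 - 4) / 2) + 1
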
class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module):
    """HiFi-GAN multi-scale + multi-period discriminator module."""

    def __init__(
        self,
        # Multi-scale discriminator related
        scales: int = 3,
        scale_downsample_pooling: str = "AvgPool1d",
        scale_downsample_pooling_params: Dict[str, Any] = {
            "kernel_size": 4,
            "stride": 2,
            "padding": 2,
        },
        scale_discriminator_params: Dict[str, Any] = {
            "in_channels": 1,
            "out_channels": 1,
            "kernel_sizes": [15, 41, 5, 3],
            "channels": 128,
            "max_downsample_channels": 1024,
            "max_groups": 16,
            "bias": True,
            "downsample_scales": [2, 2, 4, 4, 1],
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.1},
        },
        follow_official_norm: bool = True,
        # Multi-period discriminator related
        periods: List[int] = [2, 3, 5, 7, 11],
        period_discriminator_params: Dict[str, Any] = {
            "in_channels": 1,
            "out_channels": 1,
            "kernel_sizes": [5, 3],
            "channels": 32,
            "downsample_scales": [3, 3, 3, 3, 1],
            "max_downsample_channels": 1024,
            "bias": True,
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.1},
            "use_weight_norm": True,
            "use_spectral_norm": False,
        },
    ):
        """Initialize HiFiGAN multi-scale + multi-period discriminator module.

        Args:
            scales (int): Number of multi-scales.
            scale_downsample_pooling (str): Pooling module name for downsampling of the
                inputs.
            scale_downsample_pooling_params (dict): Parameters for the above pooling
                module.
            scale_discriminator_params (dict): Parameters for hifi-gan scale
                discriminator module.
            follow_official_norm (bool): Whether to follow the norm setting of the
                official implementation. The first discriminator uses spectral norm and
                the other discriminators use weight norm.
            periods (list): List of periods.
            period_discriminator_params (dict): Parameters for hifi-gan period
                discriminator module. The period parameter will be overwritten.

        """
        super().__init__()
        self.msd = HiFiGANMultiScaleDiscriminator(
            scales=scales,
            downsample_pooling=scale_downsample_pooling,
            downsample_pooling_params=scale_downsample_pooling_params,
            discriminator_params=scale_discriminator_params,
            follow_official_norm=follow_official_norm,
        )
        self.mpd = HiFiGANMultiPeriodDiscriminator(
            periods=periods,
            discriminator_params=period_discriminator_params,
        )

    def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            List[List[Tensor]]: List of lists of each discriminator's outputs,
                each of which consists of the per-layer output tensors. Multi-scale
                and multi-period ones are concatenated.

        """
        msd_outs = self.msd(x)
        mpd_outs = self.mpd(x)
        return msd_outs + mpd_outs
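# Usage sketch for the combined module (not part of the recipe): the
# multi-scale and multi-period outputs are simply concatenated, so downstream
# losses can iterate over one flat list. With the defaults above (3 scales,
# 5 periods) there are 8 sub-discriminators in total.
import torch

d = HiFiGANMultiScaleMultiPeriodDiscriminator()
outs = d(torch.randn(2, 1, 8192))
assert len(outs) == 3 + 5  # msd_outs + mpd_outs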
244
egs/ljspeech/TTS/vits2/infer.py
Executable file
@ -0,0 +1,244 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script performs model inference on the test set.

Usage:
./vits2/infer.py \
    --epoch 1000 \
    --exp-dir ./vits2/exp \
    --max-duration 500
"""


import argparse
import logging
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List

import k2
import torch
import torch.nn as nn
import torchaudio
from tokenizer import Tokenizer
from train import get_model, get_params
from tts_datamodule import LJSpeechTtsDataModule

from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger

def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=1000,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="vits2/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/tokens.txt",
        help="""Path to vocabulary.""",
    )

    return parser

def infer_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    tokenizer: Tokenizer,
) -> None:
    """Decode dataset.
    The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      tokenizer:
        Used to convert text to phonemes.
    """

    # Background worker that saves audio files to disk.
    def _save_worker(
        batch_size: int,
        cut_ids: List[str],
        audio: torch.Tensor,
        audio_pred: torch.Tensor,
        audio_lens: List[int],
        audio_lens_pred: List[int],
    ):
        for i in range(batch_size):
            torchaudio.save(
                str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
                audio[i : i + 1, : audio_lens[i]],
                sample_rate=params.sampling_rate,
            )
            torchaudio.save(
                str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
                audio_pred[i : i + 1, : audio_lens_pred[i]],
                sample_rate=params.sampling_rate,
            )

    device = next(model.parameters()).device
    num_cuts = 0
    log_interval = 5

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    futures = []
    with ThreadPoolExecutor(max_workers=1) as executor:
        for batch_idx, batch in enumerate(dl):
            batch_size = len(batch["tokens"])

            tokens = batch["tokens"]
            tokens = tokenizer.tokens_to_token_ids(tokens)
            tokens = k2.RaggedTensor(tokens)
            row_splits = tokens.shape.row_splits(1)
            tokens_lens = row_splits[1:] - row_splits[:-1]
            tokens = tokens.to(device)
            tokens_lens = tokens_lens.to(device)
            # tensor of shape (B, T)
            tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)

            audio = batch["audio"]
            audio_lens = batch["audio_lens"].tolist()
            cut_ids = [cut.id for cut in batch["cut"]]

            audio_pred, _, durations = model.inference_batch(
                text=tokens, text_lengths=tokens_lens
            )
            audio_pred = audio_pred.detach().cpu()
            # convert frame-level durations to audio lengths in samples
            audio_lens_pred = (
                (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
            )

            futures.append(
                executor.submit(
                    _save_worker,
                    batch_size,
                    cut_ids,
                    audio,
                    audio_pred,
                    audio_lens,
                    audio_lens_pred,
                )
            )

            num_cuts += batch_size

            if batch_idx % log_interval == 0:
                batch_str = f"{batch_idx}/{num_batches}"

                logging.info(
                    f"batch {batch_str}, cuts processed until now is {num_cuts}"
                )
        # wait for all background save tasks to finish
        for f in futures:
            f.result()

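# A note on the `audio_lens_pred` computation above: `durations` holds
# per-token lengths in feature frames, so summing over tokens and multiplying
# by `params.frame_shift` converts to samples. Illustration with assumed
# recipe values (frame_shift=256 samples at 22050 Hz, not read from this diff):
#   100 frames * 256 samples/frame = 25600 samples ~= 1.16 s of audio.
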
@torch.no_grad()
def main():
    parser = get_parser()
    LJSpeechTtsDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    params.suffix = f"epoch-{params.epoch}"

    params.res_dir = params.exp_dir / "infer" / params.suffix
    params.save_wav_dir = params.res_dir / "wav"
    params.save_wav_dir.mkdir(parents=True, exist_ok=True)

    setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
    logging.info("Infer started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.blank_id
    params.oov_id = tokenizer.oov_id
    params.vocab_size = tokenizer.vocab_size

    logging.info(f"Device: {device}")
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

    model.to(device)
    model.eval()

    num_param_g = sum([p.numel() for p in model.generator.parameters()])
    logging.info(f"Number of parameters in generator: {num_param_g}")
    num_param_d = sum([p.numel() for p in model.discriminator.parameters()])
    logging.info(f"Number of parameters in discriminator: {num_param_d}")
    logging.info(f"Total number of parameters: {num_param_g + num_param_d}")

    # we need cut ids to name the saved audio files.
    args.return_cuts = True
    ljspeech = LJSpeechTtsDataModule(args)

    test_cuts = ljspeech.test_cuts()
    test_dl = ljspeech.test_dataloaders(test_cuts)

    infer_dataset(
        dl=test_dl,
        params=params,
        model=model,
        tokenizer=tokenizer,
    )

    logging.info(f"Wav files are saved to {params.save_wav_dir}")
    logging.info("Done!")


if __name__ == "__main__":
    main()
335
egs/ljspeech/TTS/vits2/loss.py
Normal file
@ -0,0 +1,335 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py

# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""HiFiGAN-related loss modules.

This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.

"""

from typing import List, Tuple, Union

import torch
import torch.distributions as D
import torch.nn.functional as F
from lhotse.features.kaldi import Wav2LogFilterBank


class GeneratorAdversarialLoss(torch.nn.Module):
    """Generator adversarial loss module."""

    def __init__(
        self,
        average_by_discriminators: bool = True,
        loss_type: str = "mse",
    ):
        """Initialize GeneratorAdversarialLoss module.

        Args:
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            loss_type (str): Loss type, "mse" or "hinge".

        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.criterion = self._mse_loss
        else:
            self.criterion = self._hinge_loss

    def forward(
        self,
        outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
    ) -> torch.Tensor:
        """Calculate generator adversarial loss.

        Args:
            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of
                discriminator outputs.

        Returns:
            Tensor: Generator adversarial loss value.

        """
        if isinstance(outputs, (tuple, list)):
            adv_loss = 0.0
            for i, outputs_ in enumerate(outputs):
                if isinstance(outputs_, (tuple, list)):
                    # NOTE(kan-bayashi): case including feature maps
                    outputs_ = outputs_[-1]
                adv_loss += self.criterion(outputs_)
            if self.average_by_discriminators:
                adv_loss /= i + 1
        else:
            adv_loss = self.criterion(outputs)

        return adv_loss

    def _mse_loss(self, x):
        return F.mse_loss(x, x.new_ones(x.size()))

    def _hinge_loss(self, x):
        return -x.mean()

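# The "mse" branch above is the least-squares GAN objective: the generator is
# pulled toward discriminator outputs of 1. A minimal sketch showing the two
# criteria on a dummy logit map (illustrative values only):
import torch
import torch.nn.functional as F

logits = torch.randn(2, 1, 50)                       # fake-sample discriminator output
mse_g = F.mse_loss(logits, torch.ones_like(logits))  # (D(x_hat) - 1)^2, LSGAN
hinge_g = -logits.mean()                             # hinge-style generator loss
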
class DiscriminatorAdversarialLoss(torch.nn.Module):
    """Discriminator adversarial loss module."""

    def __init__(
        self,
        average_by_discriminators: bool = True,
        loss_type: str = "mse",
    ):
        """Initialize DiscriminatorAdversarialLoss module.

        Args:
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            loss_type (str): Loss type, "mse" or "hinge".

        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.fake_criterion = self._mse_fake_loss
            self.real_criterion = self._mse_real_loss
        else:
            self.fake_criterion = self._hinge_fake_loss
            self.real_criterion = self._hinge_real_loss

    def forward(
        self,
        outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
        outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate discriminator adversarial loss.

        Args:
            outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of discriminator
                outputs calculated from generator.
            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of discriminator
                outputs calculated from groundtruth.

        Returns:
            Tensor: Discriminator real loss value.
            Tensor: Discriminator fake loss value.

        """
        if isinstance(outputs, (tuple, list)):
            real_loss = 0.0
            fake_loss = 0.0
            for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
                if isinstance(outputs_hat_, (tuple, list)):
                    # NOTE(kan-bayashi): case including feature maps
                    outputs_hat_ = outputs_hat_[-1]
                    outputs_ = outputs_[-1]
                real_loss += self.real_criterion(outputs_)
                fake_loss += self.fake_criterion(outputs_hat_)
            if self.average_by_discriminators:
                fake_loss /= i + 1
                real_loss /= i + 1
        else:
            real_loss = self.real_criterion(outputs)
            fake_loss = self.fake_criterion(outputs_hat)

        return real_loss, fake_loss

    def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
        return F.mse_loss(x, x.new_ones(x.size()))

    def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
        return F.mse_loss(x, x.new_zeros(x.size()))

    def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
        return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))

    def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
        return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))

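# The hinge branch above can be read as standard hinge GAN terms:
# -min(x - 1, 0) penalizes real logits below 1 and -min(-x - 1, 0) penalizes
# fake logits above -1, i.e. relu(1 - x) and relu(1 + x) respectively. A quick
# equivalence check:
import torch
import torch.nn.functional as F

x = torch.randn(2, 1, 50)
real = -torch.mean(torch.min(x - 1, torch.zeros_like(x)))
fake = -torch.mean(torch.min(-x - 1, torch.zeros_like(x)))
assert torch.allclose(real, F.relu(1 - x).mean())
assert torch.allclose(fake, F.relu(1 + x).mean())
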
class FeatureMatchLoss(torch.nn.Module):
    """Feature matching loss module."""

    def __init__(
        self,
        average_by_layers: bool = True,
        average_by_discriminators: bool = True,
        include_final_outputs: bool = False,
    ):
        """Initialize FeatureMatchLoss module.

        Args:
            average_by_layers (bool): Whether to average the loss by the number
                of layers.
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            include_final_outputs (bool): Whether to include the final output of
                each discriminator for loss calculation.

        """
        super().__init__()
        self.average_by_layers = average_by_layers
        self.average_by_discriminators = average_by_discriminators
        self.include_final_outputs = include_final_outputs

    def forward(
        self,
        feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
        feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
    ) -> torch.Tensor:
        """Calculate feature matching loss.

        Args:
            feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs calculated
                from generator's outputs.
            feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs calculated
                from groundtruth.

        Returns:
            Tensor: Feature matching loss value.

        """
        feat_match_loss = 0.0
        for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
            feat_match_loss_ = 0.0
            if not self.include_final_outputs:
                feats_hat_ = feats_hat_[:-1]
                feats_ = feats_[:-1]
            for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
                feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
            if self.average_by_layers:
                feat_match_loss_ /= j + 1
            feat_match_loss += feat_match_loss_
        if self.average_by_discriminators:
            feat_match_loss /= i + 1

        return feat_match_loss

class MelSpectrogramLoss(torch.nn.Module):
    """Mel-spectrogram loss."""

    def __init__(
        self,
        sampling_rate: int = 22050,
        frame_length: int = 1024,  # in samples
        frame_shift: int = 256,  # in samples
        n_mels: int = 80,
        use_fft_mag: bool = True,
    ):
        super().__init__()
        self.wav_to_mel = Wav2LogFilterBank(
            sampling_rate=sampling_rate,
            frame_length=frame_length / sampling_rate,  # in second
            frame_shift=frame_shift / sampling_rate,  # in second
            use_fft_mag=use_fft_mag,
            num_filters=n_mels,
        )

    def forward(
        self,
        y_hat: torch.Tensor,
        y: torch.Tensor,
        return_mel: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
        """Calculate Mel-spectrogram loss.

        Args:
            y_hat (Tensor): Generated waveform tensor (B, 1, T).
            y (Tensor): Groundtruth waveform tensor (B, 1, T).
            return_mel (bool): Whether to also return the mel-spectrograms of
                y_hat and y together with the loss value.

        Returns:
            Tensor: Mel-spectrogram loss value.

        """
        mel_hat = self.wav_to_mel(y_hat.squeeze(1))
        mel = self.wav_to_mel(y.squeeze(1))
        mel_loss = F.l1_loss(mel_hat, mel)

        if return_mel:
            return mel_loss, (mel_hat, mel)

        return mel_loss

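# A note on the framing above: with frame_length=1024 and frame_shift=256 at
# 22050 Hz, the loss compares log-mel features at a 256-sample hop (~11.6 ms),
# so a waveform of T samples yields roughly T / 256 frames. Shape sketch
# (feature layout assumed, not verified against lhotse here):
#   y: (B, 1, T) waveform -> y.squeeze(1): (B, T)
#   wav_to_mel(...): approximately (B, T // 256, 80) log-mel frames,
#   i.e. about 22050 // 256 = 86 frames per second of audio.
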
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py

"""VITS-related loss modules.

This code is based on https://github.com/jaywalnut310/vits.

"""


class KLDivergenceLoss(torch.nn.Module):
    """KL divergence loss."""

    def forward(
        self,
        z_p: torch.Tensor,
        logs_q: torch.Tensor,
        m_p: torch.Tensor,
        logs_p: torch.Tensor,
        z_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate KL divergence loss.

        Args:
            z_p (Tensor): Flow hidden representation (B, H, T_feats).
            logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
            m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
            logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
            z_mask (Tensor): Mask tensor (B, 1, T_feats).

        Returns:
            Tensor: KL divergence loss.

        """
        z_p = z_p.float()
        logs_q = logs_q.float()
        m_p = m_p.float()
        logs_p = logs_p.float()
        z_mask = z_mask.float()
        kl = logs_p - logs_q - 0.5
        kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
        kl = torch.sum(kl * z_mask)
        loss = kl / torch.sum(z_mask)

        return loss

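# How the closed-form lines in `forward` above arise: with q = N(m_q, s_q^2),
# p = N(m_p, s_p^2), and z_p a sample from q (pushed through the flow), the
# code is a single-sample estimate of KL(q || p) = E_q[log q(z) - log p(z)];
# the entropy part is replaced by its exact expectation, which is where the
# constant -0.5 comes from. A line-by-line reading (comments only):
#   kl = logs_p - logs_q - 0.5
#     -> log(s_p / s_q) - 1/2, since E_q[(z - m_q)^2 / (2 s_q^2)] = 1/2.
#   kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
#     -> (z - m_p)^2 / (2 s_p^2), a one-sample estimate of the cross term.
#   kl = torch.sum(kl * z_mask); loss = kl / torch.sum(z_mask)
#     -> average over valid (unmasked) feature positions only.
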
class KLDivergenceLossWithoutFlow(torch.nn.Module):
    """KL divergence loss without flow."""

    def forward(
        self,
        m_q: torch.Tensor,
        logs_q: torch.Tensor,
        m_p: torch.Tensor,
        logs_p: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate KL divergence loss without flow.

        Args:
            m_q (Tensor): Posterior encoder projected mean (B, H, T_feats).
            logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
            m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
            logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).

        Returns:
            Tensor: KL divergence loss value.

        """
        posterior_norm = D.Normal(m_q, torch.exp(logs_q))
        prior_norm = D.Normal(m_p, torch.exp(logs_p))
        loss = D.kl_divergence(posterior_norm, prior_norm).mean()
        return loss
81
egs/ljspeech/TTS/vits2/monotonic_align/__init__.py
Normal file
@ -0,0 +1,81 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/__init__.py

"""Maximum path calculation module.

This code is based on https://github.com/jaywalnut310/vits.

"""

import warnings

import numpy as np
import torch
from numba import njit, prange

try:
    from .core import maximum_path_c

    is_cython_avalable = True
except ImportError:
    is_cython_avalable = False
    warnings.warn(
        "Cython version is not available. Fallback to 'EXPERIMENTAL' numba version. "
        "If you want to use the cython version, please build it as follows: "
        "`cd espnet2/gan_tts/vits/monotonic_align; python setup.py build_ext --inplace`"
    )


def maximum_path(neg_x_ent: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
    """Calculate maximum path.

    Args:
        neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text).
        attn_mask (Tensor): Attention mask (B, T_feats, T_text).

    Returns:
        Tensor: Maximum path tensor (B, T_feats, T_text).

    """
    device, dtype = neg_x_ent.device, neg_x_ent.dtype
    neg_x_ent = neg_x_ent.cpu().numpy().astype(np.float32)
    path = np.zeros(neg_x_ent.shape, dtype=np.int32)
    t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32)
    t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32)
    if is_cython_avalable:
        maximum_path_c(path, neg_x_ent, t_t_max, t_s_max)
    else:
        maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max)

    return torch.from_numpy(path).to(device=device, dtype=dtype)


@njit
def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf):
    """Calculate a single maximum path with numba."""
    index = t_x - 1
    for y in range(t_y):
        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
            if x == y:
                v_cur = max_neg_val
            else:
                v_cur = value[y - 1, x]
            if x == 0:
                if y == 0:
                    v_prev = 0.0
                else:
                    v_prev = max_neg_val
            else:
                v_prev = value[y - 1, x - 1]
            value[y, x] += max(v_prev, v_cur)

    for y in range(t_y - 1, -1, -1):
        path[y, index] = 1
        if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
            index = index - 1


@njit(parallel=True)
def maximum_path_numba(paths, values, t_ys, t_xs):
    """Calculate batch maximum path with numba."""
    for i in prange(paths.shape[0]):
        maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i])
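# The dynamic program above implements the monotonic alignment search used by
# VITS: text position x at frame y can only extend from (y-1, x) (stay) or
# (y-1, x-1) (advance), so the forward pass accumulates
# value[y, x] += max(value[y-1, x], value[y-1, x-1]) and the backward pass
# retraces the argmax from the last text position. A tiny end-to-end check
# (a sketch; assumes numba or the built cython extension is available):
import torch

neg_x_ent = torch.tensor([[[0.0, -1.0], [-1.0, 0.0]]])  # scores favor the diagonal
attn_mask = torch.ones(1, 2, 2)
path = maximum_path(neg_x_ent, attn_mask)
assert torch.equal(path, torch.tensor([[[1.0, 0.0], [0.0, 1.0]]]))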
51
egs/ljspeech/TTS/vits2/monotonic_align/core.pyx
Normal file
@ -0,0 +1,51 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/core.pyx

"""Maximum path calculation module with cython optimization.

This code is copied from https://github.com/jaywalnut310/vits and modified code format.

"""

cimport cython

from cython.parallel import prange


@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
    cdef int x
    cdef int y
    cdef float v_prev
    cdef float v_cur
    cdef float tmp
    cdef int index = t_x - 1

    for y in range(t_y):
        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
            if x == y:
                v_cur = max_neg_val
            else:
                v_cur = value[y - 1, x]
            if x == 0:
                if y == 0:
                    v_prev = 0.0
                else:
                    v_prev = max_neg_val
            else:
                v_prev = value[y - 1, x - 1]
            value[y, x] += max(v_prev, v_cur)

    for y in range(t_y - 1, -1, -1):
        path[y, index] = 1
        if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
            index = index - 1


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
    cdef int b = paths.shape[0]
    cdef int i
    for i in prange(b, nogil=True):
        maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
31
egs/ljspeech/TTS/vits2/monotonic_align/setup.py
Normal file
@ -0,0 +1,31 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/setup.py
"""Setup cython code."""

from Cython.Build import cythonize
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext as _build_ext


class build_ext(_build_ext):
    """Overwrite build_ext."""

    def finalize_options(self):
        """Prevent numpy from thinking it is still in its setup process."""
        _build_ext.finalize_options(self)
        __builtins__.__NUMPY_SETUP__ = False
        import numpy

        self.include_dirs.append(numpy.get_include())


exts = [
    Extension(
        name="core",
        sources=["core.pyx"],
    )
]
setup(
    name="monotonic_align",
    ext_modules=cythonize(exts, language_level=3),
    cmdclass={"build_ext": build_ext},
)
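# Assumed build workflow for this extension (a sketch, not spelled out in the
# diff): run `python setup.py build_ext --inplace` from the monotonic_align
# directory. Afterwards the `from .core import maximum_path_c` in __init__.py
# succeeds and the slower numba fallback is skipped; without the build, the
# numba path is used automatically.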
117
egs/ljspeech/TTS/vits2/posterior_encoder.py
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py

# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Posterior encoder module in VITS.

This code is based on https://github.com/jaywalnut310/vits.

"""

from typing import Optional, Tuple

import torch
from wavenet import Conv1d, WaveNet

from icefall.utils import make_pad_mask


class PosteriorEncoder(torch.nn.Module):
    """Posterior encoder module in VITS.

    This is a module of posterior encoder described in `Conditional Variational
    Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.

    .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
    Text-to-Speech`: https://arxiv.org/abs/2006.04558
    """

    def __init__(
        self,
        in_channels: int = 513,
        out_channels: int = 192,
        hidden_channels: int = 192,
        kernel_size: int = 5,
        layers: int = 16,
        stacks: int = 1,
        base_dilation: int = 1,
        global_channels: int = -1,
        dropout_rate: float = 0.0,
        bias: bool = True,
        use_weight_norm: bool = True,
    ):
        """Initialize PosteriorEncoder module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            hidden_channels (int): Number of hidden channels.
            kernel_size (int): Kernel size in WaveNet.
            layers (int): Number of layers of WaveNet.
            stacks (int): Number of repeat stacking of WaveNet.
            base_dilation (int): Base dilation factor.
            global_channels (int): Number of global conditioning channels.
            dropout_rate (float): Dropout rate.
            bias (bool): Whether to use bias parameters in conv.
            use_weight_norm (bool): Whether to apply weight norm.

        """
        super().__init__()

        # define modules
        self.input_conv = Conv1d(in_channels, hidden_channels, 1)
        self.encoder = WaveNet(
            in_channels=-1,
            out_channels=-1,
            kernel_size=kernel_size,
            layers=layers,
            stacks=stacks,
            base_dilation=base_dilation,
            residual_channels=hidden_channels,
            aux_channels=-1,
            gate_channels=hidden_channels * 2,
            skip_channels=hidden_channels,
            global_channels=global_channels,
            dropout_rate=dropout_rate,
            bias=bias,
            use_weight_norm=use_weight_norm,
            use_first_conv=False,
            use_last_conv=False,
            scale_residual=False,
            scale_skip_connect=True,
        )
        self.proj = Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(
        self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T_feats).
            x_lengths (Tensor): Length tensor (B,).
            g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).

        Returns:
            Tensor: Encoded hidden representation tensor (B, out_channels, T_feats).
            Tensor: Projected mean tensor (B, out_channels, T_feats).
            Tensor: Projected scale tensor (B, out_channels, T_feats).
            Tensor: Mask tensor for input tensor (B, 1, T_feats).

        """
        x_mask = (
            (~make_pad_mask(x_lengths))
            .unsqueeze(1)
            .to(
                dtype=x.dtype,
                device=x.device,
            )
        )
        x = self.input_conv(x) * x_mask
        x = self.encoder(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = stats.split(stats.size(1) // 2, dim=1)
        # reparameterization: sample z ~ N(m, exp(logs)^2), masked to valid frames
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask

        return z, m, logs, x_mask
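For orientation, a quick shape check of the posterior encoder (a sketch, not part of the commit; it assumes the recipe-local wavenet module is importable):

    import torch

    enc = PosteriorEncoder(in_channels=513, out_channels=192)
    x = torch.randn(2, 513, 50)  # (B, in_channels, T_feats)
    x_lengths = torch.tensor([50, 40])
    z, m, logs, x_mask = enc(x, x_lengths)
    # expected: z/m/logs of shape (2, 192, 50), x_mask of shape (2, 1, 50)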
228
egs/ljspeech/TTS/vits2/residual_coupling.py
Normal file
@ -0,0 +1,228 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py

# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Residual affine coupling modules in VITS.

This code is based on https://github.com/jaywalnut310/vits.

"""

from typing import Optional, Tuple, Union

import torch
from flow import FlipFlow
from wavenet import WaveNet


class ResidualAffineCouplingBlock(torch.nn.Module):
    """Residual affine coupling block module.

    This is a module of residual affine coupling block, which is used as "Flow" in
    `Conditional Variational Autoencoder with Adversarial Learning for End-to-End
    Text-to-Speech`_.

    .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
    Text-to-Speech`: https://arxiv.org/abs/2006.04558

    """

    def __init__(
        self,
        in_channels: int = 192,
        hidden_channels: int = 192,
        flows: int = 4,
        kernel_size: int = 5,
        base_dilation: int = 1,
        layers: int = 4,
        global_channels: int = -1,
        dropout_rate: float = 0.0,
        use_weight_norm: bool = True,
        bias: bool = True,
        use_only_mean: bool = True,
    ):
        """Initialize ResidualAffineCouplingBlock module.

        Args:
            in_channels (int): Number of input channels.
            hidden_channels (int): Number of hidden channels.
            flows (int): Number of flows.
            kernel_size (int): Kernel size for WaveNet.
            base_dilation (int): Base dilation factor for WaveNet.
            layers (int): Number of layers of WaveNet.
            stacks (int): Number of stacks of WaveNet.
            global_channels (int): Number of global channels.
            dropout_rate (float): Dropout rate.
            use_weight_norm (bool): Whether to use weight normalization in WaveNet.
            bias (bool): Whether to use bias parameters in WaveNet.
            use_only_mean (bool): Whether to estimate only mean.

        """
        super().__init__()

        self.flows = torch.nn.ModuleList()
        for i in range(flows):
            self.flows += [
                ResidualAffineCouplingLayer(
                    in_channels=in_channels,
                    hidden_channels=hidden_channels,
                    kernel_size=kernel_size,
                    base_dilation=base_dilation,
                    layers=layers,
                    stacks=1,
                    global_channels=global_channels,
                    dropout_rate=dropout_rate,
                    use_weight_norm=use_weight_norm,
                    bias=bias,
                    use_only_mean=use_only_mean,
                )
            ]
            self.flows += [FlipFlow()]

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        g: Optional[torch.Tensor] = None,
        inverse: bool = False,
    ) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).
            x_mask (Tensor): Mask tensor (B, 1, T).
            g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
            inverse (bool): Whether to inverse the flow.

        Returns:
            Tensor: Output tensor (B, in_channels, T).

        """
        if not inverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, inverse=inverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, inverse=inverse)
        return x


class ResidualAffineCouplingLayer(torch.nn.Module):
    """Residual affine coupling layer."""

    def __init__(
        self,
        in_channels: int = 192,
        hidden_channels: int = 192,
        kernel_size: int = 5,
        base_dilation: int = 1,
        layers: int = 5,
        stacks: int = 1,
        global_channels: int = -1,
        dropout_rate: float = 0.0,
        use_weight_norm: bool = True,
        bias: bool = True,
        use_only_mean: bool = True,
    ):
        """Initialize ResidualAffineCouplingLayer module.

        Args:
            in_channels (int): Number of input channels.
            hidden_channels (int): Number of hidden channels.
            kernel_size (int): Kernel size for WaveNet.
            base_dilation (int): Base dilation factor for WaveNet.
            layers (int): Number of layers of WaveNet.
            stacks (int): Number of stacks of WaveNet.
            global_channels (int): Number of global channels.
            dropout_rate (float): Dropout rate.
            use_weight_norm (bool): Whether to use weight normalization in WaveNet.
            bias (bool): Whether to use bias parameters in WaveNet.
            use_only_mean (bool): Whether to estimate only mean.

        """
        assert in_channels % 2 == 0, "in_channels should be divisible by 2"
        super().__init__()
        self.half_channels = in_channels // 2
        self.use_only_mean = use_only_mean

        # define modules
        self.input_conv = torch.nn.Conv1d(
            self.half_channels,
            hidden_channels,
            1,
        )
        self.encoder = WaveNet(
            in_channels=-1,
            out_channels=-1,
            kernel_size=kernel_size,
            layers=layers,
            stacks=stacks,
            base_dilation=base_dilation,
            residual_channels=hidden_channels,
            aux_channels=-1,
            gate_channels=hidden_channels * 2,
            skip_channels=hidden_channels,
            global_channels=global_channels,
            dropout_rate=dropout_rate,
            bias=bias,
            use_weight_norm=use_weight_norm,
            use_first_conv=False,
            use_last_conv=False,
            scale_residual=False,
            scale_skip_connect=True,
        )
        if use_only_mean:
            self.proj = torch.nn.Conv1d(
                hidden_channels,
                self.half_channels,
                1,
            )
        else:
            self.proj = torch.nn.Conv1d(
                hidden_channels,
                self.half_channels * 2,
                1,
            )
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        g: Optional[torch.Tensor] = None,
        inverse: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).
            x_mask (Tensor): Mask tensor (B, 1, T).
            g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
            inverse (bool): Whether to inverse the flow.

        Returns:
            Tensor: Output tensor (B, in_channels, T).
            Tensor: Log-determinant tensor for NLL (B,) if not inverse.

        """
        xa, xb = x.split(x.size(1) // 2, dim=1)
        h = self.input_conv(xa) * x_mask
        h = self.encoder(h, x_mask, g=g)
        stats = self.proj(h) * x_mask
        if not self.use_only_mean:
            m, logs = stats.split(stats.size(1) // 2, dim=1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        if not inverse:
            xb = m + xb * torch.exp(logs) * x_mask
            x = torch.cat([xa, xb], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet
        else:
            xb = (xb - m) * torch.exp(-logs) * x_mask
            x = torch.cat([xa, xb], 1)
            return x
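A round-trip sketch (not in the commit) showing the invertibility the coupling design provides; with an all-ones mask, running the block forward and then inverse should recover the input up to numerical error:

    import torch

    block = ResidualAffineCouplingBlock(in_channels=192, hidden_channels=192)
    x = torch.randn(2, 192, 30)
    x_mask = torch.ones(2, 1, 30)
    z = block(x, x_mask, inverse=False)
    x_rec = block(z, x_mask, inverse=True)
    assert torch.allclose(x, x_rec, atol=1e-5)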
123
egs/ljspeech/TTS/vits2/test_onnx.py
Executable file
@ -0,0 +1,123 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script is used to test the onnx model exported by vits/export-onnx.py

Use the onnx model to generate a wav:
./vits/test_onnx.py \
  --model-filename vits/exp/vits-epoch-1000.onnx \
  --tokens data/tokens.txt
"""


import argparse
import logging

import onnxruntime as ort
import torch
import torchaudio
from tokenizer import Tokenizer


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-filename",
        type=str,
        required=True,
        help="Path to the onnx model.",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/tokens.txt",
        help="""Path to vocabulary.""",
    )

    return parser


class OnnxModel:
    def __init__(self, model_filename: str):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 4

        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")

    def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
        """
        Args:
          tokens:
            A 2-D tensor of shape (1, T)
        Returns:
          A tensor of shape (1, T')
        """
        noise_scale = torch.tensor([0.667], dtype=torch.float32)
        noise_scale_dur = torch.tensor([0.8], dtype=torch.float32)
        alpha = torch.tensor([1.0], dtype=torch.float32)

        out = self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                self.model.get_inputs()[0].name: tokens.numpy(),
                self.model.get_inputs()[1].name: tokens_lens.numpy(),
                self.model.get_inputs()[2].name: noise_scale.numpy(),
                self.model.get_inputs()[3].name: alpha.numpy(),
                self.model.get_inputs()[4].name: noise_scale_dur.numpy(),
            },
        )[0]
        return torch.from_numpy(out)


def main():
    args = get_parser().parse_args()

    tokenizer = Tokenizer(args.tokens)

    logging.info("About to create onnx model")
    model = OnnxModel(args.model_filename)

    text = "I went there to see the land, the people and how their system works, end quote."
    tokens = tokenizer.texts_to_token_ids([text])
    tokens = torch.tensor(tokens)  # (1, T)
    tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)  # (1,)
    audio = model(tokens, tokens_lens)  # (1, T')

    torchaudio.save("test_onnx.wav", audio, sample_rate=22050)
    logging.info("Saved to test_onnx.wav")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
684
egs/ljspeech/TTS/vits2/text_encoder.py
Normal file
@ -0,0 +1,684 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Text encoder module in VITS.

This code is based on
- https://github.com/jaywalnut310/vits
- https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py
- https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/transducer_stateless/conformer.py
"""

import copy
import math
from typing import Optional, Tuple

import torch
from torch import Tensor, nn

from icefall.utils import is_jit_tracing, make_pad_mask


class TextEncoder(torch.nn.Module):
    """Text encoder module in VITS.

    This is a module of text encoder described in `Conditional Variational Autoencoder
    with Adversarial Learning for End-to-End Text-to-Speech`.
    """

    def __init__(
        self,
        vocabs: int,
        d_model: int = 192,
        num_heads: int = 2,
        dim_feedforward: int = 768,
        cnn_module_kernel: int = 5,
        num_layers: int = 6,
        dropout: float = 0.1,
    ):
        """Initialize TextEncoder module.

        Args:
            vocabs (int): Vocabulary size.
            d_model (int): attention dimension
            num_heads (int): number of attention heads
            dim_feedforward (int): feedforward dimension
            cnn_module_kernel (int): convolution kernel size
            num_layers (int): number of encoder layers
            dropout (float): dropout rate
        """
        super().__init__()
        self.d_model = d_model

        # define modules
        self.emb = torch.nn.Embedding(vocabs, d_model)
        torch.nn.init.normal_(self.emb.weight, 0.0, d_model**-0.5)

        # We use conformer as text encoder
        self.encoder = Transformer(
            d_model=d_model,
            num_heads=num_heads,
            dim_feedforward=dim_feedforward,
            cnn_module_kernel=cnn_module_kernel,
            num_layers=num_layers,
            dropout=dropout,
        )

        self.proj = torch.nn.Conv1d(d_model, d_model * 2, 1)

    def forward(
        self,
        x: torch.Tensor,
        x_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input index tensor (B, T_text).
            x_lengths (Tensor): Length tensor (B,).

        Returns:
            Tensor: Encoded hidden representation (B, attention_dim, T_text).
            Tensor: Projected mean tensor (B, attention_dim, T_text).
            Tensor: Projected scale tensor (B, attention_dim, T_text).
            Tensor: Mask tensor for input tensor (B, 1, T_text).

        """
        # (B, T_text, embed_dim)
        x = self.emb(x) * math.sqrt(self.d_model)

        assert x.size(1) == x_lengths.max().item()

        # (B, T_text)
        pad_mask = make_pad_mask(x_lengths)

        # the encoder assumes channel-last input (B, T_text, embed_dim)
        x = self.encoder(x, key_padding_mask=pad_mask)

        # convert to channel-first (B, embed_dim, T_text)
        x = x.transpose(1, 2)
        non_pad_mask = (~pad_mask).unsqueeze(1)
        stats = self.proj(x) * non_pad_mask
        m, logs = stats.split(stats.size(1) // 2, dim=1)

        return x, m, logs, non_pad_mask


class Transformer(nn.Module):
    """
    Args:
        d_model (int): attention dimension
        num_heads (int): number of attention heads
        dim_feedforward (int): feedforward dimension
        cnn_module_kernel (int): convolution kernel size
        num_layers (int): number of encoder layers
        dropout (float): dropout rate
    """

    def __init__(
        self,
        d_model: int = 192,
        num_heads: int = 2,
        dim_feedforward: int = 768,
        cnn_module_kernel: int = 5,
        num_layers: int = 6,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        self.num_layers = num_layers
        self.d_model = d_model

        self.encoder_pos = RelPositionalEncoding(d_model, dropout)

        encoder_layer = TransformerEncoderLayer(
            d_model=d_model,
            num_heads=num_heads,
            dim_feedforward=dim_feedforward,
            cnn_module_kernel=cnn_module_kernel,
            dropout=dropout,
        )
        self.encoder = TransformerEncoder(encoder_layer, num_layers)
        self.after_norm = nn.LayerNorm(d_model)

    def forward(self, x: Tensor, key_padding_mask: Tensor) -> Tensor:
        """
        Args:
          x:
            The input tensor. Its shape is (batch_size, seq_len, feature_dim).
          key_padding_mask:
            A boolean tensor of shape (batch_size, seq_len); True marks padding
            positions in `x`.
        """
        x, pos_emb = self.encoder_pos(x)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask)  # (T, N, C)

        x = self.after_norm(x)

        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        return x


class TransformerEncoderLayer(nn.Module):
    """
    TransformerEncoderLayer is made up of self-attn and feedforward.

    Args:
        d_model: the number of expected features in the input.
        num_heads: the number of heads in the multi-head attention models.
        dim_feedforward: the dimension of the feed-forward network model.
        cnn_module_kernel: kernel size of the convolution module.
        dropout: the dropout value (default=0.1).
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_feedforward: int,
        cnn_module_kernel: int,
        dropout: float = 0.1,
    ) -> None:
        super(TransformerEncoderLayer, self).__init__()

        self.feed_forward_macaron = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )

        self.self_attn = RelPositionMultiheadAttention(
            d_model, num_heads, dropout=dropout
        )

        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)

        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )

        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
        self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
        self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
        self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module

        self.ff_scale = 0.5
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        src: Tensor,
        pos_emb: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Pass the input through the transformer encoder layer.

        Args:
            src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim).
            pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
            key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
        """
        # macaron style feed-forward module
        src = src + self.ff_scale * self.dropout(
            self.feed_forward_macaron(self.norm_ff_macaron(src))
        )

        # multi-head self-attention module
        src_attn = self.self_attn(
            self.norm_mha(src),
            pos_emb=pos_emb,
            key_padding_mask=key_padding_mask,
        )
        src = src + self.dropout(src_attn)

        # convolution module
        src = src + self.dropout(self.conv_module(self.norm_conv(src)))

        # feed-forward module
        src = src + self.dropout(self.feed_forward(self.norm_ff(src)))

        src = self.norm_final(src)

        return src


class TransformerEncoder(nn.Module):
    r"""TransformerEncoder is a stack of N encoder layers

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer class.
        num_layers: the number of sub-encoder-layers in the encoder.
    """

    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
        super().__init__()

        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
        )
        self.num_layers = num_layers

    def forward(
        self,
        src: Tensor,
        pos_emb: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim).
            pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
            key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
        """
        output = src

        for layer_index, mod in enumerate(self.layers):
            output = mod(
                output,
                pos_emb,
                key_padding_mask=key_padding_mask,
            )

        return output


class RelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module.

    See: Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py

    Args:
        d_model: Embedding dimension.
        dropout_rate: Dropout rate.
        max_len: Maximum input length.

    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
        """Construct a PositionalEncoding object."""
        super(RelPositionalEncoding, self).__init__()

        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: Tensor) -> None:
        """Reset the positional encodings."""
        x_size = x.size(1)
        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(1) >= x_size * 2 - 1:
                # Note: TorchScript doesn't implement operator== for torch.Device
                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` is the position of a query vector and `j` is the position
        # of a key vector. We use positive relative positions when keys
        # are to the left (i>j) and negative relative positions otherwise (i<j).
        pe_positive = torch.zeros(x_size, self.d_model)
        pe_negative = torch.zeros(x_size, self.d_model)
        position = torch.arange(0, x_size, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of positive indices and concat both positive and
        # negative indices. This is used to support the shifting trick
        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
            torch.Tensor: Positional embedding tensor (batch, 2*time-1, `*`).

        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.pe[
            :,
            self.pe.size(1) // 2
            - x.size(1)
            + 1 : self.pe.size(1) // 2  # noqa E203
            + x.size(1),
        ]
        return self.dropout(x), self.dropout(pos_emb)


class RelPositionMultiheadAttention(nn.Module):
    r"""Multi-Head Attention layer with relative position encoding

    See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
    ) -> None:
        super(RelPositionMultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

        # linear transformation for positional encoding.
        self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
        self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))

        self._reset_parameters()

    def _reset_parameters(self) -> None:
        nn.init.xavier_uniform_(self.in_proj.weight)
        nn.init.constant_(self.in_proj.bias, 0.0)
        nn.init.constant_(self.out_proj.bias, 0.0)

        nn.init.xavier_uniform_(self.pos_bias_u)
        nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x: Tensor) -> Tensor:
        """Compute relative positional encoding.

        Args:
            x: Input tensor (batch, head, seq_len, 2*seq_len-1).

        Returns:
            Tensor: tensor of shape (batch, head, seq_len, seq_len)
        """
        (batch_size, num_heads, seq_len, n) = x.shape

        if not is_jit_tracing():
            assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1"

        if is_jit_tracing():
            rows = torch.arange(start=seq_len - 1, end=-1, step=-1)
            cols = torch.arange(seq_len)
            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
            indexes = rows + cols

            x = x.reshape(-1, n)
            x = torch.gather(x, dim=1, index=indexes)
            x = x.reshape(batch_size, num_heads, seq_len, seq_len)
            return x
        else:
            # Shift each row so that relative position 0 lands on the diagonal,
            # implemented as a strided view to avoid a copy.
            # Note: TorchScript requires explicit arg for stride()
            batch_stride = x.stride(0)
            head_stride = x.stride(1)
            time_stride = x.stride(2)
            n_stride = x.stride(3)
            return x.as_strided(
                (batch_size, num_heads, seq_len, seq_len),
                (batch_stride, head_stride, time_stride - n_stride, n_stride),
                storage_offset=n_stride * (seq_len - 1),
            )

    def forward(
        self,
        x: Tensor,
        pos_emb: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Args:
            x: Input tensor of shape (seq_len, batch_size, embed_dim)
            pos_emb: Positional embedding tensor, (1, 2*seq_len-1, pos_dim)
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. This is a binary mask. When the value is True,
                the corresponding value on the attention layer will be filled with -inf.
                Its shape is (batch_size, seq_len).

        Outputs:
            A tensor of shape (seq_len, batch_size, embed_dim).
        """
        seq_len, batch_size, _ = x.shape
        scaling = float(self.head_dim) ** -0.5

        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
        k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
        v = (
            v.contiguous()
            .view(seq_len, batch_size * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )

        q = q.transpose(0, 1)  # (batch_size, seq_len, num_head, head_dim)

        p = self.linear_pos(pos_emb).view(
            pos_emb.size(0), -1, self.num_heads, self.head_dim
        )
        # (1, 2*seq_len-1, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1)
        p = p.permute(0, 2, 3, 1)

        # (batch_size, num_head, seq_len, head_dim)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
        k = k.permute(1, 2, 3, 0)  # (batch_size, num_head, head_dim, seq_len)
        matrix_ac = torch.matmul(
            q_with_bias_u, k
        )  # (batch_size, num_head, seq_len, seq_len)

        # compute matrix b and matrix d
        matrix_bd = torch.matmul(
            q_with_bias_v, p
        )  # (batch_size, num_head, seq_len, 2*seq_len-1)
        matrix_bd = self.rel_shift(
            matrix_bd
        )  # (batch_size, num_head, seq_len, seq_len)

        # (batch_size, num_head, seq_len, seq_len)
        attn_output_weights = (matrix_ac + matrix_bd) * scaling
        attn_output_weights = attn_output_weights.view(
            batch_size * self.num_heads, seq_len, seq_len
        )

        if key_padding_mask is not None:
            assert key_padding_mask.shape == (batch_size, seq_len)
            attn_output_weights = attn_output_weights.view(
                batch_size, self.num_heads, seq_len, seq_len
            )
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float("-inf"),
            )
            attn_output_weights = attn_output_weights.view(
                batch_size * self.num_heads, seq_len, seq_len
            )

        attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
        attn_output_weights = nn.functional.dropout(
            attn_output_weights, p=self.dropout, training=self.training
        )

        # (batch_size * num_head, seq_len, head_dim)
        attn_output = torch.bmm(attn_output_weights, v)
        assert attn_output.shape == (
            batch_size * self.num_heads,
            seq_len,
            self.head_dim,
        )

        attn_output = (
            attn_output.transpose(0, 1)
            .contiguous()
            .view(seq_len, batch_size, self.embed_dim)
        )
        # (seq_len, batch_size, embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.
    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py

    Args:
        channels (int): The number of channels of conv layers.
        kernel_size (int): Kernel size of conv layers.
        bias (bool): Whether to use bias in conv layers (default=True).
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int,
        bias: bool = True,
    ) -> None:
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )

        padding = (kernel_size - 1) // 2
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.LayerNorm(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = Swish()

    def forward(
        self,
        x: Tensor,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Compute convolution module.

        Args:
            x: Input tensor (#time, batch, channels).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Returns:
            Tensor: Output tensor (#time, batch, channels).

        """
        # exchange the temporal dimension and the feature dimension
        x = x.permute(1, 2, 0)  # (#batch, channels, time).

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)

        # 1D Depthwise Conv
        if src_key_padding_mask is not None:
            x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
        x = self.depthwise_conv(x)
        # x is (batch, channels, time)
        x = x.permute(0, 2, 1)
        x = self.norm(x)
        x = x.permute(0, 2, 1)

        x = self.activation(x)

        x = self.pointwise_conv2(x)  # (batch, channel, time)

        return x.permute(2, 0, 1)


class Swish(nn.Module):
    """Construct a Swish activation object."""

    def forward(self, x: Tensor) -> Tensor:
        """Return Swish activation function."""
        return x * torch.sigmoid(x)


def _test_text_encoder():
    vocabs = 500
    d_model = 192
    batch_size = 5
    seq_len = 100

    model = TextEncoder(vocabs=vocabs, d_model=d_model)
    x, m, logs, mask = model(
        x=torch.randint(low=0, high=vocabs, size=(batch_size, seq_len)),
        x_lengths=torch.full((batch_size,), seq_len),
    )
    print(x.shape, m.shape, logs.shape, mask.shape)


if __name__ == "__main__":
    _test_text_encoder()
108
egs/ljspeech/TTS/vits2/tokenizer.py
Normal file
@ -0,0 +1,108 @@
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List

import g2p_en
import tacotron_cleaner.cleaners
from utils import intersperse


class Tokenizer(object):
    def __init__(self, tokens: str):
        """
        Args:
          tokens: the file that maps tokens to ids
        """
        # Parse token file
        self.token2id: Dict[str, int] = {}
        with open(tokens, "r", encoding="utf-8") as f:
            for line in f.readlines():
                info = line.rstrip().split()
                if len(info) == 1:
                    # case of space
                    token = " "
                    id = int(info[0])
                else:
                    token, id = info[0], int(info[1])
                self.token2id[token] = id

        self.blank_id = self.token2id["<blk>"]
        self.oov_id = self.token2id["<unk>"]
        self.vocab_size = len(self.token2id)

        self.g2p = g2p_en.G2p()

    def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
        """
        Args:
          texts:
            A list of transcripts.
          intersperse_blank:
            Whether to intersperse blanks in the token sequence.

        Returns:
          Return a list of token id lists [utterance][token_id]
        """
        token_ids_list = []

        for text in texts:
            # Text normalization
            text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
            # Convert to phonemes
            tokens = self.g2p(text)
            token_ids = []
            for t in tokens:
                if t in self.token2id:
                    token_ids.append(self.token2id[t])
                else:
                    token_ids.append(self.oov_id)

            if intersperse_blank:
                token_ids = intersperse(token_ids, self.blank_id)

            token_ids_list.append(token_ids)

        return token_ids_list

    def tokens_to_token_ids(
        self, tokens_list: List[str], intersperse_blank: bool = True
    ):
        """
        Args:
          tokens_list:
            A list of token lists, each corresponding to one utterance.
          intersperse_blank:
            Whether to intersperse blanks in the token sequence.

        Returns:
          Return a list of token id lists [utterance][token_id]
        """
        token_ids_list = []

        for tokens in tokens_list:
            token_ids = []
            for t in tokens:
                if t in self.token2id:
                    token_ids.append(self.token2id[t])
                else:
                    token_ids.append(self.oov_id)

            if intersperse_blank:
                token_ids = intersperse(token_ids, self.blank_id)
            token_ids_list.append(token_ids)

        return token_ids_list
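A short usage sketch for the tokenizer (not part of the commit; it assumes a tokens.txt whose entries include <blk> and <unk>, which the constructor requires):

    tokenizer = Tokenizer("data/tokens.txt")
    ids = tokenizer.texts_to_token_ids(["Hello world"])
    # g2p_en phonemizes the cleaned text; unknown symbols map to <unk>, and with
    # intersperse_blank=True a <blk> id is inserted between consecutive tokens.
    print(len(ids), len(ids[0]))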
936
egs/ljspeech/TTS/vits2/train.py
Executable file
@ -0,0 +1,936 @@
#!/usr/bin/env python3
|
||||||
|
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from shutil import copyfile
|
||||||
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import torch.nn as nn
|
||||||
|
from lhotse.cut import Cut
|
||||||
|
from lhotse.utils import fix_random_seed
|
||||||
|
from tokenizer import Tokenizer
|
||||||
|
from torch.cuda.amp import GradScaler, autocast
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from tts_datamodule import LJSpeechTtsDataModule
|
||||||
|
from utils import MetricsTracker, plot_feature, save_checkpoint
|
||||||
|
from vits import VITS
|
||||||
|
|
||||||
|
from icefall import diagnostics
|
||||||
|
from icefall.checkpoint import load_checkpoint
|
||||||
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
|
from icefall.env import get_env_info
|
||||||
|
from icefall.hooks import register_inf_check_hooks
|
||||||
|
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||||
|
|
||||||
|
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--world-size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of GPUs for DDP training.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--master-port",
|
||||||
|
type=int,
|
||||||
|
default=12354,
|
||||||
|
help="Master port to use for DDP training.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tensorboard",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Should various information be logged in tensorboard.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-epochs",
|
||||||
|
type=int,
|
||||||
|
default=1000,
|
||||||
|
help="Number of epochs to train.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--start-epoch",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="""Resume training from this epoch. It should be positive.
|
||||||
|
If larger than 1, it will load checkpoint from
|
||||||
|
exp-dir/epoch-{start_epoch-1}.pt
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="vits/exp",
|
||||||
|
help="""The experiment dir.
|
||||||
|
It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens",
|
||||||
|
type=str,
|
||||||
|
default="data/tokens.txt",
|
||||||
|
help="""Path to vocabulary.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lr", type=float, default=2.0e-4, help="The base learning rate."
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="The seed for random generators intended for reproducibility",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--print-diagnostics",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Accumulate stats on activations, print them and exit.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--inf-check",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Add hooks to check for infinite module outputs and gradients.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-every-n",
|
||||||
|
type=int,
|
||||||
|
default=20,
|
||||||
|
help="""Save checkpoint after processing this number of epochs"
|
||||||
|
periodically. We save checkpoint to exp-dir/ whenever
|
||||||
|
params.cur_epoch % save_every_n == 0. The checkpoint filename
|
||||||
|
has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
|
||||||
|
Since it will take around 1000 epochs, we suggest using a large
|
||||||
|
save_every_n to save disk space.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-fp16",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Whether to use half precision training.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def get_params() -> AttributeDict:
|
||||||
|
"""Return a dict containing training parameters.
|
||||||
|
|
||||||
|
All training related parameters that are not passed from the commandline
|
||||||
|
are saved in the variable `params`.
|
||||||
|
|
||||||
|
Commandline options are merged into `params` after they are parsed, so
|
||||||
|
you can also access them via `params`.
|
||||||
|
|
||||||
|
Explanation of options saved in `params`:
|
||||||
|
|
||||||
|
- best_train_loss: Best training loss so far. It is used to select
|
||||||
|
the model that has the lowest training loss. It is
|
||||||
|
updated during the training.
|
||||||
|
|
||||||
|
- best_valid_loss: Best validation loss so far. It is used to select
|
||||||
|
the model that has the lowest validation loss. It is
|
||||||
|
updated during the training.
|
||||||
|
|
||||||
|
- best_train_epoch: It is the epoch that has the best training loss.
|
||||||
|
|
||||||
|
- best_valid_epoch: It is the epoch that has the best validation loss.
|
||||||
|
|
||||||
|
- batch_idx_train: Used to writing statistics to tensorboard. It
|
||||||
|
contains number of batches trained so far across
|
||||||
|
epochs.
|
||||||
|
|
||||||
|
- log_interval: Print training loss if batch_idx % log_interval` is 0
|
||||||
|
|
||||||
|
- valid_interval: Run validation if batch_idx % valid_interval is 0
|
||||||
|
|
||||||
|
- feature_dim: The model input dim. It has to match the one used
|
||||||
|
in computing features.
|
||||||
|
|
||||||
|
- subsampling_factor: The subsampling factor for the model.
|
||||||
|
|
||||||
|
- encoder_dim: Hidden dim for multi-head attention model.
|
||||||
|
|
||||||
|
- num_decoder_layers: Number of decoder layer of transformer decoder.
|
||||||
|
|
||||||
|
- warm_step: The warmup period that dictates the decay of the
|
||||||
|
scale on "simple" (un-pruned) loss.
|
||||||
|
"""
|
||||||
|
params = AttributeDict(
|
||||||
|
{
|
||||||
|
# training params
|
||||||
|
"best_train_loss": float("inf"),
|
||||||
|
"best_valid_loss": float("inf"),
|
||||||
|
"best_train_epoch": -1,
|
||||||
|
"best_valid_epoch": -1,
|
||||||
|
"batch_idx_train": -1, # 0
|
||||||
|
"log_interval": 50,
|
||||||
|
"valid_interval": 200,
|
||||||
|
"env_info": get_env_info(),
|
||||||
|
"sampling_rate": 22050,
|
||||||
|
"frame_shift": 256,
|
||||||
|
"frame_length": 1024,
|
||||||
|
"feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length
|
||||||
|
"n_mels": 80,
|
||||||
|
"lambda_adv": 1.0, # loss scaling coefficient for adversarial loss
|
||||||
|
"lambda_mel": 45.0, # loss scaling coefficient for Mel loss
|
||||||
|
"lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss
|
||||||
|
"lambda_dur": 1.0, # loss scaling coefficient for duration loss
|
||||||
|
"lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return params


def load_checkpoint_if_available(
    params: AttributeDict, model: nn.Module
) -> Optional[Dict[str, Any]]:
    """Load checkpoint from file.

    If params.start_epoch is larger than 1, it will load the checkpoint from
    `params.start_epoch - 1`.

    Apart from loading state dict for `model` and `optimizer` it also updates
    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
    and `best_valid_loss` in `params`.

    Args:
      params:
        The return value of :func:`get_params`.
      model:
        The training model.
    Returns:
      Return a dict containing previously saved training info.
    """
    if params.start_epoch > 1:
        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
    else:
        return None

    assert filename.is_file(), f"{filename} does not exist!"

    saved_params = load_checkpoint(filename, model=model)

    keys = [
        "best_train_epoch",
        "best_valid_epoch",
        "batch_idx_train",
        "best_train_loss",
        "best_valid_loss",
    ]
    for k in keys:
        params[k] = saved_params[k]

    return saved_params


def get_model(params: AttributeDict) -> nn.Module:
    mel_loss_params = {
        "n_mels": params.n_mels,
        "frame_length": params.frame_length,
        "frame_shift": params.frame_shift,
    }
    model = VITS(
        vocab_size=params.vocab_size,
        feature_dim=params.feature_dim,
        sampling_rate=params.sampling_rate,
        mel_loss_params=mel_loss_params,
        lambda_adv=params.lambda_adv,
        lambda_mel=params.lambda_mel,
        lambda_feat_match=params.lambda_feat_match,
        lambda_dur=params.lambda_dur,
        lambda_kl=params.lambda_kl,
    )
    return model


def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
    """Parse batch data"""
    audio = batch["audio"].to(device)
    features = batch["features"].to(device)
    audio_lens = batch["audio_lens"].to(device)
    features_lens = batch["features_lens"].to(device)
    tokens = batch["tokens"]

    tokens = tokenizer.tokens_to_token_ids(tokens)
    tokens = k2.RaggedTensor(tokens)
    row_splits = tokens.shape.row_splits(1)
    tokens_lens = row_splits[1:] - row_splits[:-1]
    tokens = tokens.to(device)
    tokens_lens = tokens_lens.to(device)
    # a tensor of shape (B, T)
    tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)

    return audio, audio_lens, features, features_lens, tokens, tokens_lens
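

# A note on the padded layout returned above (values here are illustrative,
# not taken from the recipe): token lists [[5, 3], [7]] with blank_id=0 come
# back as the (B, T) tensor [[5, 3], [7, 0]] plus tokens_lens [2, 1]; the
# padding positions are filled with the tokenizer's blank id.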


def train_one_epoch(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    tokenizer: Tokenizer,
    optimizer_g: Optimizer,
    optimizer_d: Optimizer,
    scheduler_g: LRSchedulerType,
    scheduler_d: LRSchedulerType,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
    rank: int = 0,
) -> None:
    """Train the model for one epoch.

    The training loss from the mean of all frames is saved in
    `params.train_loss`. It runs the validation process every
    `params.valid_interval` batches.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The model for training.
      tokenizer:
        Used to convert text to phonemes.
      optimizer_g:
        The optimizer for the generator.
      optimizer_d:
        The optimizer for the discriminator.
      scheduler_g:
        The learning rate scheduler for the generator; we call step() every epoch.
      scheduler_d:
        The learning rate scheduler for the discriminator; we call step() every epoch.
      train_dl:
        Dataloader for the training dataset.
      valid_dl:
        Dataloader for the validation dataset.
      scaler:
        The scaler used for mixed precision training.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:
        Number of nodes in DDP training. If it is 1, DDP is disabled.
      rank:
        The rank of the node in DDP training. If no DDP is used, it should
        be set to 0.
    """
    model.train()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    # used to summarize the stats over iterations in one epoch
    tot_loss = MetricsTracker()

    saved_bad_model = False

    def save_bad_model(suffix: str = ""):
        save_checkpoint(
            filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
            model=model,
            params=params,
            optimizer_g=optimizer_g,
            optimizer_d=optimizer_d,
            scheduler_g=scheduler_g,
            scheduler_d=scheduler_d,
            sampler=train_dl.sampler,
            scaler=scaler,
            rank=0,
        )

    for batch_idx, batch in enumerate(train_dl):
        params.batch_idx_train += 1

        batch_size = len(batch["tokens"])
        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
            batch, tokenizer, device
        )

        loss_info = MetricsTracker()
        loss_info["samples"] = batch_size

        generator = (model.module if isinstance(model, DDP) else model).generator
        if generator.use_noised_mas:
            # MAS with Gaussian noise
            generator.noise_current_mas = max(
                generator.noise_initial_mas
                - generator.noise_scale_mas * params.batch_idx_train * 0.25,
                0.0,
            )
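            # With this schedule the Gaussian noise injected into the MAS
            # alignment scores decays linearly in the number of processed
            # batches (the 0.25 factor slows the decay) and is clamped at
            # 0.0, so training gradually falls back to noise-free MAS.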

        try:
            with autocast(enabled=params.use_fp16):
                # forward discriminator
                loss_d, stats_d = model(
                    text=tokens,
                    text_lengths=tokens_lens,
                    feats=features,
                    feats_lengths=features_lens,
                    speech=audio,
                    speech_lengths=audio_lens,
                    forward_generator=False,
                )
            for k, v in stats_d.items():
                loss_info[k] = v * batch_size
            # update discriminator
            optimizer_d.zero_grad()
            scaler.scale(loss_d).backward()
            scaler.step(optimizer_d)

            with autocast(enabled=params.use_fp16):
                # forward generator
                loss_g, stats_g = model(
                    text=tokens,
                    text_lengths=tokens_lens,
                    feats=features,
                    feats_lengths=features_lens,
                    speech=audio,
                    speech_lengths=audio_lens,
                    forward_generator=True,
                    return_sample=params.batch_idx_train % params.log_interval == 0,
                )
            for k, v in stats_g.items():
                if "returned_sample" not in k:
                    loss_info[k] = v * batch_size
            # update generator
            optimizer_g.zero_grad()
            scaler.scale(loss_g).backward()
            scaler.step(optimizer_g)
            scaler.update()

            # summary stats
            tot_loss = tot_loss + loss_info
        except:  # noqa
            save_bad_model()
            raise

        if params.print_diagnostics and batch_idx == 5:
            return

        if params.batch_idx_train % 100 == 0 and params.use_fp16:
            # If the grad scale was less than 1, try increasing it. The _growth_interval
            # of the grad scaler is configurable, but we can't configure it to have different
            # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()

            if cur_grad_scale < 8.0 or (
                cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
            ):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
                    save_bad_model(suffix="-first-warning")
                    saved_bad_model = True
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
                if cur_grad_scale < 1.0e-05:
                    save_bad_model()
                    raise RuntimeError(
                        f"grad_scale is too small, exiting: {cur_grad_scale}"
                    )
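            # GradScaler's built-in growth is slow (one doubling per
            # _growth_interval un-skipped steps); the manual doubling above
            # recovers faster after the scale has collapsed. A scale below
            # 1e-5 means the fp16 loss keeps overflowing, so the run is
            # aborted after saving a checkpoint for debugging.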

        if params.batch_idx_train % params.log_interval == 0:
            cur_lr_g = max(scheduler_g.get_last_lr())
            cur_lr_d = max(scheduler_d.get_last_lr())
            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0

            logging.info(
                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
                f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
                f"loss[{loss_info}], tot_loss[{tot_loss}], "
                f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, "
                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
            )

            if tb_writer is not None:
                tb_writer.add_scalar(
                    "train/learning_rate_g", cur_lr_g, params.batch_idx_train
                )
                tb_writer.add_scalar(
                    "train/learning_rate_d", cur_lr_d, params.batch_idx_train
                )
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
                    )
                if "returned_sample" in stats_g:
                    speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
                    tb_writer.add_audio(
                        "train/speech_hat_",
                        speech_hat_,
                        params.batch_idx_train,
                        params.sampling_rate,
                    )
                    tb_writer.add_audio(
                        "train/speech_",
                        speech_,
                        params.batch_idx_train,
                        params.sampling_rate,
                    )
                    tb_writer.add_image(
                        "train/mel_hat_",
                        plot_feature(mel_hat_),
                        params.batch_idx_train,
                        dataformats="HWC",
                    )
                    tb_writer.add_image(
                        "train/mel_",
                        plot_feature(mel_),
                        params.batch_idx_train,
                        dataformats="HWC",
                    )

        if (
            params.batch_idx_train % params.valid_interval == 0
            and not params.print_diagnostics
        ):
            logging.info("Computing validation loss")
            valid_info, (speech_hat, speech) = compute_validation_loss(
                params=params,
                model=model,
                tokenizer=tokenizer,
                valid_dl=valid_dl,
                world_size=world_size,
            )
            model.train()
            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
            logging.info(
                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
            )
            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
                )
                tb_writer.add_audio(
                    "train/valid_speech_hat",
                    speech_hat,
                    params.batch_idx_train,
                    params.sampling_rate,
                )
                tb_writer.add_audio(
                    "train/valid_speech",
                    speech,
                    params.batch_idx_train,
                    params.sampling_rate,
                )

    loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss


def compute_validation_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    tokenizer: Tokenizer,
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
    rank: int = 0,
) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
    """Run the validation process."""
    model.eval()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    # used to summarize the stats over iterations
    tot_loss = MetricsTracker()
    returned_sample = None

    with torch.no_grad():
        for batch_idx, batch in enumerate(valid_dl):
            batch_size = len(batch["tokens"])
            (
                audio,
                audio_lens,
                features,
                features_lens,
                tokens,
                tokens_lens,
            ) = prepare_input(batch, tokenizer, device)

            loss_info = MetricsTracker()
            loss_info["samples"] = batch_size

            # forward discriminator
            loss_d, stats_d = model(
                text=tokens,
                text_lengths=tokens_lens,
                feats=features,
                feats_lengths=features_lens,
                speech=audio,
                speech_lengths=audio_lens,
                forward_generator=False,
            )
            assert loss_d.requires_grad is False
            for k, v in stats_d.items():
                loss_info[k] = v * batch_size

            # forward generator
            loss_g, stats_g = model(
                text=tokens,
                text_lengths=tokens_lens,
                feats=features,
                feats_lengths=features_lens,
                speech=audio,
                speech_lengths=audio_lens,
                forward_generator=True,
            )
            assert loss_g.requires_grad is False
            for k, v in stats_g.items():
                loss_info[k] = v * batch_size

            # summary stats
            tot_loss = tot_loss + loss_info

            # infer for the first batch:
            if batch_idx == 0 and rank == 0:
                inner_model = model.module if isinstance(model, DDP) else model
                audio_pred, _, duration = inner_model.inference(
                    text=tokens[0, : tokens_lens[0].item()]
                )
                audio_pred = audio_pred.data.cpu().numpy()
                audio_len_pred = (
                    (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
                )
                assert audio_len_pred == len(audio_pred), (
                    audio_len_pred,
                    len(audio_pred),
                )
                audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
                returned_sample = (audio_pred, audio_gt)

    if world_size > 1:
        tot_loss.reduce(device)

    loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
    if loss_value < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = loss_value

    return tot_loss, returned_sample


def scan_pessimistic_batches_for_oom(
    model: Union[nn.Module, DDP],
    train_dl: torch.utils.data.DataLoader,
    tokenizer: Tokenizer,
    optimizer_g: torch.optim.Optimizer,
    optimizer_d: torch.optim.Optimizer,
    params: AttributeDict,
):
    from lhotse.dataset import find_pessimistic_batches

    logging.info(
        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
    )
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
            batch, tokenizer, device
        )
        try:
            # for discriminator
            with autocast(enabled=params.use_fp16):
                loss_d, stats_d = model(
                    text=tokens,
                    text_lengths=tokens_lens,
                    feats=features,
                    feats_lengths=features_lens,
                    speech=audio,
                    speech_lengths=audio_lens,
                    forward_generator=False,
                )
            optimizer_d.zero_grad()
            loss_d.backward()
            # for generator
            with autocast(enabled=params.use_fp16):
                loss_g, stats_g = model(
                    text=tokens,
                    text_lengths=tokens_lens,
                    feats=features,
                    feats_lengths=features_lens,
                    speech=audio,
                    speech_lengths=audio_lens,
                    forward_generator=True,
                )
            optimizer_g.zero_grad()
            loss_g.backward()
        except Exception as e:
            if "CUDA out of memory" in str(e):
                logging.error(
                    "Your GPU ran out of memory with the current "
                    "max_duration setting. We recommend decreasing "
                    "max_duration and trying again.\n"
                    f"Failing criterion: {criterion} "
                    f"(={crit_values[criterion]}) ..."
                )
            raise
    logging.info(
        f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
    )


def run(rank, world_size, args):
    """
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoints.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    """
    params = get_params()
    params.update(vars(args))

    fix_random_seed(params.seed)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info("Training started")

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
    logging.info(f"Device: {device}")

    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.blank_id
    params.oov_id = tokenizer.oov_id
    params.vocab_size = tokenizer.vocab_size

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)
    generator = model.generator
    discriminator = model.discriminator

    num_param_g = sum([p.numel() for p in generator.parameters()])
    logging.info(f"Number of parameters in generator: {num_param_g}")
    num_param_d = sum([p.numel() for p in discriminator.parameters()])
    logging.info(f"Number of parameters in discriminator: {num_param_d}")
    logging.info(f"Total number of parameters: {num_param_g + num_param_d}")

    assert params.start_epoch > 0, params.start_epoch
    checkpoints = load_checkpoint_if_available(params=params, model=model)

    model.to(device)
    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    optimizer_g = torch.optim.AdamW(
        generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
    )
    optimizer_d = torch.optim.AdamW(
        discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
    )

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)

    if checkpoints is not None:
        # load state_dict for optimizers
        if "optimizer_g" in checkpoints:
            logging.info("Loading optimizer_g state dict")
            optimizer_g.load_state_dict(checkpoints["optimizer_g"])
        if "optimizer_d" in checkpoints:
            logging.info("Loading optimizer_d state dict")
            optimizer_d.load_state_dict(checkpoints["optimizer_d"])

        # load state_dict for schedulers
        if "scheduler_g" in checkpoints:
            logging.info("Loading scheduler_g state dict")
            scheduler_g.load_state_dict(checkpoints["scheduler_g"])
        if "scheduler_d" in checkpoints:
            logging.info("Loading scheduler_d state dict")
            scheduler_d.load_state_dict(checkpoints["scheduler_d"])

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            512
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

    if params.inf_check:
        register_inf_check_hooks(model)

    ljspeech = LJSpeechTtsDataModule(args)

    train_cuts = ljspeech.train_cuts()

    def remove_short_and_long_utt(c: Cut):
        # Keep only utterances with duration between 1 second and 20 seconds
        # You should use ../local/display_manifest_statistics.py to get
        # an utterance duration distribution for your dataset to select
        # the threshold
        if c.duration < 1.0 or c.duration > 20.0:
            # logging.warning(
            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
            # )
            return False
        return True

    train_cuts = train_cuts.filter(remove_short_and_long_utt)
    train_dl = ljspeech.train_dataloaders(train_cuts)

    valid_cuts = ljspeech.valid_cuts()
    valid_dl = ljspeech.valid_dataloaders(valid_cuts)

    if not params.print_diagnostics:
        scan_pessimistic_batches_for_oom(
            model=model,
            train_dl=train_dl,
            tokenizer=tokenizer,
            optimizer_g=optimizer_g,
            optimizer_d=optimizer_d,
            params=params,
        )

    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])

    for epoch in range(params.start_epoch, params.num_epochs + 1):
        logging.info(f"Start epoch {epoch}")

        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)

        params.cur_epoch = epoch

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        train_one_epoch(
            params=params,
            model=model,
            tokenizer=tokenizer,
            optimizer_g=optimizer_g,
            optimizer_d=optimizer_d,
            scheduler_g=scheduler_g,
            scheduler_d=scheduler_d,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
        )

        if params.print_diagnostics:
            diagnostic.print_diagnostics()
            break

        if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
            filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
            save_checkpoint(
                filename=filename,
                params=params,
                model=model,
                optimizer_g=optimizer_g,
                optimizer_d=optimizer_d,
                scheduler_g=scheduler_g,
                scheduler_d=scheduler_d,
                sampler=train_dl.sampler,
                scaler=scaler,
                rank=rank,
            )
            if rank == 0:
                if params.best_train_epoch == params.cur_epoch:
                    best_train_filename = params.exp_dir / "best-train-loss.pt"
                    copyfile(src=filename, dst=best_train_filename)

                if params.best_valid_epoch == params.cur_epoch:
                    best_valid_filename = params.exp_dir / "best-valid-loss.pt"
                    copyfile(src=filename, dst=best_valid_filename)

        # step per epoch
        scheduler_g.step()
        scheduler_d.step()

    logging.info("Done!")

    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()


def main():
    parser = get_parser()
    LJSpeechTtsDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    world_size = args.world_size
    assert world_size >= 1
    if world_size > 1:
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()
218
egs/ljspeech/TTS/vits2/transform.py
Normal file
@@ -0,0 +1,218 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py

"""Flow-related transformation.

This code is derived from https://github.com/bayesiains/nflows.

"""

import numpy as np
import torch
from torch.nn import functional as F

DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3


# TODO(kan-bayashi): Documentation and type hint
def piecewise_rational_quadratic_transform(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails=None,
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    if tails is None:
        spline_fn = rational_quadratic_spline
        spline_kwargs = {}
    else:
        spline_fn = unconstrained_rational_quadratic_spline
        spline_kwargs = {"tails": tails, "tail_bound": tail_bound}

    outputs, logabsdet = spline_fn(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnormalized_derivatives=unnormalized_derivatives,
        inverse=inverse,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
        **spline_kwargs
    )
    return outputs, logabsdet
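

# Dispatch summary: with tails=None the (constrained) spline is applied and
# inputs must already lie inside its [0, 1] domain; with tails="linear" the
# unconstrained variant below applies the spline only inside
# [-tail_bound, tail_bound] and falls back to the identity outside, which is
# how the flows in this recipe typically call it.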


# TODO(kan-bayashi): Documentation and type hint
def unconstrained_rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails="linear",
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    if tails == "linear":
        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
        constant = np.log(np.exp(1 - min_derivative) - 1)
        unnormalized_derivatives[..., 0] = constant
        unnormalized_derivatives[..., -1] = constant

        outputs[outside_interval_mask] = inputs[outside_interval_mask]
        logabsdet[outside_interval_mask] = 0
    else:
        raise RuntimeError("{} tails are not implemented.".format(tails))

    (
        outputs[inside_interval_mask],
        logabsdet[inside_interval_mask],
    ) = rational_quadratic_spline(
        inputs=inputs[inside_interval_mask],
        unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
        unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
        inverse=inverse,
        left=-tail_bound,
        right=tail_bound,
        bottom=-tail_bound,
        top=tail_bound,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )

    return outputs, logabsdet


# TODO(kan-bayashi): Documentation and type hint
def rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
    cumwidths = (right - left) * cumwidths + left
    cumwidths[..., 0] = left
    cumwidths[..., -1] = right
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

    heights = F.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
    cumheights = (top - bottom) * cumheights + bottom
    cumheights[..., 0] = bottom
    cumheights[..., -1] = top
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    if inverse:
        bin_idx = _searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = _searchsorted(cumwidths, inputs)[..., None]

    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]

    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
    delta = heights / widths
    input_delta = delta.gather(-1, bin_idx)[..., 0]

    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]

    input_heights = heights.gather(-1, bin_idx)[..., 0]

    if inverse:
        a = (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        ) + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        )
        c = -input_delta * (inputs - input_cumheights)

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        root = (2 * c) / (-b - torch.sqrt(discriminant))
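        # Inverting the spline reduces to solving the quadratic
        # a * theta^2 + b * theta + c = 0 for the bin-relative position
        # theta in [0, 1]; 2c / (-b - sqrt(b^2 - 4ac)) is the numerically
        # stable form of the quadratic formula, avoiding cancellation
        # between -b and the square root of the discriminant.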
        outputs = root * input_bin_widths + input_cumwidths

        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - root).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, -logabsdet
    else:
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (
            input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
        )
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        outputs = input_cumheights + numerator / denominator

        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * theta.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - theta).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, logabsdet


def _searchsorted(bin_locations, inputs, eps=1e-6):
    bin_locations[..., -1] += eps
    return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
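

# Example (illustrative values): with bin_locations = [0.0, 0.25, 0.5, 1.0],
# an input of 0.3 yields index 1, i.e. the count of bin edges <= input minus
# one. The eps added to the last edge keeps inputs that sit exactly on the
# right boundary inside the final bin.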
327
egs/ljspeech/TTS/vits2/tts_datamodule.py
Normal file
@@ -0,0 +1,327 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
#                                                  Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    CutConcatenate,
    CutMix,
    DynamicBucketingSampler,
    PrecomputedFeatures,
    SimpleCutSampler,
    SpecAugment,
    SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
    AudioSamples,
    OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from icefall.utils import str2bool


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)
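        # Each dataloader worker thus gets a distinct but reproducible seed
        # derived from the base seed, so random augmentations differ across
        # workers while the run stays deterministic for a fixed base seed.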


class LJSpeechTtsDataModule:
    """
    DataModule for TTS experiments.
    It assumes there is always one train and valid dataloader,
    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
    and test-other).

    It contains all the common data pipeline modules used in TTS
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - cut concatenation,
    - on-the-fly feature extraction

    This class should be derived for specific corpora used in TTS tasks.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="TTS data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )

        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/spectrogram"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=int,
            default=200,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )

        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop the last batch. Used by the sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=False,
            help="When enabled, each batch will have the "
            "field: batch['cut'] with the cuts that "
            "were used to construct it.",
        )
        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            help="AudioSamples or PrecomputedFeatures",
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        logging.info("About to create train dataset")
        train = SpeechSynthesisDataset(
            return_text=False,
            return_tokens=True,
            feature_input_strategy=eval(self.args.input_strategy)(),
            return_cuts=self.args.return_cuts,
        )

        if self.args.on_the_fly_feats:
            sampling_rate = 22050
            config = SpectrogramConfig(
                sampling_rate=sampling_rate,
                frame_length=1024 / sampling_rate,  # (in seconds)
                frame_shift=256 / sampling_rate,  # (in seconds)
                use_fft_mag=True,
            )
            train = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                return_cuts=self.args.return_cuts,
            )

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                buffer_size=self.args.num_buckets * 2000,
                shuffle_buffer_size=self.args.num_buckets * 5000,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            sampling_rate = 22050
            config = SpectrogramConfig(
                sampling_rate=sampling_rate,
                frame_length=1024 / sampling_rate,  # (in seconds)
                frame_shift=256 / sampling_rate,  # (in seconds)
                use_fft_mag=True,
            )
            validate = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                return_cuts=self.args.return_cuts,
            )
        else:
            validate = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=eval(self.args.input_strategy)(),
                return_cuts=self.args.return_cuts,
            )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create valid dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.info("About to create test dataset")
        if self.args.on_the_fly_feats:
            sampling_rate = 22050
            config = SpectrogramConfig(
                sampling_rate=sampling_rate,
                frame_length=1024 / sampling_rate,  # (in seconds)
                frame_shift=256 / sampling_rate,  # (in seconds)
                use_fft_mag=True,
            )
            test = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                return_cuts=self.args.return_cuts,
            )
        else:
            test = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=eval(self.args.input_strategy)(),
                return_cuts=self.args.return_cuts,
            )
        test_sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=test_sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
        )

    @lru_cache()
    def valid_cuts(self) -> CutSet:
        logging.info("About to get validation cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
        )

    @lru_cache()
    def test_cuts(self) -> CutSet:
        logging.info("About to get test cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
        )
265
egs/ljspeech/TTS/vits2/utils.py
Normal file
@@ -0,0 +1,265 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from lhotse.dataset.sampling.base import CutSampler
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter


# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
def get_random_segments(
    x: torch.Tensor,
    x_lengths: torch.Tensor,
    segment_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get random segments.

    Args:
        x (Tensor): Input tensor (B, C, T).
        x_lengths (Tensor): Length tensor (B,).
        segment_size (int): Segment size.

    Returns:
        Tensor: Segmented tensor (B, C, segment_size).
        Tensor: Start index tensor (B,).

    """
    b, c, t = x.size()
    max_start_idx = x_lengths - segment_size
    max_start_idx[max_start_idx < 0] = 0
    start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to(
        dtype=torch.long,
    )
    segments = get_segments(x, start_idxs, segment_size)

    return segments, start_idxs
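

# In VITS-style GAN training the waveform decoder and the discriminators see
# only such short random windows of the latent sequence (cf. `segment_size`
# in the generator config), which keeps their memory cost independent of the
# utterance length.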


# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
def get_segments(
    x: torch.Tensor,
    start_idxs: torch.Tensor,
    segment_size: int,
) -> torch.Tensor:
    """Get segments.

    Args:
        x (Tensor): Input tensor (B, C, T).
        start_idxs (Tensor): Start index tensor (B,).
        segment_size (int): Segment size.

    Returns:
        Tensor: Segmented tensor (B, C, segment_size).

    """
    b, c, t = x.size()
    segments = x.new_zeros(b, c, segment_size)
    for i, start_idx in enumerate(start_idxs):
        segments[i] = x[i, :, start_idx : start_idx + segment_size]
    return segments


# from https://github.com/jaywalnut310/vits/blob/main/commons.py
def intersperse(sequence, item=0):
    result = [item] * (len(sequence) * 2 + 1)
    result[1::2] = sequence
    return result
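

# Example: intersperse([5, 3, 7]) == [0, 5, 0, 3, 0, 7, 0]. VITS interleaves
# a blank token between phonemes (the `add_blank` trick from the original
# implementation), which gives the text encoder room to model transitions.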


# from https://github.com/jaywalnut310/vits/blob/main/utils.py
MATPLOTLIB_FLAG = False


def plot_feature(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data


class MetricsTracker(collections.defaultdict):
    def __init__(self):
        # Passing the type 'int' to the base-class constructor
        # makes undefined items default to int() which is zero.
        # This class will play a role as metrics tracker.
        # It can record many metrics, including but not limited to loss.
        super(MetricsTracker, self).__init__(int)

    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v
        for k, v in other.items():
            ans[k] = ans[k] + v
        return ans

    def __mul__(self, alpha: float) -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v * alpha
        return ans

    def __str__(self) -> str:
        ans = ""
        for k, v in self.norm_items():
            norm_value = "%.4g" % v
            ans += str(k) + "=" + str(norm_value) + ", "
        samples = "%.2f" % self["samples"]
        ans += "over " + str(samples) + " samples."
        return ans

    def norm_items(self) -> List[Tuple[str, float]]:
        """
        Returns a list of pairs, like:
          [('loss_1', 0.1), ('loss_2', 0.07)]
        """
        samples = self["samples"] if "samples" in self else 1
        ans = []
        for k, v in self.items():
            if k == "samples":
                continue
            norm_value = float(v) / samples
            ans.append((k, norm_value))
        return ans
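
    # Convention used with this tracker elsewhere in the recipe: each metric
    # is accumulated as a sum weighted by batch size and "samples" counts the
    # examples seen, so `norm_items` yields per-sample averages.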

    def reduce(self, device):
        """
        Reduce using torch.distributed, which I believe ensures that
        all processes get the total.
        """
        keys = sorted(self.keys())
        s = torch.tensor([float(self[k]) for k in keys], device=device)
        dist.all_reduce(s, op=dist.ReduceOp.SUM)
        for k, v in zip(keys, s.cpu().tolist()):
            self[k] = v

    def write_summary(
        self,
        tb_writer: SummaryWriter,
        prefix: str,
        batch_idx: int,
    ) -> None:
        """Add logging information to a TensorBoard writer.

        Args:
            tb_writer: a TensorBoard writer
            prefix: a prefix for the name of the loss, e.g. "train/valid_",
                or "train/current_"
            batch_idx: The current batch index, used as the x-axis of the plot.
        """
        for k, v in self.norm_items():
            tb_writer.add_scalar(prefix + k, v, batch_idx)


# checkpoint saving and loading
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler


def save_checkpoint(
    filename: Path,
    model: Union[nn.Module, DDP],
    params: Optional[Dict[str, Any]] = None,
    optimizer_g: Optional[Optimizer] = None,
    optimizer_d: Optional[Optimizer] = None,
    scheduler_g: Optional[LRSchedulerType] = None,
    scheduler_d: Optional[LRSchedulerType] = None,
    scaler: Optional[GradScaler] = None,
    sampler: Optional[CutSampler] = None,
    rank: int = 0,
) -> None:
    """Save training information to a file.

    Args:
      filename:
        The checkpoint filename.
      model:
        The model to be saved. We only save its `state_dict()`.
      params:
        User defined parameters, e.g., epoch, loss.
      optimizer_g:
        The optimizer for the generator used in the training.
        Its `state_dict` will be saved.
      optimizer_d:
        The optimizer for the discriminator used in the training.
        Its `state_dict` will be saved.
      scheduler_g:
        The learning rate scheduler for the generator used in the training.
        Its `state_dict` will be saved.
      scheduler_d:
        The learning rate scheduler for the discriminator used in the training.
        Its `state_dict` will be saved.
      scaler:
        The GradScaler to be saved. We only save its `state_dict()`.
      rank:
        Used in DDP. We save checkpoints only for the node whose rank is 0.
    Returns:
      Return None.
    """
    if rank != 0:
        return

    logging.info(f"Saving checkpoint to {filename}")

    if isinstance(model, DDP):
        model = model.module

    checkpoint = {
        "model": model.state_dict(),
        "optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None,
        "optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None,
        "scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None,
        "scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None,
        "grad_scaler": scaler.state_dict() if scaler is not None else None,
        "sampler": sampler.state_dict() if sampler is not None else None,
    }

    if params:
        for k, v in params.items():
            assert k not in checkpoint
            checkpoint[k] = v

    torch.save(checkpoint, filename)
|
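
# Usage sketch (hypothetical, not part of the original file): a minimal call
# at the end of a training epoch. The names `vits`, `opt_g`, `opt_d`,
# `sched_g`, `sched_d`, and `scaler` are assumed to be created elsewhere in
# the training script.
#
#   save_checkpoint(
#       filename=Path("exp/vits/epoch-10.pt"),
#       model=vits,
#       params={"epoch": 10},
#       optimizer_g=opt_g,
#       optimizer_d=opt_d,
#       scheduler_g=sched_g,
#       scheduler_d=sched_d,
#       scaler=scaler,
#       rank=rank,
#   )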
610
egs/ljspeech/TTS/vits2/vits.py
Normal file
@ -0,0 +1,610 @@
# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py

# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""VITS module for GAN-TTS task."""

from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
from generator import VITSGenerator
from hifigan import (
    HiFiGANMultiPeriodDiscriminator,
    HiFiGANMultiScaleDiscriminator,
    HiFiGANMultiScaleMultiPeriodDiscriminator,
    HiFiGANPeriodDiscriminator,
    HiFiGANScaleDiscriminator,
)
from loss import (
    DiscriminatorAdversarialLoss,
    FeatureMatchLoss,
    GeneratorAdversarialLoss,
    KLDivergenceLoss,
    MelSpectrogramLoss,
)
from torch.cuda.amp import autocast
from utils import get_segments

AVAILABLE_GENERATERS = {
    "vits_generator": VITSGenerator,
}
AVAILABLE_DISCRIMINATORS = {
    "hifigan_period_discriminator": HiFiGANPeriodDiscriminator,
    "hifigan_scale_discriminator": HiFiGANScaleDiscriminator,
    "hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator,
    "hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator,
    "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator,  # NOQA
}
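
# Note (sketch): these registries map config strings to classes, e.g.
# AVAILABLE_GENERATERS["vits_generator"] returns VITSGenerator, so generator
# and discriminator variants can be selected by name from a config file.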
class VITS(nn.Module):
    """Implement VITS, `Conditional Variational Autoencoder with Adversarial
    Learning for End-to-End Text-to-Speech`."""

    def __init__(
        self,
        # generator related
        vocab_size: int,
        feature_dim: int = 513,
        sampling_rate: int = 22050,
        generator_type: str = "vits_generator",
        generator_params: Dict[str, Any] = {
            "hidden_channels": 192,
            "spks": None,
            "langs": None,
            "spk_embed_dim": None,
            "global_channels": -1,
            "segment_size": 32,
            "text_encoder_attention_heads": 2,
            "text_encoder_ffn_expand": 4,
            "text_encoder_cnn_module_kernel": 5,
            "text_encoder_blocks": 6,
            "text_encoder_dropout_rate": 0.1,
            "decoder_kernel_size": 7,
            "decoder_channels": 512,
            "decoder_upsample_scales": [8, 8, 2, 2],
            "decoder_upsample_kernel_sizes": [16, 16, 4, 4],
            "decoder_resblock_kernel_sizes": [3, 7, 11],
            "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            "use_weight_norm_in_decoder": True,
            "posterior_encoder_kernel_size": 5,
            "posterior_encoder_layers": 16,
            "posterior_encoder_stacks": 1,
            "posterior_encoder_base_dilation": 1,
            "posterior_encoder_dropout_rate": 0.0,
            "use_weight_norm_in_posterior_encoder": True,
            "flow_flows": 4,
            "flow_kernel_size": 5,
            "flow_base_dilation": 1,
            "flow_layers": 4,
            "flow_dropout_rate": 0.0,
            "use_weight_norm_in_flow": True,
            "use_only_mean_in_flow": True,
            "stochastic_duration_predictor_kernel_size": 3,
            "stochastic_duration_predictor_dropout_rate": 0.5,
            "stochastic_duration_predictor_flows": 4,
            "stochastic_duration_predictor_dds_conv_layers": 3,
            "use_noised_mas": True,
            "noise_initial_mas": 0.01,
            "noise_scale_mas": 2e-06,
        },
        # discriminator related
        discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator",
        discriminator_params: Dict[str, Any] = {
            "scales": 1,
            "scale_downsample_pooling": "AvgPool1d",
            "scale_downsample_pooling_params": {
                "kernel_size": 4,
                "stride": 2,
                "padding": 2,
            },
            "scale_discriminator_params": {
                "in_channels": 1,
                "out_channels": 1,
                "kernel_sizes": [15, 41, 5, 3],
                "channels": 128,
                "max_downsample_channels": 1024,
                "max_groups": 16,
                "bias": True,
                "downsample_scales": [2, 2, 4, 4, 1],
                "nonlinear_activation": "LeakyReLU",
                "nonlinear_activation_params": {"negative_slope": 0.1},
                "use_weight_norm": True,
                "use_spectral_norm": False,
            },
            "follow_official_norm": False,
            "periods": [2, 3, 5, 7, 11],
            "period_discriminator_params": {
                "in_channels": 1,
                "out_channels": 1,
                "kernel_sizes": [5, 3],
                "channels": 32,
                "downsample_scales": [3, 3, 3, 3, 1],
                "max_downsample_channels": 1024,
                "bias": True,
                "nonlinear_activation": "LeakyReLU",
                "nonlinear_activation_params": {"negative_slope": 0.1},
                "use_weight_norm": True,
                "use_spectral_norm": False,
            },
        },
        # loss related
        generator_adv_loss_params: Dict[str, Any] = {
            "average_by_discriminators": False,
            "loss_type": "mse",
        },
        discriminator_adv_loss_params: Dict[str, Any] = {
            "average_by_discriminators": False,
            "loss_type": "mse",
        },
        feat_match_loss_params: Dict[str, Any] = {
            "average_by_discriminators": False,
            "average_by_layers": False,
            "include_final_outputs": True,
        },
        mel_loss_params: Dict[str, Any] = {
            "frame_shift": 256,
            "frame_length": 1024,
            "n_mels": 80,
        },
        lambda_adv: float = 1.0,
        lambda_mel: float = 45.0,
        lambda_feat_match: float = 2.0,
        lambda_dur: float = 1.0,
        lambda_kl: float = 1.0,
        cache_generator_outputs: bool = True,
    ):
        """Initialize VITS module.

        Args:
            vocab_size (int): Input vocabulary size.
            feature_dim (int): Acoustic feature dimension. The actual output
                channels will be 1, since VITS is an end-to-end text-to-wave
                model, but for compatibility this value indicates the acoustic
                feature dimension.
            sampling_rate (int): Sampling rate. Not used for training, but
                referenced when saving waveforms during inference.
            generator_type (str): Generator type.
            generator_params (Dict[str, Any]): Parameter dict for generator.
            discriminator_type (str): Discriminator type.
            discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
            generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator
                adversarial loss.
            discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for
                discriminator adversarial loss.
            feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss.
            mel_loss_params (Dict[str, Any]): Parameter dict for mel loss.
            lambda_adv (float): Loss scaling coefficient for adversarial loss.
            lambda_mel (float): Loss scaling coefficient for mel spectrogram loss.
            lambda_feat_match (float): Loss scaling coefficient for feat match loss.
            lambda_dur (float): Loss scaling coefficient for duration loss.
            lambda_kl (float): Loss scaling coefficient for KL divergence loss.
            cache_generator_outputs (bool): Whether to cache generator outputs.

        """
        super().__init__()

        # define modules
        generator_class = AVAILABLE_GENERATERS[generator_type]
        if generator_type == "vits_generator":
            # NOTE(kan-bayashi): Update parameters for the compatibility.
            #   The vocab_size and feature_dim are automatically decided from
            #   the input data, where vocab_size represents the number of
            #   vocabulary entries and feature_dim represents the input
            #   acoustic feature dimension.
            generator_params.update(vocabs=vocab_size, aux_channels=feature_dim)
        self.generator = generator_class(
            **generator_params,
        )
        discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
        self.discriminator = discriminator_class(
            **discriminator_params,
        )
        self.generator_adv_loss = GeneratorAdversarialLoss(
            **generator_adv_loss_params,
        )
        self.discriminator_adv_loss = DiscriminatorAdversarialLoss(
            **discriminator_adv_loss_params,
        )
        self.feat_match_loss = FeatureMatchLoss(
            **feat_match_loss_params,
        )
        mel_loss_params.update(sampling_rate=sampling_rate)
        self.mel_loss = MelSpectrogramLoss(
            **mel_loss_params,
        )
        self.kl_loss = KLDivergenceLoss()

        # coefficients
        self.lambda_adv = lambda_adv
        self.lambda_mel = lambda_mel
        self.lambda_kl = lambda_kl
        self.lambda_feat_match = lambda_feat_match
        self.lambda_dur = lambda_dur

        # cache
        self.cache_generator_outputs = cache_generator_outputs
        self._cache = None

        # store sampling rate for saving wav file
        # (not used for the training)
        self.sampling_rate = sampling_rate

        # store parameters for test compatibility
        self.spks = self.generator.spks
        self.langs = self.generator.langs
        self.spk_embed_dim = self.generator.spk_embed_dim
    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        return_sample: bool = False,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        forward_generator: bool = True,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Perform the generator or discriminator forward, depending on
        `forward_generator`.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
            return_sample (bool): Whether to include a generated sample in the
                returned stats (generator forward only).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
            forward_generator (bool): Whether to forward the generator (True)
                or the discriminator (False).

        Returns:
            - loss (Tensor): Loss scalar tensor.
            - stats (Dict[str, float]): Statistics to be monitored.
        """
        if forward_generator:
            return self._forward_generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                speech=speech,
                speech_lengths=speech_lengths,
                return_sample=return_sample,
                sids=sids,
                spembs=spembs,
                lids=lids,
            )
        else:
            return self._forward_discrminator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                speech=speech,
                speech_lengths=speech_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
            )
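
    # Training-loop sketch (hypothetical; `batch`, `optimizer_g`, and
    # `optimizer_d` are assumed to come from the surrounding training script).
    # A typical GAN step alternates the two forwards above:
    #
    #   loss_d, stats_d = model(**batch, forward_generator=False)
    #   optimizer_d.zero_grad(); loss_d.backward(); optimizer_d.step()
    #
    #   loss_g, stats_g = model(**batch, forward_generator=True)
    #   optimizer_g.zero_grad(); loss_g.backward(); optimizer_g.step()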
    def _forward_generator(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        return_sample: bool = False,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Perform generator forward.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
            return_sample (bool): Whether to include a generated sample in the
                returned stats.
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            * loss (Tensor): Loss scalar tensor.
            * stats (Dict[str, float]): Statistics to be monitored.
        """
        # setup
        feats = feats.transpose(1, 2)
        speech = speech.unsqueeze(1)

        # calculate generator outputs
        reuse_cache = True
        if not self.cache_generator_outputs or self._cache is None:
            reuse_cache = False
            outs = self.generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
            )
        else:
            outs = self._cache

        # store cache
        if self.training and self.cache_generator_outputs and not reuse_cache:
            self._cache = outs

        # parse outputs
        speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
        _, z_p, m_p, logs_p, _, logs_q = outs_
        speech_ = get_segments(
            x=speech,
            start_idxs=start_idxs * self.generator.upsample_factor,
            segment_size=self.generator.segment_size * self.generator.upsample_factor,
        )

        # calculate discriminator outputs
        p_hat = self.discriminator(speech_hat_)
        with torch.no_grad():
            # do not store discriminator gradient in generator turn
            p = self.discriminator(speech_)

        # calculate losses
        with autocast(enabled=False):
            if not return_sample:
                mel_loss = self.mel_loss(speech_hat_, speech_)
            else:
                mel_loss, (mel_hat_, mel_) = self.mel_loss(
                    speech_hat_, speech_, return_mel=True
                )
            kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
            dur_loss = torch.sum(dur_nll.float())
            adv_loss = self.generator_adv_loss(p_hat)
            feat_match_loss = self.feat_match_loss(p_hat, p)

            mel_loss = mel_loss * self.lambda_mel
            kl_loss = kl_loss * self.lambda_kl
            dur_loss = dur_loss * self.lambda_dur
            adv_loss = adv_loss * self.lambda_adv
            feat_match_loss = feat_match_loss * self.lambda_feat_match
            loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss

        stats = dict(
            generator_loss=loss.item(),
            generator_mel_loss=mel_loss.item(),
            generator_kl_loss=kl_loss.item(),
            generator_dur_loss=dur_loss.item(),
            generator_adv_loss=adv_loss.item(),
            generator_feat_match_loss=feat_match_loss.item(),
        )

        if return_sample:
            stats["returned_sample"] = (
                speech_hat_[0].data.cpu().numpy(),
                speech_[0].data.cpu().numpy(),
                mel_hat_[0].data.cpu().numpy(),
                mel_[0].data.cpu().numpy(),
            )

        # reset cache
        if reuse_cache or not self.training:
            self._cache = None

        return loss, stats
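
    # Loss composition note (sketch): the total generator objective above is
    #   L_G = lambda_mel * L_mel + lambda_kl * L_kl + lambda_dur * L_dur
    #       + lambda_adv * L_adv + lambda_feat_match * L_fm
    # with the defaults lambda_mel = 45, lambda_kl = 1, lambda_dur = 1,
    # lambda_adv = 1, and lambda_feat_match = 2 taken from __init__ above.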
    def _forward_discrminator(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Perform discriminator forward.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            * loss (Tensor): Loss scalar tensor.
            * stats (Dict[str, float]): Statistics to be monitored.
        """
        # setup
        feats = feats.transpose(1, 2)
        speech = speech.unsqueeze(1)

        # calculate generator outputs
        reuse_cache = True
        if not self.cache_generator_outputs or self._cache is None:
            reuse_cache = False
            outs = self.generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
            )
        else:
            outs = self._cache

        # store cache
        if self.cache_generator_outputs and not reuse_cache:
            self._cache = outs

        # parse outputs
        speech_hat_, _, _, start_idxs, *_ = outs
        speech_ = get_segments(
            x=speech,
            start_idxs=start_idxs * self.generator.upsample_factor,
            segment_size=self.generator.segment_size * self.generator.upsample_factor,
        )

        # calculate discriminator outputs
        p_hat = self.discriminator(speech_hat_.detach())
        p = self.discriminator(speech_)

        # calculate losses
        with autocast(enabled=False):
            real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
            loss = real_loss + fake_loss

        stats = dict(
            discriminator_loss=loss.item(),
            discriminator_real_loss=real_loss.item(),
            discriminator_fake_loss=fake_loss.item(),
        )

        # reset cache
        if reuse_cache or not self.training:
            self._cache = None

        return loss, stats
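
    # Note (sketch): `speech_hat_.detach()` above cuts the autograd graph so
    # the discriminator update never backpropagates into the generator. With
    # the default "mse" (LSGAN-style) loss, the objective is roughly
    #   L_D = E[(D(x) - 1)^2] + E[D(x_hat)^2]
    # where x is a real waveform segment and x_hat a generated one.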
    def inference(
        self,
        text: torch.Tensor,
        feats: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        durations: Optional[torch.Tensor] = None,
        noise_scale: float = 0.667,
        noise_scale_dur: float = 0.8,
        alpha: float = 1.0,
        max_len: Optional[int] = None,
        use_teacher_forcing: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run inference for a single sample.

        Args:
            text (Tensor): Input text index tensor (T_text,).
            feats (Tensor): Feature tensor (T_feats, aux_channels).
            sids (Tensor): Speaker index tensor (1,).
            spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,).
            lids (Tensor): Language index tensor (1,).
            durations (Tensor): Ground-truth duration tensor (T_text,).
            noise_scale (float): Noise scale value for the flow.
            noise_scale_dur (float): Noise scale value for the duration predictor.
            alpha (float): Alpha parameter to control the speed of generated speech.
            max_len (Optional[int]): Maximum length.
            use_teacher_forcing (bool): Whether to use teacher forcing.

        Returns:
            * wav (Tensor): Generated waveform tensor (T_wav,).
            * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
            * duration (Tensor): Predicted duration tensor (T_text,).
        """
        # setup
        text = text[None]
        text_lengths = torch.tensor(
            [text.size(1)],
            dtype=torch.long,
            device=text.device,
        )
        if sids is not None:
            sids = sids.view(1)
        if lids is not None:
            lids = lids.view(1)
        if durations is not None:
            durations = durations.view(1, 1, -1)

        # inference
        if use_teacher_forcing:
            assert feats is not None
            feats = feats[None].transpose(1, 2)
            feats_lengths = torch.tensor(
                [feats.size(2)],
                dtype=torch.long,
                device=feats.device,
            )
            wav, att_w, dur = self.generator.inference(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
                max_len=max_len,
                use_teacher_forcing=use_teacher_forcing,
            )
        else:
            wav, att_w, dur = self.generator.inference(
                text=text,
                text_lengths=text_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
                dur=durations,
                noise_scale=noise_scale,
                noise_scale_dur=noise_scale_dur,
                alpha=alpha,
                max_len=max_len,
            )
        return wav.view(-1), att_w[0], dur[0]
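
    # Inference sketch (hypothetical; `tokens` is an encoded text sequence
    # produced by the recipe's tokenizer, which is not shown in this file):
    #
    #   model.eval()
    #   with torch.no_grad():
    #       wav, att_w, dur = model.inference(text=tokens, noise_scale=0.667)
    #   # `wav` is a 1-D tensor at `model.sampling_rate` (22050 Hz by default).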
    def inference_batch(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        sids: Optional[torch.Tensor] = None,
        durations: Optional[torch.Tensor] = None,
        noise_scale: float = 0.667,
        noise_scale_dur: float = 0.8,
        alpha: float = 1.0,
        max_len: Optional[int] = None,
        use_teacher_forcing: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run inference for one batch.

        Args:
            text (Tensor): Input text index tensor (B, T_text).
            text_lengths (Tensor): Input text length tensor (B,).
            sids (Tensor): Speaker index tensor (B,).
            noise_scale (float): Noise scale value for the flow.
            noise_scale_dur (float): Noise scale value for the duration predictor.
            alpha (float): Alpha parameter to control the speed of generated speech.
            max_len (Optional[int]): Maximum length.

        Returns:
            * wav (Tensor): Generated waveform tensor (B, T_wav).
            * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text).
            * duration (Tensor): Predicted duration tensor (B, T_text).
        """
        # inference
        wav, att_w, dur = self.generator.inference(
            text=text,
            text_lengths=text_lengths,
            sids=sids,
            noise_scale=noise_scale,
            noise_scale_dur=noise_scale_dur,
            alpha=alpha,
            max_len=max_len,
        )
        return wav, att_w, dur
348
egs/ljspeech/TTS/vits2/wavenet.py
Normal file
@ -0,0 +1,348 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py

# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""WaveNet modules.

This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.

"""

import logging
import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
class WaveNet(torch.nn.Module):
    """WaveNet with global conditioning."""

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        kernel_size: int = 3,
        layers: int = 30,
        stacks: int = 3,
        base_dilation: int = 2,
        residual_channels: int = 64,
        aux_channels: int = -1,
        gate_channels: int = 128,
        skip_channels: int = 64,
        global_channels: int = -1,
        dropout_rate: float = 0.0,
        bias: bool = True,
        use_weight_norm: bool = True,
        use_first_conv: bool = False,
        use_last_conv: bool = False,
        scale_residual: bool = False,
        scale_skip_connect: bool = False,
    ):
        """Initialize WaveNet module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of dilated convolution.
            layers (int): Number of residual block layers.
            stacks (int): Number of stacks, i.e., dilation cycles.
            base_dilation (int): Base dilation factor.
            residual_channels (int): Number of channels in residual conv.
            gate_channels (int): Number of channels in gated conv.
            skip_channels (int): Number of channels in skip conv.
            aux_channels (int): Number of channels for local conditioning feature.
            global_channels (int): Number of channels for global conditioning feature.
            dropout_rate (float): Dropout rate. 0.0 means no dropout applied.
            bias (bool): Whether to use bias parameter in conv layer.
            use_weight_norm (bool): Whether to use weight norm. If set to true, it will
                be applied to all of the conv layers.
            use_first_conv (bool): Whether to use the first conv layers.
            use_last_conv (bool): Whether to use the last conv layers.
            scale_residual (bool): Whether to scale the residual outputs.
            scale_skip_connect (bool): Whether to scale the skip connection outputs.

        """
        super().__init__()
        self.layers = layers
        self.stacks = stacks
        self.kernel_size = kernel_size
        self.base_dilation = base_dilation
        self.use_first_conv = use_first_conv
        self.use_last_conv = use_last_conv
        self.scale_skip_connect = scale_skip_connect

        # check the number of layers and stacks
        assert layers % stacks == 0
        layers_per_stack = layers // stacks

        # define first convolution
        if self.use_first_conv:
            self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)

        # define residual blocks
        self.conv_layers = torch.nn.ModuleList()
        for layer in range(layers):
            dilation = base_dilation ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                residual_channels=residual_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=aux_channels,
                global_channels=global_channels,
                dilation=dilation,
                dropout_rate=dropout_rate,
                bias=bias,
                scale_residual=scale_residual,
            )
            self.conv_layers += [conv]

        # define output layers
        if self.use_last_conv:
            self.last_conv = torch.nn.Sequential(
                torch.nn.ReLU(inplace=True),
                Conv1d1x1(skip_channels, skip_channels, bias=True),
                torch.nn.ReLU(inplace=True),
                Conv1d1x1(skip_channels, out_channels, bias=True),
            )

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()
    def forward(
        self,
        x: torch.Tensor,
        x_mask: Optional[torch.Tensor] = None,
        c: Optional[torch.Tensor] = None,
        g: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T) if use_first_conv else
                (B, residual_channels, T).
            x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
            c (Optional[Tensor]): Local conditioning features (B, aux_channels, T).
            g (Optional[Tensor]): Global conditioning features (B, global_channels, 1).

        Returns:
            Tensor: Output tensor (B, out_channels, T) if use_last_conv else
                (B, residual_channels, T).

        """
        # encode to hidden representation
        if self.use_first_conv:
            x = self.first_conv(x)

        # residual blocks
        skips = 0.0
        for f in self.conv_layers:
            x, h = f(x, x_mask=x_mask, c=c, g=g)
            skips = skips + h
        x = skips
        if self.scale_skip_connect:
            x = x * math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        if self.use_last_conv:
            x = self.last_conv(x)

        return x
    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""

        def _remove_weight_norm(m: torch.nn.Module):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the conv layers."""

        def _apply_weight_norm(m: torch.nn.Module):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    @staticmethod
    def _get_receptive_field_size(
        layers: int,
        stacks: int,
        kernel_size: int,
        base_dilation: int,
    ) -> int:
        assert layers % stacks == 0
        layers_per_cycle = layers // stacks
        dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)]
        return (kernel_size - 1) * sum(dilations) + 1

    @property
    def receptive_field_size(self) -> int:
        """Return receptive field size."""
        return self._get_receptive_field_size(
            self.layers, self.stacks, self.kernel_size, self.base_dilation
        )
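
    # Worked example (sketch): with the defaults above (layers=30, stacks=3,
    # kernel_size=3, base_dilation=2), each 10-layer cycle has dilations
    # 1, 2, 4, ..., 512 summing to 1023, so sum(dilations) = 3 * 1023 = 3069
    # and the receptive field is (3 - 1) * 3069 + 1 = 6139 samples.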
class Conv1d(torch.nn.Conv1d):
    """Conv1d module with customized initialization."""

    def __init__(self, *args, **kwargs):
        """Initialize Conv1d module."""
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        """Reset parameters."""
        torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            torch.nn.init.constant_(self.bias, 0.0)


class Conv1d1x1(Conv1d):
    """1x1 Conv1d with customized initialization."""

    def __init__(self, in_channels: int, out_channels: int, bias: bool):
        """Initialize 1x1 Conv1d module."""
        super().__init__(
            in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
        )
class ResidualBlock(torch.nn.Module):
    """Residual block module in WaveNet."""

    def __init__(
        self,
        kernel_size: int = 3,
        residual_channels: int = 64,
        gate_channels: int = 128,
        skip_channels: int = 64,
        aux_channels: int = 80,
        global_channels: int = -1,
        dropout_rate: float = 0.0,
        dilation: int = 1,
        bias: bool = True,
        scale_residual: bool = False,
    ):
        """Initialize ResidualBlock module.

        Args:
            kernel_size (int): Kernel size of dilated convolution layer.
            residual_channels (int): Number of channels for residual connection.
            gate_channels (int): Number of channels in gated convolution.
            skip_channels (int): Number of channels for skip connection.
            aux_channels (int): Number of local conditioning channels.
            global_channels (int): Number of global conditioning channels.
            dropout_rate (float): Dropout probability.
            dilation (int): Dilation factor.
            bias (bool): Whether to add bias parameter in convolution layers.
            scale_residual (bool): Whether to scale the residual outputs.

        """
        super().__init__()
        self.dropout_rate = dropout_rate
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.scale_residual = scale_residual

        # check
        assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
        assert gate_channels % 2 == 0

        # dilation conv
        padding = (kernel_size - 1) // 2 * dilation
        self.conv = Conv1d(
            residual_channels,
            gate_channels,
            kernel_size,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

        # local conditioning
        if aux_channels > 0:
            self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
        else:
            self.conv1x1_aux = None

        # global conditioning
        if global_channels > 0:
            self.conv1x1_glo = Conv1d1x1(global_channels, gate_channels, bias=False)
        else:
            self.conv1x1_glo = None

        # conv output is split into two groups
        gate_out_channels = gate_channels // 2

        # NOTE(kan-bayashi): concat two convs into a single conv for the efficiency
        #   (integrate res 1x1 + skip 1x1 convs)
        self.conv1x1_out = Conv1d1x1(
            gate_out_channels, residual_channels + skip_channels, bias=bias
        )

    def forward(
        self,
        x: torch.Tensor,
        x_mask: Optional[torch.Tensor] = None,
        c: Optional[torch.Tensor] = None,
        g: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, residual_channels, T).
            x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
            c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T).
            g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).

        Returns:
            Tensor: Output tensor for residual connection (B, residual_channels, T).
            Tensor: Output tensor for skip connection (B, skip_channels, T).

        """
        residual = x
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.conv(x)

        # split into two parts for gated activation
        splitdim = 1
        xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)

        # local conditioning
        if c is not None:
            c = self.conv1x1_aux(c)
            ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
            xa, xb = xa + ca, xb + cb

        # global conditioning
        if g is not None:
            g = self.conv1x1_glo(g)
            ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim)
            xa, xb = xa + ga, xb + gb

        x = torch.tanh(xa) * torch.sigmoid(xb)

        # residual + skip 1x1 conv
        x = self.conv1x1_out(x)
        if x_mask is not None:
            x = x * x_mask

        # split integrated conv results
        x, s = x.split([self.residual_channels, self.skip_channels], dim=1)

        # for residual connection
        x = x + residual
        if self.scale_residual:
            x = x * math.sqrt(0.5)

        return x, s
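
# Note (sketch): the block above implements the WaveNet gated activation
#   z = tanh(W_f * x + c_f + g_f) * sigmoid(W_g * x + c_g + g_g)
# where the filter/gate halves come from splitting one dilated conv output,
# and c_*/g_* are the optional local/global conditioning projections.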