first commit

This commit is contained in:
yaozengwei 2023-10-22 23:14:00 +08:00
parent 162ceaf4b3
commit 3df16b3f2b
27 changed files with 7766 additions and 0 deletions

View File

@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the LJSpeech dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/spectrogram.
"""
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, Spectrogram, SpectrogramConfig, LilcomChunkyWriter, load_manifest
from lhotse.audio import RecordingSet
from lhotse.supervision import SupervisionSet
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_spectrogram_ljspeech():
src_dir = Path("data/manifests")
output_dir = Path("data/spectrogram")
num_jobs = min(4, os.cpu_count())
sampling_rate = 22050
frame_length = 1024 / sampling_rate # (in second)
frame_shift = 256 / sampling_rate # (in second)
use_fft_mag = True
prefix = "ljspeech"
suffix = "jsonl.gz"
partition = "all"
recordings = load_manifest(
src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet
)
supervisions = load_manifest(
src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet
)
config = SpectrogramConfig(
sampling_rate=sampling_rate,
frame_length=frame_length,
frame_shift=frame_shift,
use_fft_mag=use_fft_mag,
)
extractor = Spectrogram(config)
with get_executor() as ex: # Initialize the executor only once.
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
if (output_dir / cuts_filename).is_file():
logging.info(f"{partition} already exists - skipping.")
return
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=recordings, supervisions=supervisions
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / cuts_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_spectrogram_ljspeech()

View File

@ -0,0 +1,73 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed value to choose minimum/maximum duration
to remove short and long utterances during the training.
See the function `remove_short_and_long_utt()` in vits/train.py
for usage.
"""
from lhotse import load_manifest_lazy
def main():
path = "./data/spectrogram/ljspeech_cuts_all.jsonl.gz"
cuts = load_manifest_lazy(path)
cuts.describe()
if __name__ == "__main__":
main()
"""
Cut statistics:
Cuts count: 13100
Total duration (hh:mm:ss) 23:55:18
mean 6.6
std 2.2
min 1.1
25% 5.0
50% 6.8
75% 8.4
99% 10.0
99.5% 10.1
99.9% 10.1
max 10.1
Recordings available: 13100
Features available: 13100
Supervisions available: 13100
"""

View File

@ -0,0 +1,79 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script split the LJSpeech dataset cuts into three sets:
- training, 12500
- validation, 100
- test, 500
The numbers are from https://arxiv.org/pdf/2106.06103.pdf
Usage example:
python3 ./local/split_subsets.py ./data/spectrogram
"""
import argparse
import logging
import random
from pathlib import Path
from lhotse import load_manifest_lazy
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"manifest_dir",
type=Path,
default=Path("data/spectrogram"),
help="Path to the manifest file",
)
return parser.parse_args()
def main():
args = get_args()
manifest_dir = Path(args.manifest_dir)
prefix = "ljspeech"
suffix = "jsonl.gz"
all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all.{suffix}")
cut_ids = list(all_cuts.ids)
random.shuffle(cut_ids)
train_cuts = all_cuts.subset(cut_ids=cut_ids[:12500])
valid_cuts = all_cuts.subset(cut_ids=cut_ids[12500:12500 + 100])
test_cuts = all_cuts.subset(cut_ids=cut_ids[12500 + 100:])
assert len(train_cuts) == 12500, "expected 12500 cuts for training but got len(train_cuts)"
assert len(valid_cuts) == 100, "expected 100 cuts but for validation but got len(valid_cuts)"
assert len(test_cuts) == 500, "expected 500 cuts for test but got len(test_cuts)"
train_cuts.to_file(manifest_dir / f"{prefix}_cuts_train.{suffix}")
valid_cuts.to_file(manifest_dir / f"{prefix}_cuts_valid.{suffix}")
test_cuts.to_file(manifest_dir / f"{prefix}_cuts_test.{suffix}")
logging.info("Splitted into three sets: training (12500), validation (100), and test (500)")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,70 @@
#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks the following assumptions of the generated manifest:
- Single supervision per cut
We will add more checks later if needed.
Usage example:
python3 ./local/validate_manifest.py \
./data/spectrogram/ljspeech_cuts_all.jsonl.gz
"""
import argparse
import logging
from pathlib import Path
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset.speech_synthesis import validate_for_tts
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"manifest",
type=Path,
help="Path to the manifest file",
)
return parser.parse_args()
def main():
args = get_args()
manifest = args.manifest
logging.info(f"Validating {manifest}")
assert manifest.is_file(), f"{manifest} does not exist"
cut_set = load_manifest_lazy(manifest)
assert isinstance(cut_set, CutSet)
validate_for_tts(cut_set)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

77
egs/ljspeech/tts/prepare.sh Executable file
View File

@ -0,0 +1,77 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=1
stage=-1
stop_stage=100
# dl_dir=$PWD/download
dl_dir=/star-data/zengwei/download/ljspeech/
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/LJSpeech,
# you can create a symlink
#
# ln -sfv /path/to/LJSpeech $dl_dir/LJSpeech
#
if [ ! -d $dl_dir/LJSpeech-1.1 ]; then
lhotse download ljspeech $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare LJSpeech manifest"
# We assume that you have downloaded the LJSpeech corpus
# to $dl_dir/LJSpeech
mkdir -p data/manifests
if [ ! -e data/manifests/.ljspeech.done ]; then
lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests
touch data/manifests/.ljspeech.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Compute spectrogram for LJSpeech"
mkdir -p data/spectrogram
if [ ! -e data/spectrogram/.ljspeech.done ]; then
./local/compute_spectrogram_ljspeech.py
touch data/spectrogram/.ljspeech.done
fi
if [ ! -e data/spectrogram/.ljspeech-validated.done ]; then
log "Validating data/fbank for LJSpeech"
python3 ./local/validate_manifest.py \
data/spectrogram/ljspeech_cuts_all.jsonl.gz
touch data/spectrogram/.ljspeech-validated.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Split the LJSpeech cuts into three sets"
if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
./local/split_subsets.py data/spectrogram
touch data/spectrogram/.ljspeech_split.done
fi
fi

View File

@ -0,0 +1,97 @@
#!/usr/bin/env bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
# Arnab Ghoshal, Karel Vesely
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###
# Now import all the configs specified by command-line, in left-to-right order
for ((argpos=1; argpos<$#; argpos++)); do
if [ "${!argpos}" == "--config" ]; then
argpos_plus1=$((argpos+1))
config=${!argpos_plus1}
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
. $config # source the config file.
fi
done
###
### Now we process the command line options
###
while true; do
[ -z "${1:-}" ] && break; # break if there are no arguments
case "$1" in
# If the enclosing script is called with --help option, print the help
# message and exit. Scripts should put help messages in $help_message
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
else printf "$help_message\n" 1>&2 ; fi;
exit 0 ;;
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
exit 1 ;;
# If the first command-line argument begins with "--" (e.g. --foo-bar),
# then work out the variable name as $name, which will equal "foo_bar".
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
# Next we test whether the variable in question is undefned-- if so it's
# an invalid option and we die. Note: $0 evaluates to the name of the
# enclosing script.
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
# is undefined. We then have to wrap this test inside "eval" because
# foo_bar is itself inside a variable ($name).
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
oldval="`eval echo \\$$name`";
# Work out whether we seem to be expecting a Boolean argument.
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
was_bool=true;
else
was_bool=false;
fi
# Set the variable to the right value-- the escaped quotes make it work if
# the option had spaces, like --cmd "queue.pl -sync y"
eval $name=\"$2\";
# Check that Boolean-valued arguments are really Boolean.
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
exit 1;
fi
shift 2;
;;
*) break;
esac
done
# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
true; # so this script returns exit code 0.

View File

@ -0,0 +1,161 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size*dilation - dilation)/2)
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5
kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
return kl
def rand_gumbel(shape):
"""Sample from the Gumbel distribution, protect from overflows."""
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
return -torch.log(-torch.log(uniform_samples))
def rand_gumbel_like(x):
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
return g
def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, :, idx_str:idx_end]
return ret
def rand_slice_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str
def get_timing_signal_1d(
length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = (
math.log(float(max_timescale) / float(min_timescale)) /
(num_timescales - 1))
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2])
signal = signal.view(1, channels, length)
return signal
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return x + signal.to(dtype=x.dtype, device=x.device)
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
def subsequent_mask(length):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
device = duration.device
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2,3) * mask
return path
def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
if clip_value is not None:
clip_value = float(clip_value)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1. / norm_type)
return total_norm

View File

@ -0,0 +1,194 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Stochastic duration predictor modules in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import Optional
import torch
import torch.nn.functional as F
from flow import (
ConvFlow,
DilatedDepthSeparableConv,
ElementwiseAffineFlow,
FlipFlow,
LogFlow,
)
class StochasticDurationPredictor(torch.nn.Module):
"""Stochastic duration predictor module.
This is a module of stochastic duration predictor described in `Conditional
Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2006.04558
"""
def __init__(
self,
channels: int = 192,
kernel_size: int = 3,
dropout_rate: float = 0.5,
flows: int = 4,
dds_conv_layers: int = 3,
global_channels: int = -1,
):
"""Initialize StochasticDurationPredictor module.
Args:
channels (int): Number of channels.
kernel_size (int): Kernel size.
dropout_rate (float): Dropout rate.
flows (int): Number of flows.
dds_conv_layers (int): Number of conv layers in DDS conv.
global_channels (int): Number of global conditioning channels.
"""
super().__init__()
self.pre = torch.nn.Conv1d(channels, channels, 1)
self.dds = DilatedDepthSeparableConv(
channels,
kernel_size,
layers=dds_conv_layers,
dropout_rate=dropout_rate,
)
self.proj = torch.nn.Conv1d(channels, channels, 1)
self.log_flow = LogFlow()
self.flows = torch.nn.ModuleList()
self.flows += [ElementwiseAffineFlow(2)]
for i in range(flows):
self.flows += [
ConvFlow(
2,
channels,
kernel_size,
layers=dds_conv_layers,
)
]
self.flows += [FlipFlow()]
self.post_pre = torch.nn.Conv1d(1, channels, 1)
self.post_dds = DilatedDepthSeparableConv(
channels,
kernel_size,
layers=dds_conv_layers,
dropout_rate=dropout_rate,
)
self.post_proj = torch.nn.Conv1d(channels, channels, 1)
self.post_flows = torch.nn.ModuleList()
self.post_flows += [ElementwiseAffineFlow(2)]
for i in range(flows):
self.post_flows += [
ConvFlow(
2,
channels,
kernel_size,
layers=dds_conv_layers,
)
]
self.post_flows += [FlipFlow()]
if global_channels > 0:
self.global_conv = torch.nn.Conv1d(global_channels, channels, 1)
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
w: Optional[torch.Tensor] = None,
g: Optional[torch.Tensor] = None,
inverse: bool = False,
noise_scale: float = 1.0,
) -> torch.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T_text).
x_mask (Tensor): Mask tensor (B, 1, T_text).
w (Optional[Tensor]): Duration tensor (B, 1, T_text).
g (Optional[Tensor]): Global conditioning tensor (B, channels, 1)
inverse (bool): Whether to inverse the flow.
noise_scale (float): Noise scale value.
Returns:
Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,).
If inverse, log-duration tensor (B, 1, T_text).
"""
x = x.detach() # stop gradient
x = self.pre(x)
if g is not None:
x = x + self.global_conv(g.detach()) # stop gradient
x = self.dds(x, x_mask)
x = self.proj(x) * x_mask
if not inverse:
assert w is not None, "w must be provided."
h_w = self.post_pre(w)
h_w = self.post_dds(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = (
torch.randn(
w.size(0),
2,
w.size(2),
).to(device=x.device, dtype=x.dtype)
* x_mask
)
z_q = e_q
logdet_tot_q = 0.0
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
logdet_tot_q += logdet_q
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum(
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
)
logq = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
- logdet_tot_q
)
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
logdet_tot += logdet
z = torch.cat([z0, z1], 1)
for flow in self.flows:
z, logdet = flow(z, x_mask, g=x, inverse=inverse)
logdet_tot = logdet_tot + logdet
nll = (
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
- logdet_tot
)
return nll + logq # (B,)
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = (
torch.randn(
x.size(0),
2,
x.size(2),
).to(device=x.device, dtype=x.dtype)
* noise_scale
)
for flow in flows:
z = flow(z, x_mask, g=x, inverse=inverse)
z0, z1 = z.split(1, 1)
logw = z0
return logw

View File

@ -0,0 +1,416 @@
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple
import librosa
import numpy as np
import torch
from torch import nn
from icefall.utils import make_pad_mask
# From https://github.com/espnet/espnet/blob/master/espnet2/layers/stft.py
class Stft(nn.Module):
def __init__(
self,
n_fft: int = 512,
win_length: int = None,
hop_length: int = 128,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
onesided: bool = True,
):
super().__init__()
self.n_fft = n_fft
if win_length is None:
self.win_length = n_fft
else:
self.win_length = win_length
self.hop_length = hop_length
self.center = center
self.normalized = normalized
self.onesided = onesided
if window is not None and not hasattr(torch, f"{window}_window"):
raise ValueError(f"{window} window is not implemented")
self.window = window
def extra_repr(self):
return (
f"n_fft={self.n_fft}, "
f"win_length={self.win_length}, "
f"hop_length={self.hop_length}, "
f"center={self.center}, "
f"normalized={self.normalized}, "
f"onesided={self.onesided}"
)
def forward(
self, input: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""STFT forward function.
Args:
input: (Batch, Nsamples) or (Batch, Nsample, Channels)
ilens: (Batch)
Returns:
output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
"""
bs = input.size(0)
if input.dim() == 3:
multi_channel = True
# input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
input = input.transpose(1, 2).reshape(-1, input.size(1))
else:
multi_channel = False
# NOTE(kamo):
# The default behaviour of torch.stft is compatible with librosa.stft
# about padding and scaling.
# Note that it's different from scipy.signal.stft
# output: (Batch, Freq, Frames, 2=real_imag)
# or (Batch, Channel, Freq, Frames, 2=real_imag)
if self.window is not None:
window_func = getattr(torch, f"{self.window}_window")
window = window_func(
self.win_length, dtype=input.dtype, device=input.device
)
else:
window = None
# For the compatibility of ARM devices, which do not support
# torch.stft() due to the lack of MKL (on older pytorch versions),
# there is an alternative replacement implementation with librosa.
# Note: pytorch >= 1.10.0 now has native support for FFT and STFT
# on all cpu targets including ARM.
if input.is_cuda or torch.backends.mkl.is_available():
stft_kwargs = dict(
n_fft=self.n_fft,
win_length=self.win_length,
hop_length=self.hop_length,
center=self.center,
window=window,
normalized=self.normalized,
onesided=self.onesided,
)
stft_kwargs["return_complex"] = True
output = torch.stft(input, **stft_kwargs)
output = torch.view_as_real(output)
else:
if self.training:
raise NotImplementedError(
"stft is implemented with librosa on this device, which does not "
"support the training mode."
)
# use stft_kwargs to flexibly control different PyTorch versions' kwargs
# note: librosa does not support a win_length that is < n_ftt
# but the window can be manually padded (see below).
stft_kwargs = dict(
n_fft=self.n_fft,
win_length=self.n_fft,
hop_length=self.hop_length,
center=self.center,
window=window,
pad_mode="reflect",
)
if window is not None:
# pad the given window to n_fft
n_pad_left = (self.n_fft - window.shape[0]) // 2
n_pad_right = self.n_fft - window.shape[0] - n_pad_left
stft_kwargs["window"] = torch.cat(
[torch.zeros(n_pad_left), window, torch.zeros(n_pad_right)], 0
).numpy()
else:
win_length = (
self.win_length if self.win_length is not None else self.n_fft
)
stft_kwargs["window"] = torch.ones(win_length)
output = []
# iterate over istances in a batch
for i, instance in enumerate(input):
stft = librosa.stft(input[i].numpy(), **stft_kwargs)
output.append(torch.tensor(np.stack([stft.real, stft.imag], -1)))
output = torch.stack(output, 0)
if not self.onesided:
len_conj = self.n_fft - output.shape[1]
conj = output[:, 1 : 1 + len_conj].flip(1)
conj[:, :, :, -1].data *= -1
output = torch.cat([output, conj], 1)
if self.normalized:
output = output * (stft_kwargs["window"].shape[0] ** (-0.5))
# output: (Batch, Freq, Frames, 2=real_imag)
# -> (Batch, Frames, Freq, 2=real_imag)
output = output.transpose(1, 2)
if multi_channel:
# output: (Batch * Channel, Frames, Freq, 2=real_imag)
# -> (Batch, Frame, Channel, Freq, 2=real_imag)
output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
1, 2
)
if ilens is not None:
if self.center:
pad = self.n_fft // 2
ilens = ilens + 2 * pad
olens = (
torch.div(ilens - self.n_fft, self.hop_length, rounding_mode="trunc")
+ 1
)
output.masked_fill_(make_pad_mask(olens), 0.0)
else:
olens = None
return output, olens
# From https://github.com/espnet/espnet/blob/master/espnet2/tts/feats_extract/linear_spectrogram.py
class LinearSpectrogram(nn.Module):
"""Linear amplitude spectrogram.
Stft -> amplitude-spec
"""
def __init__(
self,
n_fft: int = 1024,
win_length: int = None,
hop_length: int = 256,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
onesided: bool = True,
):
super().__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.window = window
self.stft = Stft(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
window=window,
center=center,
normalized=normalized,
onesided=onesided,
)
self.n_fft = n_fft
def output_size(self) -> int:
return self.n_fft // 2 + 1
def get_parameters(self) -> Dict[str, Any]:
"""Return the parameters required by Vocoder."""
return dict(
n_fft=self.n_fft,
n_shift=self.hop_length,
win_length=self.win_length,
window=self.window,
)
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Stft: time -> time-freq
input_stft, feats_lens = self.stft(input, input_lengths)
assert input_stft.dim() >= 4, input_stft.shape
# "2" refers to the real/imag parts of Complex
assert input_stft.shape[-1] == 2, input_stft.shape
# STFT -> Power spectrum -> Amp spectrum
# input_stft: (..., F, 2) -> (..., F)
input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10))
return input_amp, feats_lens
# From https://github.com/espnet/espnet/blob/master/espnet2/layers/log_mel.py
class LogMel(nn.Module):
"""Convert STFT to fbank feats
The arguments is same as librosa.filters.mel
Args:
fs: number > 0 [scalar] sampling rate of the incoming signal
n_fft: int > 0 [scalar] number of FFT components
n_mels: int > 0 [scalar] number of Mel bands to generate
fmin: float >= 0 [scalar] lowest frequency (in Hz)
fmax: float >= 0 [scalar] highest frequency (in Hz).
If `None`, use `fmax = fs / 2.0`
htk: use HTK formula instead of Slaney
"""
def __init__(
self,
fs: int = 16000,
n_fft: int = 512,
n_mels: int = 80,
fmin: float = None,
fmax: float = None,
htk: bool = False,
log_base: float = None,
):
super().__init__()
fmin = 0 if fmin is None else fmin
fmax = fs / 2 if fmax is None else fmax
_mel_options = dict(
sr=fs,
n_fft=n_fft,
n_mels=n_mels,
fmin=fmin,
fmax=fmax,
htk=htk,
)
self.mel_options = _mel_options
self.log_base = log_base
# Note(kamo): The mel matrix of librosa is different from kaldi.
melmat = librosa.filters.mel(**_mel_options)
# melmat: (D2, D1) -> (D1, D2)
self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
def extra_repr(self):
return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
def forward(
self,
feat: torch.Tensor,
ilens: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
mel_feat = torch.matmul(feat, self.melmat)
mel_feat = torch.clamp(mel_feat, min=1e-10)
if self.log_base is None:
logmel_feat = mel_feat.log()
elif self.log_base == 2.0:
logmel_feat = mel_feat.log2()
elif self.log_base == 10.0:
logmel_feat = mel_feat.log10()
else:
logmel_feat = mel_feat.log() / torch.log(self.log_base)
# Zero padding
if ilens is not None:
logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens), 0.0)
else:
ilens = feat.new_full(
[feat.size(0)], fill_value=feat.size(1), dtype=torch.long
)
return logmel_feat, ilens
# From https://github.com/espnet/espnet/blob/master/espnet2/tts/feats_extract/log_mel_fbank.py
class LogMelFbank(nn.Module):
"""Conventional frontend structure for TTS.
Stft -> amplitude-spec -> Log-Mel-Fbank
"""
def __init__(
self,
fs: int = 16000,
n_fft: int = 1024,
win_length: int = None,
hop_length: int = 256,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
onesided: bool = True,
n_mels: int = 80,
fmin: Optional[int] = 80,
fmax: Optional[int] = 7600,
htk: bool = False,
log_base: Optional[float] = 10.0,
):
super().__init__()
self.fs = fs
self.n_mels = n_mels
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.window = window
self.fmin = fmin
self.fmax = fmax
self.stft = Stft(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
window=window,
center=center,
normalized=normalized,
onesided=onesided,
)
self.logmel = LogMel(
fs=fs,
n_fft=n_fft,
n_mels=n_mels,
fmin=fmin,
fmax=fmax,
htk=htk,
log_base=log_base,
)
def output_size(self) -> int:
return self.n_mels
def get_parameters(self) -> Dict[str, Any]:
"""Return the parameters required by Vocoder"""
return dict(
fs=self.fs,
n_fft=self.n_fft,
n_shift=self.hop_length,
window=self.window,
n_mels=self.n_mels,
win_length=self.win_length,
fmin=self.fmin,
fmax=self.fmax,
)
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Domain-conversion: e.g. Stft: time -> time-freq
input_stft, feats_lens = self.stft(input, input_lengths)
assert input_stft.dim() >= 4, input_stft.shape
# "2" refers to the real/imag parts of Complex
assert input_stft.shape[-1] == 2, input_stft.shape
# NOTE(kamo): We use different definition for log-spec between TTS and ASR
# TTS: log_10(abs(stft))
# ASR: log_e(power(stft))
# input_stft: (..., F, 2) -> (..., F)
input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10))
input_feats, _ = self.logmel(input_amp, feats_lens)
return input_feats, feats_lens

View File

@ -0,0 +1,311 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Basic Flow modules used in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import Optional, Tuple, Union
import torch
from transform import piecewise_rational_quadratic_transform
class FlipFlow(torch.nn.Module):
"""Flip flow module."""
def forward(
self, x: torch.Tensor, *args, inverse: bool = False, **kwargs
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Flipped tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
x = torch.flip(x, [1])
if not inverse:
logdet = x.new_zeros(x.size(0))
return x, logdet
else:
return x
class LogFlow(torch.nn.Module):
"""Log flow module."""
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
inverse: bool = False,
eps: float = 1e-5,
**kwargs
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
inverse (bool): Whether to inverse the flow.
eps (float): Epsilon for log.
Returns:
Tensor: Output tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
if not inverse:
y = torch.log(torch.clamp_min(x, eps)) * x_mask
logdet = torch.sum(-y, [1, 2])
return y, logdet
else:
x = torch.exp(x) * x_mask
return x
class ElementwiseAffineFlow(torch.nn.Module):
"""Elementwise affine flow module."""
def __init__(self, channels: int):
"""Initialize ElementwiseAffineFlow module.
Args:
channels (int): Number of channels.
"""
super().__init__()
self.channels = channels
self.register_parameter("m", torch.nn.Parameter(torch.zeros(channels, 1)))
self.register_parameter("logs", torch.nn.Parameter(torch.zeros(channels, 1)))
def forward(
self, x: torch.Tensor, x_mask: torch.Tensor, inverse: bool = False, **kwargs
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
x_lengths (Tensor): Length tensor (B,).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
if not inverse:
y = self.m + torch.exp(self.logs) * x
y = y * x_mask
logdet = torch.sum(self.logs * x_mask, [1, 2])
return y, logdet
else:
x = (x - self.m) * torch.exp(-self.logs) * x_mask
return x
class Transpose(torch.nn.Module):
"""Transpose module for torch.nn.Sequential()."""
def __init__(self, dim1: int, dim2: int):
"""Initialize Transpose module."""
super().__init__()
self.dim1 = dim1
self.dim2 = dim2
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Transpose."""
return x.transpose(self.dim1, self.dim2)
class DilatedDepthSeparableConv(torch.nn.Module):
"""Dilated depth-separable conv module."""
def __init__(
self,
channels: int,
kernel_size: int,
layers: int,
dropout_rate: float = 0.0,
eps: float = 1e-5,
):
"""Initialize DilatedDepthSeparableConv module.
Args:
channels (int): Number of channels.
kernel_size (int): Kernel size.
layers (int): Number of layers.
dropout_rate (float): Dropout rate.
eps (float): Epsilon for layer norm.
"""
super().__init__()
self.convs = torch.nn.ModuleList()
for i in range(layers):
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs += [
torch.nn.Sequential(
torch.nn.Conv1d(
channels,
channels,
kernel_size,
groups=channels,
dilation=dilation,
padding=padding,
),
Transpose(1, 2),
torch.nn.LayerNorm(
channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
torch.nn.GELU(),
torch.nn.Conv1d(
channels,
channels,
1,
),
Transpose(1, 2),
torch.nn.LayerNorm(
channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
torch.nn.GELU(),
torch.nn.Dropout(dropout_rate),
)
]
def forward(
self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Output tensor (B, channels, T).
"""
if g is not None:
x = x + g
for f in self.convs:
y = f(x * x_mask)
x = x + y
return x * x_mask
class ConvFlow(torch.nn.Module):
"""Convolutional flow module."""
def __init__(
self,
in_channels: int,
hidden_channels: int,
kernel_size: int,
layers: int,
bins: int = 10,
tail_bound: float = 5.0,
):
"""Initialize ConvFlow module.
Args:
in_channels (int): Number of input channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size.
layers (int): Number of layers.
bins (int): Number of bins.
tail_bound (float): Tail bound value.
"""
super().__init__()
self.half_channels = in_channels // 2
self.hidden_channels = hidden_channels
self.bins = bins
self.tail_bound = tail_bound
self.input_conv = torch.nn.Conv1d(
self.half_channels,
hidden_channels,
1,
)
self.dds_conv = DilatedDepthSeparableConv(
hidden_channels,
kernel_size,
layers,
dropout_rate=0.0,
)
self.proj = torch.nn.Conv1d(
hidden_channels,
self.half_channels * (bins * 3 - 1),
1,
)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
g: Optional[torch.Tensor] = None,
inverse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
x_mask (Tensor): Mask tensor (B,).
g (Optional[Tensor]): Global conditioning tensor (B, channels, 1).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
xa, xb = x.split(x.size(1) // 2, 1)
h = self.input_conv(xa)
h = self.dds_conv(h, x_mask, g=g)
h = self.proj(h) * x_mask # (B, half_channels * (bins * 3 - 1), T)
b, c, t = xa.shape
# (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1)
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)
# TODO(kan-bayashi): Understand this calculation
denom = math.sqrt(self.hidden_channels)
unnorm_widths = h[..., : self.bins] / denom
unnorm_heights = h[..., self.bins : 2 * self.bins] / denom
unnorm_derivatives = h[..., 2 * self.bins :]
xb, logdet_abs = piecewise_rational_quadratic_transform(
xb,
unnorm_widths,
unnorm_heights,
unnorm_derivatives,
inverse=inverse,
tails="linear",
tail_bound=self.tail_bound,
)
x = torch.cat([xa, xb], 1) * x_mask
logdet = torch.sum(logdet_abs * x_mask, [1, 2])
if not inverse:
return x, logdet
else:
return x

View File

@ -0,0 +1,524 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Generator module in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from icefall.utils import make_pad_mask
from duration_predictor import StochasticDurationPredictor
from hifigan import HiFiGANGenerator
from posterior_encoder import PosteriorEncoder
from residual_coupling import ResidualAffineCouplingBlock
from text_encoder import TextEncoder
from utils import get_random_segments
class VITSGenerator(torch.nn.Module):
"""Generator module in VITS, `Conditional Variational Autoencoder
with Adversarial Learning for End-to-End Text-to-Speech`.
"""
def __init__(
self,
vocabs: int,
aux_channels: int = 513,
hidden_channels: int = 192,
spks: Optional[int] = None,
langs: Optional[int] = None,
spk_embed_dim: Optional[int] = None,
global_channels: int = -1,
segment_size: int = 32,
text_encoder_attention_heads: int = 2,
text_encoder_ffn_expand: int = 4,
text_encoder_blocks: int = 6,
text_encoder_dropout_rate: float = 0.1,
decoder_kernel_size: int = 7,
decoder_channels: int = 512,
decoder_upsample_scales: List[int] = [8, 8, 2, 2],
decoder_upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
decoder_resblock_kernel_sizes: List[int] = [3, 7, 11],
decoder_resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
use_weight_norm_in_decoder: bool = True,
posterior_encoder_kernel_size: int = 5,
posterior_encoder_layers: int = 16,
posterior_encoder_stacks: int = 1,
posterior_encoder_base_dilation: int = 1,
posterior_encoder_dropout_rate: float = 0.0,
use_weight_norm_in_posterior_encoder: bool = True,
flow_flows: int = 4,
flow_kernel_size: int = 5,
flow_base_dilation: int = 1,
flow_layers: int = 4,
flow_dropout_rate: float = 0.0,
use_weight_norm_in_flow: bool = True,
use_only_mean_in_flow: bool = True,
stochastic_duration_predictor_kernel_size: int = 3,
stochastic_duration_predictor_dropout_rate: float = 0.5,
stochastic_duration_predictor_flows: int = 4,
stochastic_duration_predictor_dds_conv_layers: int = 3,
):
"""Initialize VITS generator module.
Args:
vocabs (int): Input vocabulary size.
aux_channels (int): Number of acoustic feature channels.
hidden_channels (int): Number of hidden channels.
spks (Optional[int]): Number of speakers. If set to > 1, assume that the
sids will be provided as the input and use sid embedding layer.
langs (Optional[int]): Number of languages. If set to > 1, assume that the
lids will be provided as the input and use sid embedding layer.
spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0,
assume that spembs will be provided as the input.
global_channels (int): Number of global conditioning channels.
segment_size (int): Segment size for decoder.
text_encoder_attention_heads (int): Number of heads in conformer block
of text encoder.
text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block
of text encoder.
text_encoder_blocks (int): Number of conformer blocks in text encoder.
text_encoder_dropout_rate (float): Dropout rate in conformer block of
text encoder.
decoder_kernel_size (int): Decoder kernel size.
decoder_channels (int): Number of decoder initial channels.
decoder_upsample_scales (List[int]): List of upsampling scales in decoder.
decoder_upsample_kernel_sizes (List[int]): List of kernel size for
upsampling layers in decoder.
decoder_resblock_kernel_sizes (List[int]): List of kernel size for resblocks
in decoder.
decoder_resblock_dilations (List[List[int]]): List of list of dilations for
resblocks in decoder.
use_weight_norm_in_decoder (bool): Whether to apply weight normalization in
decoder.
posterior_encoder_kernel_size (int): Posterior encoder kernel size.
posterior_encoder_layers (int): Number of layers of posterior encoder.
posterior_encoder_stacks (int): Number of stacks of posterior encoder.
posterior_encoder_base_dilation (int): Base dilation of posterior encoder.
posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder.
use_weight_norm_in_posterior_encoder (bool): Whether to apply weight
normalization in posterior encoder.
flow_flows (int): Number of flows in flow.
flow_kernel_size (int): Kernel size in flow.
flow_base_dilation (int): Base dilation in flow.
flow_layers (int): Number of layers in flow.
flow_dropout_rate (float): Dropout rate in flow
use_weight_norm_in_flow (bool): Whether to apply weight normalization in
flow.
use_only_mean_in_flow (bool): Whether to use only mean in flow.
stochastic_duration_predictor_kernel_size (int): Kernel size in stochastic
duration predictor.
stochastic_duration_predictor_dropout_rate (float): Dropout rate in
stochastic duration predictor.
stochastic_duration_predictor_flows (int): Number of flows in stochastic
duration predictor.
stochastic_duration_predictor_dds_conv_layers (int): Number of DDS conv
layers in stochastic duration predictor.
"""
super().__init__()
self.segment_size = segment_size
self.text_encoder = TextEncoder(
vocabs=vocabs,
d_model=hidden_channels,
num_heads=text_encoder_attention_heads,
dim_feedforward=hidden_channels * text_encoder_ffn_expand,
num_layers=text_encoder_blocks,
dropout=text_encoder_dropout_rate,
)
self.decoder = HiFiGANGenerator(
in_channels=hidden_channels,
out_channels=1,
channels=decoder_channels,
global_channels=global_channels,
kernel_size=decoder_kernel_size,
upsample_scales=decoder_upsample_scales,
upsample_kernel_sizes=decoder_upsample_kernel_sizes,
resblock_kernel_sizes=decoder_resblock_kernel_sizes,
resblock_dilations=decoder_resblock_dilations,
use_weight_norm=use_weight_norm_in_decoder,
)
self.posterior_encoder = PosteriorEncoder(
in_channels=aux_channels,
out_channels=hidden_channels,
hidden_channels=hidden_channels,
kernel_size=posterior_encoder_kernel_size,
layers=posterior_encoder_layers,
stacks=posterior_encoder_stacks,
base_dilation=posterior_encoder_base_dilation,
global_channels=global_channels,
dropout_rate=posterior_encoder_dropout_rate,
use_weight_norm=use_weight_norm_in_posterior_encoder,
)
self.flow = ResidualAffineCouplingBlock(
in_channels=hidden_channels,
hidden_channels=hidden_channels,
flows=flow_flows,
kernel_size=flow_kernel_size,
base_dilation=flow_base_dilation,
layers=flow_layers,
global_channels=global_channels,
dropout_rate=flow_dropout_rate,
use_weight_norm=use_weight_norm_in_flow,
use_only_mean=use_only_mean_in_flow,
)
# TODO(kan-bayashi): Add deterministic version as an option
self.duration_predictor = StochasticDurationPredictor(
channels=hidden_channels,
kernel_size=stochastic_duration_predictor_kernel_size,
dropout_rate=stochastic_duration_predictor_dropout_rate,
flows=stochastic_duration_predictor_flows,
dds_conv_layers=stochastic_duration_predictor_dds_conv_layers,
global_channels=global_channels,
)
self.upsample_factor = int(np.prod(decoder_upsample_scales))
self.spks = None
if spks is not None and spks > 1:
assert global_channels > 0
self.spks = spks
self.global_emb = torch.nn.Embedding(spks, global_channels)
self.spk_embed_dim = None
if spk_embed_dim is not None and spk_embed_dim > 0:
assert global_channels > 0
self.spk_embed_dim = spk_embed_dim
self.spemb_proj = torch.nn.Linear(spk_embed_dim, global_channels)
self.langs = None
if langs is not None and langs > 1:
assert global_channels > 0
self.langs = langs
self.lang_emb = torch.nn.Embedding(langs, global_channels)
# delayed import
from monotonic_align import maximum_path
self.maximum_path = maximum_path
def forward(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
feats: torch.Tensor,
feats_lengths: torch.Tensor,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
],
]:
"""Calculate forward propagation.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, aux_channels, T_feats).
feats_lengths (Tensor): Feature length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
Returns:
Tensor: Waveform tensor (B, 1, segment_size * upsample_factor).
Tensor: Duration negative log-likelihood (NLL) tensor (B,).
Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text).
Tensor: Segments start index tensor (B,).
Tensor: Text mask tensor (B, 1, T_text).
Tensor: Feature mask tensor (B, 1, T_feats).
tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
- Tensor: Posterior encoder hidden representation (B, H, T_feats).
- Tensor: Flow hidden representation (B, H, T_feats).
- Tensor: Expanded text encoder projected mean (B, H, T_feats).
- Tensor: Expanded text encoder projected scale (B, H, T_feats).
- Tensor: Posterior encoder projected mean (B, H, T_feats).
- Tensor: Posterior encoder projected scale (B, H, T_feats).
"""
# forward text encoder
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
# calculate global conditioning
g = None
if self.spks is not None:
# speaker one-hot vector embedding: (B, global_channels, 1)
g = self.global_emb(sids.view(-1)).unsqueeze(-1)
if self.spk_embed_dim is not None:
# pretreined speaker embedding, e.g., X-vector (B, global_channels, 1)
g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
if self.langs is not None:
# language one-hot vector embedding: (B, global_channels, 1)
g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
# forward posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
# forward flow
z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats)
# monotonic alignment search
with torch.no_grad():
# negative cross-entropy
s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text)
# (B, 1, T_text)
neg_x_ent_1 = torch.sum(
-0.5 * math.log(2 * math.pi) - logs_p,
[1],
keepdim=True,
)
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_2 = torch.matmul(
-0.5 * (z_p**2).transpose(1, 2),
s_p_sq_r,
)
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_3 = torch.matmul(
z_p.transpose(1, 2),
(m_p * s_p_sq_r),
)
# (B, 1, T_text)
neg_x_ent_4 = torch.sum(
-0.5 * (m_p**2) * s_p_sq_r,
[1],
keepdim=True,
)
# (B, T_feats, T_text)
neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
# (B, 1, T_feats, T_text)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
# monotonic attention weight: (B, 1, T_feats, T_text)
attn = (
self.maximum_path(
neg_x_ent,
attn_mask.squeeze(1),
)
.unsqueeze(1)
.detach()
)
# forward duration predictor
w = attn.sum(2) # (B, 1, T_text)
dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
dur_nll = dur_nll / torch.sum(x_mask)
# expand the length to match with the feature sequence
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
# get random segments
z_segments, z_start_idxs = get_random_segments(
z,
feats_lengths,
self.segment_size,
)
# forward decoder with random segments
wav = self.decoder(z_segments, g=g)
return (
wav,
dur_nll,
attn,
z_start_idxs,
x_mask,
y_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
)
def inference(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
feats: Optional[torch.Tensor] = None,
feats_lengths: Optional[torch.Tensor] = None,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
dur: Optional[torch.Tensor] = None,
noise_scale: float = 0.667,
noise_scale_dur: float = 0.8,
alpha: float = 1.0,
max_len: Optional[int] = None,
use_teacher_forcing: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Run inference.
Args:
text (Tensor): Input text index tensor (B, T_text,).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, aux_channels, T_feats,).
feats_lengths (Tensor): Feature length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided,
skip the prediction of durations (i.e., teacher forcing).
noise_scale (float): Noise scale parameter for flow.
noise_scale_dur (float): Noise scale parameter for duration predictor.
alpha (float): Alpha parameter to control the speed of generated speech.
max_len (Optional[int]): Maximum length of acoustic feature sequence.
use_teacher_forcing (bool): Whether to use teacher forcing.
Returns:
Tensor: Generated waveform tensor (B, T_wav).
Tensor: Monotonic attention weight tensor (B, T_feats, T_text).
Tensor: Duration tensor (B, T_text).
"""
# encoder
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
g = None
if self.spks is not None:
# (B, global_channels, 1)
g = self.global_emb(sids.view(-1)).unsqueeze(-1)
if self.spk_embed_dim is not None:
# (B, global_channels, 1)
g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
if self.langs is not None:
# (B, global_channels, 1)
g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
if use_teacher_forcing:
# forward posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
# forward flow
z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats)
# monotonic alignment search
s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text)
# (B, 1, T_text)
neg_x_ent_1 = torch.sum(
-0.5 * math.log(2 * math.pi) - logs_p,
[1],
keepdim=True,
)
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_2 = torch.matmul(
-0.5 * (z_p**2).transpose(1, 2),
s_p_sq_r,
)
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_3 = torch.matmul(
z_p.transpose(1, 2),
(m_p * s_p_sq_r),
)
# (B, 1, T_text)
neg_x_ent_4 = torch.sum(
-0.5 * (m_p**2) * s_p_sq_r,
[1],
keepdim=True,
)
# (B, T_feats, T_text)
neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
# (B, 1, T_feats, T_text)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
# monotonic attention weight: (B, 1, T_feats, T_text)
attn = self.maximum_path(
neg_x_ent,
attn_mask.squeeze(1),
).unsqueeze(1)
dur = attn.sum(2) # (B, 1, T_text)
# forward decoder with random segments
wav = self.decoder(z * y_mask, g=g)
else:
# duration
if dur is None:
logw = self.duration_predictor(
x,
x_mask,
g=g,
inverse=True,
noise_scale=noise_scale_dur,
)
w = torch.exp(logw) * x_mask * alpha
dur = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long()
y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = self._generate_path(dur, attn_mask)
# expand the length to match with the feature sequence
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
m_p = torch.matmul(
attn.squeeze(1),
m_p.transpose(1, 2),
).transpose(1, 2)
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
logs_p = torch.matmul(
attn.squeeze(1),
logs_p.transpose(1, 2),
).transpose(1, 2)
# decoder
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=g, inverse=True)
wav = self.decoder((z * y_mask)[:, :, :max_len], g=g)
return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1)
def _generate_path(self, dur: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""Generate path a.k.a. monotonic attention.
Args:
dur (Tensor): Duration tensor (B, 1, T_text).
mask (Tensor): Attention mask tensor (B, 1, T_feats, T_text).
Returns:
Tensor: Path tensor (B, 1, T_feats, T_text).
"""
b, _, t_y, t_x = mask.shape
cum_dur = torch.cumsum(dur, -1)
cum_dur_flat = cum_dur.view(b * t_x)
path = torch.arange(t_y, dtype=dur.dtype, device=dur.device)
path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)
path = path.view(b, t_x, t_y).to(dtype=mask.dtype)
# path will be like (t_x = 3, t_y = 5):
# [[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.],
# [1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.],
# [1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]]
path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
return path.unsqueeze(1).transpose(2, 3) * mask

View File

@ -0,0 +1,933 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""HiFi-GAN Modules.
This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
"""
import copy
import logging
from typing import Any, Dict, List, Optional
import numpy as np
import torch
import torch.nn.functional as F
class HiFiGANGenerator(torch.nn.Module):
"""HiFiGAN generator module."""
def __init__(
self,
in_channels: int = 80,
out_channels: int = 1,
channels: int = 512,
global_channels: int = -1,
kernel_size: int = 7,
upsample_scales: List[int] = [8, 8, 2, 2],
upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
resblock_kernel_sizes: List[int] = [3, 7, 11],
resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
use_additional_convs: bool = True,
bias: bool = True,
nonlinear_activation: str = "LeakyReLU",
nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
use_weight_norm: bool = True,
):
"""Initialize HiFiGANGenerator module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
channels (int): Number of hidden representation channels.
global_channels (int): Number of global conditioning channels.
kernel_size (int): Kernel size of initial and final conv layer.
upsample_scales (List[int]): List of upsampling scales.
upsample_kernel_sizes (List[int]): List of kernel sizes for upsample layers.
resblock_kernel_sizes (List[int]): List of kernel sizes for residual blocks.
resblock_dilations (List[List[int]]): List of list of dilations for residual
blocks.
use_additional_convs (bool): Whether to use additional conv layers in
residual blocks.
bias (bool): Whether to add bias parameter in convolution layers.
nonlinear_activation (str): Activation function module name.
nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
function.
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
be applied to all of the conv layers.
"""
super().__init__()
# check hyperparameters are valid
assert kernel_size % 2 == 1, "Kernel size must be odd number."
assert len(upsample_scales) == len(upsample_kernel_sizes)
assert len(resblock_dilations) == len(resblock_kernel_sizes)
# define modules
self.upsample_factor = int(np.prod(upsample_scales) * out_channels)
self.num_upsamples = len(upsample_kernel_sizes)
self.num_blocks = len(resblock_kernel_sizes)
self.input_conv = torch.nn.Conv1d(
in_channels,
channels,
kernel_size,
1,
padding=(kernel_size - 1) // 2,
)
self.upsamples = torch.nn.ModuleList()
self.blocks = torch.nn.ModuleList()
for i in range(len(upsample_kernel_sizes)):
assert upsample_kernel_sizes[i] == 2 * upsample_scales[i]
self.upsamples += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
torch.nn.ConvTranspose1d(
channels // (2**i),
channels // (2 ** (i + 1)),
upsample_kernel_sizes[i],
upsample_scales[i],
padding=upsample_scales[i] // 2 + upsample_scales[i] % 2,
output_padding=upsample_scales[i] % 2,
),
)
]
for j in range(len(resblock_kernel_sizes)):
self.blocks += [
ResidualBlock(
kernel_size=resblock_kernel_sizes[j],
channels=channels // (2 ** (i + 1)),
dilations=resblock_dilations[j],
bias=bias,
use_additional_convs=use_additional_convs,
nonlinear_activation=nonlinear_activation,
nonlinear_activation_params=nonlinear_activation_params,
)
]
self.output_conv = torch.nn.Sequential(
# NOTE(kan-bayashi): follow official implementation but why
# using different slope parameter here? (0.1 vs. 0.01)
torch.nn.LeakyReLU(),
torch.nn.Conv1d(
channels // (2 ** (i + 1)),
out_channels,
kernel_size,
1,
padding=(kernel_size - 1) // 2,
),
torch.nn.Tanh(),
)
if global_channels > 0:
self.global_conv = torch.nn.Conv1d(global_channels, channels, 1)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
# reset parameters
self.reset_parameters()
def forward(
self, c: torch.Tensor, g: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Calculate forward propagation.
Args:
c (Tensor): Input tensor (B, in_channels, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Output tensor (B, out_channels, T).
"""
c = self.input_conv(c)
if g is not None:
c = c + self.global_conv(g)
for i in range(self.num_upsamples):
c = self.upsamples[i](c)
cs = 0.0 # initialize
for j in range(self.num_blocks):
cs += self.blocks[i * self.num_blocks + j](c)
c = cs / self.num_blocks
c = self.output_conv(c)
return c
def reset_parameters(self):
"""Reset parameters.
This initialization follows the official implementation manner.
https://github.com/jik876/hifi-gan/blob/master/models.py
"""
def _reset_parameters(m: torch.nn.Module):
if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
m.weight.data.normal_(0.0, 0.01)
logging.debug(f"Reset parameters in {m}.")
self.apply(_reset_parameters)
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m: torch.nn.Module):
try:
logging.debug(f"Weight norm is removed from {m}.")
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv1d) or isinstance(
m, torch.nn.ConvTranspose1d
):
torch.nn.utils.weight_norm(m)
logging.debug(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
def inference(
self, c: torch.Tensor, g: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Perform inference.
Args:
c (torch.Tensor): Input tensor (T, in_channels).
g (Optional[Tensor]): Global conditioning tensor (global_channels, 1).
Returns:
Tensor: Output tensor (T ** upsample_factor, out_channels).
"""
if g is not None:
g = g.unsqueeze(0)
c = self.forward(c.transpose(1, 0).unsqueeze(0), g=g)
return c.squeeze(0).transpose(1, 0)
class ResidualBlock(torch.nn.Module):
"""Residual block module in HiFiGAN."""
def __init__(
self,
kernel_size: int = 3,
channels: int = 512,
dilations: List[int] = [1, 3, 5],
bias: bool = True,
use_additional_convs: bool = True,
nonlinear_activation: str = "LeakyReLU",
nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
):
"""Initialize ResidualBlock module.
Args:
kernel_size (int): Kernel size of dilation convolution layer.
channels (int): Number of channels for convolution layer.
dilations (List[int]): List of dilation factors.
use_additional_convs (bool): Whether to use additional convolution layers.
bias (bool): Whether to add bias parameter in convolution layers.
nonlinear_activation (str): Activation function module name.
nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
function.
"""
super().__init__()
self.use_additional_convs = use_additional_convs
self.convs1 = torch.nn.ModuleList()
if use_additional_convs:
self.convs2 = torch.nn.ModuleList()
assert kernel_size % 2 == 1, "Kernel size must be odd number."
for dilation in dilations:
self.convs1 += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
torch.nn.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation,
bias=bias,
padding=(kernel_size - 1) // 2 * dilation,
),
)
]
if use_additional_convs:
self.convs2 += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
torch.nn.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
bias=bias,
padding=(kernel_size - 1) // 2,
),
)
]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
Returns:
Tensor: Output tensor (B, channels, T).
"""
for idx in range(len(self.convs1)):
xt = self.convs1[idx](x)
if self.use_additional_convs:
xt = self.convs2[idx](xt)
x = xt + x
return x
class HiFiGANPeriodDiscriminator(torch.nn.Module):
"""HiFiGAN period discriminator module."""
def __init__(
self,
in_channels: int = 1,
out_channels: int = 1,
period: int = 3,
kernel_sizes: List[int] = [5, 3],
channels: int = 32,
downsample_scales: List[int] = [3, 3, 3, 3, 1],
max_downsample_channels: int = 1024,
bias: bool = True,
nonlinear_activation: str = "LeakyReLU",
nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
use_weight_norm: bool = True,
use_spectral_norm: bool = False,
):
"""Initialize HiFiGANPeriodDiscriminator module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
period (int): Period.
kernel_sizes (list): Kernel sizes of initial conv layers and the final conv
layer.
channels (int): Number of initial channels.
downsample_scales (List[int]): List of downsampling scales.
max_downsample_channels (int): Number of maximum downsampling channels.
use_additional_convs (bool): Whether to use additional conv layers in
residual blocks.
bias (bool): Whether to add bias parameter in convolution layers.
nonlinear_activation (str): Activation function module name.
nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
function.
use_weight_norm (bool): Whether to use weight norm.
If set to true, it will be applied to all of the conv layers.
use_spectral_norm (bool): Whether to use spectral norm.
If set to true, it will be applied to all of the conv layers.
"""
super().__init__()
assert len(kernel_sizes) == 2
assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number."
assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number."
self.period = period
self.convs = torch.nn.ModuleList()
in_chs = in_channels
out_chs = channels
for downsample_scale in downsample_scales:
self.convs += [
torch.nn.Sequential(
torch.nn.Conv2d(
in_chs,
out_chs,
(kernel_sizes[0], 1),
(downsample_scale, 1),
padding=((kernel_sizes[0] - 1) // 2, 0),
),
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
)
]
in_chs = out_chs
# NOTE(kan-bayashi): Use downsample_scale + 1?
out_chs = min(out_chs * 4, max_downsample_channels)
self.output_conv = torch.nn.Conv2d(
out_chs,
out_channels,
(kernel_sizes[1] - 1, 1),
1,
padding=((kernel_sizes[1] - 1) // 2, 0),
)
if use_weight_norm and use_spectral_norm:
raise ValueError("Either use use_weight_norm or use_spectral_norm.")
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
# apply spectral norm
if use_spectral_norm:
self.apply_spectral_norm()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Calculate forward propagation.
Args:
c (Tensor): Input tensor (B, in_channels, T).
Returns:
list: List of each layer's tensors.
"""
# transform 1d to 2d -> (B, C, T/P, P)
b, c, t = x.shape
if t % self.period != 0:
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t += n_pad
x = x.view(b, c, t // self.period, self.period)
# forward conv
outs = []
for layer in self.convs:
x = layer(x)
outs += [x]
x = self.output_conv(x)
x = torch.flatten(x, 1, -1)
outs += [x]
return outs
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv2d):
torch.nn.utils.weight_norm(m)
logging.debug(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
def apply_spectral_norm(self):
"""Apply spectral normalization module from all of the layers."""
def _apply_spectral_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv2d):
torch.nn.utils.spectral_norm(m)
logging.debug(f"Spectral norm is applied to {m}.")
self.apply(_apply_spectral_norm)
class HiFiGANMultiPeriodDiscriminator(torch.nn.Module):
"""HiFiGAN multi-period discriminator module."""
def __init__(
self,
periods: List[int] = [2, 3, 5, 7, 11],
discriminator_params: Dict[str, Any] = {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [5, 3],
"channels": 32,
"downsample_scales": [3, 3, 3, 3, 1],
"max_downsample_channels": 1024,
"bias": True,
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
"use_weight_norm": True,
"use_spectral_norm": False,
},
):
"""Initialize HiFiGANMultiPeriodDiscriminator module.
Args:
periods (List[int]): List of periods.
discriminator_params (Dict[str, Any]): Parameters for hifi-gan period
discriminator module. The period parameter will be overwritten.
"""
super().__init__()
self.discriminators = torch.nn.ModuleList()
for period in periods:
params = copy.deepcopy(discriminator_params)
params["period"] = period
self.discriminators += [HiFiGANPeriodDiscriminator(**params)]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T).
Returns:
List: List of list of each discriminator outputs, which consists of each
layer output tensors.
"""
outs = []
for f in self.discriminators:
outs += [f(x)]
return outs
class HiFiGANScaleDiscriminator(torch.nn.Module):
"""HiFi-GAN scale discriminator module."""
def __init__(
self,
in_channels: int = 1,
out_channels: int = 1,
kernel_sizes: List[int] = [15, 41, 5, 3],
channels: int = 128,
max_downsample_channels: int = 1024,
max_groups: int = 16,
bias: int = True,
downsample_scales: List[int] = [2, 2, 4, 4, 1],
nonlinear_activation: str = "LeakyReLU",
nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
use_weight_norm: bool = True,
use_spectral_norm: bool = False,
):
"""Initilize HiFiGAN scale discriminator module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_sizes (List[int]): List of four kernel sizes. The first will be used
for the first conv layer, and the second is for downsampling part, and
the remaining two are for the last two output layers.
channels (int): Initial number of channels for conv layer.
max_downsample_channels (int): Maximum number of channels for downsampling
layers.
bias (bool): Whether to add bias parameter in convolution layers.
downsample_scales (List[int]): List of downsampling scales.
nonlinear_activation (str): Activation function module name.
nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
function.
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
be applied to all of the conv layers.
use_spectral_norm (bool): Whether to use spectral norm. If set to true, it
will be applied to all of the conv layers.
"""
super().__init__()
self.layers = torch.nn.ModuleList()
# check kernel size is valid
assert len(kernel_sizes) == 4
for ks in kernel_sizes:
assert ks % 2 == 1
# add first layer
self.layers += [
torch.nn.Sequential(
torch.nn.Conv1d(
in_channels,
channels,
# NOTE(kan-bayashi): Use always the same kernel size
kernel_sizes[0],
bias=bias,
padding=(kernel_sizes[0] - 1) // 2,
),
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
)
]
# add downsample layers
in_chs = channels
out_chs = channels
# NOTE(kan-bayashi): Remove hard coding?
groups = 4
for downsample_scale in downsample_scales:
self.layers += [
torch.nn.Sequential(
torch.nn.Conv1d(
in_chs,
out_chs,
kernel_size=kernel_sizes[1],
stride=downsample_scale,
padding=(kernel_sizes[1] - 1) // 2,
groups=groups,
bias=bias,
),
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
)
]
in_chs = out_chs
# NOTE(kan-bayashi): Remove hard coding?
out_chs = min(in_chs * 2, max_downsample_channels)
# NOTE(kan-bayashi): Remove hard coding?
groups = min(groups * 4, max_groups)
# add final layers
out_chs = min(in_chs * 2, max_downsample_channels)
self.layers += [
torch.nn.Sequential(
torch.nn.Conv1d(
in_chs,
out_chs,
kernel_size=kernel_sizes[2],
stride=1,
padding=(kernel_sizes[2] - 1) // 2,
bias=bias,
),
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
)
]
self.layers += [
torch.nn.Conv1d(
out_chs,
out_channels,
kernel_size=kernel_sizes[3],
stride=1,
padding=(kernel_sizes[3] - 1) // 2,
bias=bias,
),
]
if use_weight_norm and use_spectral_norm:
raise ValueError("Either use use_weight_norm or use_spectral_norm.")
# apply weight norm
self.use_weight_norm = use_weight_norm
if use_weight_norm:
self.apply_weight_norm()
# apply spectral norm
self.use_spectral_norm = use_spectral_norm
if use_spectral_norm:
self.apply_spectral_norm()
# backward compatibility
self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T).
Returns:
List[Tensor]: List of output tensors of each layer.
"""
outs = []
for f in self.layers:
x = f(x)
outs += [x]
return outs
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv1d):
torch.nn.utils.weight_norm(m)
logging.debug(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
def apply_spectral_norm(self):
"""Apply spectral normalization module from all of the layers."""
def _apply_spectral_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv1d):
torch.nn.utils.spectral_norm(m)
logging.debug(f"Spectral norm is applied to {m}.")
self.apply(_apply_spectral_norm)
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m):
try:
logging.debug(f"Weight norm is removed from {m}.")
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
def remove_spectral_norm(self):
"""Remove spectral normalization module from all of the layers."""
def _remove_spectral_norm(m):
try:
logging.debug(f"Spectral norm is removed from {m}.")
torch.nn.utils.remove_spectral_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_spectral_norm)
def _load_state_dict_pre_hook(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""Fix the compatibility of weight / spectral normalization issue.
Some pretrained models are trained with configs that use weight / spectral
normalization, but actually, the norm is not applied. This causes the mismatch
of the parameters with configs. To solve this issue, when parameter mismatch
happens in loading pretrained model, we remove the norm from the current model.
See also:
- https://github.com/espnet/espnet/pull/5240
- https://github.com/espnet/espnet/pull/5249
- https://github.com/kan-bayashi/ParallelWaveGAN/pull/409
"""
current_module_keys = [x for x in state_dict.keys() if x.startswith(prefix)]
if self.use_weight_norm and any(
[k.endswith("weight") for k in current_module_keys]
):
logging.warning(
"It seems weight norm is not applied in the pretrained model but the"
" current model uses it. To keep the compatibility, we remove the norm"
" from the current model. This may cause unexpected behavior due to the"
" parameter mismatch in finetuning. To avoid this issue, please change"
" the following parameters in config to false:\n"
" - discriminator_params.follow_official_norm\n"
" - discriminator_params.scale_discriminator_params.use_weight_norm\n"
" - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
"\n"
"See also:\n"
" - https://github.com/espnet/espnet/pull/5240\n"
" - https://github.com/espnet/espnet/pull/5249"
)
self.remove_weight_norm()
self.use_weight_norm = False
for k in current_module_keys:
if k.endswith("weight_g") or k.endswith("weight_v"):
del state_dict[k]
if self.use_spectral_norm and any(
[k.endswith("weight") for k in current_module_keys]
):
logging.warning(
"It seems spectral norm is not applied in the pretrained model but the"
" current model uses it. To keep the compatibility, we remove the norm"
" from the current model. This may cause unexpected behavior due to the"
" parameter mismatch in finetuning. To avoid this issue, please change"
" the following parameters in config to false:\n"
" - discriminator_params.follow_official_norm\n"
" - discriminator_params.scale_discriminator_params.use_weight_norm\n"
" - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
"\n"
"See also:\n"
" - https://github.com/espnet/espnet/pull/5240\n"
" - https://github.com/espnet/espnet/pull/5249"
)
self.remove_spectral_norm()
self.use_spectral_norm = False
for k in current_module_keys:
if (
k.endswith("weight_u")
or k.endswith("weight_v")
or k.endswith("weight_orig")
):
del state_dict[k]
class HiFiGANMultiScaleDiscriminator(torch.nn.Module):
"""HiFi-GAN multi-scale discriminator module."""
def __init__(
self,
scales: int = 3,
downsample_pooling: str = "AvgPool1d",
# follow the official implementation setting
downsample_pooling_params: Dict[str, Any] = {
"kernel_size": 4,
"stride": 2,
"padding": 2,
},
discriminator_params: Dict[str, Any] = {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [15, 41, 5, 3],
"channels": 128,
"max_downsample_channels": 1024,
"max_groups": 16,
"bias": True,
"downsample_scales": [2, 2, 4, 4, 1],
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
},
follow_official_norm: bool = False,
):
"""Initilize HiFiGAN multi-scale discriminator module.
Args:
scales (int): Number of multi-scales.
downsample_pooling (str): Pooling module name for downsampling of the
inputs.
downsample_pooling_params (Dict[str, Any]): Parameters for the above pooling
module.
discriminator_params (Dict[str, Any]): Parameters for hifi-gan scale
discriminator module.
follow_official_norm (bool): Whether to follow the norm setting of the
official implementaion. The first discriminator uses spectral norm
and the other discriminators use weight norm.
"""
super().__init__()
self.discriminators = torch.nn.ModuleList()
# add discriminators
for i in range(scales):
params = copy.deepcopy(discriminator_params)
if follow_official_norm:
if i == 0:
params["use_weight_norm"] = False
params["use_spectral_norm"] = True
else:
params["use_weight_norm"] = True
params["use_spectral_norm"] = False
self.discriminators += [HiFiGANScaleDiscriminator(**params)]
self.pooling = None
if scales > 1:
self.pooling = getattr(torch.nn, downsample_pooling)(
**downsample_pooling_params
)
def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T).
Returns:
List[List[torch.Tensor]]: List of list of each discriminator outputs,
which consists of eachlayer output tensors.
"""
outs = []
for f in self.discriminators:
outs += [f(x)]
if self.pooling is not None:
x = self.pooling(x)
return outs
class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module):
"""HiFi-GAN multi-scale + multi-period discriminator module."""
def __init__(
self,
# Multi-scale discriminator related
scales: int = 3,
scale_downsample_pooling: str = "AvgPool1d",
scale_downsample_pooling_params: Dict[str, Any] = {
"kernel_size": 4,
"stride": 2,
"padding": 2,
},
scale_discriminator_params: Dict[str, Any] = {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [15, 41, 5, 3],
"channels": 128,
"max_downsample_channels": 1024,
"max_groups": 16,
"bias": True,
"downsample_scales": [2, 2, 4, 4, 1],
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
},
follow_official_norm: bool = True,
# Multi-period discriminator related
periods: List[int] = [2, 3, 5, 7, 11],
period_discriminator_params: Dict[str, Any] = {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [5, 3],
"channels": 32,
"downsample_scales": [3, 3, 3, 3, 1],
"max_downsample_channels": 1024,
"bias": True,
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
"use_weight_norm": True,
"use_spectral_norm": False,
},
):
"""Initilize HiFiGAN multi-scale + multi-period discriminator module.
Args:
scales (int): Number of multi-scales.
scale_downsample_pooling (str): Pooling module name for downsampling of the
inputs.
scale_downsample_pooling_params (dict): Parameters for the above pooling
module.
scale_discriminator_params (dict): Parameters for hifi-gan scale
discriminator module.
follow_official_norm (bool): Whether to follow the norm setting of the
official implementaion. The first discriminator uses spectral norm and
the other discriminators use weight norm.
periods (list): List of periods.
period_discriminator_params (dict): Parameters for hifi-gan period
discriminator module. The period parameter will be overwritten.
"""
super().__init__()
self.msd = HiFiGANMultiScaleDiscriminator(
scales=scales,
downsample_pooling=scale_downsample_pooling,
downsample_pooling_params=scale_downsample_pooling_params,
discriminator_params=scale_discriminator_params,
follow_official_norm=follow_official_norm,
)
self.mpd = HiFiGANMultiPeriodDiscriminator(
periods=periods,
discriminator_params=period_discriminator_params,
)
def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T).
Returns:
List[List[Tensor]]: List of list of each discriminator outputs,
which consists of each layer output tensors. Multi scale and
multi period ones are concatenated.
"""
msd_outs = self.msd(x)
mpd_outs = self.mpd(x)
return msd_outs + mpd_outs

View File

@ -0,0 +1,332 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""HiFiGAN-related loss modules.
This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
"""
from typing import List, Optional, Tuple, Union
import torch
import torch.distributions as D
import torch.nn.functional as F
from lhotse.features.kaldi import Wav2LogFilterBank
class GeneratorAdversarialLoss(torch.nn.Module):
"""Generator adversarial loss module."""
def __init__(
self,
average_by_discriminators: bool = True,
loss_type: str = "mse",
):
"""Initialize GeneratorAversarialLoss module.
Args:
average_by_discriminators (bool): Whether to average the loss by
the number of discriminators.
loss_type (str): Loss type, "mse" or "hinge".
"""
super().__init__()
self.average_by_discriminators = average_by_discriminators
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
if loss_type == "mse":
self.criterion = self._mse_loss
else:
self.criterion = self._hinge_loss
def forward(
self,
outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
"""Calcualate generator adversarial loss.
Args:
outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
outputs, list of discriminator outputs, or list of list of discriminator
outputs..
Returns:
Tensor: Generator adversarial loss value.
"""
if isinstance(outputs, (tuple, list)):
adv_loss = 0.0
for i, outputs_ in enumerate(outputs):
if isinstance(outputs_, (tuple, list)):
# NOTE(kan-bayashi): case including feature maps
outputs_ = outputs_[-1]
adv_loss += self.criterion(outputs_)
if self.average_by_discriminators:
adv_loss /= i + 1
else:
adv_loss = self.criterion(outputs)
return adv_loss
def _mse_loss(self, x):
return F.mse_loss(x, x.new_ones(x.size()))
def _hinge_loss(self, x):
return -x.mean()
class DiscriminatorAdversarialLoss(torch.nn.Module):
"""Discriminator adversarial loss module."""
def __init__(
self,
average_by_discriminators: bool = True,
loss_type: str = "mse",
):
"""Initialize DiscriminatorAversarialLoss module.
Args:
average_by_discriminators (bool): Whether to average the loss by
the number of discriminators.
loss_type (str): Loss type, "mse" or "hinge".
"""
super().__init__()
self.average_by_discriminators = average_by_discriminators
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
if loss_type == "mse":
self.fake_criterion = self._mse_fake_loss
self.real_criterion = self._mse_real_loss
else:
self.fake_criterion = self._hinge_fake_loss
self.real_criterion = self._hinge_real_loss
def forward(
self,
outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Calcualate discriminator adversarial loss.
Args:
outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
outputs, list of discriminator outputs, or list of list of discriminator
outputs calculated from generator.
outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
outputs, list of discriminator outputs, or list of list of discriminator
outputs calculated from groundtruth.
Returns:
Tensor: Discriminator real loss value.
Tensor: Discriminator fake loss value.
"""
if isinstance(outputs, (tuple, list)):
real_loss = 0.0
fake_loss = 0.0
for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
if isinstance(outputs_hat_, (tuple, list)):
# NOTE(kan-bayashi): case including feature maps
outputs_hat_ = outputs_hat_[-1]
outputs_ = outputs_[-1]
real_loss += self.real_criterion(outputs_)
fake_loss += self.fake_criterion(outputs_hat_)
if self.average_by_discriminators:
fake_loss /= i + 1
real_loss /= i + 1
else:
real_loss = self.real_criterion(outputs)
fake_loss = self.fake_criterion(outputs_hat)
return real_loss, fake_loss
def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
return F.mse_loss(x, x.new_ones(x.size()))
def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
return F.mse_loss(x, x.new_zeros(x.size()))
def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))
def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
class FeatureMatchLoss(torch.nn.Module):
"""Feature matching loss module."""
def __init__(
self,
average_by_layers: bool = True,
average_by_discriminators: bool = True,
include_final_outputs: bool = False,
):
"""Initialize FeatureMatchLoss module.
Args:
average_by_layers (bool): Whether to average the loss by the number
of layers.
average_by_discriminators (bool): Whether to average the loss by
the number of discriminators.
include_final_outputs (bool): Whether to include the final output of
each discriminator for loss calculation.
"""
super().__init__()
self.average_by_layers = average_by_layers
self.average_by_discriminators = average_by_discriminators
self.include_final_outputs = include_final_outputs
def forward(
self,
feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
) -> torch.Tensor:
"""Calculate feature matching loss.
Args:
feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
discriminator outputs or list of discriminator outputs calcuated
from generator's outputs.
feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
discriminator outputs or list of discriminator outputs calcuated
from groundtruth..
Returns:
Tensor: Feature matching loss value.
"""
feat_match_loss = 0.0
for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
feat_match_loss_ = 0.0
if not self.include_final_outputs:
feats_hat_ = feats_hat_[:-1]
feats_ = feats_[:-1]
for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
if self.average_by_layers:
feat_match_loss_ /= j + 1
feat_match_loss += feat_match_loss_
if self.average_by_discriminators:
feat_match_loss /= i + 1
return feat_match_loss
class MelSpectrogramLoss(torch.nn.Module):
"""Mel-spectrogram loss."""
def __init__(
self,
sampling_rate: int = 22050,
frame_length: int = 1024, # in samples
frame_shift: int = 256, # in samples
n_mels: int = 80,
use_fft_mag: bool = True,
):
super().__init__()
self.wav_to_mel = Wav2LogFilterBank(
sampling_rate=sampling_rate,
frame_length=frame_length / sampling_rate, # in second
frame_shift=frame_shift / sampling_rate, # in second
use_fft_mag=use_fft_mag,
num_filters=n_mels,
)
def forward(
self,
y_hat: torch.Tensor,
y: torch.Tensor,
) -> torch.Tensor:
"""Calculate Mel-spectrogram loss.
Args:
y_hat (Tensor): Generated waveform tensor (B, 1, T).
y (Tensor): Groundtruth waveform tensor (B, 1, T).
spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor
(B, T, n_fft // 2 + 1). if provided, use it instead of groundtruth
waveform.
Returns:
Tensor: Mel-spectrogram loss value.
"""
mel_hat = self.wav_to_mel(y_hat.squeeze(1))
mel = self.wav_to_mel(y.squeeze(1))
mel_loss = F.l1_loss(mel_hat, mel)
return mel_loss
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py
"""VITS-related loss modules.
This code is based on https://github.com/jaywalnut310/vits.
"""
class KLDivergenceLoss(torch.nn.Module):
"""KL divergence loss."""
def forward(
self,
z_p: torch.Tensor,
logs_q: torch.Tensor,
m_p: torch.Tensor,
logs_p: torch.Tensor,
z_mask: torch.Tensor,
) -> torch.Tensor:
"""Calculate KL divergence loss.
Args:
z_p (Tensor): Flow hidden representation (B, H, T_feats).
logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
z_mask (Tensor): Mask tensor (B, 1, T_feats).
Returns:
Tensor: KL divergence loss.
"""
z_p = z_p.float()
logs_q = logs_q.float()
m_p = m_p.float()
logs_p = logs_p.float()
z_mask = z_mask.float()
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
kl = torch.sum(kl * z_mask)
loss = kl / torch.sum(z_mask)
return loss
class KLDivergenceLossWithoutFlow(torch.nn.Module):
"""KL divergence loss without flow."""
def forward(
self,
m_q: torch.Tensor,
logs_q: torch.Tensor,
m_p: torch.Tensor,
logs_p: torch.Tensor,
) -> torch.Tensor:
"""Calculate KL divergence loss without flow.
Args:
m_q (Tensor): Posterior encoder projected mean (B, H, T_feats).
logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
"""
posterior_norm = D.Normal(m_q, torch.exp(logs_q))
prior_norm = D.Normal(m_p, torch.exp(logs_p))
loss = D.kl_divergence(posterior_norm, prior_norm).mean()
return loss

View File

@ -0,0 +1,534 @@
import copy
import math
import torch
from torch import nn
from torch.nn import functional as F
import commons
import modules
import attentions
import monotonic_align
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from commons import init_weights, get_padding
class StochasticDurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
super().__init__()
filter_channels = in_channels # it needs to be removed from future version.
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.n_flows = n_flows
self.gin_channels = gin_channels
self.log_flow = modules.Log()
self.flows = nn.ModuleList()
self.flows.append(modules.ElementwiseAffine(2))
for i in range(n_flows):
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4):
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
x = torch.detach(x)
x = self.pre(x)
if g is not None:
g = torch.detach(g)
x = x + self.cond(g)
x = self.convs(x, x_mask)
x = self.proj(x) * x_mask
if not reverse:
flows = self.flows
assert w is not None
logdet_tot_q = 0
h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = e_q
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
logdet_tot_q += logdet_q
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
logdet_tot += logdet
z = torch.cat([z0, z1], 1)
for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
return nll + logq # [b]
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1)
logw = z0
return logw
class DurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.gin_channels = gin_channels
self.drop = nn.Dropout(p_dropout)
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
self.norm_1 = modules.LayerNorm(filter_channels)
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
self.norm_2 = modules.LayerNorm(filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
def forward(self, x, x_mask, g=None):
x = torch.detach(x)
if g is not None:
g = torch.detach(g)
x = x + self.cond(g)
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
class TextEncoder(nn.Module):
def __init__(self,
n_vocab,
out_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout):
super().__init__()
self.n_vocab = n_vocab
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.emb = nn.Embedding(n_vocab, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
self.encoder = attentions.Encoder(
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout)
self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths):
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.encoder(x * x_mask, x_mask)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return x, m, logs, x_mask
class ResidualCouplingBlock(nn.Module):
def __init__(self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
n_flows=4,
gin_channels=0):
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.n_flows = n_flows
self.gin_channels = gin_channels
self.flows = nn.ModuleList()
for i in range(n_flows):
self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
self.flows.append(modules.Flip())
def forward(self, x, x_mask, g=None, reverse=False):
if not reverse:
for flow in self.flows:
x, _ = flow(x, x_mask, g=g, reverse=reverse)
else:
for flow in reversed(self.flows):
x = flow(x, x_mask, g=g, reverse=reverse)
return x
class PosteriorEncoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, g=None):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
return z, m, logs, x_mask
class Generator(torch.nn.Module):
def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(weight_norm(
ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
k, u, padding=(k-u)//2)))
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel//(2**(i+1))
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
self.ups.apply(init_weights)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
def forward(self, x, g=None):
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i*self.num_kernels+j](x)
else:
xs += self.resblocks[i*self.num_kernels+j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
class DiscriminatorP(torch.nn.Module):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super(DiscriminatorP, self).__init__()
self.period = period
self.use_spectral_norm = use_spectral_norm
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList([
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
])
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList([
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
])
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
fmap = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(MultiPeriodDiscriminator, self).__init__()
periods = [2,3,5,7,11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
self.discriminators = nn.ModuleList(discs)
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
y_d_gs.append(y_d_g)
fmap_rs.append(fmap_r)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class SynthesizerTrn(nn.Module):
"""
Synthesizer for Training
"""
def __init__(self,
n_vocab,
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
n_speakers=0,
gin_channels=0,
use_sdp=True,
**kwargs):
super().__init__()
self.n_vocab = n_vocab
self.spec_channels = spec_channels
self.inter_channels = inter_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.resblock = resblock
self.resblock_kernel_sizes = resblock_kernel_sizes
self.resblock_dilation_sizes = resblock_dilation_sizes
self.upsample_rates = upsample_rates
self.upsample_initial_channel = upsample_initial_channel
self.upsample_kernel_sizes = upsample_kernel_sizes
self.segment_size = segment_size
self.n_speakers = n_speakers
self.gin_channels = gin_channels
self.use_sdp = use_sdp
self.enc_p = TextEncoder(n_vocab,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout)
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
if use_sdp:
self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
else:
self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
if n_speakers > 1:
self.emb_g = nn.Embedding(n_speakers, gin_channels)
def forward(self, x, x_lengths, y, y_lengths, sid=None):
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
if self.n_speakers > 0:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
else:
g = None
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
z_p = self.flow(z, y_mask, g=g)
with torch.no_grad():
# negative cross-entropy
s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s]
neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s]
neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
w = attn.sum(2)
if self.use_sdp:
l_length = self.dp(x, x_mask, w, g=g)
l_length = l_length / torch.sum(x_mask)
else:
logw_ = torch.log(w + 1e-6) * x_mask
logw = self.dp(x, x_mask, g=g)
l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging
# expand prior
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
o = self.dec(z_slice, g=g)
return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
if self.n_speakers > 0:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
else:
g = None
if self.use_sdp:
logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
else:
logw = self.dp(x, x_mask, g=g)
w = torch.exp(logw) * x_mask * length_scale
w_ceil = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = commons.generate_path(w_ceil, attn_mask)
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=g, reverse=True)
o = self.dec((z * y_mask)[:,:,:max_len], g=g)
return o, attn, y_mask, (z, z_p, m_p, logs_p)
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
assert self.n_speakers > 0, "n_speakers have to be larger than 0."
g_src = self.emb_g(sid_src).unsqueeze(-1)
g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
z_p = self.flow(z, y_mask, g=g_src)
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
o_hat = self.dec(z_hat * y_mask, g=g_tgt)
return o_hat, y_mask, (z, z_p, z_hat)

View File

@ -0,0 +1,81 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/__init__.py
"""Maximum path calculation module.
This code is based on https://github.com/jaywalnut310/vits.
"""
import warnings
import numpy as np
import torch
from numba import njit, prange
try:
from .core import maximum_path_c
is_cython_avalable = True
except ImportError:
is_cython_avalable = False
warnings.warn(
"Cython version is not available. Fallback to 'EXPERIMETAL' numba version. "
"If you want to use the cython version, please build it as follows: "
"`cd espnet2/gan_tts/vits/monotonic_align; python setup.py build_ext --inplace`"
)
def maximum_path(neg_x_ent: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
"""Calculate maximum path.
Args:
neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text).
attn_mask (Tensor): Attention mask (B, T_feats, T_text).
Returns:
Tensor: Maximum path tensor (B, T_feats, T_text).
"""
device, dtype = neg_x_ent.device, neg_x_ent.dtype
neg_x_ent = neg_x_ent.cpu().numpy().astype(np.float32)
path = np.zeros(neg_x_ent.shape, dtype=np.int32)
t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32)
t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32)
if is_cython_avalable:
maximum_path_c(path, neg_x_ent, t_t_max, t_s_max)
else:
maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max)
return torch.from_numpy(path).to(device=device, dtype=dtype)
@njit
def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf):
"""Calculate a single maximum path with numba."""
index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[y - 1, x]
if x == 0:
if y == 0:
v_prev = 0.0
else:
v_prev = max_neg_val
else:
v_prev = value[y - 1, x - 1]
value[y, x] += max(v_prev, v_cur)
for y in range(t_y - 1, -1, -1):
path[y, index] = 1
if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
index = index - 1
@njit(parallel=True)
def maximum_path_numba(paths, values, t_ys, t_xs):
"""Calculate batch maximum path with numba."""
for i in prange(paths.shape[0]):
maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i])

View File

@ -0,0 +1,51 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/core.pyx
"""Maximum path calculation module with cython optimization.
This code is copied from https://github.com/jaywalnut310/vits and modifed code format.
"""
cimport cython
from cython.parallel import prange
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
cdef int x
cdef int y
cdef float v_prev
cdef float v_cur
cdef float tmp
cdef int index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[y - 1, x]
if x == 0:
if y == 0:
v_prev = 0.0
else:
v_prev = max_neg_val
else:
v_prev = value[y - 1, x - 1]
value[y, x] += max(v_prev, v_cur)
for y in range(t_y - 1, -1, -1):
path[y, index] = 1
if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
index = index - 1
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
cdef int b = paths.shape[0]
cdef int i
for i in prange(b, nogil=True):
maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])

View File

@ -0,0 +1,31 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/setup.py
"""Setup cython code."""
from Cython.Build import cythonize
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext as _build_ext
class build_ext(_build_ext):
"""Overwrite build_ext."""
def finalize_options(self):
"""Prevent numpy from thinking it is still in its setup process."""
_build_ext.finalize_options(self)
__builtins__.__NUMPY_SETUP__ = False
import numpy
self.include_dirs.append(numpy.get_include())
exts = [
Extension(
name="core",
sources=["core.pyx"],
)
]
setup(
name="monotonic_align",
ext_modules=cythonize(exts, language_level=3),
cmdclass={"build_ext": build_ext},
)

View File

@ -0,0 +1,117 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Posterior encoder module in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
from typing import Optional, Tuple
import torch
from icefall.utils import make_pad_mask
from wavenet import WaveNet, Conv1d
class PosteriorEncoder(torch.nn.Module):
"""Posterior encoder module in VITS.
This is a module of posterior encoder described in `Conditional Variational
Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2006.04558
"""
def __init__(
self,
in_channels: int = 513,
out_channels: int = 192,
hidden_channels: int = 192,
kernel_size: int = 5,
layers: int = 16,
stacks: int = 1,
base_dilation: int = 1,
global_channels: int = -1,
dropout_rate: float = 0.0,
bias: bool = True,
use_weight_norm: bool = True,
):
"""Initilialize PosteriorEncoder module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size in WaveNet.
layers (int): Number of layers of WaveNet.
stacks (int): Number of repeat stacking of WaveNet.
base_dilation (int): Base dilation factor.
global_channels (int): Number of global conditioning channels.
dropout_rate (float): Dropout rate.
bias (bool): Whether to use bias parameters in conv.
use_weight_norm (bool): Whether to apply weight norm.
"""
super().__init__()
# define modules
self.input_conv = Conv1d(in_channels, hidden_channels, 1)
self.encoder = WaveNet(
in_channels=-1,
out_channels=-1,
kernel_size=kernel_size,
layers=layers,
stacks=stacks,
base_dilation=base_dilation,
residual_channels=hidden_channels,
aux_channels=-1,
gate_channels=hidden_channels * 2,
skip_channels=hidden_channels,
global_channels=global_channels,
dropout_rate=dropout_rate,
bias=bias,
use_weight_norm=use_weight_norm,
use_first_conv=False,
use_last_conv=False,
scale_residual=False,
scale_skip_connect=True,
)
self.proj = Conv1d(hidden_channels, out_channels * 2, 1)
def forward(
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T_feats).
x_lengths (Tensor): Length tensor (B,).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Encoded hidden representation tensor (B, out_channels, T_feats).
Tensor: Projected mean tensor (B, out_channels, T_feats).
Tensor: Projected scale tensor (B, out_channels, T_feats).
Tensor: Mask tensor for input tensor (B, 1, T_feats).
"""
x_mask = (
(~make_pad_mask(x_lengths))
.unsqueeze(1)
.to(
dtype=x.dtype,
device=x.device,
)
)
x = self.input_conv(x) * x_mask
x = self.encoder(x, x_mask, g=g)
stats = self.proj(x) * x_mask
m, logs = stats.split(stats.size(1) // 2, dim=1)
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
return z, m, logs, x_mask

View File

@ -0,0 +1,229 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Residual affine coupling modules in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
from typing import Optional, Tuple, Union
import torch
from flow import FlipFlow
from wavenet import WaveNet
class ResidualAffineCouplingBlock(torch.nn.Module):
"""Residual affine coupling block module.
This is a module of residual affine coupling block, which used as "Flow" in
`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2006.04558
"""
def __init__(
self,
in_channels: int = 192,
hidden_channels: int = 192,
flows: int = 4,
kernel_size: int = 5,
base_dilation: int = 1,
layers: int = 4,
global_channels: int = -1,
dropout_rate: float = 0.0,
use_weight_norm: bool = True,
bias: bool = True,
use_only_mean: bool = True,
):
"""Initilize ResidualAffineCouplingBlock module.
Args:
in_channels (int): Number of input channels.
hidden_channels (int): Number of hidden channels.
flows (int): Number of flows.
kernel_size (int): Kernel size for WaveNet.
base_dilation (int): Base dilation factor for WaveNet.
layers (int): Number of layers of WaveNet.
stacks (int): Number of stacks of WaveNet.
global_channels (int): Number of global channels.
dropout_rate (float): Dropout rate.
use_weight_norm (bool): Whether to use weight normalization in WaveNet.
bias (bool): Whether to use bias paramters in WaveNet.
use_only_mean (bool): Whether to estimate only mean.
"""
super().__init__()
self.flows = torch.nn.ModuleList()
for i in range(flows):
self.flows += [
ResidualAffineCouplingLayer(
in_channels=in_channels,
hidden_channels=hidden_channels,
kernel_size=kernel_size,
base_dilation=base_dilation,
layers=layers,
stacks=1,
global_channels=global_channels,
dropout_rate=dropout_rate,
use_weight_norm=use_weight_norm,
bias=bias,
use_only_mean=use_only_mean,
)
]
self.flows += [FlipFlow()]
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
g: Optional[torch.Tensor] = None,
inverse: bool = False,
) -> torch.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T).
x_lengths (Tensor): Length tensor (B,).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, in_channels, T).
"""
if not inverse:
for flow in self.flows:
x, _ = flow(x, x_mask, g=g, inverse=inverse)
else:
for flow in reversed(self.flows):
x = flow(x, x_mask, g=g, inverse=inverse)
return x
class ResidualAffineCouplingLayer(torch.nn.Module):
"""Residual affine coupling layer."""
def __init__(
self,
in_channels: int = 192,
hidden_channels: int = 192,
kernel_size: int = 5,
base_dilation: int = 1,
layers: int = 5,
stacks: int = 1,
global_channels: int = -1,
dropout_rate: float = 0.0,
use_weight_norm: bool = True,
bias: bool = True,
use_only_mean: bool = True,
):
"""Initialzie ResidualAffineCouplingLayer module.
Args:
in_channels (int): Number of input channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size for WaveNet.
base_dilation (int): Base dilation factor for WaveNet.
layers (int): Number of layers of WaveNet.
stacks (int): Number of stacks of WaveNet.
global_channels (int): Number of global channels.
dropout_rate (float): Dropout rate.
use_weight_norm (bool): Whether to use weight normalization in WaveNet.
bias (bool): Whether to use bias paramters in WaveNet.
use_only_mean (bool): Whether to estimate only mean.
"""
assert in_channels % 2 == 0, "in_channels should be divisible by 2"
super().__init__()
self.half_channels = in_channels // 2
self.use_only_mean = use_only_mean
# define modules
self.input_conv = torch.nn.Conv1d(
self.half_channels,
hidden_channels,
1,
)
self.encoder = WaveNet(
in_channels=-1,
out_channels=-1,
kernel_size=kernel_size,
layers=layers,
stacks=stacks,
base_dilation=base_dilation,
residual_channels=hidden_channels,
aux_channels=-1,
gate_channels=hidden_channels * 2,
skip_channels=hidden_channels,
global_channels=global_channels,
dropout_rate=dropout_rate,
bias=bias,
use_weight_norm=use_weight_norm,
use_first_conv=False,
use_last_conv=False,
scale_residual=False,
scale_skip_connect=True,
)
if use_only_mean:
self.proj = torch.nn.Conv1d(
hidden_channels,
self.half_channels,
1,
)
else:
self.proj = torch.nn.Conv1d(
hidden_channels,
self.half_channels * 2,
1,
)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
g: Optional[torch.Tensor] = None,
inverse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T).
x_lengths (Tensor): Length tensor (B,).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, in_channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
xa, xb = x.split(x.size(1) // 2, dim=1)
h = self.input_conv(xa) * x_mask
h = self.encoder(h, x_mask, g=g)
stats = self.proj(h) * x_mask
if not self.use_only_mean:
m, logs = stats.split(stats.size(1) // 2, dim=1)
else:
m = stats
logs = torch.zeros_like(m)
if not inverse:
xb = m + xb * torch.exp(logs) * x_mask
x = torch.cat([xa, xb], 1)
logdet = torch.sum(logs, [1, 2])
return x, logdet
else:
xb = (xb - m) * torch.exp(-logs) * x_mask
x = torch.cat([xa, xb], 1)
return x

View File

@ -0,0 +1,17 @@
# https://github.com/jaywalnut310/vits/blob/main/text/symbols.py
""" from https://github.com/keithito/tacotron """
'''
Defines the set of symbols used in text input to the model.
'''
_pad = '_'
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
# Export all symbols:
symbol_table = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
# Special symbol ids
SPACE_ID = symbol_table.index(" ")

View File

@ -0,0 +1,534 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text encoder module in VITS.
This code is based on
- https://github.com/jaywalnut310/vits
- https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py
"""
import copy
import math
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from icefall.utils import make_pad_mask
class TextEncoder(torch.nn.Module):
"""Text encoder module in VITS.
This is a module of text encoder described in `Conditional Variational Autoencoder
with Adversarial Learning for End-to-End Text-to-Speech`.
"""
def __init__(
self,
vocabs: int,
d_model: int = 192,
num_heads: int = 2,
dim_feedforward: int = 768,
num_layers: int = 6,
dropout: float = 0.1,
):
"""Initialize TextEncoder module.
Args:
vocabs (int): Vocabulary size.
d_model (int): attention dimension
num_heads (int): number of attention heads
dim_feedforward (int): feedforward dimention
num_layers (int): number of encoder layers
dropout (float): dropout rate
"""
super().__init__()
self.d_model = d_model
# define modules
self.emb = torch.nn.Embedding(vocabs, d_model)
torch.nn.init.normal_(self.emb.weight, 0.0, d_model**-0.5)
self.encoder = Transformer(
d_model=d_model,
num_heads=num_heads,
dim_feedforward=dim_feedforward,
num_layers=num_layers,
dropout=dropout,
)
self.proj = torch.nn.Conv1d(d_model, d_model * 2, 1)
def forward(
self,
x: torch.Tensor,
x_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input index tensor (B, T_text).
x_lengths (Tensor): Length tensor (B,).
Returns:
Tensor: Encoded hidden representation (B, attention_dim, T_text).
Tensor: Projected mean tensor (B, attention_dim, T_text).
Tensor: Projected scale tensor (B, attention_dim, T_text).
Tensor: Mask tensor for input tensor (B, 1, T_text).
"""
# (B, T_text, embed_dim)
x = self.emb(x) * math.sqrt(self.d_model)
assert x.size(1) == x_lengths.max().item()
# (B, T_text)
pad_mask = make_pad_mask(x_lengths)
# encoder assume the channel last (B, T_text, embed_dim)
x = self.encoder(x, key_padding_mask=pad_mask)
# convert the channel first (B, embed_dim, T_text)
x = x.transpose(1, 2)
non_pad_mask = (~pad_mask).unsqueeze(1)
stats = self.proj(x) * non_pad_mask
m, logs = stats.split(stats.size(1) // 2, dim=1)
return x, m, logs, non_pad_mask
class Transformer(nn.Module):
"""
Args:
d_model (int): attention dimension
num_heads (int): number of attention heads
dim_feedforward (int): feedforward dimention
num_layers (int): number of encoder layers
dropout (float): dropout rate
"""
def __init__(
self,
d_model: int = 192,
num_heads: int = 2,
dim_feedforward: int = 768,
num_layers: int = 6,
dropout: float = 0.1,
) -> None:
super().__init__()
self.num_layers = num_layers
self.d_model = d_model
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_layer = TransformerEncoderLayer(
d_model=d_model,
num_heads=num_heads,
dim_feedforward=dim_feedforward,
dropout=dropout,
)
self.encoder = TransformerEncoder(encoder_layer, num_layers)
self.after_norm = nn.LayerNorm(d_model)
def forward(
self, x: Tensor, key_padding_mask: Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
lengths:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
"""
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
x = self.encoder(
x, pos_emb, key_padding_mask=key_padding_mask
) # (T, N, C)
x = self.after_norm(x)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x
class TransformerEncoderLayer(nn.Module):
"""
TransformerEncoderLayer is made up of self-attn and feedforward.
Args:
d_model: the number of expected features in the input.
num_heads: the number of heads in the multi-head attention models.
dim_feedforward: the dimension of the feed-forward network model.
dropout: the dropout value (default=0.1).
"""
def __init__(
self,
d_model: int,
num_heads: int,
dim_feedforward: int,
dropout: float = 0.1,
) -> None:
super(TransformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
self.dropout = nn.Dropout(dropout)
def forward(
self,
src: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Pass the input through the transformer encoder layer.
Args:
src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim).
pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
"""
# multi-head self-attention module
src_attn = self.self_attn(
self.norm_mha(src),
pos_emb=pos_emb,
key_padding_mask=key_padding_mask,
)
src = src + self.dropout(src_attn)
# feed-forward module
src = src + self.dropout(self.feed_forward(self.norm_ff(src)))
src = self.norm_final(src)
return src
class TransformerEncoder(nn.Module):
r"""TransformerEncoder is a stack of N encoder layers
Args:
encoder_layer: an instance of the TransformerEncoderLayer class.
num_layers: the number of sub-encoder-layers in the encoder.
"""
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
super().__init__()
self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
)
self.num_layers = num_layers
def forward(
self,
src: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim).
pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
"""
output = src
for layer_index, mod in enumerate(self.layers):
output = mod(
output,
pos_emb,
key_padding_mask=key_padding_mask,
)
return output
class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module.
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
Args:
d_model: Embedding dimension.
dropout_rate: Dropout rate.
max_len: Maximum input length.
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: Tensor) -> None:
"""Reset the positional encodings."""
x_size = x.size(1)
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x_size * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vector and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x_size, self.d_model)
pe_negative = torch.zeros(x_size, self.d_model)
position = torch.arange(0, x_size, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
# Reserve the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
"""
self.extend_pe(x)
x = x * self.xscale
pos_emb = self.pe[
:,
self.pe.size(1) // 2
- x.size(1)
+ 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1),
]
return self.dropout(x), self.dropout(pos_emb)
class RelPositionMultiheadAttention(nn.Module):
r"""Multi-Head Attention layer with relative position encoding
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Args:
embed_dim: total dimension of the model.
num_heads: parallel attention heads.
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
# linear transformation for positional encoding.
self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self._reset_parameters()
def _reset_parameters(self) -> None:
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.constant_(self.in_proj.bias, 0.0)
nn.init.constant_(self.out_proj.bias, 0.0)
nn.init.xavier_uniform_(self.pos_bias_u)
nn.init.xavier_uniform_(self.pos_bias_v)
def rel_shift(self, x: Tensor) -> Tensor:
"""Compute relative positional encoding.
Args:
x: Input tensor (batch, head, seq_len, 2*seq_len-1).
Returns:
Tensor: tensor of shape (batch, head, seq_len, seq_len)
"""
(batch_size, num_heads, seq_len, n) = x.shape
assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1"
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, seq_len, seq_len),
(batch_stride, head_stride, time_stride - n_stride, n_stride),
storage_offset=n_stride * (seq_len - 1),
)
def forward(
self,
x: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Args:
x: Input tensor of shape (seq_len, batch_size, embed_dim)
pos_emb: Positional embedding tensor, (1, 2*seq_len-1, pos_dim)
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. This is an binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
Its shape is (batch_size, seq_len).
Outputs:
A tensor of shape (seq_len, batch_size, embed_dim).
"""
seq_len, batch_size, _ = x.shape
scaling = float(self.head_dim) ** -0.5
q, k, v = self.in_proj(x).chunk(3, dim=-1)
q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
v = v.contiguous().view(seq_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim)
p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.num_heads, self.head_dim)
# (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1)
p = p.permute(0, 2, 3, 1)
# (batch_size, num_head, seq_len, head_dim)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len)
matrix_ac = torch.matmul(q_with_bias_u, k) # (batch_size, num_head, seq_len, seq_len)
# compute matrix b and matrix d
matrix_bd = torch.matmul(q_with_bias_v, p) # (batch_size, num_head, seq_len, 2*seq_len-1)
matrix_bd = self.rel_shift(matrix_bd) # (batch_size, num_head, seq_len, seq_len)
# (batch_size, num_head, seq_len, seq_len)
attn_output_weights = (matrix_ac + matrix_bd) * scaling
attn_output_weights = attn_output_weights.view(batch_size * self.num_heads, seq_len, seq_len)
if key_padding_mask is not None:
assert key_padding_mask.shape == (batch_size, seq_len)
attn_output_weights = attn_output_weights.view(
batch_size, self.num_heads, seq_len, seq_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)
attn_output_weights = attn_output_weights.view(
batch_size * self.num_heads, seq_len, seq_len
)
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
attn_output_weights = nn.functional.dropout(
attn_output_weights, p=self.dropout, training=self.training
)
# (batch_size * num_head, seq_len, head_dim)
attn_output = torch.bmm(attn_output_weights, v)
assert attn_output.shape == (batch_size * self.num_heads, seq_len, self.head_dim)
attn_output = (
attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, self.embed_dim)
)
# (seq_len, batch_size, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
class Swish(nn.Module):
"""Construct an Swish object."""
def forward(self, x: Tensor) -> Tensor:
"""Return Swich activation function."""
return x * torch.sigmoid(x)
def _test_text_encoder():
vocabs = 500
d_model = 192
batch_size = 5
seq_len = 100
m = TextEncoder(vocabs=vocabs, d_model=d_model)
x, m, logs, mask = m(
x=torch.randint(low=0, high=vocabs, size=(batch_size, seq_len)),
x_lengths=torch.full((batch_size,), seq_len),
)
print(x.shape, m.shape, logs.shape, mask.shape)
if __name__ == "__main__":
_test_text_encoder()

896
egs/ljspeech/tts/vits/train.py Executable file
View File

@ -0,0 +1,896 @@
#!/usr/bin/env python3
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Union
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from torch.optim import Optimizer
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import LJSpeechTtsDataModule
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
setup_logger,
str2bool,
)
from symbols import symbol_table
from utils import (
MetricsTracker,
prepare_token_batch,
save_checkpoint,
save_checkpoint_with_global_batch_idx,
)
from vits import VITS
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=30,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=1,
help="""Resume training from this epoch. It should be positive.
If larger than 1, it will load checkpoint from
exp-dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--start-batch",
type=int,
default=0,
help="""If positive, --start-epoch is ignored and
it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="vits/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lr", type=float, default=2.0e-4, help="The base learning rate."
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--print-diagnostics",
type=str2bool,
default=False,
help="Accumulate stats on activations, print them and exit.",
)
parser.add_argument(
"--inf-check",
type=str2bool,
default=False,
help="Add hooks to check for infinite module outputs and gradients.",
)
parser.add_argument(
"--save-every-n",
type=int,
default=4000,
help="""Save checkpoint after processing this number of batches"
periodically. We save checkpoint to exp-dir/ whenever
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 1.
""",
)
parser.add_argument(
"--keep-last-k",
type=int,
default=30,
help="""Only keep this number of checkpoints on disk.
For instance, if it is 3, there are only 3 checkpoints
in the exp-dir with filenames `checkpoint-xxx.pt`.
It does not affect checkpoints with name `epoch-xxx.pt`.
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=False,
help="Whether to use half precision training.",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
- encoder_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- warm_step: The warmup period that dictates the decay of the
scale on "simple" (un-pruned) loss.
"""
params = AttributeDict(
{
# training params
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": -1, # 0
"log_interval": 50,
# "reset_interval": 200,
"valid_interval": 500,
"env_info": get_env_info(),
"sampling_rate": 22050,
"feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length
"vocab_size": len(symbol_table),
"mel_loss_params": {
"frame_shift": 256,
"frame_length": 1024,
"n_mels": 80,
},
"lambda_adv": 1.0, # loss scaling coefficient for adversarial loss
"lambda_mel": 45.0, # loss scaling coefficient for Mel loss
"lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss
"lambda_dur": 1.0, # loss scaling coefficient for duration loss
"lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict, model: nn.Module
) -> Optional[Dict[str, Any]]:
"""Load checkpoint from file.
If params.start_batch is positive, it will load the checkpoint from
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`.
Apart from loading state dict for `model` and `optimizer` it also updates
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
Returns:
Return a dict containing previously saved training info.
"""
if params.start_batch > 0:
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
elif params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
assert filename.is_file(), f"{filename} does not exist!"
saved_params = load_checkpoint(filename, model=model)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
if params.start_batch > 0:
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
return saved_params
def get_model(params: AttributeDict) -> nn.Module:
model = VITS(
vocab_size=params.vocab_size,
feature_dim=params.feature_dim,
sampling_rate=params.sampling_rate,
mel_loss_params=params.mel_loss_params,
lambda_adv=params.lambda_adv,
lambda_mel=params.lambda_mel,
lambda_feat_match=params.lambda_feat_match,
lambda_dur=params.lambda_dur,
lambda_kl=params.lambda_kl,
)
return model
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
# used to summary the stats over iterations
tot_loss = MetricsTracker()
with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl):
batch_size = len(batch["text"])
audio = batch["audio"].to(device)
features = batch["features"].to(device)
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
text = batch["text"]
tokens, tokens_lens = prepare_token_batch(text)
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
loss_info = MetricsTracker()
loss_info['samples'] = batch_size
# forward discriminator
loss_d, stats_d = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=False,
)
assert loss_d.requires_grad is False
for k, v in stats_d.items():
loss_info[k] = v * batch_size
# forward generator
loss_g, stats_g = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=True,
)
assert loss_g.requires_grad is False
for k, v in stats_g.items():
loss_info[k] = v * batch_size
# summary stats
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(device)
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: Union[nn.Module, DDP],
optimizer_g: Optimizer,
optimizer_d: Optimizer,
scheduler_g: LRSchedulerType,
scheduler_d: LRSchedulerType,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
scheduler:
The learning rate scheduler, we call step() every step.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
The scaler used for mix precision training.
model_avg:
The stored model averaged from the start of training.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
rank:
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
model.train()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
# used to summary the stats over iterations
tot_loss = MetricsTracker()
saved_bad_model = False
def save_bad_model(suffix: str = ""):
save_checkpoint(
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
model=model,
params=params,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
sampler=train_dl.sampler,
scaler=scaler,
rank=0,
)
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["text"])
audio = batch["audio"].to(device)
features = batch["features"].to(device)
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
text = batch["text"]
tokens, tokens_lens = prepare_token_batch(text)
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
loss_info = MetricsTracker()
loss_info['samples'] = batch_size
try:
with autocast(enabled=params.use_fp16):
# forward discriminator
loss_d, stats_d = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=False,
)
for k, v in stats_d.items():
loss_info[k] = v * batch_size
# update discriminator
optimizer_d.zero_grad()
scaler.scale(loss_d).backward()
scaler.step(optimizer_d)
with autocast(enabled=params.use_fp16):
# forward generator
loss_g, stats_g = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=True,
)
for k, v in stats_g.items():
loss_info[k] = v * batch_size
# update generator
optimizer_g.zero_grad()
scaler.scale(loss_g).backward()
scaler.step(optimizer_g)
scaler.update()
# summary stats
# tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
tot_loss = tot_loss + loss_info
except: # noqa
save_bad_model()
raise
if params.print_diagnostics and batch_idx == 5:
return
if (
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
model=model,
params=params,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
rank=rank,
)
# if batch_idx % 100 == 0 and params.use_fp16:
if params.batch_idx_train % 100 == 0 and params.use_fp16:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
# if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
save_bad_model(suffix="-first-warning")
saved_bad_model = True
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
# if batch_idx % params.log_interval == 0:
if params.batch_idx_train % params.log_interval == 0:
cur_lr_g = max(scheduler_g.get_last_lr())
cur_lr_d = max(scheduler_d.get_last_lr())
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
f"loss[{loss_info}], tot_loss[{tot_loss}], "
f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate_g", cur_lr_g, params.batch_idx_train
)
tb_writer.add_scalar(
"train/learning_rate_d", cur_lr_d, params.batch_idx_train
)
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
# if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def scan_pessimistic_batches_for_oom(
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,
optimizer_g: torch.optim.Optimizer,
optimizer_d: torch.optim.Optimizer,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
logging.info(
"Sanity check -- see if any of the batches in epoch 1 would cause OOM."
)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
audio = batch["audio"].to(device)
features = batch["features"].to(device)
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
text = batch["text"]
tokens, tokens_lens = prepare_token_batch(text)
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
try:
# for discriminator
with autocast(enabled=params.use_fp16):
loss_d, stats_d = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=False,
)
optimizer_d.zero_grad()
loss_d.backward()
# for generator
with autocast(enabled=params.use_fp16):
loss_g, stats_g = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=True,
)
optimizer_g.zero_grad()
loss_g.backward()
except Exception as e:
if "CUDA out of memory" in str(e):
logging.error(
"Your GPU ran out of memory with the current "
"max_duration setting. We recommend decreasing "
"max_duration and trying again.\n"
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
raise
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
logging.info(params)
logging.info("About to create model")
model = get_model(params)
generator = model.generator
discriminator = model.discriminator
num_param_g = sum([p.numel() for p in generator.parameters()])
logging.info(f"Number of parameters in generator: {num_param_g}")
num_param_d = sum([p.numel() for p in discriminator.parameters()])
logging.info(f"Number of parameters in discriminator: {num_param_d}")
logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer_g = torch.optim.AdamW(
generator.parameters(),
lr=params.lr,
betas=(0.8, 0.99),
eps=1e-9,
weight_decay=0,
)
optimizer_d = torch.optim.AdamW(
discriminator.parameters(),
lr=params.lr,
betas=(0.8, 0.99),
eps=1e-9,
weight_decay=0,
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)
if checkpoints is not None:
# load state_dict for optimizers
if "optimizer_g" in checkpoints:
logging.info("Loading optimizer_g state dict")
optimizer_g.load_state_dict(checkpoints["optimizer_g"])
if "optimizer_d" in checkpoints:
logging.info("Loading optimizer_d state dict")
optimizer_d.load_state_dict(checkpoints["optimizer_d"])
# load state_dict for schedulers
if "scheduler_g" in checkpoints:
logging.info("Loading scheduler_g state dict")
scheduler_g.load_state_dict(checkpoints["scheduler_g"])
if "scheduler_d" in checkpoints:
logging.info("Loading scheduler_d state dict")
scheduler_d.load_state_dict(checkpoints["scheduler_d"])
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
if params.inf_check:
register_inf_check_hooks(model)
ljspeech = LJSpeechTtsDataModule(args)
train_cuts = ljspeech.train_cuts()
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
# saved in the middle of an epoch
sampler_state_dict = checkpoints["sampler"]
else:
sampler_state_dict = None
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 20.0:
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_dl = ljspeech.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = ljspeech.valid_cuts()
valid_dl = ljspeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1):
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint(
filename=filename,
params=params,
model=model,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
if rank == 0:
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
# step per epoch
scheduler_g.step()
scheduler_d.step()
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
LJSpeechTtsDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,217 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py
"""Flow-related transformation.
This code is derived from https://github.com/bayesiains/nflows.
"""
import numpy as np
import torch
from torch.nn import functional as F
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
# TODO(kan-bayashi): Documentation and type hint
def piecewise_rational_quadratic_transform(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
else:
spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
outputs, logabsdet = spline_fn(
inputs=inputs,
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs
)
return outputs, logabsdet
# TODO(kan-bayashi): Documentation and type hint
def unconstrained_rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails="linear",
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs)
logabsdet = torch.zeros_like(inputs)
if tails == "linear":
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError("{} tails are not implemented.".format(tails))
(
outputs[inside_interval_mask],
logabsdet[inside_interval_mask],
) = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
)
return outputs, logabsdet
# TODO(kan-bayashi): Documentation and type hint
def rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0.0,
right=1.0,
bottom=0.0,
top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if torch.min(inputs) < left or torch.max(inputs) > right:
raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0:
raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
heights = F.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1)
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
heights = cumheights[..., 1:] - cumheights[..., :-1]
if inverse:
bin_idx = _searchsorted(cumheights, inputs)[..., None]
else:
bin_idx = _searchsorted(cumwidths, inputs)[..., None]
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
delta = heights / widths
input_delta = delta.gather(-1, bin_idx)[..., 0]
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
input_heights = heights.gather(-1, bin_idx)[..., 0]
if inverse:
a = (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
) + input_heights * (input_delta - input_derivatives)
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
)
c = -input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
root = (2 * c) / (-b - torch.sqrt(discriminant))
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -logabsdet
else:
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
)
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, logabsdet
def _searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1

View File

@ -0,0 +1,306 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
SpeechSynthesisDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class LJSpeechTtsDataModule:
"""
DataModule for tts experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="TTS data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/spectrogram"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_tokens=False,
feature_input_strategy=eval(self.args.input_strategy)(),
)
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = SpectrogramConfig(
sampling_rate=sampling_rate,
frame_length=1024 / sampling_rate, # (in second),
frame_shift=256 / sampling_rate, # (in second)
use_fft_mag=True,
)
train = SpeechSynthesisDataset(
return_tokens=False,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = SpectrogramConfig(
sampling_rate=sampling_rate,
frame_length=1024 / sampling_rate, # (in second),
frame_shift=256 / sampling_rate, # (in second)
use_fft_mag=True,
)
validate = SpeechSynthesisDataset(
return_tokens=False,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
)
else:
validate = SpeechSynthesisDataset(
return_tokens=False,
feature_input_strategy=eval(self.args.input_strategy)(),
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create valid dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = SpectrogramConfig(
sampling_rate=sampling_rate,
frame_length=1024 / sampling_rate, # (in second),
frame_shift=256 / sampling_rate, # (in second)
use_fft_mag=True,
)
test = SpeechSynthesisDataset(
return_tokens=False,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
)
else:
test = SpeechSynthesisDataset(
return_tokens=False,
feature_input_strategy=eval(self.args.input_strategy)(),
)
test_sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=test_sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
)
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get validation cuts")
return load_manifest_lazy(
self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
)

View File

@ -0,0 +1,470 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Function to get random segments."""
from typing import Any, Dict, List, Optional, Tuple, Union
import collections
import logging
import re
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
from lhotse.dataset.sampling.base import CutSampler
from pathlib import Path
from phonemizer import phonemize
from symbols import symbol_table
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from unidecode import unidecode
def get_random_segments(
x: torch.Tensor,
x_lengths: torch.Tensor,
segment_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get random segments.
Args:
x (Tensor): Input tensor (B, C, T).
x_lengths (Tensor): Length tensor (B,).
segment_size (int): Segment size.
Returns:
Tensor: Segmented tensor (B, C, segment_size).
Tensor: Start index tensor (B,).
"""
b, c, t = x.size()
max_start_idx = x_lengths - segment_size
max_start_idx[max_start_idx < 0] = 0
start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to(
dtype=torch.long,
)
segments = get_segments(x, start_idxs, segment_size)
return segments, start_idxs
def get_segments(
x: torch.Tensor,
start_idxs: torch.Tensor,
segment_size: int,
) -> torch.Tensor:
"""Get segments.
Args:
x (Tensor): Input tensor (B, C, T).
start_idxs (Tensor): Start index tensor (B,).
segment_size (int): Segment size.
Returns:
Tensor: Segmented tensor (B, C, segment_size).
"""
b, c, t = x.size()
segments = x.new_zeros(b, c, segment_size)
for i, start_idx in enumerate(start_idxs):
segments[i] = x[i, :, start_idx : start_idx + segment_size]
return segments
# https://github.com/espnet/espnet/blob/master/espnet2/torch_utils/device_funcs.py
def force_gatherable(data, device):
"""Change object to gatherable in torch.nn.DataParallel recursively
The difference from to_device() is changing to torch.Tensor if float or int
value is found.
The restriction to the returned value in DataParallel:
The object must be
- torch.cuda.Tensor
- 1 or more dimension. 0-dimension-tensor sends warning.
or a list, tuple, dict.
"""
if isinstance(data, dict):
return {k: force_gatherable(v, device) for k, v in data.items()}
# DataParallel can't handle NamedTuple well
elif isinstance(data, tuple) and type(data) is not tuple:
return type(data)(*[force_gatherable(o, device) for o in data])
elif isinstance(data, (list, tuple, set)):
return type(data)(force_gatherable(v, device) for v in data)
elif isinstance(data, np.ndarray):
return force_gatherable(torch.from_numpy(data), device)
elif isinstance(data, torch.Tensor):
if data.dim() == 0:
# To 1-dim array
data = data[None]
return data.to(device)
elif isinstance(data, float):
return torch.tensor([data], dtype=torch.float, device=device)
elif isinstance(data, int):
return torch.tensor([data], dtype=torch.long, device=device)
elif data is None:
return None
else:
warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
return data
# The following codes are based on https://github.com/jaywalnut310/vits
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
def expand_abbreviations(text):
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, ' ', text)
def convert_to_ascii(text):
return unidecode(text)
def text_clean(text):
'''Pipeline for English text, including abbreviation expansion. + punctuation + stress.
Returns:
A string of phonemes.
'''
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_abbreviations(text)
phonemes = phonemize(
text,
language='en-us',
backend='espeak',
strip=True,
preserve_punctuation=True,
with_stress=True,
)
phonemes = collapse_whitespace(phonemes)
return phonemes
# Mappings from symbol to numeric ID and vice versa:
symbol_to_id = {s: i for i, s in enumerate(symbol_table)}
id_to_symbol = {i: s for i, s in enumerate(symbol_table)}
# def text_to_sequence(text: str) -> List[int]:
# '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
# '''
# cleaned_text = text_clean(text)
# sequence = [symbol_to_id[symbol] for symbol in cleaned_text]
# return sequence
#
#
# def sequence_to_text(sequence: List[int]) -> str:
# '''Converts a sequence of IDs back to a string'''
# result = ''.join(id_to_symbol[symbol_id] for symbol_id in sequence)
# return result
def intersperse(sequence, item=0):
result = [item] * (len(sequence) * 2 + 1)
result[1::2] = sequence
return result
def prepare_token_batch(
texts: List[str],
intersperse_blank: bool = True,
blank_id: int = 0,
pad_id: int = 0,
) -> torch.Tensor:
"""Convert a list of text strings into a batch of symbol tokens with padding.
Args:
texts: list of text strings
intersperse_blank: whether to intersperse blank tokens in the converted token sequence.
blank_id: index of blank token
pad_id: padding index
"""
# normalize text
normalized_texts = []
for text in texts:
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_abbreviations(text)
normalized_texts.append(text)
# convert to phonemes
phonemes = phonemize(
normalized_texts,
language='en-us',
backend='espeak',
strip=True,
preserve_punctuation=True,
with_stress=True,
)
# convert to symbol ids
lengths = []
sequences = []
for idx, sequence in enumerate(phonemes):
try:
sequence = [symbol_to_id[symbol] for symbol in collapse_whitespace(sequence)]
except RuntimeError:
print(text[idx])
print(normalized_texts[idx])
if intersperse_blank:
sequence = intersperse(sequence, blank_id)
sequences.append(torch.tensor(sequence, dtype=torch.int64))
lengths.append(len(sequence))
sequences = pad_sequence(sequences, batch_first=True, padding_value=pad_id)
lengths = torch.tensor(lengths, dtype=torch.int64)
return sequences, lengths
class MetricsTracker(collections.defaultdict):
def __init__(self):
# Passing the type 'int' to the base-class constructor
# makes undefined items default to int() which is zero.
# This class will play a role as metrics tracker.
# It can record many metrics, including but not limited to loss.
super(MetricsTracker, self).__init__(int)
def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
ans = MetricsTracker()
for k, v in self.items():
ans[k] = v
for k, v in other.items():
ans[k] = ans[k] + v
return ans
def __mul__(self, alpha: float) -> "MetricsTracker":
ans = MetricsTracker()
for k, v in self.items():
ans[k] = v * alpha
return ans
def __str__(self) -> str:
ans = ""
for k, v in self.norm_items():
norm_value = "%.4g" % v
ans += str(k) + "=" + str(norm_value) + ", "
samples = "%.2f" % self["samples"]
ans += "over" + str(samples) + " samples."
return ans
def norm_items(self) -> List[Tuple[str, float]]:
"""
Returns a list of pairs, like:
[('loss_1', 0.1), ('loss_2', 0.07)]
"""
samples = self["samples"] if "samples" in self else 1
ans = []
for k, v in self.items():
if k == "samples":
continue
norm_value = float(v) / samples
ans.append((k, norm_value))
return ans
def reduce(self, device):
"""
Reduce using torch.distributed, which I believe ensures that
all processes get the total.
"""
keys = sorted(self.keys())
s = torch.tensor([float(self[k]) for k in keys], device=device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
for k, v in zip(keys, s.cpu().tolist()):
self[k] = v
def write_summary(
self,
tb_writer: SummaryWriter,
prefix: str,
batch_idx: int,
) -> None:
"""Add logging information to a TensorBoard writer.
Args:
tb_writer: a TensorBoard writer
prefix: a prefix for the name of the loss, e.g. "train/valid_",
or "train/current_"
batch_idx: The current batch index, used as the x-axis of the plot.
"""
for k, v in self.norm_items():
tb_writer.add_scalar(prefix + k, v, batch_idx)
# checkpoint saving and loading
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
def save_checkpoint(
filename: Path,
model: Union[nn.Module, DDP],
params: Optional[Dict[str, Any]] = None,
optimizer_g: Optional[Optimizer] = None,
optimizer_d: Optional[Optimizer] = None,
scheduler_g: Optional[LRSchedulerType] = None,
scheduler_d: Optional[LRSchedulerType] = None,
scaler: Optional[GradScaler] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
) -> None:
"""Save training information to a file.
Args:
filename:
The checkpoint filename.
model:
The model to be saved. We only save its `state_dict()`.
model_avg:
The stored model averaged from the start of training.
params:
User defined parameters, e.g., epoch, loss.
optimizer_g:
The optimizer for generator used in the training.
Its `state_dict` will be saved.
optimizer_d:
The optimizer for discriminator used in the training.
Its `state_dict` will be saved.
scheduler_g:
The learning rate scheduler for generator used in the training.
Its `state_dict` will be saved.
scheduler_d:
The learning rate scheduler for discriminator used in the training.
Its `state_dict` will be saved.
scalar:
The GradScaler to be saved. We only save its `state_dict()`.
rank:
Used in DDP. We save checkpoint only for the node whose rank is 0.
Returns:
Return None.
"""
if rank != 0:
return
logging.info(f"Saving checkpoint to {filename}")
if isinstance(model, DDP):
model = model.module
checkpoint = {
"model": model.state_dict(),
"optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None,
"optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None,
"scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None,
"scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None,
"grad_scaler": scaler.state_dict() if scaler is not None else None,
"sampler": sampler.state_dict() if sampler is not None else None,
}
if params:
for k, v in params.items():
assert k not in checkpoint
checkpoint[k] = v
torch.save(checkpoint, filename)
def save_checkpoint_with_global_batch_idx(
out_dir: Path,
global_batch_idx: int,
model: Union[nn.Module, DDP],
params: Optional[Dict[str, Any]] = None,
optimizer_g: Optional[Optimizer] = None,
optimizer_d: Optional[Optimizer] = None,
scheduler_g: Optional[LRSchedulerType] = None,
scheduler_d: Optional[LRSchedulerType] = None,
scaler: Optional[GradScaler] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
):
"""Save training info after processing given number of batches.
Args:
out_dir:
The directory to save the checkpoint.
global_batch_idx:
The number of batches processed so far from the very start of the
training. The saved checkpoint will have the following filename:
f'out_dir / checkpoint-{global_batch_idx}.pt'
model:
The neural network model whose `state_dict` will be saved in the
checkpoint.
params:
A dict of training configurations to be saved.
optimizer_g:
The optimizer for generator used in the training.
Its `state_dict` will be saved.
optimizer_d:
The optimizer for discriminator used in the training.
Its `state_dict` will be saved.
scheduler_g:
The learning rate scheduler for generator used in the training.
Its `state_dict` will be saved.
scheduler_d:
The learning rate scheduler for discriminator used in the training.
Its `state_dict` will be saved.
scaler:
The scaler used for mix precision training. Its `state_dict` will
be saved.
sampler:
The sampler used in the training dataset.
rank:
The rank ID used in DDP training of the current node. Set it to 0
if DDP is not used.
"""
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
save_checkpoint(
filename=filename,
model=model,
params=params,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
scaler=scaler,
sampler=sampler,
rank=rank,
)

View File

@ -0,0 +1,567 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""VITS module for GAN-TTS task."""
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from hifigan import (
HiFiGANMultiPeriodDiscriminator,
HiFiGANMultiScaleDiscriminator,
HiFiGANMultiScaleMultiPeriodDiscriminator,
HiFiGANPeriodDiscriminator,
HiFiGANScaleDiscriminator,
)
from loss import (
DiscriminatorAdversarialLoss,
FeatureMatchLoss,
GeneratorAdversarialLoss,
KLDivergenceLoss,
MelSpectrogramLoss,
)
from utils import get_segments
from generator import VITSGenerator
AVAILABLE_GENERATERS = {
"vits_generator": VITSGenerator,
}
AVAILABLE_DISCRIMINATORS = {
"hifigan_period_discriminator": HiFiGANPeriodDiscriminator,
"hifigan_scale_discriminator": HiFiGANScaleDiscriminator,
"hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator,
"hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator,
"hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA
}
class VITS(nn.Module):
"""Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`
"""
def __init__(
self,
# generator related
vocab_size: int,
feature_dim: int = 513,
sampling_rate: int = 22050,
generator_type: str = "vits_generator",
generator_params: Dict[str, Any] = {
"hidden_channels": 192,
"spks": None,
"langs": None,
"spk_embed_dim": None,
"global_channels": -1,
"segment_size": 32,
"text_encoder_attention_heads": 2,
"text_encoder_ffn_expand": 4,
"text_encoder_blocks": 6,
"text_encoder_dropout_rate": 0.1,
"decoder_kernel_size": 7,
"decoder_channels": 512,
"decoder_upsample_scales": [8, 8, 2, 2],
"decoder_upsample_kernel_sizes": [16, 16, 4, 4],
"decoder_resblock_kernel_sizes": [3, 7, 11],
"decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"use_weight_norm_in_decoder": True,
"posterior_encoder_kernel_size": 5,
"posterior_encoder_layers": 16,
"posterior_encoder_stacks": 1,
"posterior_encoder_base_dilation": 1,
"posterior_encoder_dropout_rate": 0.0,
"use_weight_norm_in_posterior_encoder": True,
"flow_flows": 4,
"flow_kernel_size": 5,
"flow_base_dilation": 1,
"flow_layers": 4,
"flow_dropout_rate": 0.0,
"use_weight_norm_in_flow": True,
"use_only_mean_in_flow": True,
"stochastic_duration_predictor_kernel_size": 3,
"stochastic_duration_predictor_dropout_rate": 0.5,
"stochastic_duration_predictor_flows": 4,
"stochastic_duration_predictor_dds_conv_layers": 3,
},
# discriminator related
discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator",
discriminator_params: Dict[str, Any] = {
"scales": 1,
"scale_downsample_pooling": "AvgPool1d",
"scale_downsample_pooling_params": {
"kernel_size": 4,
"stride": 2,
"padding": 2,
},
"scale_discriminator_params": {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [15, 41, 5, 3],
"channels": 128,
"max_downsample_channels": 1024,
"max_groups": 16,
"bias": True,
"downsample_scales": [2, 2, 4, 4, 1],
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
"use_weight_norm": True,
"use_spectral_norm": False,
},
"follow_official_norm": False,
"periods": [2, 3, 5, 7, 11],
"period_discriminator_params": {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [5, 3],
"channels": 32,
"downsample_scales": [3, 3, 3, 3, 1],
"max_downsample_channels": 1024,
"bias": True,
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
"use_weight_norm": True,
"use_spectral_norm": False,
},
},
# loss related
generator_adv_loss_params: Dict[str, Any] = {
"average_by_discriminators": False,
"loss_type": "mse",
},
discriminator_adv_loss_params: Dict[str, Any] = {
"average_by_discriminators": False,
"loss_type": "mse",
},
feat_match_loss_params: Dict[str, Any] = {
"average_by_discriminators": False,
"average_by_layers": False,
"include_final_outputs": True,
},
mel_loss_params: Dict[str, Any] = {
"frame_shift": 256,
"frame_length": 1024,
"n_mels": 80,
},
lambda_adv: float = 1.0,
lambda_mel: float = 45.0,
lambda_feat_match: float = 2.0,
lambda_dur: float = 1.0,
lambda_kl: float = 1.0,
cache_generator_outputs: bool = True,
):
"""Initialize VITS module.
Args:
idim (int): Input vocabrary size.
odim (int): Acoustic feature dimension. The actual output channels will
be 1 since VITS is the end-to-end text-to-wave model but for the
compatibility odim is used to indicate the acoustic feature dimension.
sampling_rate (int): Sampling rate, not used for the training but it will
be referred in saving waveform during the inference.
generator_type (str): Generator type.
generator_params (Dict[str, Any]): Parameter dict for generator.
discriminator_type (str): Discriminator type.
discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator
adversarial loss.
discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for
discriminator adversarial loss.
feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss.
mel_loss_params (Dict[str, Any]): Parameter dict for mel loss.
lambda_adv (float): Loss scaling coefficient for adversarial loss.
lambda_mel (float): Loss scaling coefficient for mel spectrogram loss.
lambda_feat_match (float): Loss scaling coefficient for feat match loss.
lambda_dur (float): Loss scaling coefficient for duration loss.
lambda_kl (float): Loss scaling coefficient for KL divergence loss.
cache_generator_outputs (bool): Whether to cache generator outputs.
"""
super().__init__()
# define modules
generator_class = AVAILABLE_GENERATERS[generator_type]
if generator_type == "vits_generator":
# NOTE(kan-bayashi): Update parameters for the compatibility.
# The idim and odim is automatically decided from input data,
# where idim represents #vocabularies and odim represents
# the input acoustic feature dimension.
generator_params.update(vocabs=vocab_size, aux_channels=feature_dim)
self.generator = generator_class(
**generator_params,
)
discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
self.discriminator = discriminator_class(
**discriminator_params,
)
self.generator_adv_loss = GeneratorAdversarialLoss(
**generator_adv_loss_params,
)
self.discriminator_adv_loss = DiscriminatorAdversarialLoss(
**discriminator_adv_loss_params,
)
self.feat_match_loss = FeatureMatchLoss(
**feat_match_loss_params,
)
mel_loss_params.update(sampling_rate=sampling_rate)
self.mel_loss = MelSpectrogramLoss(
**mel_loss_params,
)
self.kl_loss = KLDivergenceLoss()
# coefficients
self.lambda_adv = lambda_adv
self.lambda_mel = lambda_mel
self.lambda_kl = lambda_kl
self.lambda_feat_match = lambda_feat_match
self.lambda_dur = lambda_dur
# cache
self.cache_generator_outputs = cache_generator_outputs
self._cache = None
# store sampling rate for saving wav file
# (not used for the training)
self.sampling_rate = sampling_rate
# store parameters for test compatibility
self.spks = self.generator.spks
self.langs = self.generator.langs
self.spk_embed_dim = self.generator.spk_embed_dim
def forward(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
feats: torch.Tensor,
feats_lengths: torch.Tensor,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
forward_generator: bool = True,
) -> Dict[str, Any]:
"""Perform generator forward.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
feats_lengths (Tensor): Feature length tensor (B,).
speech (Tensor): Speech waveform tensor (B, T_wav).
speech_lengths (Tensor): Speech length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
forward_generator (bool): Whether to forward generator.
Returns:
Dict[str, Any]:
- loss (Tensor): Loss scalar tensor.
- stats (Dict[str, float]): Statistics to be monitored.
- weight (Tensor): Weight tensor to summarize losses.
- optim_idx (int): Optimizer index (0 for G and 1 for D).
"""
if forward_generator:
return self._forward_generator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
speech=speech,
speech_lengths=speech_lengths,
sids=sids,
spembs=spembs,
lids=lids,
)
else:
return self._forward_discrminator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
speech=speech,
speech_lengths=speech_lengths,
sids=sids,
spembs=spembs,
lids=lids,
)
def _forward_generator(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
feats: torch.Tensor,
feats_lengths: torch.Tensor,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
) -> Dict[str, Any]:
"""Perform generator forward.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
feats_lengths (Tensor): Feature length tensor (B,).
speech (Tensor): Speech waveform tensor (B, T_wav).
speech_lengths (Tensor): Speech length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
Returns:
Dict[str, Any]:
* loss (Tensor): Loss scalar tensor.
* stats (Dict[str, float]): Statistics to be monitored.
* weight (Tensor): Weight tensor to summarize losses.
* optim_idx (int): Optimizer index (0 for G and 1 for D).
"""
# setup
feats = feats.transpose(1, 2)
speech = speech.unsqueeze(1)
# calculate generator outputs
reuse_cache = True
if not self.cache_generator_outputs or self._cache is None:
reuse_cache = False
outs = self.generator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids,
)
else:
outs = self._cache
# store cache
if self.training and self.cache_generator_outputs and not reuse_cache:
self._cache = outs
# parse outputs
speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
_, z_p, m_p, logs_p, _, logs_q = outs_
speech_ = get_segments(
x=speech,
start_idxs=start_idxs * self.generator.upsample_factor,
segment_size=self.generator.segment_size * self.generator.upsample_factor,
)
# calculate discriminator outputs
p_hat = self.discriminator(speech_hat_)
with torch.no_grad():
# do not store discriminator gradient in generator turn
p = self.discriminator(speech_)
# calculate losses
with autocast(enabled=False):
mel_loss = self.mel_loss(speech_hat_, speech_)
kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
dur_loss = torch.sum(dur_nll.float())
adv_loss = self.generator_adv_loss(p_hat)
feat_match_loss = self.feat_match_loss(p_hat, p)
mel_loss = mel_loss * self.lambda_mel
kl_loss = kl_loss * self.lambda_kl
dur_loss = dur_loss * self.lambda_dur
adv_loss = adv_loss * self.lambda_adv
feat_match_loss = feat_match_loss * self.lambda_feat_match
loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss
stats = dict(
generator_loss=loss.item(),
generator_mel_loss=mel_loss.item(),
generator_kl_loss=kl_loss.item(),
generator_dur_loss=dur_loss.item(),
generator_adv_loss=adv_loss.item(),
generator_feat_match_loss=feat_match_loss.item(),
)
# reset cache
if reuse_cache or not self.training:
self._cache = None
return loss, stats
def _forward_discrminator(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
feats: torch.Tensor,
feats_lengths: torch.Tensor,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
) -> Dict[str, Any]:
"""Perform discriminator forward.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
feats_lengths (Tensor): Feature length tensor (B,).
speech (Tensor): Speech waveform tensor (B, T_wav).
speech_lengths (Tensor): Speech length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
Returns:
Dict[str, Any]:
* loss (Tensor): Loss scalar tensor.
* stats (Dict[str, float]): Statistics to be monitored.
* weight (Tensor): Weight tensor to summarize losses.
* optim_idx (int): Optimizer index (0 for G and 1 for D).
"""
# setup
feats = feats.transpose(1, 2)
speech = speech.unsqueeze(1)
# calculate generator outputs
reuse_cache = True
if not self.cache_generator_outputs or self._cache is None:
reuse_cache = False
outs = self.generator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids,
)
else:
outs = self._cache
# store cache
if self.cache_generator_outputs and not reuse_cache:
self._cache = outs
# parse outputs
speech_hat_, _, _, start_idxs, *_ = outs
speech_ = get_segments(
x=speech,
start_idxs=start_idxs * self.generator.upsample_factor,
segment_size=self.generator.segment_size * self.generator.upsample_factor,
)
# calculate discriminator outputs
p_hat = self.discriminator(speech_hat_.detach())
p = self.discriminator(speech_)
# calculate losses
with autocast(enabled=False):
real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
loss = real_loss + fake_loss
stats = dict(
discriminator_loss=loss.item(),
discriminator_real_loss=real_loss.item(),
discriminator_fake_loss=fake_loss.item(),
)
# reset cache
if reuse_cache or not self.training:
self._cache = None
return loss, stats
def inference(
self,
text: torch.Tensor,
feats: Optional[torch.Tensor] = None,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
durations: Optional[torch.Tensor] = None,
noise_scale: float = 0.667,
noise_scale_dur: float = 0.8,
alpha: float = 1.0,
max_len: Optional[int] = None,
use_teacher_forcing: bool = False,
) -> Dict[str, torch.Tensor]:
"""Run inference.
Args:
text (Tensor): Input text index tensor (T_text,).
feats (Tensor): Feature tensor (T_feats, aux_channels).
sids (Tensor): Speaker index tensor (1,).
spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,).
lids (Tensor): Language index tensor (1,).
durations (Tensor): Ground-truth duration tensor (T_text,).
noise_scale (float): Noise scale value for flow.
noise_scale_dur (float): Noise scale value for duration predictor.
alpha (float): Alpha parameter to control the speed of generated speech.
max_len (Optional[int]): Maximum length.
use_teacher_forcing (bool): Whether to use teacher forcing.
Returns:
Dict[str, Tensor]:
* wav (Tensor): Generated waveform tensor (T_wav,).
* att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
* duration (Tensor): Predicted duration tensor (T_text,).
"""
# setup
text = text[None]
text_lengths = torch.tensor(
[text.size(1)],
dtype=torch.long,
device=text.device,
)
if sids is not None:
sids = sids.view(1)
if lids is not None:
lids = lids.view(1)
if durations is not None:
durations = durations.view(1, 1, -1)
# inference
if use_teacher_forcing:
assert feats is not None
feats = feats[None].transpose(1, 2)
feats_lengths = torch.tensor(
[feats.size(2)],
dtype=torch.long,
device=feats.device,
)
wav, att_w, dur = self.generator.inference(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids,
max_len=max_len,
use_teacher_forcing=use_teacher_forcing,
)
else:
wav, att_w, dur = self.generator.inference(
text=text,
text_lengths=text_lengths,
sids=sids,
spembs=spembs,
lids=lids,
dur=durations,
noise_scale=noise_scale,
noise_scale_dur=noise_scale_dur,
alpha=alpha,
max_len=max_len,
)
return dict(wav=wav.view(-1), att_w=att_w[0], duration=dur[0])

View File

@ -0,0 +1,349 @@
# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""WaveNet modules.
This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
"""
import math
import logging
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
class WaveNet(torch.nn.Module):
"""WaveNet with global conditioning."""
def __init__(
self,
in_channels: int = 1,
out_channels: int = 1,
kernel_size: int = 3,
layers: int = 30,
stacks: int = 3,
base_dilation: int = 2,
residual_channels: int = 64,
aux_channels: int = -1,
gate_channels: int = 128,
skip_channels: int = 64,
global_channels: int = -1,
dropout_rate: float = 0.0,
bias: bool = True,
use_weight_norm: bool = True,
use_first_conv: bool = False,
use_last_conv: bool = False,
scale_residual: bool = False,
scale_skip_connect: bool = False,
):
"""Initialize WaveNet module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_size (int): Kernel size of dilated convolution.
layers (int): Number of residual block layers.
stacks (int): Number of stacks i.e., dilation cycles.
base_dilation (int): Base dilation factor.
residual_channels (int): Number of channels in residual conv.
gate_channels (int): Number of channels in gated conv.
skip_channels (int): Number of channels in skip conv.
aux_channels (int): Number of channels for local conditioning feature.
global_channels (int): Number of channels for global conditioning feature.
dropout_rate (float): Dropout rate. 0.0 means no dropout applied.
bias (bool): Whether to use bias parameter in conv layer.
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
be applied to all of the conv layers.
use_first_conv (bool): Whether to use the first conv layers.
use_last_conv (bool): Whether to use the last conv layers.
scale_residual (bool): Whether to scale the residual outputs.
scale_skip_connect (bool): Whether to scale the skip connection outputs.
"""
super().__init__()
self.layers = layers
self.stacks = stacks
self.kernel_size = kernel_size
self.base_dilation = base_dilation
self.use_first_conv = use_first_conv
self.use_last_conv = use_last_conv
self.scale_skip_connect = scale_skip_connect
# check the number of layers and stacks
assert layers % stacks == 0
layers_per_stack = layers // stacks
# define first convolution
if self.use_first_conv:
self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)
# define residual blocks
self.conv_layers = torch.nn.ModuleList()
for layer in range(layers):
dilation = base_dilation ** (layer % layers_per_stack)
conv = ResidualBlock(
kernel_size=kernel_size,
residual_channels=residual_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=aux_channels,
global_channels=global_channels,
dilation=dilation,
dropout_rate=dropout_rate,
bias=bias,
scale_residual=scale_residual,
)
self.conv_layers += [conv]
# define output layers
if self.use_last_conv:
self.last_conv = torch.nn.Sequential(
torch.nn.ReLU(inplace=True),
Conv1d1x1(skip_channels, skip_channels, bias=True),
torch.nn.ReLU(inplace=True),
Conv1d1x1(skip_channels, out_channels, bias=True),
)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
def forward(
self,
x: torch.Tensor,
x_mask: Optional[torch.Tensor] = None,
c: Optional[torch.Tensor] = None,
g: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T) if use_first_conv else
(B, residual_channels, T).
x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
c (Optional[Tensor]): Local conditioning features (B, aux_channels, T).
g (Optional[Tensor]): Global conditioning features (B, global_channels, 1).
Returns:
Tensor: Output tensor (B, out_channels, T) if use_last_conv else
(B, residual_channels, T).
"""
# encode to hidden representation
if self.use_first_conv:
x = self.first_conv(x)
# residual block
skips = 0.0
for f in self.conv_layers:
x, h = f(x, x_mask=x_mask, c=c, g=g)
skips = skips + h
x = skips
if self.scale_skip_connect:
x = x * math.sqrt(1.0 / len(self.conv_layers))
# apply final layers
if self.use_last_conv:
x = self.last_conv(x)
return x
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m: torch.nn.Module):
try:
logging.debug(f"Weight norm is removed from {m}.")
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
torch.nn.utils.weight_norm(m)
logging.debug(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
@staticmethod
def _get_receptive_field_size(
layers: int,
stacks: int,
kernel_size: int,
base_dilation: int,
) -> int:
assert layers % stacks == 0
layers_per_cycle = layers // stacks
dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)]
return (kernel_size - 1) * sum(dilations) + 1
@property
def receptive_field_size(self) -> int:
"""Return receptive field size."""
return self._get_receptive_field_size(
self.layers, self.stacks, self.kernel_size, self.base_dilation
)
class Conv1d(torch.nn.Conv1d):
"""Conv1d module with customized initialization."""
def __init__(self, *args, **kwargs):
"""Initialize Conv1d module."""
super().__init__(*args, **kwargs)
def reset_parameters(self):
"""Reset parameters."""
torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
if self.bias is not None:
torch.nn.init.constant_(self.bias, 0.0)
class Conv1d1x1(Conv1d):
"""1x1 Conv1d with customized initialization."""
def __init__(self, in_channels: int, out_channels: int, bias: bool):
"""Initialize 1x1 Conv1d module."""
super().__init__(
in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
)
class ResidualBlock(torch.nn.Module):
"""Residual block module in WaveNet."""
def __init__(
self,
kernel_size: int = 3,
residual_channels: int = 64,
gate_channels: int = 128,
skip_channels: int = 64,
aux_channels: int = 80,
global_channels: int = -1,
dropout_rate: float = 0.0,
dilation: int = 1,
bias: bool = True,
scale_residual: bool = False,
):
"""Initialize ResidualBlock module.
Args:
kernel_size (int): Kernel size of dilation convolution layer.
residual_channels (int): Number of channels for residual connection.
skip_channels (int): Number of channels for skip connection.
aux_channels (int): Number of local conditioning channels.
dropout (float): Dropout probability.
dilation (int): Dilation factor.
bias (bool): Whether to add bias parameter in convolution layers.
scale_residual (bool): Whether to scale the residual outputs.
"""
super().__init__()
self.dropout_rate = dropout_rate
self.residual_channels = residual_channels
self.skip_channels = skip_channels
self.scale_residual = scale_residual
# check
assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
assert gate_channels % 2 == 0
# dilation conv
padding = (kernel_size - 1) // 2 * dilation
self.conv = Conv1d(
residual_channels,
gate_channels,
kernel_size,
padding=padding,
dilation=dilation,
bias=bias,
)
# local conditioning
if aux_channels > 0:
self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
else:
self.conv1x1_aux = None
# global conditioning
if global_channels > 0:
self.conv1x1_glo = Conv1d1x1(global_channels, gate_channels, bias=False)
else:
self.conv1x1_glo = None
# conv output is split into two groups
gate_out_channels = gate_channels // 2
# NOTE(kan-bayashi): concat two convs into a single conv for the efficiency
# (integrate res 1x1 + skip 1x1 convs)
self.conv1x1_out = Conv1d1x1(
gate_out_channels, residual_channels + skip_channels, bias=bias
)
def forward(
self,
x: torch.Tensor,
x_mask: Optional[torch.Tensor] = None,
c: Optional[torch.Tensor] = None,
g: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, residual_channels, T).
x_mask Optional[torch.Tensor]: Mask tensor (B, 1, T).
c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Output tensor for residual connection (B, residual_channels, T).
Tensor: Output tensor for skip connection (B, skip_channels, T).
"""
residual = x
x = F.dropout(x, p=self.dropout_rate, training=self.training)
x = self.conv(x)
# split into two part for gated activation
splitdim = 1
xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)
# local conditioning
if c is not None:
c = self.conv1x1_aux(c)
ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
xa, xb = xa + ca, xb + cb
# global conditioning
if g is not None:
g = self.conv1x1_glo(g)
ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim)
xa, xb = xa + ga, xb + gb
x = torch.tanh(xa) * torch.sigmoid(xb)
# residual + skip 1x1 conv
x = self.conv1x1_out(x)
if x_mask is not None:
x = x * x_mask
# split integrated conv results
x, s = x.split([self.residual_channels, self.skip_channels], dim=1)
# for residual connection
x = x + residual
if self.scale_residual:
x = x * math.sqrt(0.5)
return x, s