From 3df16b3f2b655a9fa700af2cafd17a8c79662307 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sun, 22 Oct 2023 23:14:00 +0800 Subject: [PATCH] first commit --- .../tts/local/compute_spectrogram_ljspeech.py | 100 ++ .../tts/local/display_manifest_statistics.py | 73 ++ egs/ljspeech/tts/local/split_subsets.py | 79 ++ egs/ljspeech/tts/local/validate_manifest.py | 70 ++ egs/ljspeech/tts/prepare.sh | 77 ++ egs/ljspeech/tts/shared/parse_options.sh | 97 ++ egs/ljspeech/tts/vits/commons.py | 161 +++ egs/ljspeech/tts/vits/duration_predictor.py | 194 ++++ egs/ljspeech/tts/vits/features.py | 416 ++++++++ egs/ljspeech/tts/vits/flow.py | 311 ++++++ egs/ljspeech/tts/vits/generator.py | 524 ++++++++++ egs/ljspeech/tts/vits/hifigan.py | 933 ++++++++++++++++++ egs/ljspeech/tts/vits/loss.py | 332 +++++++ egs/ljspeech/tts/vits/models.py | 534 ++++++++++ .../tts/vits/monotonic_align/__init__.py | 81 ++ .../tts/vits/monotonic_align/core.pyx | 51 + .../tts/vits/monotonic_align/setup.py | 31 + egs/ljspeech/tts/vits/posterior_encoder.py | 117 +++ egs/ljspeech/tts/vits/residual_coupling.py | 229 +++++ egs/ljspeech/tts/vits/symbols.py | 17 + egs/ljspeech/tts/vits/text_encoder.py | 534 ++++++++++ egs/ljspeech/tts/vits/train.py | 896 +++++++++++++++++ egs/ljspeech/tts/vits/transform.py | 217 ++++ egs/ljspeech/tts/vits/tts_datamodule.py | 306 ++++++ egs/ljspeech/tts/vits/utils.py | 470 +++++++++ egs/ljspeech/tts/vits/vits.py | 567 +++++++++++ egs/ljspeech/tts/vits/wavenet.py | 349 +++++++ 27 files changed, 7766 insertions(+) create mode 100755 egs/ljspeech/tts/local/compute_spectrogram_ljspeech.py create mode 100755 egs/ljspeech/tts/local/display_manifest_statistics.py create mode 100755 egs/ljspeech/tts/local/split_subsets.py create mode 100755 egs/ljspeech/tts/local/validate_manifest.py create mode 100755 egs/ljspeech/tts/prepare.sh create mode 100755 egs/ljspeech/tts/shared/parse_options.sh create mode 100644 egs/ljspeech/tts/vits/commons.py create mode 100644 
egs/ljspeech/tts/vits/duration_predictor.py create mode 100644 egs/ljspeech/tts/vits/features.py create mode 100644 egs/ljspeech/tts/vits/flow.py create mode 100644 egs/ljspeech/tts/vits/generator.py create mode 100644 egs/ljspeech/tts/vits/hifigan.py create mode 100644 egs/ljspeech/tts/vits/loss.py create mode 100644 egs/ljspeech/tts/vits/models.py create mode 100644 egs/ljspeech/tts/vits/monotonic_align/__init__.py create mode 100644 egs/ljspeech/tts/vits/monotonic_align/core.pyx create mode 100644 egs/ljspeech/tts/vits/monotonic_align/setup.py create mode 100644 egs/ljspeech/tts/vits/posterior_encoder.py create mode 100644 egs/ljspeech/tts/vits/residual_coupling.py create mode 100644 egs/ljspeech/tts/vits/symbols.py create mode 100644 egs/ljspeech/tts/vits/text_encoder.py create mode 100755 egs/ljspeech/tts/vits/train.py create mode 100644 egs/ljspeech/tts/vits/transform.py create mode 100644 egs/ljspeech/tts/vits/tts_datamodule.py create mode 100644 egs/ljspeech/tts/vits/utils.py create mode 100644 egs/ljspeech/tts/vits/vits.py create mode 100644 egs/ljspeech/tts/vits/wavenet.py diff --git a/egs/ljspeech/tts/local/compute_spectrogram_ljspeech.py b/egs/ljspeech/tts/local/compute_spectrogram_ljspeech.py new file mode 100755 index 000000000..3603af07d --- /dev/null +++ b/egs/ljspeech/tts/local/compute_spectrogram_ljspeech.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def compute_spectrogram_ljspeech():
    """Compute linear spectrogram features for the whole LJSpeech corpus.

    Reads the recording/supervision manifests from ``data/manifests``,
    extracts magnitude spectrograms (1024-sample window, 256-sample hop at
    22.05 kHz) and writes both the features and the resulting cut manifest
    to ``data/spectrogram``.  If the cut manifest already exists the function
    returns without doing any work.
    """
    src_dir = Path("data/manifests")
    output_dir = Path("data/spectrogram")
    # os.cpu_count() may return None when the count cannot be determined.
    num_jobs = min(4, os.cpu_count() or 1)

    sampling_rate = 22050
    frame_length = 1024 / sampling_rate  # (in second)
    frame_shift = 256 / sampling_rate  # (in second)
    use_fft_mag = True

    prefix = "ljspeech"
    suffix = "jsonl.gz"
    partition = "all"

    # Decide whether there is anything to do before loading manifests or
    # spinning up an executor.
    cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
    if (output_dir / cuts_filename).is_file():
        logging.info(f"{partition} already exists - skipping.")
        return

    recordings = load_manifest(
        src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet
    )
    supervisions = load_manifest(
        src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
    )

    config = SpectrogramConfig(
        sampling_rate=sampling_rate,
        frame_length=frame_length,
        frame_shift=frame_shift,
        use_fft_mag=use_fft_mag,
    )
    extractor = Spectrogram(config)

    with get_executor() as ex:  # Initialize the executor only once.
        logging.info(f"Processing {partition}")
        cut_set = CutSet.from_manifests(
            recordings=recordings, supervisions=supervisions
        )

        cut_set = cut_set.compute_and_store_features(
            extractor=extractor,
            storage_path=f"{output_dir}/{prefix}_feats_{partition}",
            # when an executor is specified, make more partitions
            num_jobs=num_jobs if ex is None else 80,
            executor=ex,
            storage_type=LilcomChunkyWriter,
        )
        cut_set.to_file(output_dir / cuts_filename)
def main():
    """Print duration statistics of the full LJSpeech cut manifest."""
    manifest_path = "./data/spectrogram/ljspeech_cuts_all.jsonl.gz"
    load_manifest_lazy(manifest_path).describe()
def get_args():
    """Parse the command line of the split script.

    Returns:
      An ``argparse.Namespace`` whose ``manifest_dir`` attribute is the
      directory holding ``ljspeech_cuts_all.jsonl.gz``.  The argument is
      optional and falls back to ``data/spectrogram``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "manifest_dir",
        type=Path,
        # Without nargs="?" a positional argument is mandatory and the
        # default below would never be used.
        nargs="?",
        default=Path("data/spectrogram"),
        help="Path to the directory containing the cuts manifest",
    )

    return parser.parse_args()


def main():
    """Split the LJSpeech cuts into train (12500), valid (100) and test (500).

    The cut ids are shuffled with a fixed seed so that re-running the script
    reproduces the same partition, then the three subsets are written next
    to the input manifest.
    """
    args = get_args()

    manifest_dir = Path(args.manifest_dir)
    prefix = "ljspeech"
    suffix = "jsonl.gz"
    all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all.{suffix}")

    cut_ids = list(all_cuts.ids)
    # Dedicated, seeded generator: reproducible split, global RNG untouched.
    random.Random(42).shuffle(cut_ids)

    train_cuts = all_cuts.subset(cut_ids=cut_ids[:12500])
    valid_cuts = all_cuts.subset(cut_ids=cut_ids[12500 : 12500 + 100])
    test_cuts = all_cuts.subset(cut_ids=cut_ids[12500 + 100 :])
    assert (
        len(train_cuts) == 12500
    ), f"expected 12500 cuts for training but got {len(train_cuts)}"
    assert (
        len(valid_cuts) == 100
    ), f"expected 100 cuts for validation but got {len(valid_cuts)}"
    assert len(test_cuts) == 500, f"expected 500 cuts for test but got {len(test_cuts)}"

    train_cuts.to_file(manifest_dir / f"{prefix}_cuts_train.{suffix}")
    valid_cuts.to_file(manifest_dir / f"{prefix}_cuts_valid.{suffix}")
    test_cuts.to_file(manifest_dir / f"{prefix}_cuts_test.{suffix}")

    logging.info("Splitted into three sets: training (12500), validation (100), and test (500)")
def get_args():
    """Return the parsed command line (a single positional ``manifest`` path)."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("manifest", type=Path, help="Path to the manifest file")
    return arg_parser.parse_args()


def main():
    """Load the given cut manifest and run lhotse's TTS validation on it."""
    manifest = get_args().manifest
    logging.info(f"Validating {manifest}")

    assert manifest.is_file(), f"{manifest} does not exist"

    cuts = load_manifest_lazy(manifest)
    assert isinstance(cuts, CutSet)
    validate_for_tts(cuts)
+mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/LJSpeech, + # you can create a symlink + # + # ln -sfv /path/to/LJSpeech $dl_dir/LJSpeech + # + if [ ! -d $dl_dir/LJSpeech-1.1 ]; then + lhotse download ljspeech $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare LJSpeech manifest" + # We assume that you have downloaded the LJSpeech corpus + # to $dl_dir/LJSpeech + mkdir -p data/manifests + if [ ! -e data/manifests/.ljspeech.done ]; then + lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests + touch data/manifests/.ljspeech.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compute spectrogram for LJSpeech" + mkdir -p data/spectrogram + if [ ! -e data/spectrogram/.ljspeech.done ]; then + ./local/compute_spectrogram_ljspeech.py + touch data/spectrogram/.ljspeech.done + fi + + if [ ! -e data/spectrogram/.ljspeech-validated.done ]; then + log "Validating data/fbank for LJSpeech" + python3 ./local/validate_manifest.py \ + data/spectrogram/ljspeech_cuts_all.jsonl.gz + touch data/spectrogram/.ljspeech-validated.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Split the LJSpeech cuts into three sets" + if [ ! 
-e data/spectrogram/.ljspeech_split.done ]; then + ./local/split_subsets.py data/spectrogram + touch data/spectrogram/.ljspeech_split.done + fi +fi + + diff --git a/egs/ljspeech/tts/shared/parse_options.sh b/egs/ljspeech/tts/shared/parse_options.sh new file mode 100755 index 000000000..71fb9e5ea --- /dev/null +++ b/egs/ljspeech/tts/shared/parse_options.sh @@ -0,0 +1,97 @@ +#!/usr/bin/env bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." +# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... +### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. 
+ fi +done + + +### +### Now we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! 
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialise the weight of any module whose class name contains "Conv".

    Intended to be used with ``module.apply(init_weights)``; other modules
    are left untouched.
    """
    classname = m.__class__.__name__
    if "Conv" in classname:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    """Return the padding that keeps the length unchanged for a dilated conv."""
    return int((kernel_size * dilation - dilation) / 2)


def convert_pad_shape(pad_shape):
    """Flatten a per-dimension ``[[left, right], ...]`` list for ``F.pad``.

    ``F.pad`` expects the last dimension first, hence the reversal.
    """
    return [item for sublist in pad_shape[::-1] for item in sublist]


def intersperse(lst, item):
    """Return a list with ``item`` placed before, between and after ``lst``'s elements."""
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """Element-wise KL(P||Q) between diagonal Gaussians given means and log-stds."""
    kl = (logs_q - logs_p) - 0.5
    kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    return kl


def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
    """Gumbel noise with the same shape, dtype and device as ``x``."""
    return rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)


def slice_segments(x, ids_str, segment_size=4):
    """Cut one window of ``segment_size`` frames per batch item.

    ``x`` is indexed as ``x[i, :, start:start + segment_size]`` with the
    start taken from ``ids_str[i]``.
    """
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Cut a random window per batch item and return ``(segments, start_ids)``.

    When ``x_lengths`` is omitted the full time dimension of ``x`` is used
    as the valid length for every item.
    """
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """Build a sinusoidal position signal of shape ``(1, channels, length)``."""
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    # max(..., 1) avoids a division by zero when only one timescale is
    # available (channels of 2 or 3).
    log_timescale_increment = math.log(
        float(max_timescale) / float(min_timescale)
    ) / max(num_timescales - 1, 1)
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    # One zero row is appended when the channel count is odd.
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    """Add the sinusoidal position signal to ``x`` of shape ``(b, channels, length)``."""
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    """Concatenate the sinusoidal position signal to ``x`` along ``axis``."""
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
    """Lower-triangular mask of ones with shape ``(1, 1, length, length)``."""
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


# Gated activation: tanh on the first ``n_channels[0]`` channels of the sum,
# sigmoid on the remaining ones, multiplied element-wise.
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def shift_1d(x):
    """Shift ``x`` one step to the right along the last axis, zero-filling the front."""
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x


def sequence_mask(length, max_length=None):
    """Boolean mask ``(len(length), max_length)`` that is True inside each length."""
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def generate_path(duration, mask):
    """Expand per-token durations into a monotonic alignment path.

    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]

    Returns a tensor shaped like ``mask``.
    """
    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    # Subtracting the previous token's cumulative mask leaves only the frames
    # that belong to the current token.
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
    """Clamp gradients to ``[-clip_value, clip_value]`` and return their total norm.

    The norm is accumulated from the gradients as they are *before* clamping.
    Passing ``clip_value=None`` only computes the norm.  Parameters without
    a gradient are ignored.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm
class StochasticDurationPredictor(torch.nn.Module):
    """Stochastic duration predictor module.

    This is a module of stochastic duration predictor described in `Conditional
    Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.

    In the forward direction it returns a per-utterance loss computed from
    the text encoding and the given durations; in the inverse direction it
    samples log-durations from noise.

    .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
        Text-to-Speech`: https://arxiv.org/abs/2106.06103

    """

    def __init__(
        self,
        channels: int = 192,
        kernel_size: int = 3,
        dropout_rate: float = 0.5,
        flows: int = 4,
        dds_conv_layers: int = 3,
        global_channels: int = -1,
    ):
        """Initialize StochasticDurationPredictor module.

        Args:
            channels (int): Number of channels.
            kernel_size (int): Kernel size.
            dropout_rate (float): Dropout rate.
            flows (int): Number of flows.
            dds_conv_layers (int): Number of conv layers in DDS conv.
            global_channels (int): Number of global conditioning channels.
                The conditioning projection is only created when this is
                positive.

        """
        super().__init__()

        # Conditioning branch applied to the text encoding ``x``.
        self.pre = torch.nn.Conv1d(channels, channels, 1)
        self.dds = DilatedDepthSeparableConv(
            channels,
            kernel_size,
            layers=dds_conv_layers,
            dropout_rate=dropout_rate,
        )
        self.proj = torch.nn.Conv1d(channels, channels, 1)

        # Main flow stack: one affine flow followed by ``flows`` pairs of
        # (ConvFlow, FlipFlow), all operating on 2-channel variables.
        self.log_flow = LogFlow()
        self.flows = torch.nn.ModuleList()
        self.flows += [ElementwiseAffineFlow(2)]
        for i in range(flows):
            self.flows += [
                ConvFlow(
                    2,
                    channels,
                    kernel_size,
                    layers=dds_conv_layers,
                )
            ]
            self.flows += [FlipFlow()]

        # Posterior branch: encodes the durations ``w`` and has its own flow
        # stack with the same layout as the main one.
        self.post_pre = torch.nn.Conv1d(1, channels, 1)
        self.post_dds = DilatedDepthSeparableConv(
            channels,
            kernel_size,
            layers=dds_conv_layers,
            dropout_rate=dropout_rate,
        )
        self.post_proj = torch.nn.Conv1d(channels, channels, 1)
        self.post_flows = torch.nn.ModuleList()
        self.post_flows += [ElementwiseAffineFlow(2)]
        for i in range(flows):
            self.post_flows += [
                ConvFlow(
                    2,
                    channels,
                    kernel_size,
                    layers=dds_conv_layers,
                )
            ]
            self.post_flows += [FlipFlow()]

        # NOTE: ``self.global_conv`` does not exist when global_channels <= 0,
        # so ``forward`` must not be given ``g`` in that configuration.
        if global_channels > 0:
            self.global_conv = torch.nn.Conv1d(global_channels, channels, 1)

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        w: Optional[torch.Tensor] = None,
        g: Optional[torch.Tensor] = None,
        inverse: bool = False,
        noise_scale: float = 1.0,
    ) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, channels, T_text).
            x_mask (Tensor): Mask tensor (B, 1, T_text).
            w (Optional[Tensor]): Duration tensor (B, 1, T_text). Required
                when ``inverse`` is False.
            g (Optional[Tensor]): Global conditioning tensor
                (B, global_channels, 1).
            inverse (bool): Whether to inverse the flow.
            noise_scale (float): Noise scale value (used only when
                ``inverse`` is True).

        Returns:
            Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,).
                If inverse, log-duration tensor (B, 1, T_text).

        """
        x = x.detach()  # stop gradient
        x = self.pre(x)
        if g is not None:
            x = x + self.global_conv(g.detach())  # stop gradient
        x = self.dds(x, x_mask)
        x = self.proj(x) * x_mask

        if not inverse:
            assert w is not None, "w must be provided."
            # Encode the durations; the posterior flows are conditioned on
            # the sum of text and duration encodings.
            h_w = self.post_pre(w)
            h_w = self.post_dds(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            # Two noise channels per text position.
            e_q = (
                torch.randn(
                    w.size(0),
                    2,
                    w.size(2),
                ).to(device=x.device, dtype=x.dtype)
                * x_mask
            )
            z_q = e_q
            logdet_tot_q = 0.0
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            # u lies in (0, 1) and is subtracted from the durations.
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            # log|d sigmoid(z_u) / d z_u| = logsigmoid(z_u) + logsigmoid(-z_u)
            logdet_tot_q += torch.sum(
                (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
            )
            logq = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
                - logdet_tot_q
            )

            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in self.flows:
                z, logdet = flow(z, x_mask, g=x, inverse=inverse)
                logdet_tot = logdet_tot + logdet
            nll = (
                torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
                - logdet_tot
            )
            return nll + logq  # (B,)
        else:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
            z = (
                torch.randn(
                    x.size(0),
                    2,
                    x.size(2),
                ).to(device=x.device, dtype=x.dtype)
                * noise_scale
            )
            # NOTE(review): assumes each flow returns only the transformed
            # tensor when inverse=True - confirm against flow.py.
            for flow in flows:
                z = flow(z, x_mask, g=x, inverse=inverse)
            # Only the first of the two channels is returned as log-duration.
            z0, z1 = z.split(1, 1)
            logw = z0
            return logw
# From https://github.com/espnet/espnet/blob/master/espnet2/layers/stft.py
class Stft(nn.Module):
    """Short-time Fourier transform returning stacked real/imaginary parts.

    Args:
        n_fft: FFT size.
        win_length: Window length in samples. Falls back to ``n_fft`` if None.
        hop_length: Hop size in samples.
        window: Name of a ``torch.<name>_window`` function, or None for a
            rectangular window.
        center: Whether frames are centered (passed through to the STFT).
        normalized: Whether to scale the output by ``n_fft ** -0.5``.
        onesided: Whether to return only the non-redundant frequency bins.

    Raises:
        ValueError: If ``window`` does not name a ``torch.<name>_window``.
    """

    def __init__(
        self,
        n_fft: int = 512,
        win_length: Optional[int] = None,
        hop_length: int = 128,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        super().__init__()
        self.n_fft = n_fft
        if win_length is None:
            self.win_length = n_fft
        else:
            self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        if window is not None and not hasattr(torch, f"{window}_window"):
            raise ValueError(f"{window} window is not implemented")
        self.window = window

    def extra_repr(self):
        """Return the hyper-parameters shown in the module's repr."""
        return (
            f"n_fft={self.n_fft}, "
            f"win_length={self.win_length}, "
            f"hop_length={self.hop_length}, "
            f"center={self.center}, "
            f"normalized={self.normalized}, "
            f"onesided={self.onesided}"
        )

    def _stft_librosa(
        self, input: torch.Tensor, window: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Compute the STFT of a CPU batch with librosa (inference only).

        Returns a tensor laid out as (Batch, Freq, Frames, 2=real_imag), the
        same layout the torch.stft branch produces.
        """
        if self.training:
            raise NotImplementedError(
                "stft is implemented with librosa on this device, which does not "
                "support the training mode."
            )

        # librosa does not support a win_length that is < n_fft, so the window
        # is always zero-padded to n_fft by hand.  With no window function a
        # rectangular window of win_length samples is used, which is what
        # torch.stft does for window=None.
        if window is None:
            window = torch.ones(self.win_length, dtype=input.dtype)
        n_pad_left = (self.n_fft - window.shape[0]) // 2
        n_pad_right = self.n_fft - window.shape[0] - n_pad_left
        # librosa only accepts numpy arrays / lists as explicit windows.
        padded_window = torch.cat(
            [window.new_zeros(n_pad_left), window, window.new_zeros(n_pad_right)],
            0,
        ).numpy()

        stft_kwargs = dict(
            n_fft=self.n_fft,
            win_length=self.n_fft,
            hop_length=self.hop_length,
            center=self.center,
            window=padded_window,
            pad_mode="reflect",
        )

        output = []
        # iterate over instances in a batch
        for instance in input:
            stft = librosa.stft(instance.numpy(), **stft_kwargs)
            output.append(torch.tensor(np.stack([stft.real, stft.imag], -1)))
        output = torch.stack(output, 0)
        if not self.onesided:
            # Rebuild the redundant half as the mirrored complex conjugate.
            len_conj = self.n_fft - output.shape[1]
            conj = output[:, 1 : 1 + len_conj].flip(1)
            conj[:, :, :, -1].data *= -1
            output = torch.cat([output, conj], 1)
        if self.normalized:
            output = output * (padded_window.shape[0] ** (-0.5))
        return output

    def forward(
        self, input: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """STFT forward function.

        Args:
            input: (Batch, Nsamples) or (Batch, Nsample, Channels)
            ilens: (Batch)
        Returns:
            output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
            olens: (Batch) number of valid frames per item, or None when
                ``ilens`` is not given. Frames at or beyond ``olens`` are
                zeroed in ``output``.

        """
        bs = input.size(0)
        if input.dim() == 3:
            multi_channel = True
            # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
            input = input.transpose(1, 2).reshape(-1, input.size(1))
        else:
            multi_channel = False

        # NOTE(kamo):
        #   The default behaviour of torch.stft is compatible with librosa.stft
        #   about padding and scaling.
        #   Note that it's different from scipy.signal.stft

        # output: (Batch, Freq, Frames, 2=real_imag)
        # or (Batch, Channel, Freq, Frames, 2=real_imag)
        if self.window is not None:
            window_func = getattr(torch, f"{self.window}_window")
            window = window_func(
                self.win_length, dtype=input.dtype, device=input.device
            )
        else:
            window = None

        # For the compatibility of ARM devices, which do not support
        # torch.stft() due to the lack of MKL (on older pytorch versions),
        # there is an alternative replacement implementation with librosa.
        # Note: pytorch >= 1.10.0 now has native support for FFT and STFT
        # on all cpu targets including ARM.
        if input.is_cuda or torch.backends.mkl.is_available():
            output = torch.stft(
                input,
                n_fft=self.n_fft,
                win_length=self.win_length,
                hop_length=self.hop_length,
                center=self.center,
                window=window,
                normalized=self.normalized,
                onesided=self.onesided,
                return_complex=True,
            )
            output = torch.view_as_real(output)
        else:
            output = self._stft_librosa(input, window)

        # output: (Batch, Freq, Frames, 2=real_imag)
        # -> (Batch, Frames, Freq, 2=real_imag)
        output = output.transpose(1, 2)
        if multi_channel:
            # output: (Batch * Channel, Frames, Freq, 2=real_imag)
            # -> (Batch, Frame, Channel, Freq, 2=real_imag)
            output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
                1, 2
            )

        if ilens is not None:
            if self.center:
                pad = self.n_fft // 2
                ilens = ilens + 2 * pad

            olens = (
                torch.div(ilens - self.n_fft, self.hop_length, rounding_mode="trunc")
                + 1
            )
            # Build the (Batch, Frames) padding mask against the real number of
            # frames and append singleton axes, so that it broadcasts over the
            # trailing (Channels,) Freq and real/imag dimensions.
            frame_idx = torch.arange(output.size(1), device=output.device)
            pad_mask = frame_idx.unsqueeze(0) >= olens.to(output.device).unsqueeze(1)
            pad_mask = pad_mask[(...,) + (None,) * (output.dim() - 2)]
            output = output.masked_fill(pad_mask, 0.0)
        else:
            olens = None

        return output, olens
# From https://github.com/espnet/espnet/blob/master/espnet2/layers/log_mel.py
class LogMel(nn.Module):
    """Convert STFT to fbank feats

    The arguments is same as librosa.filters.mel

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz).
            If `None`, use `fmin = 0`
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        log_base: base of the logarithm; `None` selects the natural log
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: Optional[float] = None,
        fmax: Optional[float] = None,
        htk: bool = False,
        log_base: Optional[float] = None,
    ):
        super().__init__()

        fmin = 0 if fmin is None else fmin
        fmax = fs / 2 if fmax is None else fmax
        _mel_options = dict(
            sr=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        self.mel_options = _mel_options
        self.log_base = log_base

        # Note(kamo): The mel matrix of librosa is different from kaldi.
        melmat = librosa.filters.mel(**_mel_options)
        # melmat: (D2, D1) -> (D1, D2)
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

    def extra_repr(self):
        """Return the mel filterbank options shown in the module's repr."""
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())

    def forward(
        self,
        feat: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project a spectrogram onto the mel filterbank and take the log.

        Args:
            feat: Spectrogram whose last axis has D1 bins and whose axis 1 is
                time, e.g. (B, T, D1).
            ilens: Optional number of valid frames per batch item (B,).

        Returns:
            The log-mel features (same layout with D2 mel bins; frames at or
            beyond ``ilens`` are zeroed), and the lengths: ``ilens`` itself,
            or ``feat.size(1)`` for every item when ``ilens`` is None.
        """
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)
        mel_feat = torch.clamp(mel_feat, min=1e-10)

        if self.log_base is None:
            logmel_feat = mel_feat.log()
        elif self.log_base == 2.0:
            logmel_feat = mel_feat.log2()
        elif self.log_base == 10.0:
            logmel_feat = mel_feat.log10()
        else:
            # torch.log() only accepts tensors, so wrap the Python scalar base.
            logmel_feat = mel_feat.log() / mel_feat.new_tensor(self.log_base).log()

        # Zero padding
        if ilens is not None:
            # (B, T) mask over the time axis, extended with singleton axes so
            # that it broadcasts over the mel (and any further) dimensions.
            frame_idx = torch.arange(logmel_feat.size(1), device=logmel_feat.device)
            pad_mask = frame_idx.unsqueeze(0) >= ilens.to(
                logmel_feat.device
            ).unsqueeze(1)
            pad_mask = pad_mask[(...,) + (None,) * (logmel_feat.dim() - 2)]
            logmel_feat = logmel_feat.masked_fill(pad_mask, 0.0)
        else:
            ilens = feat.new_full(
                [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
            )
        return logmel_feat, ilens
+ + Stft -> amplitude-spec -> Log-Mel-Fbank + """ + + def __init__( + self, + fs: int = 16000, + n_fft: int = 1024, + win_length: int = None, + hop_length: int = 256, + window: Optional[str] = "hann", + center: bool = True, + normalized: bool = False, + onesided: bool = True, + n_mels: int = 80, + fmin: Optional[int] = 80, + fmax: Optional[int] = 7600, + htk: bool = False, + log_base: Optional[float] = 10.0, + ): + super().__init__() + + self.fs = fs + self.n_mels = n_mels + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.fmin = fmin + self.fmax = fmax + + self.stft = Stft( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window=window, + center=center, + normalized=normalized, + onesided=onesided, + ) + + self.logmel = LogMel( + fs=fs, + n_fft=n_fft, + n_mels=n_mels, + fmin=fmin, + fmax=fmax, + htk=htk, + log_base=log_base, + ) + + def output_size(self) -> int: + return self.n_mels + + def get_parameters(self) -> Dict[str, Any]: + """Return the parameters required by Vocoder""" + return dict( + fs=self.fs, + n_fft=self.n_fft, + n_shift=self.hop_length, + window=self.window, + n_mels=self.n_mels, + win_length=self.win_length, + fmin=self.fmin, + fmax=self.fmax, + ) + + def forward( + self, input: torch.Tensor, input_lengths: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Domain-conversion: e.g. 
Stft: time -> time-freq + input_stft, feats_lens = self.stft(input, input_lengths) + + assert input_stft.dim() >= 4, input_stft.shape + # "2" refers to the real/imag parts of Complex + assert input_stft.shape[-1] == 2, input_stft.shape + + # NOTE(kamo): We use different definition for log-spec between TTS and ASR + # TTS: log_10(abs(stft)) + # ASR: log_e(power(stft)) + + # input_stft: (..., F, 2) -> (..., F) + input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2 + input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10)) + input_feats, _ = self.logmel(input_amp, feats_lens) + return input_feats, feats_lens diff --git a/egs/ljspeech/tts/vits/flow.py b/egs/ljspeech/tts/vits/flow.py new file mode 100644 index 000000000..04fb99b42 --- /dev/null +++ b/egs/ljspeech/tts/vits/flow.py @@ -0,0 +1,311 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Basic Flow modules used in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + +import math +from typing import Optional, Tuple, Union + +import torch + +from transform import piecewise_rational_quadratic_transform + + +class FlipFlow(torch.nn.Module): + """Flip flow module.""" + + def forward( + self, x: torch.Tensor, *args, inverse: bool = False, **kwargs + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Flipped tensor (B, channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. 
class LogFlow(torch.nn.Module):
    """Flow layer applying an element-wise logarithm (exponential when inverted)."""

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        inverse: bool = False,
        eps: float = 1e-5,
        **kwargs
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Apply the log flow, or its inverse.

        Args:
            x (Tensor): Input tensor (B, channels, T).
            x_mask (Tensor): Mask tensor (B, 1, T).
            inverse (bool): Whether to inverse the flow.
            eps (float): Lower clamp applied to ``x`` before taking the log.

        Returns:
            Tensor: Output tensor (B, channels, T).
            Tensor: Log-determinant tensor for NLL (B,) if not inverse.

        """
        if inverse:
            # Inverse direction: only the masked exponential is returned.
            return torch.exp(x) * x_mask

        clamped = torch.clamp_min(x, eps)
        log_x = torch.log(clamped) * x_mask
        # The log-determinant is the negated sum of the masked outputs.
        return log_x, torch.sum(-log_x, [1, 2])
class Transpose(torch.nn.Module):
    """Module wrapper around a fixed-axis transpose, usable in torch.nn.Sequential()."""

    def __init__(self, dim1: int, dim2: int):
        """Remember the two axes that forward() swaps."""
        super().__init__()
        self.dim1, self.dim2 = dim1, dim2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with axes ``dim1`` and ``dim2`` exchanged."""
        return torch.transpose(x, self.dim1, self.dim2)
class ConvFlow(torch.nn.Module):
    """Convolutional flow: a spline coupling layer over half of the channels.

    The first half of the channels parameterizes a piecewise rational
    quadratic transform that is applied to the second half.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        kernel_size: int,
        layers: int,
        bins: int = 10,
        tail_bound: float = 5.0,
    ):
        """Initialize ConvFlow module.

        Args:
            in_channels (int): Number of input channels.
            hidden_channels (int): Number of hidden channels.
            kernel_size (int): Kernel size.
            layers (int): Number of layers.
            bins (int): Number of bins.
            tail_bound (float): Tail bound value.

        """
        super().__init__()
        self.half_channels = in_channels // 2
        self.hidden_channels = hidden_channels
        self.bins = bins
        self.tail_bound = tail_bound

        # Each transformed channel needs `bins` widths, `bins` heights and
        # `bins - 1` derivatives.
        params_per_channel = bins * 3 - 1

        self.input_conv = torch.nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.dds_conv = DilatedDepthSeparableConv(
            hidden_channels,
            kernel_size,
            layers,
            dropout_rate=0.0,
        )
        self.proj = torch.nn.Conv1d(
            hidden_channels, self.half_channels * params_per_channel, 1
        )
        # The projection starts at zero, so every spline parameter is zero at
        # initialization.
        torch.nn.init.zeros_(self.proj.weight)
        torch.nn.init.zeros_(self.proj.bias)

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        g: Optional[torch.Tensor] = None,
        inverse: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, channels, T).
            x_mask (Tensor): Mask tensor (B, 1, T).
            g (Optional[Tensor]): Global conditioning tensor (B, channels, 1).
            inverse (bool): Whether to inverse the flow.

        Returns:
            Tensor: Output tensor (B, channels, T).
            Tensor: Log-determinant tensor for NLL (B,) if not inverse.

        """
        x_a, x_b = x.split(x.size(1) // 2, 1)

        # (B, half_channels * (bins * 3 - 1), T)
        hidden = self.input_conv(x_a)
        hidden = self.dds_conv(hidden, x_mask, g=g)
        hidden = self.proj(hidden) * x_mask

        batch, half, frames = x_a.shape
        # (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1)
        hidden = hidden.reshape(batch, half, -1, frames).permute(0, 1, 3, 2)

        # Widths and heights are scaled down by sqrt(hidden_channels); the
        # derivatives are used unscaled.
        scale = math.sqrt(self.hidden_channels)
        first, second = self.bins, 2 * self.bins
        x_b, logdet_abs = piecewise_rational_quadratic_transform(
            x_b,
            hidden[..., :first] / scale,
            hidden[..., first:second] / scale,
            hidden[..., second:],
            inverse=inverse,
            tails="linear",
            tail_bound=self.tail_bound,
        )
        out = torch.cat([x_a, x_b], 1) * x_mask
        if inverse:
            return out
        return out, torch.sum(logdet_abs * x_mask, [1, 2])
class VITSGenerator(torch.nn.Module):
    """Generator module in VITS, `Conditional Variational Autoencoder
    with Adversarial Learning for End-to-End Text-to-Speech`.
    """

    def __init__(
        self,
        vocabs: int,
        aux_channels: int = 513,
        hidden_channels: int = 192,
        spks: Optional[int] = None,
        langs: Optional[int] = None,
        spk_embed_dim: Optional[int] = None,
        global_channels: int = -1,
        segment_size: int = 32,
        text_encoder_attention_heads: int = 2,
        text_encoder_ffn_expand: int = 4,
        text_encoder_blocks: int = 6,
        text_encoder_dropout_rate: float = 0.1,
        decoder_kernel_size: int = 7,
        decoder_channels: int = 512,
        decoder_upsample_scales: List[int] = [8, 8, 2, 2],
        decoder_upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
        decoder_resblock_kernel_sizes: List[int] = [3, 7, 11],
        decoder_resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        use_weight_norm_in_decoder: bool = True,
        posterior_encoder_kernel_size: int = 5,
        posterior_encoder_layers: int = 16,
        posterior_encoder_stacks: int = 1,
        posterior_encoder_base_dilation: int = 1,
        posterior_encoder_dropout_rate: float = 0.0,
        use_weight_norm_in_posterior_encoder: bool = True,
        flow_flows: int = 4,
        flow_kernel_size: int = 5,
        flow_base_dilation: int = 1,
        flow_layers: int = 4,
        flow_dropout_rate: float = 0.0,
        use_weight_norm_in_flow: bool = True,
        use_only_mean_in_flow: bool = True,
        stochastic_duration_predictor_kernel_size: int = 3,
        stochastic_duration_predictor_dropout_rate: float = 0.5,
        stochastic_duration_predictor_flows: int = 4,
        stochastic_duration_predictor_dds_conv_layers: int = 3,
    ):
        """Initialize VITS generator module.

        Args:
            vocabs (int): Input vocabulary size.
            aux_channels (int): Number of acoustic feature channels.
            hidden_channels (int): Number of hidden channels.
            spks (Optional[int]): Number of speakers. If set to > 1, assume that the
                sids will be provided as the input and use sid embedding layer.
            langs (Optional[int]): Number of languages. If set to > 1, assume that the
                lids will be provided as the input and use lid embedding layer.
            spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0,
                assume that spembs will be provided as the input.
            global_channels (int): Number of global conditioning channels.
            segment_size (int): Segment size for decoder.
            text_encoder_attention_heads (int): Number of heads in conformer block
                of text encoder.
            text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block
                of text encoder.
            text_encoder_blocks (int): Number of conformer blocks in text encoder.
            text_encoder_dropout_rate (float): Dropout rate in conformer block of
                text encoder.
            decoder_kernel_size (int): Decoder kernel size.
            decoder_channels (int): Number of decoder initial channels.
            decoder_upsample_scales (List[int]): List of upsampling scales in decoder.
            decoder_upsample_kernel_sizes (List[int]): List of kernel size for
                upsampling layers in decoder.
            decoder_resblock_kernel_sizes (List[int]): List of kernel size for resblocks
                in decoder.
            decoder_resblock_dilations (List[List[int]]): List of list of dilations for
                resblocks in decoder.
            use_weight_norm_in_decoder (bool): Whether to apply weight normalization in
                decoder.
            posterior_encoder_kernel_size (int): Posterior encoder kernel size.
            posterior_encoder_layers (int): Number of layers of posterior encoder.
            posterior_encoder_stacks (int): Number of stacks of posterior encoder.
            posterior_encoder_base_dilation (int): Base dilation of posterior encoder.
            posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder.
            use_weight_norm_in_posterior_encoder (bool): Whether to apply weight
                normalization in posterior encoder.
            flow_flows (int): Number of flows in flow.
            flow_kernel_size (int): Kernel size in flow.
            flow_base_dilation (int): Base dilation in flow.
            flow_layers (int): Number of layers in flow.
            flow_dropout_rate (float): Dropout rate in flow
            use_weight_norm_in_flow (bool): Whether to apply weight normalization in
                flow.
            use_only_mean_in_flow (bool): Whether to use only mean in flow.
            stochastic_duration_predictor_kernel_size (int): Kernel size in stochastic
                duration predictor.
            stochastic_duration_predictor_dropout_rate (float): Dropout rate in
                stochastic duration predictor.
            stochastic_duration_predictor_flows (int): Number of flows in stochastic
                duration predictor.
            stochastic_duration_predictor_dds_conv_layers (int): Number of DDS conv
                layers in stochastic duration predictor.

        """
        super().__init__()
        self.segment_size = segment_size
        self.text_encoder = TextEncoder(
            vocabs=vocabs,
            d_model=hidden_channels,
            num_heads=text_encoder_attention_heads,
            dim_feedforward=hidden_channels * text_encoder_ffn_expand,
            num_layers=text_encoder_blocks,
            dropout=text_encoder_dropout_rate,
        )
        self.decoder = HiFiGANGenerator(
            in_channels=hidden_channels,
            out_channels=1,
            channels=decoder_channels,
            global_channels=global_channels,
            kernel_size=decoder_kernel_size,
            upsample_scales=decoder_upsample_scales,
            upsample_kernel_sizes=decoder_upsample_kernel_sizes,
            resblock_kernel_sizes=decoder_resblock_kernel_sizes,
            resblock_dilations=decoder_resblock_dilations,
            use_weight_norm=use_weight_norm_in_decoder,
        )
        self.posterior_encoder = PosteriorEncoder(
            in_channels=aux_channels,
            out_channels=hidden_channels,
            hidden_channels=hidden_channels,
            kernel_size=posterior_encoder_kernel_size,
            layers=posterior_encoder_layers,
            stacks=posterior_encoder_stacks,
            base_dilation=posterior_encoder_base_dilation,
            global_channels=global_channels,
            dropout_rate=posterior_encoder_dropout_rate,
            use_weight_norm=use_weight_norm_in_posterior_encoder,
        )
        self.flow = ResidualAffineCouplingBlock(
            in_channels=hidden_channels,
            hidden_channels=hidden_channels,
            flows=flow_flows,
            kernel_size=flow_kernel_size,
            base_dilation=flow_base_dilation,
            layers=flow_layers,
            global_channels=global_channels,
            dropout_rate=flow_dropout_rate,
            use_weight_norm=use_weight_norm_in_flow,
            use_only_mean=use_only_mean_in_flow,
        )
        # TODO(kan-bayashi): Add deterministic version as an option
        self.duration_predictor = StochasticDurationPredictor(
            channels=hidden_channels,
            kernel_size=stochastic_duration_predictor_kernel_size,
            dropout_rate=stochastic_duration_predictor_dropout_rate,
            flows=stochastic_duration_predictor_flows,
            dds_conv_layers=stochastic_duration_predictor_dds_conv_layers,
            global_channels=global_channels,
        )

        # Total upsampling of the decoder, i.e. waveform samples per frame.
        self.upsample_factor = int(np.prod(decoder_upsample_scales))
        self.spks = None
        if spks is not None and spks > 1:
            assert global_channels > 0
            self.spks = spks
            self.global_emb = torch.nn.Embedding(spks, global_channels)
        self.spk_embed_dim = None
        if spk_embed_dim is not None and spk_embed_dim > 0:
            assert global_channels > 0
            self.spk_embed_dim = spk_embed_dim
            self.spemb_proj = torch.nn.Linear(spk_embed_dim, global_channels)
        self.langs = None
        if langs is not None and langs > 1:
            assert global_channels > 0
            self.langs = langs
            self.lang_emb = torch.nn.Embedding(langs, global_channels)

        # delayed import
        from monotonic_align import maximum_path

        self.maximum_path = maximum_path

    def _global_conditioning(
        self,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Sum the enabled speaker / language embeddings into one tensor.

        Returns a (B, global_channels, 1) tensor, or None when no conditioning
        source is enabled on this model.
        """
        g = None
        if self.spks is not None:
            # speaker one-hot vector embedding: (B, global_channels, 1)
            g = self.global_emb(sids.view(-1)).unsqueeze(-1)
        if self.spk_embed_dim is not None:
            # A single un-batched embedding (spk_embed_dim,) becomes a batch
            # of one; batched (B, spk_embed_dim) input is used unchanged so
            # that normalization runs over the embedding axis.
            if spembs.dim() == 1:
                spembs = spembs.unsqueeze(0)
            # pretrained speaker embedding, e.g., X-vector (B, global_channels, 1)
            g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1)
            g = g_ if g is None else g + g_
        if self.langs is not None:
            # language one-hot vector embedding: (B, global_channels, 1)
            g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
            g = g_ if g is None else g + g_
        return g

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        Tuple[
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
        ],
    ]:
        """Calculate forward propagation.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, aux_channels, T_feats).
            feats_lengths (Tensor): Feature length tensor (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            Tensor: Waveform tensor (B, 1, segment_size * upsample_factor).
            Tensor: Duration negative log-likelihood (NLL) tensor (B,).
            Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text).
            Tensor: Segments start index tensor (B,).
            Tensor: Text mask tensor (B, 1, T_text).
            Tensor: Feature mask tensor (B, 1, T_feats).
            tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
                - Tensor: Posterior encoder hidden representation (B, H, T_feats).
                - Tensor: Flow hidden representation (B, H, T_feats).
                - Tensor: Expanded text encoder projected mean (B, H, T_feats).
                - Tensor: Expanded text encoder projected scale (B, H, T_feats).
                - Tensor: Posterior encoder projected mean (B, H, T_feats).
                - Tensor: Posterior encoder projected scale (B, H, T_feats).

        """
        # forward text encoder
        x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)

        # calculate global conditioning
        g = self._global_conditioning(sids, spembs, lids)

        # forward posterior encoder
        z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)

        # forward flow
        z_p = self.flow(z, y_mask, g=g)  # (B, H, T_feats)

        # monotonic alignment search
        with torch.no_grad():
            # negative cross-entropy
            s_p_sq_r = torch.exp(-2 * logs_p)  # (B, H, T_text)
            # (B, 1, T_text)
            neg_x_ent_1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - logs_p,
                [1],
                keepdim=True,
            )
            # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
            neg_x_ent_2 = torch.matmul(
                -0.5 * (z_p**2).transpose(1, 2),
                s_p_sq_r,
            )
            # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
            neg_x_ent_3 = torch.matmul(
                z_p.transpose(1, 2),
                (m_p * s_p_sq_r),
            )
            # (B, 1, T_text)
            neg_x_ent_4 = torch.sum(
                -0.5 * (m_p**2) * s_p_sq_r,
                [1],
                keepdim=True,
            )
            # (B, T_feats, T_text)
            neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
            # (B, 1, T_feats, T_text)
            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            # monotonic attention weight: (B, 1, T_feats, T_text)
            attn = (
                self.maximum_path(
                    neg_x_ent,
                    attn_mask.squeeze(1),
                )
                .unsqueeze(1)
                .detach()
            )

        # forward duration predictor
        w = attn.sum(2)  # (B, 1, T_text)
        dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
        dur_nll = dur_nll / torch.sum(x_mask)

        # expand the length to match with the feature sequence
        # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
        # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)

        # get random segments
        z_segments, z_start_idxs = get_random_segments(
            z,
            feats_lengths,
            self.segment_size,
        )

        # forward decoder with random segments
        wav = self.decoder(z_segments, g=g)

        return (
            wav,
            dur_nll,
            attn,
            z_start_idxs,
            x_mask,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
        )

    def inference(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: Optional[torch.Tensor] = None,
        feats_lengths: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        dur: Optional[torch.Tensor] = None,
        noise_scale: float = 0.667,
        noise_scale_dur: float = 0.8,
        alpha: float = 1.0,
        max_len: Optional[int] = None,
        use_teacher_forcing: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run inference.

        Args:
            text (Tensor): Input text index tensor (B, T_text,).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, aux_channels, T_feats,).
            feats_lengths (Tensor): Feature length tensor (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim),
                or a single un-batched embedding (spk_embed_dim,).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
            dur (Optional[Tensor]): Ground-truth duration. If provided,
                skip the prediction of durations (i.e., teacher forcing).
                It is summed over dims 1 and 2 below, so it has to be 3-D,
                laid out like the predicted durations (B, 1, T_text).
            noise_scale (float): Noise scale parameter for flow.
            noise_scale_dur (float): Noise scale parameter for duration predictor.
            alpha (float): Alpha parameter to control the speed of generated speech.
            max_len (Optional[int]): Maximum length of acoustic feature sequence.
            use_teacher_forcing (bool): Whether to use teacher forcing.

        Returns:
            Tensor: Generated waveform tensor (B, T_wav).
            Tensor: Monotonic attention weight tensor (B, T_feats, T_text).
            Tensor: Duration tensor (B, T_text).

        """
        # encoder
        x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
        # global conditioning: (B, global_channels, 1) or None
        g = self._global_conditioning(sids, spembs, lids)

        if use_teacher_forcing:
            # forward posterior encoder
            z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)

            # forward flow
            z_p = self.flow(z, y_mask, g=g)  # (B, H, T_feats)

            # monotonic alignment search
            s_p_sq_r = torch.exp(-2 * logs_p)  # (B, H, T_text)
            # (B, 1, T_text)
            neg_x_ent_1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - logs_p,
                [1],
                keepdim=True,
            )
            # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
            neg_x_ent_2 = torch.matmul(
                -0.5 * (z_p**2).transpose(1, 2),
                s_p_sq_r,
            )
            # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
            neg_x_ent_3 = torch.matmul(
                z_p.transpose(1, 2),
                (m_p * s_p_sq_r),
            )
            # (B, 1, T_text)
            neg_x_ent_4 = torch.sum(
                -0.5 * (m_p**2) * s_p_sq_r,
                [1],
                keepdim=True,
            )
            # (B, T_feats, T_text)
            neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
            # (B, 1, T_feats, T_text)
            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            # monotonic attention weight: (B, 1, T_feats, T_text)
            attn = self.maximum_path(
                neg_x_ent,
                attn_mask.squeeze(1),
            ).unsqueeze(1)
            dur = attn.sum(2)  # (B, 1, T_text)

            # forward decoder on the whole (masked) posterior latent sequence
            wav = self.decoder(z * y_mask, g=g)
        else:
            # duration
            if dur is None:
                logw = self.duration_predictor(
                    x,
                    x_mask,
                    g=g,
                    inverse=True,
                    noise_scale=noise_scale_dur,
                )
                w = torch.exp(logw) * x_mask * alpha
                dur = torch.ceil(w)
            y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long()
            y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device)
            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = self._generate_path(dur, attn_mask)

            # expand the length to match with the feature sequence
            # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
            m_p = torch.matmul(
                attn.squeeze(1),
                m_p.transpose(1, 2),
            ).transpose(1, 2)
            # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
            logs_p = torch.matmul(
                attn.squeeze(1),
                logs_p.transpose(1, 2),
            ).transpose(1, 2)

            # decoder
            z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
            z = self.flow(z_p, y_mask, g=g, inverse=True)
            wav = self.decoder((z * y_mask)[:, :, :max_len], g=g)

        return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1)

    def _generate_path(self, dur: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Generate path a.k.a. monotonic attention.

        Args:
            dur (Tensor): Duration tensor (B, 1, T_text).
            mask (Tensor): Attention mask tensor (B, 1, T_feats, T_text).

        Returns:
            Tensor: Path tensor (B, 1, T_feats, T_text).

        """
        b, _, t_y, t_x = mask.shape
        cum_dur = torch.cumsum(dur, -1)
        cum_dur_flat = cum_dur.view(b * t_x)
        path = torch.arange(t_y, dtype=dur.dtype, device=dur.device)
        path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)
        path = path.view(b, t_x, t_y).to(dtype=mask.dtype)
        # path will be like (t_x = 3, t_y = 5):
        # [[[1., 1., 0., 0., 0.],      [[[1., 1., 0., 0., 0.],
        #   [1., 1., 1., 1., 0.],  -->   [0., 0., 1., 1., 0.],
        #   [1., 1., 1., 1., 1.]]]       [0., 0., 0., 0., 1.]]]
        path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
        return path.unsqueeze(1).transpose(2, 3) * mask
+ out_channels (int): Number of output channels. + channels (int): Number of hidden representation channels. + global_channels (int): Number of global conditioning channels. + kernel_size (int): Kernel size of initial and final conv layer. + upsample_scales (List[int]): List of upsampling scales. + upsample_kernel_sizes (List[int]): List of kernel sizes for upsample layers. + resblock_kernel_sizes (List[int]): List of kernel sizes for residual blocks. + resblock_dilations (List[List[int]]): List of list of dilations for residual + blocks. + use_additional_convs (bool): Whether to use additional conv layers in + residual blocks. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation + function. + use_weight_norm (bool): Whether to use weight norm. If set to true, it will + be applied to all of the conv layers. + + """ + super().__init__() + + # check hyperparameters are valid + assert kernel_size % 2 == 1, "Kernel size must be odd number." 
+ assert len(upsample_scales) == len(upsample_kernel_sizes) + assert len(resblock_dilations) == len(resblock_kernel_sizes) + + # define modules + self.upsample_factor = int(np.prod(upsample_scales) * out_channels) + self.num_upsamples = len(upsample_kernel_sizes) + self.num_blocks = len(resblock_kernel_sizes) + self.input_conv = torch.nn.Conv1d( + in_channels, + channels, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, + ) + self.upsamples = torch.nn.ModuleList() + self.blocks = torch.nn.ModuleList() + for i in range(len(upsample_kernel_sizes)): + assert upsample_kernel_sizes[i] == 2 * upsample_scales[i] + self.upsamples += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + torch.nn.ConvTranspose1d( + channels // (2**i), + channels // (2 ** (i + 1)), + upsample_kernel_sizes[i], + upsample_scales[i], + padding=upsample_scales[i] // 2 + upsample_scales[i] % 2, + output_padding=upsample_scales[i] % 2, + ), + ) + ] + for j in range(len(resblock_kernel_sizes)): + self.blocks += [ + ResidualBlock( + kernel_size=resblock_kernel_sizes[j], + channels=channels // (2 ** (i + 1)), + dilations=resblock_dilations[j], + bias=bias, + use_additional_convs=use_additional_convs, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + ) + ] + self.output_conv = torch.nn.Sequential( + # NOTE(kan-bayashi): follow official implementation but why + # using different slope parameter here? (0.1 vs. 
0.01) + torch.nn.LeakyReLU(), + torch.nn.Conv1d( + channels // (2 ** (i + 1)), + out_channels, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, + ), + torch.nn.Tanh(), + ) + if global_channels > 0: + self.global_conv = torch.nn.Conv1d(global_channels, channels, 1) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # reset parameters + self.reset_parameters() + + def forward( + self, c: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Calculate forward propagation. + + Args: + c (Tensor): Input tensor (B, in_channels, T). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + + Returns: + Tensor: Output tensor (B, out_channels, T). + + """ + c = self.input_conv(c) + if g is not None: + c = c + self.global_conv(g) + for i in range(self.num_upsamples): + c = self.upsamples[i](c) + cs = 0.0 # initialize + for j in range(self.num_blocks): + cs += self.blocks[i * self.num_blocks + j](c) + c = cs / self.num_blocks + c = self.output_conv(c) + + return c + + def reset_parameters(self): + """Reset parameters. + + This initialization follows the official implementation manner. 
class ResidualBlock(torch.nn.Module):
    """Residual block module in HiFiGAN."""

    def __init__(
        self,
        kernel_size: int = 3,
        channels: int = 512,
        dilations: List[int] = [1, 3, 5],
        bias: bool = True,
        use_additional_convs: bool = True,
        nonlinear_activation: str = "LeakyReLU",
        nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
    ):
        """Build one activation + dilated conv stage per dilation factor.

        Args:
            kernel_size (int): Kernel size of dilation convolution layer.
            channels (int): Number of channels for convolution layer.
            dilations (List[int]): List of dilation factors.
            bias (bool): Whether to add bias parameter in convolution layers.
            use_additional_convs (bool): Whether to use additional convolution
                layers.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (Dict[str, Any]): Hyperparameters for
                activation function.

        """
        super().__init__()
        self.use_additional_convs = use_additional_convs
        self.convs1 = torch.nn.ModuleList()
        if use_additional_convs:
            self.convs2 = torch.nn.ModuleList()
        assert kernel_size % 2 == 1, "Kernel size must be odd number."

        half_width = (kernel_size - 1) // 2

        def _activated_conv(dilation: int) -> torch.nn.Sequential:
            # "Same" padding: the time axis length is preserved for any dilation.
            activation = getattr(torch.nn, nonlinear_activation)(
                **nonlinear_activation_params
            )
            conv = torch.nn.Conv1d(
                channels,
                channels,
                kernel_size,
                1,
                dilation=dilation,
                bias=bias,
                padding=half_width * dilation,
            )
            return torch.nn.Sequential(activation, conv)

        for dilation in dilations:
            self.convs1.append(_activated_conv(dilation))
            if use_additional_convs:
                # The additional convolution is never dilated.
                self.convs2.append(_activated_conv(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, channels, T).

        Returns:
            Tensor: Output tensor (B, channels, T).

        """
        if self.use_additional_convs:
            for dilated, plain in zip(self.convs1, self.convs2):
                x = plain(dilated(x)) + x
        else:
            for dilated in self.convs1:
                x = dilated(x) + x
        return x
+ + self.period = period + self.convs = torch.nn.ModuleList() + in_chs = in_channels + out_chs = channels + for downsample_scale in downsample_scales: + self.convs += [ + torch.nn.Sequential( + torch.nn.Conv2d( + in_chs, + out_chs, + (kernel_sizes[0], 1), + (downsample_scale, 1), + padding=((kernel_sizes[0] - 1) // 2, 0), + ), + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + ) + ] + in_chs = out_chs + # NOTE(kan-bayashi): Use downsample_scale + 1? + out_chs = min(out_chs * 4, max_downsample_channels) + self.output_conv = torch.nn.Conv2d( + out_chs, + out_channels, + (kernel_sizes[1] - 1, 1), + 1, + padding=((kernel_sizes[1] - 1) // 2, 0), + ) + + if use_weight_norm and use_spectral_norm: + raise ValueError("Either use use_weight_norm or use_spectral_norm.") + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # apply spectral norm + if use_spectral_norm: + self.apply_spectral_norm() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + c (Tensor): Input tensor (B, in_channels, T). + + Returns: + list: List of each layer's tensors. 
class HiFiGANMultiPeriodDiscriminator(torch.nn.Module):
    """HiFiGAN multi-period discriminator module."""

    def __init__(
        self,
        periods: List[int] = [2, 3, 5, 7, 11],
        discriminator_params: Dict[str, Any] = {
            "in_channels": 1,
            "out_channels": 1,
            "kernel_sizes": [5, 3],
            "channels": 32,
            "downsample_scales": [3, 3, 3, 3, 1],
            "max_downsample_channels": 1024,
            "bias": True,
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.1},
            "use_weight_norm": True,
            "use_spectral_norm": False,
        },
    ):
        """Initialize HiFiGANMultiPeriodDiscriminator module.

        Args:
            periods (List[int]): List of periods.
            discriminator_params (Dict[str, Any]): Parameters for hifi-gan period
                discriminator module. The period parameter will be overwritten.

        """
        super().__init__()
        self.discriminators = torch.nn.ModuleList()
        for period in periods:
            # Work on a private copy so neither the caller's dict nor the shared
            # default argument is modified when "period" is injected.
            params = copy.deepcopy(discriminator_params)
            params["period"] = period
            self.discriminators += [HiFiGANPeriodDiscriminator(**params)]

    def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            List[List[Tensor]]: One entry per period discriminator, in the order
                of ``periods``; each entry is whatever that discriminator
                returns (its list of layer output tensors).

        """
        return [discriminator(x) for discriminator in self.discriminators]
If set to true, it will + be applied to all of the conv layers. + use_spectral_norm (bool): Whether to use spectral norm. If set to true, it + will be applied to all of the conv layers. + + """ + super().__init__() + self.layers = torch.nn.ModuleList() + + # check kernel size is valid + assert len(kernel_sizes) == 4 + for ks in kernel_sizes: + assert ks % 2 == 1 + + # add first layer + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_channels, + channels, + # NOTE(kan-bayashi): Use always the same kernel size + kernel_sizes[0], + bias=bias, + padding=(kernel_sizes[0] - 1) // 2, + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + + # add downsample layers + in_chs = channels + out_chs = channels + # NOTE(kan-bayashi): Remove hard coding? + groups = 4 + for downsample_scale in downsample_scales: + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chs, + out_chs, + kernel_size=kernel_sizes[1], + stride=downsample_scale, + padding=(kernel_sizes[1] - 1) // 2, + groups=groups, + bias=bias, + ), + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + ) + ] + in_chs = out_chs + # NOTE(kan-bayashi): Remove hard coding? + out_chs = min(in_chs * 2, max_downsample_channels) + # NOTE(kan-bayashi): Remove hard coding? 
+ groups = min(groups * 4, max_groups) + + # add final layers + out_chs = min(in_chs * 2, max_downsample_channels) + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chs, + out_chs, + kernel_size=kernel_sizes[2], + stride=1, + padding=(kernel_sizes[2] - 1) // 2, + bias=bias, + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + self.layers += [ + torch.nn.Conv1d( + out_chs, + out_channels, + kernel_size=kernel_sizes[3], + stride=1, + padding=(kernel_sizes[3] - 1) // 2, + bias=bias, + ), + ] + + if use_weight_norm and use_spectral_norm: + raise ValueError("Either use use_weight_norm or use_spectral_norm.") + + # apply weight norm + self.use_weight_norm = use_weight_norm + if use_weight_norm: + self.apply_weight_norm() + + # apply spectral norm + self.use_spectral_norm = use_spectral_norm + if use_spectral_norm: + self.apply_spectral_norm() + + # backward compatibility + self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List[Tensor]: List of output tensors of each layer. 
+ + """ + outs = [] + for f in self.layers: + x = f(x) + outs += [x] + + return outs + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv1d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def apply_spectral_norm(self): + """Apply spectral normalization module from all of the layers.""" + + def _apply_spectral_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv1d): + torch.nn.utils.spectral_norm(m) + logging.debug(f"Spectral norm is applied to {m}.") + + self.apply(_apply_spectral_norm) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def remove_spectral_norm(self): + """Remove spectral normalization module from all of the layers.""" + + def _remove_spectral_norm(m): + try: + logging.debug(f"Spectral norm is removed from {m}.") + torch.nn.utils.remove_spectral_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_spectral_norm) + + def _load_state_dict_pre_hook( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + """Fix the compatibility of weight / spectral normalization issue. + + Some pretrained models are trained with configs that use weight / spectral + normalization, but actually, the norm is not applied. This causes the mismatch + of the parameters with configs. To solve this issue, when parameter mismatch + happens in loading pretrained model, we remove the norm from the current model. 
+ + See also: + - https://github.com/espnet/espnet/pull/5240 + - https://github.com/espnet/espnet/pull/5249 + - https://github.com/kan-bayashi/ParallelWaveGAN/pull/409 + + """ + current_module_keys = [x for x in state_dict.keys() if x.startswith(prefix)] + if self.use_weight_norm and any( + [k.endswith("weight") for k in current_module_keys] + ): + logging.warning( + "It seems weight norm is not applied in the pretrained model but the" + " current model uses it. To keep the compatibility, we remove the norm" + " from the current model. This may cause unexpected behavior due to the" + " parameter mismatch in finetuning. To avoid this issue, please change" + " the following parameters in config to false:\n" + " - discriminator_params.follow_official_norm\n" + " - discriminator_params.scale_discriminator_params.use_weight_norm\n" + " - discriminator_params.scale_discriminator_params.use_spectral_norm\n" + "\n" + "See also:\n" + " - https://github.com/espnet/espnet/pull/5240\n" + " - https://github.com/espnet/espnet/pull/5249" + ) + self.remove_weight_norm() + self.use_weight_norm = False + for k in current_module_keys: + if k.endswith("weight_g") or k.endswith("weight_v"): + del state_dict[k] + + if self.use_spectral_norm and any( + [k.endswith("weight") for k in current_module_keys] + ): + logging.warning( + "It seems spectral norm is not applied in the pretrained model but the" + " current model uses it. To keep the compatibility, we remove the norm" + " from the current model. This may cause unexpected behavior due to the" + " parameter mismatch in finetuning. 
class HiFiGANMultiScaleDiscriminator(torch.nn.Module):
    """HiFi-GAN multi-scale discriminator module."""

    def __init__(
        self,
        scales: int = 3,
        downsample_pooling: str = "AvgPool1d",
        # follow the official implementation setting
        downsample_pooling_params: Dict[str, Any] = {
            "kernel_size": 4,
            "stride": 2,
            "padding": 2,
        },
        discriminator_params: Dict[str, Any] = {
            "in_channels": 1,
            "out_channels": 1,
            "kernel_sizes": [15, 41, 5, 3],
            "channels": 128,
            "max_downsample_channels": 1024,
            "max_groups": 16,
            "bias": True,
            "downsample_scales": [2, 2, 4, 4, 1],
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.1},
        },
        follow_official_norm: bool = False,
    ):
        """Initialize HiFiGAN multi-scale discriminator module.

        Args:
            scales (int): Number of multi-scales.
            downsample_pooling (str): Pooling module name for downsampling of the
                inputs.
            downsample_pooling_params (Dict[str, Any]): Parameters for the above
                pooling module.
            discriminator_params (Dict[str, Any]): Parameters for hifi-gan scale
                discriminator module.
            follow_official_norm (bool): Whether to follow the norm setting of the
                official implementation. The first discriminator uses spectral
                norm and the other discriminators use weight norm.

        """
        super().__init__()
        self.discriminators = torch.nn.ModuleList()

        # add discriminators
        for i in range(scales):
            # Private copy: the norm flags below must not leak into the caller's
            # dict or into the shared default argument.
            params = copy.deepcopy(discriminator_params)
            if follow_official_norm:
                first = i == 0
                params["use_weight_norm"] = not first
                params["use_spectral_norm"] = first
            self.discriminators += [HiFiGANScaleDiscriminator(**params)]

        # Pooling is only needed to produce the input of a following scale.
        self.pooling = None
        if scales > 1:
            self.pooling = getattr(torch.nn, downsample_pooling)(
                **downsample_pooling_params
            )

    def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            List[List[torch.Tensor]]: List of list of each discriminator outputs,
                which consists of each layer output tensors.

        """
        outs = []
        last = len(self.discriminators) - 1
        for idx, discriminator in enumerate(self.discriminators):
            outs.append(discriminator(x))
            # Downsample only when another discriminator will consume the result;
            # pooling after the final scale would be computed and thrown away.
            if self.pooling is not None and idx < last:
                x = self.pooling(x)

        return outs
3, 3, 3, 1], + "max_downsample_channels": 1024, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + ): + """Initilize HiFiGAN multi-scale + multi-period discriminator module. + + Args: + scales (int): Number of multi-scales. + scale_downsample_pooling (str): Pooling module name for downsampling of the + inputs. + scale_downsample_pooling_params (dict): Parameters for the above pooling + module. + scale_discriminator_params (dict): Parameters for hifi-gan scale + discriminator module. + follow_official_norm (bool): Whether to follow the norm setting of the + official implementaion. The first discriminator uses spectral norm and + the other discriminators use weight norm. + periods (list): List of periods. + period_discriminator_params (dict): Parameters for hifi-gan period + discriminator module. The period parameter will be overwritten. + + """ + super().__init__() + self.msd = HiFiGANMultiScaleDiscriminator( + scales=scales, + downsample_pooling=scale_downsample_pooling, + downsample_pooling_params=scale_downsample_pooling_params, + discriminator_params=scale_discriminator_params, + follow_official_norm=follow_official_norm, + ) + self.mpd = HiFiGANMultiPeriodDiscriminator( + periods=periods, + discriminator_params=period_discriminator_params, + ) + + def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List[List[Tensor]]: List of list of each discriminator outputs, + which consists of each layer output tensors. Multi scale and + multi period ones are concatenated. 
class GeneratorAdversarialLoss(torch.nn.Module):
    """Generator adversarial loss module."""

    def __init__(
        self,
        average_by_discriminators: bool = True,
        loss_type: str = "mse",
    ):
        """Initialize GeneratorAdversarialLoss module.

        Args:
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            loss_type (str): Loss type, "mse" or "hinge".

        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.criterion = self._mse_loss
        else:
            self.criterion = self._hinge_loss

    def forward(
        self,
        outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
    ) -> torch.Tensor:
        """Calculate generator adversarial loss.

        Args:
            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of
                discriminator outputs.

        Returns:
            Tensor: Generator adversarial loss value. An empty list yields the
                plain float ``0.0`` since there is nothing to accumulate.

        """
        if not isinstance(outputs, (tuple, list)):
            return self.criterion(outputs)

        adv_loss = 0.0
        for outputs_ in outputs:
            if isinstance(outputs_, (tuple, list)):
                # NOTE(kan-bayashi): case including feature maps
                outputs_ = outputs_[-1]
            adv_loss += self.criterion(outputs_)

        # Divide by the explicit count instead of a leaked loop index, which is
        # undefined when ``outputs`` is empty.
        num_discriminators = len(outputs)
        if self.average_by_discriminators and num_discriminators > 0:
            adv_loss /= num_discriminators

        return adv_loss

    def _mse_loss(self, x):
        # Least-squares GAN objective: push discriminator outputs towards 1.
        return F.mse_loss(x, x.new_ones(x.size()))

    def _hinge_loss(self, x):
        # Hinge GAN objective for the generator: maximize the mean output.
        return -x.mean()
class FeatureMatchLoss(torch.nn.Module):
    """Feature matching loss module."""

    def __init__(
        self,
        average_by_layers: bool = True,
        average_by_discriminators: bool = True,
        include_final_outputs: bool = False,
    ):
        """Initialize FeatureMatchLoss module.

        Args:
            average_by_layers (bool): Whether to average the loss by the number
                of layers.
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            include_final_outputs (bool): Whether to include the final output of
                each discriminator for loss calculation.

        """
        super().__init__()
        self.average_by_layers = average_by_layers
        self.average_by_discriminators = average_by_discriminators
        self.include_final_outputs = include_final_outputs

    def forward(
        self,
        feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
        feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
    ) -> torch.Tensor:
        """Calculate feature matching loss.

        Args:
            feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs calculated
                from generator's outputs.
            feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs calculated
                from groundtruth.

        Returns:
            Tensor: Feature matching loss value.

        """
        feat_match_loss = 0.0
        # Materialise the pairs so the averaging denominators are explicit
        # counts rather than loop indices leaked out of the loops.
        disc_pairs = list(zip(feats_hat, feats))
        for feats_hat_, feats_ in disc_pairs:
            feat_match_loss_ = 0.0
            if not self.include_final_outputs:
                feats_hat_ = feats_hat_[:-1]
                feats_ = feats_[:-1]
            layer_pairs = list(zip(feats_hat_, feats_))
            for feat_hat_, feat_ in layer_pairs:
                # Ground-truth features are targets only: no gradient to D.
                feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
            # A discriminator that contributes no layers adds 0 instead of
            # dividing by an unbound or stale layer index.
            if self.average_by_layers and layer_pairs:
                feat_match_loss_ /= len(layer_pairs)
            feat_match_loss += feat_match_loss_
        if self.average_by_discriminators and disc_pairs:
            feat_match_loss /= len(disc_pairs)

        return feat_match_loss
class KLDivergenceLoss(torch.nn.Module):
    """KL divergence loss."""

    def forward(
        self,
        z_p: torch.Tensor,
        logs_q: torch.Tensor,
        m_p: torch.Tensor,
        logs_p: torch.Tensor,
        z_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate KL divergence loss.

        Args:
            z_p (Tensor): Flow hidden representation (B, H, T_feats).
            logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
            m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
            logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
            z_mask (Tensor): Mask tensor (B, 1, T_feats).

        Returns:
            Tensor: KL divergence loss.

        """
        # Everything is evaluated in float32, whatever dtype comes in.
        z_p, logs_q, m_p, logs_p, z_mask = (
            t.float() for t in (z_p, logs_q, m_p, logs_p, z_mask)
        )
        log_scale_term = logs_p - logs_q - 0.5
        mahalanobis_term = 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
        masked_total = torch.sum((log_scale_term + mahalanobis_term) * z_mask)
        # Normalise by the number of unmasked frames.
        return masked_total / torch.sum(z_mask)
+ logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats). + m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats). + logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats). + """ + posterior_norm = D.Normal(m_q, torch.exp(logs_q)) + prior_norm = D.Normal(m_p, torch.exp(logs_p)) + loss = D.kl_divergence(posterior_norm, prior_norm).mean() + return loss diff --git a/egs/ljspeech/tts/vits/models.py b/egs/ljspeech/tts/vits/models.py new file mode 100644 index 000000000..f5acdeb2b --- /dev/null +++ b/egs/ljspeech/tts/vits/models.py @@ -0,0 +1,534 @@ +import copy +import math +import torch +from torch import nn +from torch.nn import functional as F + +import commons +import modules +import attentions +import monotonic_align + +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from commons import init_weights, get_padding + + +class StochasticDurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): + super().__init__() + filter_channels = in_channels # it needs to be removed from future version. 
+ self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.log_flow = modules.Log() + self.flows = nn.ModuleList() + self.flows.append(modules.ElementwiseAffine(2)) + for i in range(n_flows): + self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.flows.append(modules.Flip()) + + self.post_pre = nn.Conv1d(1, filter_channels, 1) + self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + self.post_flows = nn.ModuleList() + self.post_flows.append(modules.ElementwiseAffine(2)) + for i in range(4): + self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.post_flows.append(modules.Flip()) + + self.pre = nn.Conv1d(in_channels, filter_channels, 1) + self.proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + + def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): + x = torch.detach(x) + x = self.pre(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert w is not None + + logdet_tot_q = 0 + h_w = self.post_pre(w) + h_w = self.post_convs(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask + z_q = e_q + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum((F.logsigmoid(z_u) + 
class DurationPredictor(nn.Module):
    """Deterministic duration predictor.

    Two conv -> ReLU -> LayerNorm -> dropout stages followed by a 1x1
    projection to a single channel. Inputs are detached, so no gradient
    flows back into whatever produced ``x`` or ``g``.
    """

    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        half_kernel = kernel_size // 2
        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=half_kernel)
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=half_kernel)
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        # The conditioning projection only exists when conditioning is configured.
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        """Return the masked single-channel prediction for ``x``."""
        hidden = torch.detach(x)
        if g is not None:
            hidden = hidden + self.cond(torch.detach(g))
        stages = ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2))
        for conv, norm in stages:
            hidden = self.drop(norm(torch.relu(conv(hidden * x_mask))))
        return self.proj(hidden * x_mask) * x_mask
class ResidualCouplingBlock(nn.Module):
    """Stack of ``n_flows`` (coupling layer, flip) pairs, runnable in both directions."""

    def __init__(self,
                 channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_layers,
                 n_flows=4,
                 gin_channels=0):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            coupling = modules.ResidualCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                dilation_rate,
                n_layers,
                gin_channels=gin_channels,
                mean_only=True,
            )
            self.flows.append(coupling)
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply the flows in order, or in reversed order when ``reverse`` is set."""
        if reverse:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
            return x
        for flow in self.flows:
            # In the forward direction each flow also returns a second value,
            # which is discarded here.
            x, _ = flow(x, x_mask, g=g, reverse=reverse)
        return x
class Generator(torch.nn.Module):
    """Upsampling decoder: transposed-conv stages, each followed by averaged residual blocks."""

    def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        resblock_cls = modules.ResBlock1 if resblock == '1' else modules.ResBlock2

        # Each upsampling stage halves the channel count.
        self.ups = nn.ModuleList()
        for stage, (rate, kernel) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            in_ch = upsample_initial_channel // (2 ** stage)
            out_ch = upsample_initial_channel // (2 ** (stage + 1))
            up = ConvTranspose1d(in_ch, out_ch, kernel, rate, padding=(kernel - rate) // 2)
            self.ups.append(weight_norm(up))

        # One group of residual blocks per stage, stored flat in stage order.
        self.resblocks = nn.ModuleList()
        for stage in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (stage + 1))
            for kernel, dilation in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock_cls(ch, kernel, dilation))

        # `ch` is the channel count of the last stage.
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        """Map the input through all stages and squash the result with tanh."""
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for stage in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[stage](x)
            first = stage * self.num_kernels
            summed = None
            for offset in range(self.num_kernels):
                branch = self.resblocks[first + offset](x)
                if summed is None:
                    summed = branch
                else:
                    summed += branch
            x = summed / self.num_kernels
        # The final activation uses leaky_relu's default slope.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Strip weight normalisation from the upsamplers and residual blocks."""
        print('Removing weight norm...')
        for up in self.ups:
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training

    Holds the text encoder (``enc_p``), posterior encoder (``enc_q``), flow,
    duration predictor (``dp``) and waveform decoder (``dec``). A speaker
    embedding (``emb_g``) is created whenever ``n_speakers > 0``, which is the
    same condition ``forward``, ``infer`` and ``voice_conversion`` use to look
    it up.
    """

    def __init__(self,
                 n_vocab,
                 spec_channels,
                 segment_size,
                 inter_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout,
                 resblock,
                 resblock_kernel_sizes,
                 resblock_dilation_sizes,
                 upsample_rates,
                 upsample_initial_channel,
                 upsample_kernel_sizes,
                 n_speakers=0,
                 gin_channels=0,
                 use_sdp=True,
                 **kwargs):

        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels

        self.use_sdp = use_sdp

        self.enc_p = TextEncoder(n_vocab,
                                 inter_channels,
                                 hidden_channels,
                                 filter_channels,
                                 n_heads,
                                 n_layers,
                                 kernel_size,
                                 p_dropout)
        self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
        self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)

        if use_sdp:
            self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
        else:
            self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)

        # Must match the `self.n_speakers > 0` checks below. The previous
        # `> 1` left `emb_g` undefined for n_speakers == 1, so forward()
        # and infer() raised AttributeError.
        if n_speakers > 0:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)

    def forward(self, x, x_lengths, y, y_lengths, sid=None):
        """Run the training pass.

        Returns:
            Tuple of the decoded slice ``o``, the duration loss ``l_length``,
            the alignment ``attn``, the slice indices ``ids_slice``, the text
            and feature masks, and ``(z, z_p, m_p, logs_p, m_q, logs_q)``.
        """
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
        if self.n_speakers > 0:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = None

        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)

        # The alignment search is not differentiated through.
        with torch.no_grad():
            # negative cross-entropy
            s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
            neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True)  # [b, 1, t_s]
            neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r)  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r))  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True)  # [b, 1, t_s]
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4

            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()

        # Per-token durations: how many frames are aligned to each token.
        w = attn.sum(2)
        if self.use_sdp:
            l_length = self.dp(x, x_mask, w, g=g)
            l_length = l_length / torch.sum(x_mask)
        else:
            logw_ = torch.log(w + 1e-6) * x_mask
            logw = self.dp(x, x_mask, g=g)
            l_length = torch.sum((logw - logw_)**2, [1, 2]) / torch.sum(x_mask)  # for averaging

        # expand prior
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)

        # Only a random slice is decoded to a waveform.
        z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
        o = self.dec(z_slice, g=g)
        return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

    def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
        """Synthesize from text using predicted durations.

        Returns:
            Tuple of the waveform ``o``, the alignment ``attn``, the feature
            mask ``y_mask`` and ``(z, z_p, m_p, logs_p)``.
        """
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
        if self.n_speakers > 0:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = None

        if self.use_sdp:
            logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
        else:
            logw = self.dp(x, x_mask, g=g)
        w = torch.exp(logw) * x_mask * length_scale
        w_ceil = torch.ceil(w)
        # Every utterance is given at least one output frame.
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = commons.generate_path(w_ceil, attn_mask)

        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask)[:, :, :max_len], g=g)
        return o, attn, y_mask, (z, z_p, m_p, logs_p)

    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
        """Re-decode ``y`` with the target speaker's embedding.

        Encodes with the source speaker, runs the flow forward with the source
        embedding and backward with the target embedding, then decodes.
        """
        assert self.n_speakers > 0, "n_speakers have to be larger than 0."
        g_src = self.emb_g(sid_src).unsqueeze(-1)
        g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.dec(z_hat * y_mask, g=g_tgt)
        return o_hat, y_mask, (z, z_p, z_hat)
def maximum_path(neg_x_ent: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
    """Calculate maximum path.

    The work is done on CPU with NumPy arrays. The result is moved back to
    the device and dtype of ``neg_x_ent``.

    Args:
        neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text).
        attn_mask (Tensor): Attention mask (B, T_feats, T_text).

    Returns:
        Tensor: Maximum path tensor (B, T_feats, T_text).

    """
    device, dtype = neg_x_ent.device, neg_x_ent.dtype
    neg_x_ent = neg_x_ent.cpu().numpy().astype(np.float32)
    path = np.zeros(neg_x_ent.shape, dtype=np.int32)
    # Valid lengths per batch item, read from the first column / first row of
    # the mask sums (assumes the mask is a full rectangle of ones -- TODO confirm).
    t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32)
    t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32)
    # Both kernels fill `path` in place.
    if is_cython_avalable:
        maximum_path_c(path, neg_x_ent, t_t_max, t_s_max)
    else:
        maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max)

    return torch.from_numpy(path).to(device=device, dtype=dtype)


@njit
def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf):
    """Calculate a single maximum path with numba.

    ``value`` is overwritten in place with cumulative scores, and ``path``
    receives a 1 at the chosen column of every row in ``range(t_y)``.
    """
    index = t_x - 1
    # Forward pass: each cell adds the better of its two predecessors,
    # (y - 1, x) and (y - 1, x - 1).
    for y in range(t_y):
        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
            if x == y:
                # Staying in column x is not allowed on the diagonal.
                v_cur = max_neg_val
            else:
                v_cur = value[y - 1, x]
            if x == 0:
                if y == 0:
                    # The path starts at cell (0, 0).
                    v_prev = 0.0
                else:
                    v_prev = max_neg_val
            else:
                v_prev = value[y - 1, x - 1]
            value[y, x] += max(v_prev, v_cur)

    # Backtracking from the last row: step one column left when forced
    # (index == y) or when the diagonal predecessor scored higher.
    for y in range(t_y - 1, -1, -1):
        path[y, index] = 1
        if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
            index = index - 1


@njit(parallel=True)
def maximum_path_numba(paths, values, t_ys, t_xs):
    """Calculate batch maximum path with numba.

    Each batch item is handled independently inside a ``prange`` loop.
    """
    for i in prange(paths.shape[0]):
        maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i])
# Single-item maximum path search.
# `value` is overwritten in place with cumulative scores; `path` receives a 1
# at the chosen column of every row in range(t_y).
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
    cdef int x
    cdef int y
    cdef float v_prev
    cdef float v_cur
    cdef float tmp  # declared but not used in this function
    cdef int index = t_x - 1

    # Forward pass: each cell adds the better of its two predecessors,
    # (y - 1, x) and (y - 1, x - 1).
    for y in range(t_y):
        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
            if x == y:
                # Staying in column x is not allowed on the diagonal.
                v_cur = max_neg_val
            else:
                v_cur = value[y - 1, x]
            if x == 0:
                if y == 0:
                    # The path starts at cell (0, 0).
                    v_prev = 0.0
                else:
                    v_prev = max_neg_val
            else:
                v_prev = value[y - 1, x - 1]
            value[y, x] += max(v_prev, v_cur)

    # Backtracking from the last row: step one column left when forced
    # (index == y) or when the diagonal predecessor scored higher.
    # NOTE(review): bounds checks and wraparound are disabled, so this appears
    # to rely on t_x <= t_y to keep `value[y - 1, ...]` in range -- confirm.
    for y in range(t_y - 1, -1, -1):
        path[y, index] = 1
        if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
            index = index - 1


# Batch entry point: runs maximum_path_each for every batch item inside a
# prange loop with the GIL released.
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
    cdef int b = paths.shape[0]
    cdef int i
    for i in prange(b, nogil=True):
        maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
build_ext.""" + + def finalize_options(self): + """Prevent numpy from thinking it is still in its setup process.""" + _build_ext.finalize_options(self) + __builtins__.__NUMPY_SETUP__ = False + import numpy + + self.include_dirs.append(numpy.get_include()) + + +exts = [ + Extension( + name="core", + sources=["core.pyx"], + ) +] +setup( + name="monotonic_align", + ext_modules=cythonize(exts, language_level=3), + cmdclass={"build_ext": build_ext}, +) diff --git a/egs/ljspeech/tts/vits/posterior_encoder.py b/egs/ljspeech/tts/vits/posterior_encoder.py new file mode 100644 index 000000000..c78fd647f --- /dev/null +++ b/egs/ljspeech/tts/vits/posterior_encoder.py @@ -0,0 +1,117 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Posterior encoder module in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + +from typing import Optional, Tuple + +import torch + +from icefall.utils import make_pad_mask +from wavenet import WaveNet, Conv1d + + +class PosteriorEncoder(torch.nn.Module): + """Posterior encoder module in VITS. + + This is a module of posterior encoder described in `Conditional Variational + Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_. + + .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`: https://arxiv.org/abs/2006.04558 + """ + + def __init__( + self, + in_channels: int = 513, + out_channels: int = 192, + hidden_channels: int = 192, + kernel_size: int = 5, + layers: int = 16, + stacks: int = 1, + base_dilation: int = 1, + global_channels: int = -1, + dropout_rate: float = 0.0, + bias: bool = True, + use_weight_norm: bool = True, + ): + """Initilialize PosteriorEncoder module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. 
+ hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size in WaveNet. + layers (int): Number of layers of WaveNet. + stacks (int): Number of repeat stacking of WaveNet. + base_dilation (int): Base dilation factor. + global_channels (int): Number of global conditioning channels. + dropout_rate (float): Dropout rate. + bias (bool): Whether to use bias parameters in conv. + use_weight_norm (bool): Whether to apply weight norm. + + """ + super().__init__() + + # define modules + self.input_conv = Conv1d(in_channels, hidden_channels, 1) + self.encoder = WaveNet( + in_channels=-1, + out_channels=-1, + kernel_size=kernel_size, + layers=layers, + stacks=stacks, + base_dilation=base_dilation, + residual_channels=hidden_channels, + aux_channels=-1, + gate_channels=hidden_channels * 2, + skip_channels=hidden_channels, + global_channels=global_channels, + dropout_rate=dropout_rate, + bias=bias, + use_weight_norm=use_weight_norm, + use_first_conv=False, + use_last_conv=False, + scale_residual=False, + scale_skip_connect=True, + ) + self.proj = Conv1d(hidden_channels, out_channels * 2, 1) + + def forward( + self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T_feats). + x_lengths (Tensor): Length tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + + Returns: + Tensor: Encoded hidden representation tensor (B, out_channels, T_feats). + Tensor: Projected mean tensor (B, out_channels, T_feats). + Tensor: Projected scale tensor (B, out_channels, T_feats). + Tensor: Mask tensor for input tensor (B, 1, T_feats). 
class ResidualAffineCouplingBlock(torch.nn.Module):
    """Residual affine coupling block module.

    This is a module of residual affine coupling block, which is used as "Flow" in
    `Conditional Variational Autoencoder with Adversarial Learning for End-to-End
    Text-to-Speech`_.

    The block is a chain of ``flows`` pairs, each pair being a
    ``ResidualAffineCouplingLayer`` followed by a ``FlipFlow``.

    .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
        Text-to-Speech`: https://arxiv.org/abs/2106.06103

    """

    def __init__(
        self,
        in_channels: int = 192,
        hidden_channels: int = 192,
        flows: int = 4,
        kernel_size: int = 5,
        base_dilation: int = 1,
        layers: int = 4,
        global_channels: int = -1,
        dropout_rate: float = 0.0,
        use_weight_norm: bool = True,
        bias: bool = True,
        use_only_mean: bool = True,
    ):
        """Initialize ResidualAffineCouplingBlock module.

        Args:
            in_channels (int): Number of input channels.
            hidden_channels (int): Number of hidden channels.
            flows (int): Number of (coupling layer, flip) pairs.
            kernel_size (int): Kernel size for WaveNet.
            base_dilation (int): Base dilation factor for WaveNet.
            layers (int): Number of layers of WaveNet.
            global_channels (int): Number of global channels.
            dropout_rate (float): Dropout rate.
            use_weight_norm (bool): Whether to use weight normalization in WaveNet.
            bias (bool): Whether to use bias parameters in WaveNet.
            use_only_mean (bool): Whether to estimate only mean.

        """
        super().__init__()

        self.flows = torch.nn.ModuleList()
        for _ in range(flows):
            self.flows.append(
                ResidualAffineCouplingLayer(
                    in_channels=in_channels,
                    hidden_channels=hidden_channels,
                    kernel_size=kernel_size,
                    base_dilation=base_dilation,
                    layers=layers,
                    # The number of WaveNet stacks is fixed here and is
                    # deliberately not exposed as a constructor argument.
                    stacks=1,
                    global_channels=global_channels,
                    dropout_rate=dropout_rate,
                    use_weight_norm=use_weight_norm,
                    bias=bias,
                    use_only_mean=use_only_mean,
                )
            )
            self.flows.append(FlipFlow())

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        g: Optional[torch.Tensor] = None,
        inverse: bool = False,
    ) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).
            x_mask (Tensor): Mask tensor, handed unchanged to every flow
                (presumably (B, 1, T), as produced by the encoders - TODO confirm).
            g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
            inverse (bool): Whether to inverse the flow.

        Returns:
            Tensor: Output tensor (B, in_channels, T).

        """
        if not inverse:
            # In the forward direction every flow returns a pair; only the
            # transformed tensor is kept, the second element is discarded.
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, inverse=inverse)
        else:
            # Inversion has to undo the flows last-to-first; here each flow
            # returns the tensor alone.
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, inverse=inverse)
        return x
+ + """ + assert in_channels % 2 == 0, "in_channels should be divisible by 2" + super().__init__() + self.half_channels = in_channels // 2 + self.use_only_mean = use_only_mean + + # define modules + self.input_conv = torch.nn.Conv1d( + self.half_channels, + hidden_channels, + 1, + ) + self.encoder = WaveNet( + in_channels=-1, + out_channels=-1, + kernel_size=kernel_size, + layers=layers, + stacks=stacks, + base_dilation=base_dilation, + residual_channels=hidden_channels, + aux_channels=-1, + gate_channels=hidden_channels * 2, + skip_channels=hidden_channels, + global_channels=global_channels, + dropout_rate=dropout_rate, + bias=bias, + use_weight_norm=use_weight_norm, + use_first_conv=False, + use_last_conv=False, + scale_residual=False, + scale_skip_connect=True, + ) + if use_only_mean: + self.proj = torch.nn.Conv1d( + hidden_channels, + self.half_channels, + 1, + ) + else: + self.proj = torch.nn.Conv1d( + hidden_channels, + self.half_channels * 2, + 1, + ) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + inverse: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T). + x_lengths (Tensor): Length tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Output tensor (B, in_channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. 
# https://github.com/jaywalnut310/vits/blob/main/text/symbols.py
# (originally from https://github.com/keithito/tacotron)
#
# Defines the set of symbols used in text input to the model.

# Padding symbol; it is the first entry of the table, i.e. id 0.
_pad = "_"
# Punctuation marks; the trailing blank is the word separator.
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"


# Export all symbols: a symbol's id is its position in this list, so the
# group order (pad, punctuation, letters, IPA) must not change.
symbol_table = [_pad, *_punctuation, *_letters, *_letters_ipa]

# Special symbol ids
SPACE_ID = symbol_table.index(" ")
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Text encoder module in VITS. + +This code is based on + - https://github.com/jaywalnut310/vits + - https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py +""" + +import copy +import math +from typing import Optional, Tuple + +import torch +from torch import Tensor, nn + +from icefall.utils import make_pad_mask + + +class TextEncoder(torch.nn.Module): + """Text encoder module in VITS. + + This is a module of text encoder described in `Conditional Variational Autoencoder + with Adversarial Learning for End-to-End Text-to-Speech`. + """ + + def __init__( + self, + vocabs: int, + d_model: int = 192, + num_heads: int = 2, + dim_feedforward: int = 768, + num_layers: int = 6, + dropout: float = 0.1, + ): + """Initialize TextEncoder module. + + Args: + vocabs (int): Vocabulary size. 
+ d_model (int): attention dimension + num_heads (int): number of attention heads + dim_feedforward (int): feedforward dimention + num_layers (int): number of encoder layers + dropout (float): dropout rate + """ + super().__init__() + self.d_model = d_model + + # define modules + self.emb = torch.nn.Embedding(vocabs, d_model) + torch.nn.init.normal_(self.emb.weight, 0.0, d_model**-0.5) + + self.encoder = Transformer( + d_model=d_model, + num_heads=num_heads, + dim_feedforward=dim_feedforward, + num_layers=num_layers, + dropout=dropout, + ) + + self.proj = torch.nn.Conv1d(d_model, d_model * 2, 1) + + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input index tensor (B, T_text). + x_lengths (Tensor): Length tensor (B,). + + Returns: + Tensor: Encoded hidden representation (B, attention_dim, T_text). + Tensor: Projected mean tensor (B, attention_dim, T_text). + Tensor: Projected scale tensor (B, attention_dim, T_text). + Tensor: Mask tensor for input tensor (B, 1, T_text). 
class Transformer(nn.Module):
    """Transformer encoder with relative positional encoding.

    The module wires together a ``RelPositionalEncoding``, a stack of
    ``num_layers`` ``TransformerEncoderLayer`` and a final ``LayerNorm``.

    Args:
        d_model (int): attention dimension
        num_heads (int): number of attention heads
        dim_feedforward (int): feedforward dimension
        num_layers (int): number of encoder layers
        dropout (float): dropout rate
    """

    def __init__(
        self,
        d_model: int = 192,
        num_heads: int = 2,
        dim_feedforward: int = 768,
        num_layers: int = 6,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        self.num_layers = num_layers
        self.d_model = d_model

        self.encoder_pos = RelPositionalEncoding(d_model, dropout)

        encoder_layer = TransformerEncoderLayer(
            d_model=d_model,
            num_heads=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.encoder = TransformerEncoder(encoder_layer, num_layers)
        self.after_norm = nn.LayerNorm(d_model)

    def forward(self, x: Tensor, key_padding_mask: Tensor) -> Tensor:
        """Encode a batch of sequences.

        Args:
            x:
                The input tensor. Its shape is (batch_size, seq_len, feature_dim).
            key_padding_mask:
                Mask of shape (batch_size, seq_len) that is forwarded to the
                encoder layers; True marks padded positions.

        Returns:
            The encoded tensor, of shape (batch_size, seq_len, feature_dim).
        """
        # NOTE(review): encoder_pos multiplies x by sqrt(d_model), and
        # TextEncoder.forward already scales the embeddings by the same factor
        # before calling this module - confirm the double scaling is intended.
        x, pos_emb = self.encoder_pos(x)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask)  # (T, N, C)

        x = self.after_norm(x)

        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        return x
class TransformerEncoder(nn.Module):
    r"""TransformerEncoder is a stack of N encoder layers

    Args:
        encoder_layer: the layer to replicate; it is called as
            ``layer(src, pos_emb, key_padding_mask=...)``.
        num_layers: the number of sub-encoder-layers in the encoder.
    """

    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
        super().__init__()

        # Every layer is an independent deep copy, so no parameters are shared
        # between the layers (or with the template that was passed in).
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
        )
        self.num_layers = num_layers

    def forward(
        self,
        src: Tensor,
        pos_emb: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim).
            pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
            key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)

        Returns:
            The output of the last layer (``src`` itself if there are no layers).
        """
        output = src

        for layer in self.layers:
            output = layer(
                output,
                pos_emb,
                key_padding_mask=key_padding_mask,
            )

        return output
class RelPositionMultiheadAttention(nn.Module):
    r"""Multi-Head Attention layer with relative position encoding

    See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: dropout probability applied to the attention weights. Default: 0.0.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        # One fused projection produces query, key and value.
        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

        # linear transformation for positional encoding.
        self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
        # these two learnable bias are used in matrix c and matrix d
        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
        # They are allocated uninitialised; _reset_parameters() fills them.
        self.pos_bias_u = nn.Parameter(torch.empty(num_heads, self.head_dim))
        self.pos_bias_v = nn.Parameter(torch.empty(num_heads, self.head_dim))

        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Initialise the input projection and the two positional biases."""
        nn.init.xavier_uniform_(self.in_proj.weight)
        nn.init.constant_(self.in_proj.bias, 0.0)
        nn.init.constant_(self.out_proj.bias, 0.0)

        nn.init.xavier_uniform_(self.pos_bias_u)
        nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x: Tensor) -> Tensor:
        """Compute relative positional encoding.

        Row ``i`` of the result is the window ``x[..., i, seq_len-1-i : 2*seq_len-1-i]``,
        obtained without copying by re-striding the input.

        Args:
            x: Input tensor (batch, head, seq_len, 2*seq_len-1).

        Returns:
            Tensor: tensor of shape (batch, head, seq_len, seq_len)
        """
        (batch_size, num_heads, seq_len, n) = x.shape

        assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1"

        # Note: TorchScript requires explicit arg for stride()
        batch_stride = x.stride(0)
        head_stride = x.stride(1)
        time_stride = x.stride(2)
        n_stride = x.stride(3)
        # Stepping one row down also steps one column to the left
        # (time_stride - n_stride); the offset makes row 0 start at column
        # seq_len - 1.
        return x.as_strided(
            (batch_size, num_heads, seq_len, seq_len),
            (batch_stride, head_stride, time_stride - n_stride, n_stride),
            storage_offset=n_stride * (seq_len - 1),
        )

    def forward(
        self,
        x: Tensor,
        pos_emb: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Args:
            x: Input tensor of shape (seq_len, batch_size, embed_dim)
            pos_emb: Positional embedding tensor, (1, 2*seq_len-1, pos_dim)
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. This is an binary mask. When the value is True,
                the corresponding value on the attention layer will be filled with -inf.
                Its shape is (batch_size, seq_len).

        Outputs:
            A tensor of shape (seq_len, batch_size, embed_dim).
        """
        seq_len, batch_size, _ = x.shape
        scaling = float(self.head_dim) ** -0.5

        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
        k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
        # (batch_size * num_head, seq_len, head_dim)
        v = (
            v.contiguous()
            .view(seq_len, batch_size * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )

        q = q.transpose(0, 1)  # (batch_size, seq_len, num_head, head_dim)

        p = self.linear_pos(pos_emb).view(
            pos_emb.size(0), -1, self.num_heads, self.head_dim
        )
        # (1, 2*seq_len-1, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1)
        p = p.permute(0, 2, 3, 1)

        # (batch_size, num_head, seq_len, head_dim)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
        k = k.permute(1, 2, 3, 0)  # (batch_size, num_head, head_dim, seq_len)
        # (batch_size, num_head, seq_len, seq_len)
        matrix_ac = torch.matmul(q_with_bias_u, k)

        # compute matrix b and matrix d
        # (batch_size, num_head, seq_len, 2*seq_len-1)
        matrix_bd = torch.matmul(q_with_bias_v, p)
        # (batch_size, num_head, seq_len, seq_len)
        matrix_bd = self.rel_shift(matrix_bd)

        # (batch_size, num_head, seq_len, seq_len)
        attn_output_weights = (matrix_ac + matrix_bd) * scaling

        if key_padding_mask is not None:
            assert key_padding_mask.shape == (batch_size, seq_len)
            # Broadcast the mask over heads and query positions while the
            # scores are still 4-D.
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float("-inf"),
            )

        attn_output_weights = attn_output_weights.view(
            batch_size * self.num_heads, seq_len, seq_len
        )

        attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
        attn_output_weights = nn.functional.dropout(
            attn_output_weights, p=self.dropout, training=self.training
        )

        # (batch_size * num_head, seq_len, head_dim)
        attn_output = torch.bmm(attn_output_weights, v)
        assert attn_output.shape == (
            batch_size * self.num_heads,
            seq_len,
            self.head_dim,
        )

        # merge the heads back: (seq_len, batch_size, embed_dim)
        attn_output = (
            attn_output.transpose(0, 1)
            .contiguous()
            .view(seq_len, batch_size, self.embed_dim)
        )
        attn_output = self.out_proj(attn_output)

        return attn_output


class Swish(nn.Module):
    """Swish activation module."""

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x * sigmoid(x)``."""
        return x * torch.sigmoid(x)


def _test_text_encoder():
    """Smoke test: run TextEncoder on random tokens and print the output shapes."""
    vocabs = 500
    d_model = 192
    batch_size = 5
    seq_len = 100

    encoder = TextEncoder(vocabs=vocabs, d_model=d_model)
    x, m, logs, mask = encoder(
        x=torch.randint(low=0, high=vocabs, size=(batch_size, seq_len)),
        x_lengths=torch.full((batch_size,), seq_len),
    )
    print(x.shape, m.shape, logs.shape, mask.shape)


if __name__ == "__main__":
    _test_text_encoder()
remove_checkpoints +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + setup_logger, + str2bool, +) + +from symbols import symbol_table +from utils import ( + MetricsTracker, + prepare_token_batch, + save_checkpoint, + save_checkpoint_with_global_batch_idx, +) +from vits import VITS + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lr", type=float, default=2.0e-4, help="The base learning rate." 
def get_params() -> AttributeDict:
    """Return a dict containing training parameters.

    All training related parameters that are not passed from the commandline
    are saved in the variable `params`.

    Commandline options are merged into `params` after they are parsed, so
    you can also access them via `params`.

    Explanation of options saved in `params`:

        - best_train_loss: Best training loss so far. It is used to select
                           the model that has the lowest training loss. It is
                           updated during the training.

        - best_valid_loss: Best validation loss so far. It is used to select
                           the model that has the lowest validation loss. It is
                           updated during the training.

        - best_train_epoch: It is the epoch that has the best training loss.

        - best_valid_epoch: It is the epoch that has the best validation loss.

        - batch_idx_train: Used to writing statistics to tensorboard. It
                           contains number of batches trained so far across
                           epochs.

        - log_interval:  Print training loss if batch_idx % log_interval is 0

        - valid_interval:  Run validation if batch_idx % valid_interval is 0

        - env_info: The value returned by get_env_info().

        - sampling_rate: Sampling rate passed to the model.

        - feature_dim: The model input dim. It has to match the one used
                       in computing features.

        - vocab_size: Number of entries in `symbol_table`.

        - mel_loss_params: Settings of the Mel loss (frame_shift,
                           frame_length, n_mels), passed to the model.

        - lambda_adv, lambda_mel, lambda_feat_match, lambda_dur, lambda_kl:
                           Loss scaling coefficients, passed to the model.
    """
    params = AttributeDict(
        {
            # training params
            "best_train_loss": float("inf"),
            "best_valid_loss": float("inf"),
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            # Starts at -1 because train_one_epoch increments it at the top of
            # every step, so the first trained batch gets index 0.
            "batch_idx_train": -1,
            "log_interval": 50,
            "valid_interval": 500,
            "env_info": get_env_info(),
            "sampling_rate": 22050,
            "feature_dim": 513,  # 1024 // 2 + 1, 1024 is fft_length
            "vocab_size": len(symbol_table),
            "mel_loss_params": {
                "frame_shift": 256,
                "frame_length": 1024,
                "n_mels": 80,
            },
            "lambda_adv": 1.0,  # loss scaling coefficient for adversarial loss
            "lambda_mel": 45.0,  # loss scaling coefficient for Mel loss
            "lambda_feat_match": 2.0,  # loss scaling coefficient for feat match loss
            "lambda_dur": 1.0,  # loss scaling coefficient for duration loss
            "lambda_kl": 1.0,  # loss scaling coefficient for KL divergence loss
        }
    )

    return params
+ + saved_params = load_checkpoint(filename, model=model) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def get_model(params: AttributeDict) -> nn.Module: + model = VITS( + vocab_size=params.vocab_size, + feature_dim=params.feature_dim, + sampling_rate=params.sampling_rate, + mel_loss_params=params.mel_loss_params, + lambda_adv=params.lambda_adv, + lambda_mel=params.lambda_mel, + lambda_feat_match=params.lambda_feat_match, + lambda_dur=params.lambda_dur, + lambda_kl=params.lambda_kl, + ) + return model + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + batch_size = len(batch["text"]) + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + text = batch["text"] + tokens, tokens_lens = prepare_token_batch(text) + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + + loss_info = MetricsTracker() + loss_info['samples'] = batch_size + + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=False, + ) + assert loss_d.requires_grad is False + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + + # 
def train_one_epoch(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    optimizer_g: Optimizer,
    optimizer_d: Optimizer,
    scheduler_g: LRSchedulerType,
    scheduler_d: LRSchedulerType,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
    rank: int = 0,
) -> None:
    """Train the model for one epoch.

    For every batch the discriminator is updated first and the generator
    second; both updates go through the same grad `scaler`, which is
    updated once per batch after the generator step.

    The generator loss summed over the epoch and divided by the number of
    samples is saved in `params.train_loss`. It runs the validation process
    every `params.valid_interval` global batches and saves a checkpoint
    every `params.save_every_n` global batches.

    If `params.print_diagnostics` is set, the function returns after the
    batch with index 5 without updating `params.train_loss`.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The model for training. It is called with `forward_generator=False`
        to get the discriminator loss and with `forward_generator=True`
        to get the generator loss.
      optimizer_g:
        The optimizer that is stepped with the generator loss.
      optimizer_d:
        The optimizer that is stepped with the discriminator loss.
      scheduler_g:
        The learning rate scheduler of the generator. It is NOT stepped in
        this function; it is only read for logging and passed to the
        checkpoint writers.
      scheduler_d:
        The learning rate scheduler of the discriminator. Same usage as
        `scheduler_g`.
      train_dl:
        Dataloader for the training dataset.
      valid_dl:
        Dataloader for the validation dataset.
      scaler:
        The scaler used for mix precision training.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:
        Number of nodes in DDP training. If it is 1, DDP is disabled.
      rank:
        The rank of the node in DDP training. If no DDP is used, it should
        be set to 0.

    Raises:
      RuntimeError:
        If fp16 is used and the grad scale drops below 1.0e-05.
      Exception:
        Anything raised during the forward/backward/step of a batch is
        re-raised after a `bad-model` checkpoint has been written.
    """
    model.train()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    # Accumulates the (batch-size weighted) stats over all batches of the epoch
    tot_loss = MetricsTracker()

    # Ensures the "-first-warning" checkpoint is written at most once per epoch
    saved_bad_model = False

    def save_bad_model(suffix: str = ""):
        # Dump the full training state for debugging. The real rank is part
        # of the filename while `rank=0` is passed to save_checkpoint --
        # presumably so that every rank writes its own file; TODO confirm
        # against save_checkpoint.
        save_checkpoint(
            filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
            model=model,
            params=params,
            optimizer_g=optimizer_g,
            optimizer_d=optimizer_d,
            scheduler_g=scheduler_g,
            scheduler_d=scheduler_d,
            sampler=train_dl.sampler,
            scaler=scaler,
            rank=0,
        )

    for batch_idx, batch in enumerate(train_dl):
        # Global batch counter; it is restored from the checkpoint when resuming
        params.batch_idx_train += 1

        batch_size = len(batch["text"])
        audio = batch["audio"].to(device)
        features = batch["features"].to(device)
        audio_lens = batch["audio_lens"].to(device)
        features_lens = batch["features_lens"].to(device)
        text = batch["text"]
        tokens, tokens_lens = prepare_token_batch(text)
        tokens = tokens.to(device)
        tokens_lens = tokens_lens.to(device)

        # Stats of the current batch only
        loss_info = MetricsTracker()
        loss_info['samples'] = batch_size

        try:
            with autocast(enabled=params.use_fp16):
                # forward discriminator
                loss_d, stats_d = model(
                    text=tokens,
                    text_lengths=tokens_lens,
                    feats=features,
                    feats_lengths=features_lens,
                    speech=audio,
                    speech_lengths=audio_lens,
                    forward_generator=False,
                )
            # Weight by batch size so that sums can later be divided by "samples"
            for k, v in stats_d.items():
                loss_info[k] = v * batch_size
            # update discriminator
            optimizer_d.zero_grad()
            scaler.scale(loss_d).backward()
            scaler.step(optimizer_d)

            with autocast(enabled=params.use_fp16):
                # forward generator
                loss_g, stats_g = model(
                    text=tokens,
                    text_lengths=tokens_lens,
                    feats=features,
                    feats_lengths=features_lens,
                    speech=audio,
                    speech_lengths=audio_lens,
                    forward_generator=True,
                )
            for k, v in stats_g.items():
                loss_info[k] = v * batch_size
            # update generator
            optimizer_g.zero_grad()
            scaler.scale(loss_g).backward()
            scaler.step(optimizer_g)
            # A single scaler update per batch, after both optimizers stepped
            scaler.update()

            # summary stats
            # tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
            tot_loss = tot_loss + loss_info
        except:  # noqa
            # Deliberately catches everything: keep the failing state, then re-raise
            save_bad_model()
            raise

        if params.print_diagnostics and batch_idx == 5:
            return

        if (
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
                model=model,
                params=params,
                optimizer_g=optimizer_g,
                optimizer_d=optimizer_d,
                scheduler_g=scheduler_g,
                scheduler_d=scheduler_d,
                sampler=train_dl.sampler,
                scaler=scaler,
                rank=rank,
            )
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,
                rank=rank,
            )

        # if batch_idx % 100 == 0 and params.use_fp16:
        if params.batch_idx_train % 100 == 0 and params.use_fp16:
            # If the grad scale was less than 1, try increasing it. The _growth_interval
            # of the grad scaler is configurable, but we can't configure it to have different
            # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()

            # if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
                    save_bad_model(suffix="-first-warning")
                    saved_bad_model = True
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
            if cur_grad_scale < 1.0e-05:
                save_bad_model()
                raise RuntimeError(
                    f"grad_scale is too small, exiting: {cur_grad_scale}"
                )

        # if batch_idx % params.log_interval == 0:
        if params.batch_idx_train % params.log_interval == 0:
            cur_lr_g = max(scheduler_g.get_last_lr())
            cur_lr_d = max(scheduler_d.get_last_lr())
            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0

            logging.info(
                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
                f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
                f"loss[{loss_info}], tot_loss[{tot_loss}], "
                f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, "
                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
            )

            if tb_writer is not None:
                tb_writer.add_scalar(
                    "train/learning_rate_g", cur_lr_g, params.batch_idx_train
                )
                tb_writer.add_scalar(
                    "train/learning_rate_d", cur_lr_d, params.batch_idx_train
                )
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(
                    tb_writer, "train/tot_", params.batch_idx_train
                )
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
                    )

        # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
        if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
            logging.info("Computing validation loss")
            valid_info = compute_validation_loss(
                params=params,
                model=model,
                valid_dl=valid_dl,
                world_size=world_size,
            )
            # compute_validation_loss switches the model to eval mode
            model.train()
            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
            logging.info(
                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
            )
            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
                )

    # Per-sample generator loss over the whole epoch
    loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss
forward_generator=True, + ) + optimizer_g.zero_grad() + loss_g.backward() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + generator = model.generator + discriminator = model.discriminator + + num_param_g = sum([p.numel() for p in generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + 
model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer_g = torch.optim.AdamW( + generator.parameters(), + lr=params.lr, + betas=(0.8, 0.99), + eps=1e-9, + weight_decay=0, + ) + optimizer_d = torch.optim.AdamW( + discriminator.parameters(), + lr=params.lr, + betas=(0.8, 0.99), + eps=1e-9, + weight_decay=0, + ) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) + + if checkpoints is not None: + # load state_dict for optimizers + if "optimizer_g" in checkpoints: + logging.info("Loading optimizer_g state dict") + optimizer_g.load_state_dict(checkpoints["optimizer_g"]) + if "optimizer_d" in checkpoints: + logging.info("Loading optimizer_d state dict") + optimizer_d.load_state_dict(checkpoints["optimizer_d"]) + + # load state_dict for schedulers + if "scheduler_g" in checkpoints: + logging.info("Loading scheduler_g state dict") + scheduler_g.load_state_dict(checkpoints["scheduler_g"]) + if "scheduler_d" in checkpoints: + logging.info("Loading scheduler_d state dict") + scheduler_d.load_state_dict(checkpoints["scheduler_d"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + ljspeech = LJSpeechTtsDataModule(args) + + train_cuts = ljspeech.train_cuts() + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + + # You should use 
../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_dl = ljspeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = ljspeech.valid_cuts() + valid_dl = ljspeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = 
def main():
    """Parse the command line and launch training.

    The data-module options are registered on top of the options from
    `get_parser()`, and `exp_dir` is converted to a `Path`. With
    `world_size == 1` training runs in the current process; otherwise one
    process per rank is spawned and joined.
    """
    arg_parser = get_parser()
    LJSpeechTtsDataModule.add_arguments(arg_parser)
    args = arg_parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    num_procs = args.world_size
    assert num_procs >= 1

    if num_procs <= 1:
        # Single process: call the worker directly as rank 0.
        run(rank=0, world_size=1, args=args)
        return

    # mp.spawn supplies the rank as the first positional argument of `run`.
    mp.spawn(run, args=(num_procs, args), nprocs=num_procs, join=True)
import numpy as np
import torch
from torch.nn import functional as F

DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3


def piecewise_rational_quadratic_transform(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails=None,
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Apply an element-wise piecewise rational-quadratic spline transform.

    Dispatches to :func:`rational_quadratic_spline` when ``tails`` is None
    (inputs must lie in the default ``[0, 1]`` domain of that function) and
    to :func:`unconstrained_rational_quadratic_spline` otherwise.

    Args:
        inputs: Tensor of values to transform.
        unnormalized_widths: Unnormalized bin widths; the last dimension is
            the number of bins, the leading dimensions match ``inputs``.
        unnormalized_heights: Unnormalized bin heights, same shape as
            ``unnormalized_widths``.
        unnormalized_derivatives: Unnormalized knot derivatives. The last
            dimension is ``num_bins + 1`` when ``tails`` is None and
            ``num_bins - 1`` otherwise (the boundary knots are padded in).
        inverse: Whether to apply the inverse transform.
        tails: None, or the tail type (only ``"linear"`` is implemented).
        tail_bound: Half-width of the spline interval when ``tails`` is set.
        min_bin_width: Lower bound of each normalized bin width.
        min_bin_height: Lower bound of each normalized bin height.
        min_derivative: Lower bound of each knot derivative.

    Returns:
        Tuple ``(outputs, logabsdet)``, both shaped like ``inputs``.
    """
    if tails is None:
        spline_fn = rational_quadratic_spline
        spline_kwargs = {}
    else:
        spline_fn = unconstrained_rational_quadratic_spline
        spline_kwargs = {"tails": tails, "tail_bound": tail_bound}

    outputs, logabsdet = spline_fn(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnormalized_derivatives=unnormalized_derivatives,
        inverse=inverse,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
        **spline_kwargs
    )
    return outputs, logabsdet


def unconstrained_rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails="linear",
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Rational-quadratic spline on ``[-tail_bound, tail_bound]`` with tails.

    Elements inside the interval go through
    :func:`rational_quadratic_spline`; elements outside are passed through
    unchanged with a log-determinant of 0 (linear tails).

    Args:
        inputs: Tensor of values to transform.
        unnormalized_widths: Unnormalized bin widths, last dim ``num_bins``.
        unnormalized_heights: Unnormalized bin heights, last dim ``num_bins``.
        unnormalized_derivatives: Unnormalized interior-knot derivatives,
            last dim ``num_bins - 1``.
        inverse: Whether to apply the inverse transform.
        tails: Tail type; only ``"linear"`` is implemented.
        tail_bound: Half-width of the spline interval.
        min_bin_width: Lower bound of each normalized bin width.
        min_bin_height: Lower bound of each normalized bin height.
        min_derivative: Lower bound of each knot derivative.

    Returns:
        Tuple ``(outputs, logabsdet)``, both shaped like ``inputs``.

    Raises:
        RuntimeError: If ``tails`` is not ``"linear"``.
    """
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    if tails == "linear":
        # F.pad returns a new tensor, so the caller's tensor is not modified.
        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
        # softplus(constant) == 1 - min_derivative, i.e. the boundary
        # derivative becomes exactly 1 and matches the identity tails.
        constant = np.log(np.exp(1 - min_derivative) - 1)
        unnormalized_derivatives[..., 0] = constant
        unnormalized_derivatives[..., -1] = constant

        outputs[outside_interval_mask] = inputs[outside_interval_mask]
        logabsdet[outside_interval_mask] = 0
    else:
        raise RuntimeError("{} tails are not implemented.".format(tails))

    # Skip the spline when nothing falls inside the interval: the inner
    # function reduces over `inputs` with torch.min/torch.max, which fail on
    # empty tensors.
    if inside_interval_mask.any():
        (
            outputs[inside_interval_mask],
            logabsdet[inside_interval_mask],
        ) = rational_quadratic_spline(
            inputs=inputs[inside_interval_mask],
            unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
            unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
            unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
            inverse=inverse,
            left=-tail_bound,
            right=tail_bound,
            bottom=-tail_bound,
            top=tail_bound,
            min_bin_width=min_bin_width,
            min_bin_height=min_bin_height,
            min_derivative=min_derivative,
        )

    return outputs, logabsdet


def rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Monotonic rational-quadratic spline from [left, right] to [bottom, top].

    Args:
        inputs: Tensor of values in ``[left, right]``.
        unnormalized_widths: Unnormalized bin widths, last dim ``num_bins``.
        unnormalized_heights: Unnormalized bin heights, last dim ``num_bins``.
        unnormalized_derivatives: Unnormalized knot derivatives, last dim
            ``num_bins + 1``.
        inverse: Whether to apply the inverse transform.
        left: Lower end of the input domain.
        right: Upper end of the input domain.
        bottom: Lower end of the output range.
        top: Upper end of the output range.
        min_bin_width: Lower bound of each normalized bin width.
        min_bin_height: Lower bound of each normalized bin height.
        min_derivative: Lower bound of each knot derivative.

    Returns:
        Tuple ``(outputs, logabsdet)``, both shaped like ``inputs``.

    Raises:
        ValueError: If an input lies outside ``[left, right]`` or if the
            minimal bin width/height is too large for the number of bins.
    """
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    # Normalized widths with a floor, then knot positions along the x axis.
    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
    cumwidths = (right - left) * cumwidths + left
    # Pin the end knots exactly to the domain borders.
    cumwidths[..., 0] = left
    cumwidths[..., -1] = right
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

    # Same construction along the y axis.
    heights = F.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
    cumheights = (top - bottom) * cumheights + bottom
    cumheights[..., 0] = bottom
    cumheights[..., -1] = top
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    # The inverse looks the bin up along y, the forward pass along x.
    if inverse:
        bin_idx = _searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = _searchsorted(cumwidths, inputs)[..., None]

    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]

    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
    delta = heights / widths
    input_delta = delta.gather(-1, bin_idx)[..., 0]

    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]

    input_heights = heights.gather(-1, bin_idx)[..., 0]

    if inverse:
        # Solve the quadratic a * theta^2 + b * theta + c = 0 for theta.
        a = (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        ) + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        )
        c = -input_delta * (inputs - input_cumheights)

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        root = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - root).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        # The log-determinant of the inverse is the negated forward one.
        return outputs, -logabsdet
    else:
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (
            input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
        )
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        outputs = input_cumheights + numerator / denominator

        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * theta.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - theta).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, logabsdet


def _searchsorted(bin_locations, inputs, eps=1e-6):
    """Return, per input, the index of the bin that contains it.

    ``bin_locations`` holds the sorted bin edges in its last dimension. The
    last edge is widened by ``eps`` so that an input equal to it is assigned
    to the last bin. The widening is done on a copy; the caller's tensor is
    left untouched.
    """
    bin_locations = bin_locations.clone()
    bin_locations[..., -1] += eps
    return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
class LJSpeechTtsDataModule:
    """
    DataModule for TTS experiments.

    It builds the train/valid/test PyTorch DataLoaders from Lhotse CutSets
    and loads the LJSpeech cut manifests from `args.manifest_dir`.

    The data pipeline is configured through the command line options
    registered by :meth:`add_arguments`, e.g.:
        - dynamic batch size (`--max-duration`),
        - bucketing samplers,
        - precomputed or on-the-fly spectrogram features.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        """Register the data related command line options on `parser`."""
        group = parser.add_argument_group(
            title="TTS data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )

        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/spectrogram"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            # A duration in seconds: parse as float so that it agrees with
            # the float default and fractional values are accepted.
            type=float,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )

        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop last batch. Used by sampler.",
        )

        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            help="AudioSamples or PrecomputedFeatures",
        )

    def _build_dataset(self) -> SpeechSynthesisDataset:
        """Create the dataset shared by the train/valid/test dataloaders.

        With `--on-the-fly-feats` the spectrogram is computed on the fly
        (22050 Hz, frame length 1024 samples, frame shift 256 samples);
        otherwise the strategy named by `--input-strategy` is used.
        """
        if self.args.on_the_fly_feats:
            sampling_rate = 22050
            config = SpectrogramConfig(
                sampling_rate=sampling_rate,
                frame_length=1024 / sampling_rate,  # (in second),
                frame_shift=256 / sampling_rate,  # (in second)
                use_fft_mag=True,
            )
            feature_input_strategy = OnTheFlyFeatures(Spectrogram(config))
        else:
            # NOTE(review): eval() executes whatever string was passed via
            # --input-strategy. Acceptable for a trusted command line only;
            # consider an explicit name -> class mapping.
            feature_input_strategy = eval(self.args.input_strategy)()

        return SpeechSynthesisDataset(
            return_tokens=False,
            feature_input_strategy=feature_input_strategy,
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        logging.info("About to create train dataset")
        train = self._build_dataset()

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            # The sampler already yields whole batches of cuts.
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        """Create the (unshuffled) validation dataloader for `cuts_valid`."""
        logging.info("About to create dev dataset")
        validate = self._build_dataset()
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create valid dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        """Create the (unshuffled) test dataloader for `cuts`."""
        logging.info("About to create test dataset")
        test = self._build_dataset()
        test_sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=test_sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def train_cuts(self) -> CutSet:
        """Lazily load `ljspeech_cuts_train.jsonl.gz` from the manifest dir."""
        logging.info("About to get train cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
        )

    @lru_cache()
    def valid_cuts(self) -> CutSet:
        """Lazily load `ljspeech_cuts_valid.jsonl.gz` from the manifest dir."""
        logging.info("About to get validation cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
        )

    @lru_cache()
    def test_cuts(self) -> CutSet:
        """Lazily load `ljspeech_cuts_test.jsonl.gz` from the manifest dir."""
        logging.info("About to get test cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
        )
def force_gatherable(data, device):
    """Recursively convert ``data`` into something DataParallel can gather.

    Unlike a plain "move to device", Python ``float`` and ``int`` scalars are
    turned into 1-element tensors, and 0-dim tensors are promoted to 1-dim.

    Conversion rules, applied in this order:
      * ``dict``: values are converted, keys kept.
      * tuple subclass (e.g. NamedTuple): rebuilt by passing the converted
        fields positionally.
      * ``list`` / ``tuple`` / ``set``: rebuilt as the same container type.
      * ``np.ndarray``: converted through ``torch.from_numpy``.
      * ``torch.Tensor``: 0-dim tensors gain a leading axis, then ``.to(device)``.
      * ``float`` -> float tensor ``[data]``; ``int`` -> long tensor ``[data]``.
      * ``None`` is returned unchanged.
      * anything else: a warning is issued and the object is returned as is.

    Args:
        data: Arbitrary (possibly nested) object.
        device: Target device for created or moved tensors.

    Returns:
        The converted object, mirroring the structure of ``data``.

    """
    if isinstance(data, dict):
        return {key: force_gatherable(value, device) for key, value in data.items()}

    # DataParallel can't handle NamedTuple well, so this check must come
    # before the generic list/tuple/set branch.
    if isinstance(data, tuple) and type(data) is not tuple:
        fields = [force_gatherable(field, device) for field in data]
        return type(data)(*fields)

    if isinstance(data, (list, tuple, set)):
        container = type(data)
        return container(force_gatherable(item, device) for item in data)

    if isinstance(data, np.ndarray):
        return force_gatherable(torch.from_numpy(data), device)

    if isinstance(data, torch.Tensor):
        # To 1-dim array
        tensor = data[None] if data.dim() == 0 else data
        return tensor.to(device)

    if isinstance(data, float):
        return torch.tensor([data], dtype=torch.float, device=device)

    if isinstance(data, int):
        return torch.tensor([data], dtype=torch.long, device=device)

    if data is None:
        return None

    warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
    return data
def text_clean(text):
    '''Clean English text and convert it to phonemes.

    The text is converted to ASCII, lowercased and has its abbreviations
    expanded. It is then phonemized with the espeak backend (``en-us``),
    keeping punctuation and stress marks, and whitespace runs in the result
    are collapsed to single spaces.

    Returns:
      A string of phonemes.
    '''
    normalized = expand_abbreviations(lowercase(convert_to_ascii(text)))
    phonemes = phonemize(
        normalized,
        language='en-us',
        backend='espeak',
        strip=True,
        preserve_punctuation=True,
        with_stress=True,
    )
    return collapse_whitespace(phonemes)
def prepare_token_batch(
    texts: List[str],
    intersperse_blank: bool = True,
    blank_id: int = 0,
    pad_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert a list of text strings into a batch of symbol tokens with padding.

    Args:
      texts: list of text strings
      intersperse_blank: whether to intersperse blank tokens in the converted
        token sequence.
      blank_id: index of blank token
      pad_id: padding index

    Returns:
      A tuple ``(sequences, lengths)``:
        - sequences: int64 tensor of token ids, one row per text, padded with
          ``pad_id`` to the longest sequence in the batch.
        - lengths: int64 tensor holding the unpadded length of each row.

    Raises:
      KeyError: if a phoneme symbol is not present in ``symbol_to_id``. The
        offending input is logged before the error is re-raised.
    """
    # normalize text
    normalized_texts = [
        expand_abbreviations(lowercase(convert_to_ascii(text))) for text in texts
    ]

    # convert to phonemes
    phonemes = phonemize(
        normalized_texts,
        language='en-us',
        backend='espeak',
        strip=True,
        preserve_punctuation=True,
        with_stress=True,
    )

    # convert to symbol ids
    lengths = []
    sequences = []
    for idx, phoneme_str in enumerate(phonemes):
        try:
            sequence = [
                symbol_to_id[symbol] for symbol in collapse_whitespace(phoneme_str)
            ]
        except KeyError as err:
            # A missing dictionary entry raises KeyError (not RuntimeError).
            # Report the text that produced it, then let the error propagate
            # instead of continuing with an unconverted sequence.
            logging.error(
                f"Unknown symbol {err.args[0]!r} while tokenizing "
                f"text={texts[idx]!r} (normalized={normalized_texts[idx]!r})"
            )
            raise
        if intersperse_blank:
            sequence = intersperse(sequence, blank_id)
        sequences.append(torch.tensor(sequence, dtype=torch.int64))
        lengths.append(len(sequence))

    sequences = pad_sequence(sequences, batch_first=True, padding_value=pad_id)
    lengths = torch.tensor(lengths, dtype=torch.int64)
    return sequences, lengths
is zero. + # This class will play a role as metrics tracker. + # It can record many metrics, including but not limited to loss. + super(MetricsTracker, self).__init__(int) + + def __add__(self, other: "MetricsTracker") -> "MetricsTracker": + ans = MetricsTracker() + for k, v in self.items(): + ans[k] = v + for k, v in other.items(): + ans[k] = ans[k] + v + return ans + + def __mul__(self, alpha: float) -> "MetricsTracker": + ans = MetricsTracker() + for k, v in self.items(): + ans[k] = v * alpha + return ans + + def __str__(self) -> str: + ans = "" + for k, v in self.norm_items(): + norm_value = "%.4g" % v + ans += str(k) + "=" + str(norm_value) + ", " + samples = "%.2f" % self["samples"] + ans += "over" + str(samples) + " samples." + return ans + + def norm_items(self) -> List[Tuple[str, float]]: + """ + Returns a list of pairs, like: + [('loss_1', 0.1), ('loss_2', 0.07)] + """ + samples = self["samples"] if "samples" in self else 1 + ans = [] + for k, v in self.items(): + if k == "samples": + continue + norm_value = float(v) / samples + ans.append((k, norm_value)) + return ans + + def reduce(self, device): + """ + Reduce using torch.distributed, which I believe ensures that + all processes get the total. + """ + keys = sorted(self.keys()) + s = torch.tensor([float(self[k]) for k in keys], device=device) + dist.all_reduce(s, op=dist.ReduceOp.SUM) + for k, v in zip(keys, s.cpu().tolist()): + self[k] = v + + def write_summary( + self, + tb_writer: SummaryWriter, + prefix: str, + batch_idx: int, + ) -> None: + """Add logging information to a TensorBoard writer. + + Args: + tb_writer: a TensorBoard writer + prefix: a prefix for the name of the loss, e.g. "train/valid_", + or "train/current_" + batch_idx: The current batch index, used as the x-axis of the plot. 
def save_checkpoint(
    filename: Path,
    model: Union[nn.Module, DDP],
    params: Optional[Dict[str, Any]] = None,
    optimizer_g: Optional[Optimizer] = None,
    optimizer_d: Optional[Optimizer] = None,
    scheduler_g: Optional[LRSchedulerType] = None,
    scheduler_d: Optional[LRSchedulerType] = None,
    scaler: Optional[GradScaler] = None,
    sampler: Optional[CutSampler] = None,
    rank: int = 0,
) -> None:
    """Save training information to a file.

    Args:
      filename:
        The checkpoint filename.
      model:
        The model to be saved. We only save its `state_dict()`. A DDP wrapper
        is unwrapped first.
      params:
        User defined parameters, e.g., epoch, loss. They are stored at the top
        level of the checkpoint.
      optimizer_g:
        The optimizer for generator used in the training.
        Its `state_dict` will be saved.
      optimizer_d:
        The optimizer for discriminator used in the training.
        Its `state_dict` will be saved.
      scheduler_g:
        The learning rate scheduler for generator used in the training.
        Its `state_dict` will be saved.
      scheduler_d:
        The learning rate scheduler for discriminator used in the training.
        Its `state_dict` will be saved.
      scaler:
        The GradScaler to be saved. We only save its `state_dict()`.
      sampler:
        The sampler used in the training dataset. Its `state_dict` will be
        saved.
      rank:
        Used in DDP. We save checkpoint only for the node whose rank is 0.
    Returns:
      Return None.
    Raises:
      ValueError: if a key in `params` clashes with one of the checkpoint's
        own keys (e.g. "model").
    """
    if rank != 0:
        return

    logging.info(f"Saving checkpoint to {filename}")

    if isinstance(model, DDP):
        model = model.module

    def _state_dict_or_none(obj):
        # Optional components are stored as None when they are not supplied.
        return obj.state_dict() if obj is not None else None

    checkpoint = {
        "model": model.state_dict(),
        "optimizer_g": _state_dict_or_none(optimizer_g),
        "optimizer_d": _state_dict_or_none(optimizer_d),
        "scheduler_g": _state_dict_or_none(scheduler_g),
        "scheduler_d": _state_dict_or_none(scheduler_d),
        "grad_scaler": _state_dict_or_none(scaler),
        "sampler": _state_dict_or_none(sampler),
    }

    if params:
        # An explicit check (rather than `assert`) so that user parameters can
        # never silently overwrite saved state when Python runs with -O.
        clashes = [k for k in params if k in checkpoint]
        if clashes:
            raise ValueError(
                f"params keys clash with reserved checkpoint keys: {clashes}"
            )
        checkpoint.update(params)

    torch.save(checkpoint, filename)
+ Its `state_dict` will be saved. + scheduler_d: + The learning rate scheduler for discriminator used in the training. + Its `state_dict` will be saved. + scaler: + The scaler used for mix precision training. Its `state_dict` will + be saved. + sampler: + The sampler used in the training dataset. + rank: + The rank ID used in DDP training of the current node. Set it to 0 + if DDP is not used. + """ + out_dir = Path(out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + filename = out_dir / f"checkpoint-{global_batch_idx}.pt" + save_checkpoint( + filename=filename, + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + scaler=scaler, + sampler=sampler, + rank=rank, + ) diff --git a/egs/ljspeech/tts/vits/vits.py b/egs/ljspeech/tts/vits/vits.py new file mode 100644 index 000000000..da9d144f2 --- /dev/null +++ b/egs/ljspeech/tts/vits/vits.py @@ -0,0 +1,567 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""VITS module for GAN-TTS task.""" + +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn +from torch.cuda.amp import autocast + +from hifigan import ( + HiFiGANMultiPeriodDiscriminator, + HiFiGANMultiScaleDiscriminator, + HiFiGANMultiScaleMultiPeriodDiscriminator, + HiFiGANPeriodDiscriminator, + HiFiGANScaleDiscriminator, +) +from loss import ( + DiscriminatorAdversarialLoss, + FeatureMatchLoss, + GeneratorAdversarialLoss, + KLDivergenceLoss, + MelSpectrogramLoss, +) +from utils import get_segments +from generator import VITSGenerator + + +AVAILABLE_GENERATERS = { + "vits_generator": VITSGenerator, +} +AVAILABLE_DISCRIMINATORS = { + "hifigan_period_discriminator": HiFiGANPeriodDiscriminator, + "hifigan_scale_discriminator": HiFiGANScaleDiscriminator, + "hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator, + 
"hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator, + "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA +} + + +class VITS(nn.Module): + """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech` + """ + + def __init__( + self, + # generator related + vocab_size: int, + feature_dim: int = 513, + sampling_rate: int = 22050, + generator_type: str = "vits_generator", + generator_params: Dict[str, Any] = { + "hidden_channels": 192, + "spks": None, + "langs": None, + "spk_embed_dim": None, + "global_channels": -1, + "segment_size": 32, + "text_encoder_attention_heads": 2, + "text_encoder_ffn_expand": 4, + "text_encoder_blocks": 6, + "text_encoder_dropout_rate": 0.1, + "decoder_kernel_size": 7, + "decoder_channels": 512, + "decoder_upsample_scales": [8, 8, 2, 2], + "decoder_upsample_kernel_sizes": [16, 16, 4, 4], + "decoder_resblock_kernel_sizes": [3, 7, 11], + "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "use_weight_norm_in_decoder": True, + "posterior_encoder_kernel_size": 5, + "posterior_encoder_layers": 16, + "posterior_encoder_stacks": 1, + "posterior_encoder_base_dilation": 1, + "posterior_encoder_dropout_rate": 0.0, + "use_weight_norm_in_posterior_encoder": True, + "flow_flows": 4, + "flow_kernel_size": 5, + "flow_base_dilation": 1, + "flow_layers": 4, + "flow_dropout_rate": 0.0, + "use_weight_norm_in_flow": True, + "use_only_mean_in_flow": True, + "stochastic_duration_predictor_kernel_size": 3, + "stochastic_duration_predictor_dropout_rate": 0.5, + "stochastic_duration_predictor_flows": 4, + "stochastic_duration_predictor_dds_conv_layers": 3, + }, + # discriminator related + discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator", + discriminator_params: Dict[str, Any] = { + "scales": 1, + "scale_downsample_pooling": "AvgPool1d", + "scale_downsample_pooling_params": { + "kernel_size": 4, + "stride": 2, + 
"padding": 2, + }, + "scale_discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 128, + "max_downsample_channels": 1024, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 4, 4, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + "follow_official_norm": False, + "periods": [2, 3, 5, 7, 11], + "period_discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 32, + "downsample_scales": [3, 3, 3, 3, 1], + "max_downsample_channels": 1024, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + }, + # loss related + generator_adv_loss_params: Dict[str, Any] = { + "average_by_discriminators": False, + "loss_type": "mse", + }, + discriminator_adv_loss_params: Dict[str, Any] = { + "average_by_discriminators": False, + "loss_type": "mse", + }, + feat_match_loss_params: Dict[str, Any] = { + "average_by_discriminators": False, + "average_by_layers": False, + "include_final_outputs": True, + }, + mel_loss_params: Dict[str, Any] = { + "frame_shift": 256, + "frame_length": 1024, + "n_mels": 80, + }, + lambda_adv: float = 1.0, + lambda_mel: float = 45.0, + lambda_feat_match: float = 2.0, + lambda_dur: float = 1.0, + lambda_kl: float = 1.0, + cache_generator_outputs: bool = True, + ): + """Initialize VITS module. + + Args: + idim (int): Input vocabrary size. + odim (int): Acoustic feature dimension. The actual output channels will + be 1 since VITS is the end-to-end text-to-wave model but for the + compatibility odim is used to indicate the acoustic feature dimension. + sampling_rate (int): Sampling rate, not used for the training but it will + be referred in saving waveform during the inference. 
+ generator_type (str): Generator type. + generator_params (Dict[str, Any]): Parameter dict for generator. + discriminator_type (str): Discriminator type. + discriminator_params (Dict[str, Any]): Parameter dict for discriminator. + generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator + adversarial loss. + discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for + discriminator adversarial loss. + feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss. + mel_loss_params (Dict[str, Any]): Parameter dict for mel loss. + lambda_adv (float): Loss scaling coefficient for adversarial loss. + lambda_mel (float): Loss scaling coefficient for mel spectrogram loss. + lambda_feat_match (float): Loss scaling coefficient for feat match loss. + lambda_dur (float): Loss scaling coefficient for duration loss. + lambda_kl (float): Loss scaling coefficient for KL divergence loss. + cache_generator_outputs (bool): Whether to cache generator outputs. + + """ + super().__init__() + + # define modules + generator_class = AVAILABLE_GENERATERS[generator_type] + if generator_type == "vits_generator": + # NOTE(kan-bayashi): Update parameters for the compatibility. + # The idim and odim is automatically decided from input data, + # where idim represents #vocabularies and odim represents + # the input acoustic feature dimension. 
+ generator_params.update(vocabs=vocab_size, aux_channels=feature_dim) + self.generator = generator_class( + **generator_params, + ) + discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type] + self.discriminator = discriminator_class( + **discriminator_params, + ) + self.generator_adv_loss = GeneratorAdversarialLoss( + **generator_adv_loss_params, + ) + self.discriminator_adv_loss = DiscriminatorAdversarialLoss( + **discriminator_adv_loss_params, + ) + self.feat_match_loss = FeatureMatchLoss( + **feat_match_loss_params, + ) + mel_loss_params.update(sampling_rate=sampling_rate) + self.mel_loss = MelSpectrogramLoss( + **mel_loss_params, + ) + self.kl_loss = KLDivergenceLoss() + + # coefficients + self.lambda_adv = lambda_adv + self.lambda_mel = lambda_mel + self.lambda_kl = lambda_kl + self.lambda_feat_match = lambda_feat_match + self.lambda_dur = lambda_dur + + # cache + self.cache_generator_outputs = cache_generator_outputs + self._cache = None + + # store sampling rate for saving wav file + # (not used for the training) + self.sampling_rate = sampling_rate + + # store parameters for test compatibility + self.spks = self.generator.spks + self.langs = self.generator.langs + self.spk_embed_dim = self.generator.spk_embed_dim + + def forward( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + forward_generator: bool = True, + ) -> Dict[str, Any]: + """Perform generator forward. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): Feature length tensor (B,). + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). 
+ sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + forward_generator (bool): Whether to forward generator. + + Returns: + Dict[str, Any]: + - loss (Tensor): Loss scalar tensor. + - stats (Dict[str, float]): Statistics to be monitored. + - weight (Tensor): Weight tensor to summarize losses. + - optim_idx (int): Optimizer index (0 for G and 1 for D). + + """ + if forward_generator: + return self._forward_generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + speech=speech, + speech_lengths=speech_lengths, + sids=sids, + spembs=spembs, + lids=lids, + ) + else: + return self._forward_discrminator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + speech=speech, + speech_lengths=speech_lengths, + sids=sids, + spembs=spembs, + lids=lids, + ) + + def _forward_generator( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + ) -> Dict[str, Any]: + """Perform generator forward. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): Feature length tensor (B,). + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + + Returns: + Dict[str, Any]: + * loss (Tensor): Loss scalar tensor. 
+ * stats (Dict[str, float]): Statistics to be monitored. + * weight (Tensor): Weight tensor to summarize losses. + * optim_idx (int): Optimizer index (0 for G and 1 for D). + + """ + # setup + feats = feats.transpose(1, 2) + speech = speech.unsqueeze(1) + + # calculate generator outputs + reuse_cache = True + if not self.cache_generator_outputs or self._cache is None: + reuse_cache = False + outs = self.generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + sids=sids, + spembs=spembs, + lids=lids, + ) + else: + outs = self._cache + + # store cache + if self.training and self.cache_generator_outputs and not reuse_cache: + self._cache = outs + + # parse outputs + speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs + _, z_p, m_p, logs_p, _, logs_q = outs_ + speech_ = get_segments( + x=speech, + start_idxs=start_idxs * self.generator.upsample_factor, + segment_size=self.generator.segment_size * self.generator.upsample_factor, + ) + + # calculate discriminator outputs + p_hat = self.discriminator(speech_hat_) + with torch.no_grad(): + # do not store discriminator gradient in generator turn + p = self.discriminator(speech_) + + # calculate losses + with autocast(enabled=False): + mel_loss = self.mel_loss(speech_hat_, speech_) + kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask) + dur_loss = torch.sum(dur_nll.float()) + adv_loss = self.generator_adv_loss(p_hat) + feat_match_loss = self.feat_match_loss(p_hat, p) + + mel_loss = mel_loss * self.lambda_mel + kl_loss = kl_loss * self.lambda_kl + dur_loss = dur_loss * self.lambda_dur + adv_loss = adv_loss * self.lambda_adv + feat_match_loss = feat_match_loss * self.lambda_feat_match + loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss + + stats = dict( + generator_loss=loss.item(), + generator_mel_loss=mel_loss.item(), + generator_kl_loss=kl_loss.item(), + generator_dur_loss=dur_loss.item(), + generator_adv_loss=adv_loss.item(), + 
generator_feat_match_loss=feat_match_loss.item(), + ) + + # reset cache + if reuse_cache or not self.training: + self._cache = None + + return loss, stats + + def _forward_discrminator( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + ) -> Dict[str, Any]: + """Perform discriminator forward. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): Feature length tensor (B,). + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + + Returns: + Dict[str, Any]: + * loss (Tensor): Loss scalar tensor. + * stats (Dict[str, float]): Statistics to be monitored. + * weight (Tensor): Weight tensor to summarize losses. + * optim_idx (int): Optimizer index (0 for G and 1 for D). 
+ + """ + # setup + feats = feats.transpose(1, 2) + speech = speech.unsqueeze(1) + + # calculate generator outputs + reuse_cache = True + if not self.cache_generator_outputs or self._cache is None: + reuse_cache = False + outs = self.generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + sids=sids, + spembs=spembs, + lids=lids, + ) + else: + outs = self._cache + + # store cache + if self.cache_generator_outputs and not reuse_cache: + self._cache = outs + + # parse outputs + speech_hat_, _, _, start_idxs, *_ = outs + speech_ = get_segments( + x=speech, + start_idxs=start_idxs * self.generator.upsample_factor, + segment_size=self.generator.segment_size * self.generator.upsample_factor, + ) + + # calculate discriminator outputs + p_hat = self.discriminator(speech_hat_.detach()) + p = self.discriminator(speech_) + + # calculate losses + with autocast(enabled=False): + real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p) + loss = real_loss + fake_loss + + stats = dict( + discriminator_loss=loss.item(), + discriminator_real_loss=real_loss.item(), + discriminator_fake_loss=fake_loss.item(), + ) + + # reset cache + if reuse_cache or not self.training: + self._cache = None + + return loss, stats + + def inference( + self, + text: torch.Tensor, + feats: Optional[torch.Tensor] = None, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + durations: Optional[torch.Tensor] = None, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + alpha: float = 1.0, + max_len: Optional[int] = None, + use_teacher_forcing: bool = False, + ) -> Dict[str, torch.Tensor]: + """Run inference. + + Args: + text (Tensor): Input text index tensor (T_text,). + feats (Tensor): Feature tensor (T_feats, aux_channels). + sids (Tensor): Speaker index tensor (1,). + spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,). + lids (Tensor): Language index tensor (1,). 
+ durations (Tensor): Ground-truth duration tensor (T_text,). + noise_scale (float): Noise scale value for flow. + noise_scale_dur (float): Noise scale value for duration predictor. + alpha (float): Alpha parameter to control the speed of generated speech. + max_len (Optional[int]): Maximum length. + use_teacher_forcing (bool): Whether to use teacher forcing. + + Returns: + Dict[str, Tensor]: + * wav (Tensor): Generated waveform tensor (T_wav,). + * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text). + * duration (Tensor): Predicted duration tensor (T_text,). + + """ + # setup + text = text[None] + text_lengths = torch.tensor( + [text.size(1)], + dtype=torch.long, + device=text.device, + ) + if sids is not None: + sids = sids.view(1) + if lids is not None: + lids = lids.view(1) + if durations is not None: + durations = durations.view(1, 1, -1) + + # inference + if use_teacher_forcing: + assert feats is not None + feats = feats[None].transpose(1, 2) + feats_lengths = torch.tensor( + [feats.size(2)], + dtype=torch.long, + device=feats.device, + ) + wav, att_w, dur = self.generator.inference( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + sids=sids, + spembs=spembs, + lids=lids, + max_len=max_len, + use_teacher_forcing=use_teacher_forcing, + ) + else: + wav, att_w, dur = self.generator.inference( + text=text, + text_lengths=text_lengths, + sids=sids, + spembs=spembs, + lids=lids, + dur=durations, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + alpha=alpha, + max_len=max_len, + ) + return dict(wav=wav.view(-1), att_w=att_w[0], duration=dur[0]) diff --git a/egs/ljspeech/tts/vits/wavenet.py b/egs/ljspeech/tts/vits/wavenet.py new file mode 100644 index 000000000..cbb44a8f4 --- /dev/null +++ b/egs/ljspeech/tts/vits/wavenet.py @@ -0,0 +1,349 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 
# (http://www.apache.org/licenses/LICENSE-2.0)

"""WaveNet modules.

This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.

"""

import math
import logging

from typing import Optional, Tuple

import torch
import torch.nn.functional as F


class WaveNet(torch.nn.Module):
    """WaveNet with global conditioning.

    The network is a sequence of ``ResidualBlock`` layers with cyclic
    dilations. The skip outputs of all blocks are summed to form the output,
    optionally preceded by a 1x1 input conv and followed by a small output
    head.
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        kernel_size: int = 3,
        layers: int = 30,
        stacks: int = 3,
        base_dilation: int = 2,
        residual_channels: int = 64,
        aux_channels: int = -1,
        gate_channels: int = 128,
        skip_channels: int = 64,
        global_channels: int = -1,
        dropout_rate: float = 0.0,
        bias: bool = True,
        use_weight_norm: bool = True,
        use_first_conv: bool = False,
        use_last_conv: bool = False,
        scale_residual: bool = False,
        scale_skip_connect: bool = False,
    ):
        """Initialize WaveNet module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of dilated convolution.
            layers (int): Number of residual block layers. Must be divisible
                by ``stacks``.
            stacks (int): Number of stacks i.e., dilation cycles.
            base_dilation (int): Base dilation factor.
            residual_channels (int): Number of channels in residual conv.
            aux_channels (int): Number of channels for local conditioning feature.
                Values <= 0 disable local conditioning.
            gate_channels (int): Number of channels in gated conv.
            skip_channels (int): Number of channels in skip conv.
            global_channels (int): Number of channels for global conditioning
                feature. Values <= 0 disable global conditioning.
            dropout_rate (float): Dropout rate. 0.0 means no dropout applied.
            bias (bool): Whether to use bias parameter in the residual blocks'
                conv layers.
            use_weight_norm (bool): Whether to use weight norm. If set to true, it will
                be applied to all of the conv layers.
            use_first_conv (bool): Whether to use the first conv layers.
            use_last_conv (bool): Whether to use the last conv layers.
            scale_residual (bool): Whether to scale the residual outputs.
            scale_skip_connect (bool): Whether to scale the skip connection outputs.

        """
        super().__init__()
        self.layers = layers
        self.stacks = stacks
        self.kernel_size = kernel_size
        self.base_dilation = base_dilation
        self.use_first_conv = use_first_conv
        self.use_last_conv = use_last_conv
        self.scale_skip_connect = scale_skip_connect

        # check the number of layers and stacks
        assert layers % stacks == 0
        layers_per_stack = layers // stacks

        # define first convolution
        if self.use_first_conv:
            self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)

        # define residual blocks; the dilation restarts at 1 for every stack
        self.conv_layers = torch.nn.ModuleList()
        for layer in range(layers):
            dilation = base_dilation ** (layer % layers_per_stack)
            self.conv_layers.append(
                ResidualBlock(
                    kernel_size=kernel_size,
                    residual_channels=residual_channels,
                    gate_channels=gate_channels,
                    skip_channels=skip_channels,
                    aux_channels=aux_channels,
                    global_channels=global_channels,
                    dilation=dilation,
                    dropout_rate=dropout_rate,
                    bias=bias,
                    scale_residual=scale_residual,
                )
            )

        # define output layers
        if self.use_last_conv:
            self.last_conv = torch.nn.Sequential(
                torch.nn.ReLU(inplace=True),
                Conv1d1x1(skip_channels, skip_channels, bias=True),
                torch.nn.ReLU(inplace=True),
                Conv1d1x1(skip_channels, out_channels, bias=True),
            )

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(
        self,
        x: torch.Tensor,
        x_mask: Optional[torch.Tensor] = None,
        c: Optional[torch.Tensor] = None,
        g: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T) if use_first_conv else
                (B, residual_channels, T).
            x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
            c (Optional[Tensor]): Local conditioning features (B, aux_channels, T).
            g (Optional[Tensor]): Global conditioning features (B, global_channels, 1).

        Returns:
            Tensor: Output tensor (B, out_channels, T) if use_last_conv else
                (B, skip_channels, T), i.e. the sum of the skip outputs.

        """
        # encode to hidden representation
        if self.use_first_conv:
            x = self.first_conv(x)

        # residual blocks: the residual output feeds the next block, while the
        # skip outputs are accumulated
        skips = 0.0
        for block in self.conv_layers:
            x, h = block(x, x_mask=x_mask, c=c, g=g)
            skips = skips + h
        x = skips
        if self.scale_skip_connect:
            x = x * math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        if self.use_last_conv:
            x = self.last_conv(x)

        return x

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""

        def _remove_weight_norm(m: torch.nn.Module):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return
            # Log only after a successful removal so that modules which never
            # had weight norm are not reported.
            logging.debug("Weight norm is removed from %s.", m)

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module from all of the layers."""

        def _apply_weight_norm(m: torch.nn.Module):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.weight_norm(m)
                logging.debug("Weight norm is applied to %s.", m)

        self.apply(_apply_weight_norm)

    @staticmethod
    def _get_receptive_field_size(
        layers: int,
        stacks: int,
        kernel_size: int,
        base_dilation: int,
    ) -> int:
        """Return the receptive field size for the given layer configuration.

        Each layer contributes ``(kernel_size - 1) * dilation`` samples, with
        the dilation cycling once per stack.
        """
        assert layers % stacks == 0
        layers_per_cycle = layers // stacks
        dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)]
        return (kernel_size - 1) * sum(dilations) + 1

    @property
    def receptive_field_size(self) -> int:
        """Return receptive field size."""
        return self._get_receptive_field_size(
            self.layers, self.stacks, self.kernel_size, self.base_dilation
        )


class Conv1d(torch.nn.Conv1d):
    """Conv1d module with customized initialization."""

    def __init__(self, *args, **kwargs):
        """Initialize Conv1d module.

        All arguments are forwarded unchanged to ``torch.nn.Conv1d``.
        """
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        """Reset parameters.

        The weight is drawn with Kaiming-normal initialization (ReLU gain) and
        the bias, when present, is set to zero.
        """
        torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            torch.nn.init.constant_(self.bias, 0.0)


class Conv1d1x1(Conv1d):
    """1x1 Conv1d with customized initialization."""

    def __init__(self, in_channels: int, out_channels: int, bias: bool):
        """Initialize 1x1 Conv1d module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bias (bool): Whether to add a bias parameter.

        """
        super().__init__(
            in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
        )


class ResidualBlock(torch.nn.Module):
    """Residual block module in WaveNet."""

    def __init__(
        self,
        kernel_size: int = 3,
        residual_channels: int = 64,
        gate_channels: int = 128,
        skip_channels: int = 64,
        aux_channels: int = 80,
        global_channels: int = -1,
        dropout_rate: float = 0.0,
        dilation: int = 1,
        bias: bool = True,
        scale_residual: bool = False,
    ):
        """Initialize ResidualBlock module.

        Args:
            kernel_size (int): Kernel size of dilation convolution layer.
                Must be odd.
            residual_channels (int): Number of channels for residual connection.
            gate_channels (int): Number of channels of the gated convolution.
                Must be even because the output is split into two halves.
            skip_channels (int): Number of channels for skip connection.
            aux_channels (int): Number of local conditioning channels.
                Values <= 0 disable local conditioning.
            global_channels (int): Number of global conditioning channels.
                Values <= 0 disable global conditioning.
            dropout_rate (float): Dropout probability.
            dilation (int): Dilation factor.
            bias (bool): Whether to add bias parameter in convolution layers.
            scale_residual (bool): Whether to scale the residual outputs.

        """
        super().__init__()
        self.dropout_rate = dropout_rate
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.scale_residual = scale_residual

        # check
        assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
        assert gate_channels % 2 == 0

        # dilation conv; the padding keeps the time length unchanged
        padding = (kernel_size - 1) // 2 * dilation
        self.conv = Conv1d(
            residual_channels,
            gate_channels,
            kernel_size,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

        # local conditioning
        if aux_channels > 0:
            self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
        else:
            self.conv1x1_aux = None

        # global conditioning
        if global_channels > 0:
            self.conv1x1_glo = Conv1d1x1(global_channels, gate_channels, bias=False)
        else:
            self.conv1x1_glo = None

        # conv output is split into two groups
        gate_out_channels = gate_channels // 2

        # NOTE(kan-bayashi): concat two convs into a single conv for the efficiency
        #   (integrate res 1x1 + skip 1x1 convs)
        self.conv1x1_out = Conv1d1x1(
            gate_out_channels, residual_channels + skip_channels, bias=bias
        )

    def forward(
        self,
        x: torch.Tensor,
        x_mask: Optional[torch.Tensor] = None,
        c: Optional[torch.Tensor] = None,
        g: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, residual_channels, T).
            x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
            c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T).
                Must be None if the block was built with aux_channels <= 0.
            g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
                Must be None if the block was built with global_channels <= 0.

        Returns:
            Tensor: Output tensor for residual connection (B, residual_channels, T).
            Tensor: Output tensor for skip connection (B, skip_channels, T).

        """
        residual = x
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.conv(x)

        # split into two part for gated activation
        splitdim = 1
        xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)

        # local conditioning
        if c is not None:
            c = self.conv1x1_aux(c)
            ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
            xa, xb = xa + ca, xb + cb

        # global conditioning
        if g is not None:
            g = self.conv1x1_glo(g)
            ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim)
            xa, xb = xa + ga, xb + gb

        x = torch.tanh(xa) * torch.sigmoid(xb)

        # residual + skip 1x1 conv
        x = self.conv1x1_out(x)
        if x_mask is not None:
            x = x * x_mask

        # split integrated conv results
        x, s = x.split([self.residual_channels, self.skip_channels], dim=1)

        # for residual connection
        x = x + residual
        if self.scale_residual:
            x = x * math.sqrt(0.5)

        return x, s