Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 08:04:18 +00:00)
replace phonemizer with g2p
This commit is contained in: parent 3df16b3f2b, commit b719581e2f
egs/ljspeech/tts/local/prepare_token_file.py (new executable file, 116 lines)
@@ -0,0 +1,116 @@
#!/usr/bin/env python3
# Copyright      2023  Xiaomi Corp.        (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file reads the texts in given manifest and generate the file that maps tokens to IDs.
"""

import argparse
import logging
from collections import Counter
from pathlib import Path
from typing import Dict

import g2p_en
import tacotron_cleaner.cleaners
from lhotse import load_manifest


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--manifest-file",
        type=Path,
        default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"),
        help="Path to the manifest file",
    )

    parser.add_argument(
        "--tokens",
        type=Path,
        default=Path("data/tokens.txt"),
        help="Path to the tokens",
    )

    return parser.parse_args()


def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
    """Write a symbol to ID mapping to a file.

    Note:
      No need to implement `read_mapping` as it can be done
      through :func:`k2.SymbolTable.from_file`.

    Args:
      filename:
        Filename to save the mapping.
      sym2id:
        A dict mapping symbols to IDs.
    Returns:
      Return None.
    """
    with open(filename, "w", encoding="utf-8") as f:
        for sym, i in sym2id.items():
            f.write(f"{sym} {i}\n")


def get_token2id(manifest_file: Path) -> Dict[str, int]:
    """Return a dict that maps token to IDs."""
    extra_tokens = {
        "<blk>": 0,  # blank
        "<sos/eos>": 1,  # sos and eos symbols.
        "<unk>": 2,  # OOV
    }
    cut_set = load_manifest(manifest_file)
    g2p = g2p_en.G2p()
    counter = Counter()

    for cut in cut_set:
        # Each cut only contain one supervision
        assert len(cut.supervisions) == 1, len(cut.supervisions)
        text = cut.supervisions[0].normalized_text
        # Text normalization
        text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
        # Convert to phonemes
        tokens = g2p(text)
        for t in tokens:
            counter[t] += 1

    # Sort by the number of occurrences in descending order
    tokens_and_counts = sorted(counter.items(), key=lambda x: -x[1])

    for token, idx in extra_tokens.items():
        tokens_and_counts.insert(idx, (token, None))

    token2id: Dict[str, int] = {
        token: i for i, (token, count) in enumerate(tokens_and_counts)
    }
    return token2id


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    manifest_file = Path(args.manifest_file)
    out_file = Path(args.tokens)

    token2id = get_token2id(manifest_file)
    write_mapping(out_file, token2id)
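For orientation, a hedged sketch of what the script above produces: g2p_en emits ARPAbet-style phoneme tokens (plus spaces and punctuation), and tokens.txt lists one "<symbol> <id>" pair per line with the three reserved symbols first. The sentence and IDs below are illustrative only; the actual inventory depends on the manifest.

    # Illustrative only; exact phoneme strings and IDs depend on the data.
    import g2p_en

    g2p = g2p_en.G2p()
    print(g2p("Hello world"))
    # e.g. ['HH', 'AH0', 'L', 'OW1', ' ', 'W', 'ER1', 'L', 'D']

    # data/tokens.txt then looks roughly like:
    #   <blk> 0
    #   <sos/eos> 1
    #   <unk> 2
    #   AH0 3
    #   ...
    # A line containing only an ID denotes the space token
    # (see how Tokenizer parses it later in this commit).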
@@ -52,7 +52,8 @@ def main():
     manifest_dir = Path(args.manifest_dir)
     prefix = "ljspeech"
     suffix = "jsonl.gz"
-    all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all.{suffix}")
+    # all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all.{suffix}")
+    all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all_phonemized.{suffix}")

     cut_ids = list(all_cuts.ids)
     random.shuffle(cut_ids)
@@ -66,11 +66,50 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   fi
 fi

-if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 3: Split the LJSpeech cuts into three sets"
-  if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
-    ./local/split_subsets.py data/spectrogram
-    touch data/spectrogram/.ljspeech_split.done
+# if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+#   log "Stage 3: Phonemize the transcripts for LJSpeech"
+#   if [ ! -e data/spectrogram/.ljspeech_phonemized.done ]; then
+#     ./local/phonemize_text.py data/spectrogram
+#     touch data/spectrogram/.ljspeech_phonemized.done
+#   fi
+# fi
+
+# if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+#   log "Stage 4: Split the LJSpeech cuts into three sets"
+#   if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
+#     ./local/split_subsets.py data/spectrogram
+#     touch data/spectrogram/.ljspeech_split.done
+#   fi
+# fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Split the LJSpeech cuts into train, valid and test sets"
+  if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
+    lhotse subset --last 600 \
+      data/spectrogram/ljspeech_cuts_all.jsonl.gz \
+      data/spectrogram/ljspeech_cuts_validtest.jsonl.gz
+    lhotse subset --first 100 \
+      data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \
+      data/spectrogram/ljspeech_cuts_valid.jsonl.gz
+    lhotse subset --last 500 \
+      data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \
+      data/spectrogram/ljspeech_cuts_test.jsonl.gz
+    rm data/spectrogram/ljspeech_cuts_validtest.jsonl.gz
+
+    n=$(( $(gunzip -c data/spectrogram/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 ))
+    lhotse subset --first $n \
+      data/spectrogram/ljspeech_cuts_all.jsonl.gz \
+      data/spectrogram/ljspeech_cuts_train.jsonl.gz
+    touch data/spectrogram/.ljspeech_split.done
+  fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Generate token file"
+  if [ ! -e data/tokens.txt ]; then
+    ./local/prepare_token_file.py \
+      --manifest-file data/spectrogram/ljspeech_cuts_train.jsonl.gz \
+      --tokens data/tokens.txt
   fi
 fi
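The rewritten stage 3 above carves the last 600 cuts into validation and test and keeps the rest for training. A small sketch of the arithmetic, assuming the full manifest holds 13,100 cuts (the real count comes from the wc -l in the stage itself):

    # Illustrative only; the actual split is done by `lhotse subset` above.
    total = 13100              # assumed size of ljspeech_cuts_all.jsonl.gz
    validtest = 600            # last 600 cuts
    valid = 100                # first 100 of validtest
    test = 500                 # last 500 of validtest
    train = total - validtest
    print(train, valid, test)  # 12500 100 500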
@@ -515,10 +515,12 @@ class VITSGenerator(torch.nn.Module):
         cum_dur_flat = cum_dur.view(b * t_x)
         path = torch.arange(t_y, dtype=dur.dtype, device=dur.device)
         path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)
-        path = path.view(b, t_x, t_y).to(dtype=mask.dtype)
+        # path = path.view(b, t_x, t_y).to(dtype=mask.dtype)
+        path = path.view(b, t_x, t_y).to(dtype=torch.float)
         # path will be like (t_x = 3, t_y = 5):
         # [[[1., 1., 0., 0., 0.],      [[[1., 1., 0., 0., 0.],
         #   [1., 1., 1., 1., 0.],  -->   [0., 0., 1., 1., 0.],
         #   [1., 1., 1., 1., 1.]]]      [0., 0., 0., 0., 1.]]]
         path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
+        # path = path.to(dtype=mask.dtype)
         return path.unsqueeze(1).transpose(2, 3) * mask
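The hunk above only changes the dtype handling, but the duration-to-alignment trick is worth seeing in isolation. A self-contained sketch with made-up durations (three tokens spanning five frames), reproducing the matrix in the comment:

    import torch
    import torch.nn.functional as F

    # Hypothetical per-token durations for one utterance: 3 tokens, 5 frames.
    dur = torch.tensor([[2., 2., 1.]])                     # (b, t_x)
    b, t_x = dur.shape
    t_y = int(dur.sum())

    cum_dur = torch.cumsum(dur, dim=-1)                    # [[2., 4., 5.]]
    cum_dur_flat = cum_dur.view(b * t_x)
    path = torch.arange(t_y, dtype=dur.dtype)              # [0., 1., 2., 3., 4.]
    path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)   # (b*t_x, t_y) bool
    path = path.view(b, t_x, t_y).to(dtype=torch.float)
    # Subtracting the previous token's row keeps only the frames that
    # belong to each token, i.e. a hard monotonic alignment.
    path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
    print(path)
    # tensor([[[1., 1., 0., 0., 0.],
    #          [0., 0., 1., 1., 0.],
    #          [0., 0., 0., 0., 1.]]])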
egs/ljspeech/tts/vits/infer.py (new executable file, 366 lines)
@@ -0,0 +1,366 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
#                                                  Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# (standard Apache 2.0 license text, identical to the header of
#  prepare_token_file.py above)
"""
Usage:
(1) greedy search
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method greedy_search

(examples (2)-(7) repeat the same command for beam_search, modified_beam_search,
fast_beam_search and its nbest / nbest_oracle / nbest_LG variants, with the
corresponding --beam-size / --beam / --max-contexts / --max-states /
--num-paths / --nbest-scale options)
"""


import argparse
import logging
import math
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torchaudio

from train2 import get_model, get_params

from icefall.checkpoint import (
    average_checkpoints,
    find_checkpoints,
    load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    make_pad_mask,
    setup_logger,
    store_transcripts,
    str2bool,
    write_error_stats,
)
from tts_datamodule import LJSpeechTtsDataModule
from utils import prepare_token_batch

LOG_EPS = math.log(1e-10)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="The experiment dir",
    )

    return parser


def infer_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      word_table:
        The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_7" if beam size of 7 is used.
      Its value is a list of tuples. Each tuple contains two elements:
      The first is the reference transcript, and the second is the
      predicted result.
    """
    # Background worker save audios to disk.
    def _save_worker(
        batch_size: int,
        cut_ids: List[str],
        audio: torch.Tensor,
        audio_pred: torch.Tensor,
        audio_lens: List[int],
        audio_lens_pred: List[int],
    ):
        for i in range(batch_size):
            torchaudio.save(
                str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
                audio[i:i + 1, :audio_lens[i]],
                sample_rate=params.sampling_rate,
            )
            torchaudio.save(
                str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
                audio_pred[i:i + 1, :audio_lens_pred[i]],
                sample_rate=params.sampling_rate,
            )

    device = next(model.parameters()).device
    num_cuts = 0
    log_interval = 10

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    futures = []
    with ThreadPoolExecutor(max_workers=1) as executor:
        # We only want one background worker so that serialization is deterministic.
        for batch_idx, batch in enumerate(dl):
            batch_size = len(batch["text"])
            text = batch["text"]
            tokens, tokens_lens = prepare_token_batch(text)
            tokens = tokens.to(device)
            tokens_lens = tokens_lens.to(device)

            audio = batch["audio"]
            audio_lens = batch["audio_lens"].tolist()
            cut_ids = [cut.id for cut in batch["cut"]]

            audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens)
            audio_pred = audio_pred.detach().cpu()
            # convert to samples
            audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()

            # import pdb
            # pdb.set_trace()

            futures.append(
                executor.submit(
                    _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred
                )
            )

            num_cuts += batch_size

            if batch_idx % log_interval == 0:
                batch_str = f"{batch_idx}/{num_batches}"

                logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
        # return results
        for f in futures:
            f.result()


@torch.no_grad()
def main():
    parser = get_parser()
    LJSpeechTtsDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    params.res_dir = params.exp_dir / "infer" / params.suffix
    params.save_wav_dir = params.res_dir / "wav"
    params.save_wav_dir.mkdir(parents=True, exist_ok=True)

    setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
    logging.info("Infer started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    if params.iter > 0:
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            : params.avg
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        elif len(filenames) < params.avg:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))
    elif params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = []
        for i in range(start, params.epoch + 1):
            if i >= 1:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))

    model.to(device)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    ljspeech = LJSpeechTtsDataModule(args)

    test_cuts = ljspeech.test_cuts()
    test_dl = ljspeech.test_dataloaders(test_cuts)

    infer_dataset(
        dl=test_dl,
        params=params,
        model=model,
    )

    # save_results(
    #     params=params,
    #     test_set_name=test_set,
    #     results_dict=results_dict,
    # )

    logging.info("Done!")


# torch.set_num_threads(1)
# torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()
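In infer_dataset above, predicted durations are counted in frames, so the generated lengths in samples are recovered by multiplying the summed durations by the hop size (frame_shift, 256 samples here). A sketch with made-up durations:

    import torch

    frame_shift = 256                      # hop size in samples, as in get_params()
    durations = torch.tensor([[3, 5, 2],   # frames per token, batch of 2 (made up)
                              [4, 4, 0]])
    audio_lens_pred = (durations.sum(1) * frame_shift).to(dtype=torch.int64).tolist()
    print(audio_lens_pred)  # [2560, 2048] samples, i.e. ~0.12 s and ~0.09 s at 22.05 kHz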
@@ -241,7 +241,8 @@ class MelSpectrogramLoss(torch.nn.Module):
         self,
         y_hat: torch.Tensor,
         y: torch.Tensor,
-    ) -> torch.Tensor:
+        return_mel: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
         """Calculate Mel-spectrogram loss.

         Args:
@@ -259,6 +260,9 @@ class MelSpectrogramLoss(torch.nn.Module):
         mel = self.wav_to_mel(y.squeeze(1))
         mel_loss = F.l1_loss(mel_hat, mel)

+        if return_mel:
+            return mel_loss, (mel_hat, mel)
+
         return mel_loss
egs/ljspeech/tts/vits/tokenizer.py (new file, 80 lines)
@@ -0,0 +1,80 @@
# Copyright      2023  Xiaomi Corp.        (authors: Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# (standard Apache 2.0 license text, identical to the header of
#  prepare_token_file.py above)


from typing import Dict, List

import g2p_en
import tacotron_cleaner.cleaners

from utils import intersperse


class Tokenizer(object):
    def __init__(self, tokens: str):
        """
        Args:
          tokens: the file that maps tokens to ids
        """
        # Parse token file
        self.token2id: Dict[str, int] = {}
        with open(tokens, "r", encoding="utf-8") as f:
            for line in f.readlines():
                info = line.rstrip().split()
                if len(info) == 1:
                    # case of space
                    token = " "
                    id = int(info[0])
                else:
                    token, id = info[0], int(info[1])
                self.token2id[token] = id

        self.blank_id = self.token2id["<blk>"]
        self.oov_id = self.token2id["<unk>"]
        self.vocab_size = len(self.token2id)

        self.g2p = g2p_en.G2p()

    def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
        """
        Args:
          texts:
            A list of transcripts.
          intersperse_blank:
            Whether to intersperse blanks in the token sequence.

        Returns:
          Return a list of token id list [utterance][token_id]
        """
        token_ids_list = []

        for text in texts:
            # Text normalization
            text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
            # Convert to phonemes
            tokens = self.g2p(text)
            token_ids = []
            for t in tokens:
                if t in self.token2id:
                    token_ids.append(self.token2id[t])
                else:
                    token_ids.append(self.oov_id)

            if intersperse_blank:
                token_ids = intersperse(token_ids, self.blank_id)

            token_ids_list.append(token_ids)

        return token_ids_list
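A short usage sketch of the new Tokenizer; it assumes the recipe directory (egs/ljspeech/tts/vits) is the working directory and that data/tokens.txt was generated by local/prepare_token_file.py as shown earlier:

    # Hedged example; file path and sentences are illustrative only.
    from tokenizer import Tokenizer

    tokenizer = Tokenizer("data/tokens.txt")
    print(tokenizer.vocab_size, tokenizer.blank_id, tokenizer.oov_id)

    ids = tokenizer.texts_to_token_ids(
        ["Hello world", "This is a test."],
        intersperse_blank=True,
    )
    # ids is a list of per-utterance ID lists, with blank_id inserted
    # around every phoneme when intersperse_blank=True.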
@@ -1,10 +1,32 @@
 #!/usr/bin/env python3
+# Copyright    2021-2023  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Wei Kang,
+#                                                       Mingshuang Luo,
+#                                                       Zengwei Yao,
+#                                                       Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
 import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Union

+import k2
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn

@@ -27,10 +49,10 @@ from icefall.utils import (
     str2bool,
 )

-from symbols import symbol_table
+from tokenizer import Tokenizer
 from utils import (
     MetricsTracker,
-    prepare_token_batch,
+    plot_feature,
     save_checkpoint,
     save_checkpoint_with_global_batch_idx,
 )

@@ -101,6 +123,13 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--tokens",
+        type=str,
+        default="data/tokens.txt",
+        help="""Path to tokens.txt.""",
+    )
+
     parser.add_argument(
         "--lr", type=float, default=2.0e-4, help="The base learning rate."
     )
@@ -213,16 +242,16 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": -1,  # 0
-            "log_interval": 50,
+            "log_interval": 10,
+            "draw_interval": 500,
             # "reset_interval": 200,
-            "valid_interval": 500,
+            "valid_interval": 200,
             "env_info": get_env_info(),
             "sampling_rate": 22050,
+            "frame_shift": 256,
+            "frame_length": 1024,
             "feature_dim": 513,  # 1024 // 2 + 1, 1024 is fft_length
-            "vocab_size": len(symbol_table),
             "mel_loss_params": {
-                "frame_shift": 256,
-                "frame_length": 1024,
                 "n_mels": 80,
             },
             "lambda_adv": 1.0,  # loss scaling coefficient for adversarial loss

@@ -287,11 +316,16 @@ def load_checkpoint_if_available(


 def get_model(params: AttributeDict) -> nn.Module:
+    mel_loss_params = params.mel_loss_params
+    mel_loss_params.update(
+        frame_length=params.frame_length,
+        frame_shift=params.frame_shift,
+    )
     model = VITS(
         vocab_size=params.vocab_size,
         feature_dim=params.feature_dim,
         sampling_rate=params.sampling_rate,
-        mel_loss_params=params.mel_loss_params,
+        mel_loss_params=mel_loss_params,
         lambda_adv=params.lambda_adv,
         lambda_mel=params.lambda_mel,
         lambda_feat_match=params.lambda_feat_match,
@@ -301,79 +335,30 @@ def get_model(params: AttributeDict) -> nn.Module:
     return model


-def compute_validation_loss(
-    params: AttributeDict,
-    model: Union[nn.Module, DDP],
-    valid_dl: torch.utils.data.DataLoader,
-    world_size: int = 1,
-) -> MetricsTracker:
-    """Run the validation process."""
-    model.eval()
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
-    (... remaining body of the old compute_validation_loss removed here; the
-     function is re-added, updated, in the hunk at -604,9 +608,87 further below ...)
-    return tot_loss
+def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
+    """Parse batch data"""
+    audio = batch["audio"].to(device)
+    features = batch["features"].to(device)
+    audio_lens = batch["audio_lens"].to(device)
+    features_lens = batch["features_lens"].to(device)
+    text = batch["text"]
+
+    tokens = tokenizer.texts_to_token_ids(text)
+    tokens = k2.RaggedTensor(tokens)
+    row_splits = tokens.shape.row_splits(1)
+    tokens_lens = row_splits[1:] - row_splits[:-1]
+    tokens = tokens.to(device)
+    tokens_lens = tokens_lens.to(device)
+    # a tensor of shape (B, T)
+    tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+
+    return audio, audio_lens, features, features_lens, tokens, tokens_lens


 def train_one_epoch(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
+    tokenizer: Tokenizer,
     optimizer_g: Optimizer,
     optimizer_d: Optimizer,
     scheduler_g: LRSchedulerType,
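The new prepare_input leans on k2.RaggedTensor to turn ragged per-utterance token lists into a padded (B, T) tensor plus lengths. The conversion in isolation, with made-up IDs:

    import k2
    import torch

    blank_id = 0
    token_ids = [[3, 0, 7, 0, 5], [4, 0, 9]]        # two utterances, made-up IDs

    tokens = k2.RaggedTensor(token_ids)
    row_splits = tokens.shape.row_splits(1)          # tensor([0, 5, 8])
    tokens_lens = row_splits[1:] - row_splits[:-1]   # tensor([5, 3])
    tokens = tokens.pad(mode="constant", padding_value=blank_id)
    print(tokens)       # tensor([[3, 0, 7, 0, 5],
                        #         [4, 0, 9, 0, 0]])
    print(tokens_lens)  # tensor([5, 3])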
@@ -442,18 +427,13 @@ def train_one_epoch(
             params.batch_idx_train += 1

             batch_size = len(batch["text"])
-            audio = batch["audio"].to(device)
-            features = batch["features"].to(device)
-            audio_lens = batch["audio_lens"].to(device)
-            features_lens = batch["features_lens"].to(device)
-            text = batch["text"]
-            tokens, tokens_lens = prepare_token_batch(text)
-            tokens = tokens.to(device)
-            tokens_lens = tokens_lens.to(device)
+            audio, audio_lens, features, features_lens, tokens, tokens_lens = \
+                prepare_input(batch, tokenizer, device)

             loss_info = MetricsTracker()
             loss_info['samples'] = batch_size

+            return_sample = params.batch_idx_train % params.log_interval == 0
             try:
                 with autocast(enabled=params.use_fp16):
                     # forward discriminator

@@ -483,9 +463,13 @@ def train_one_epoch(
                         speech=audio,
                         speech_lengths=audio_lens,
                         forward_generator=True,
+                        return_sample=return_sample,
                     )
                     for k, v in stats_g.items():
-                        loss_info[k] = v * batch_size
+                        if "return_sample" not in k:
+                            loss_info[k] = v * batch_size
+                    if return_sample:
+                        speech_hat_, speech_, mel_hat_, mel_ = stats_g["return_sample"]
                 # update generator
                 optimizer_g.zero_grad()
                 scaler.scale(loss_g).backward()
@@ -577,13 +561,27 @@ def train_one_epoch(
                 tb_writer.add_scalar(
                     "train/grad_scale", cur_grad_scale, params.batch_idx_train
                 )
+                if return_sample:
+                    tb_writer.add_audio(
+                        "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate
+                    )
+                    tb_writer.add_audio(
+                        "train/speech_", speech_, params.batch_idx_train, params.sampling_rate
+                    )
+                    tb_writer.add_image(
+                        "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC'
+                    )
+                    tb_writer.add_image(
+                        "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC'
+                    )

         # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
         if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
             logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
+            valid_info, (speech_hat, speech) = compute_validation_loss(
                 params=params,
                 model=model,
+                tokenizer=tokenizer,
                 valid_dl=valid_dl,
                 world_size=world_size,
             )

@@ -596,6 +594,12 @@ def train_one_epoch(
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
+                tb_writer.add_audio(
+                    "train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate
+                )
+                tb_writer.add_audio(
+                    "train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate
+                )

     loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
     params.train_loss = loss_value
@@ -604,9 +608,87 @@ def train_one_epoch(
         params.best_train_loss = params.train_loss


+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    tokenizer: Tokenizer,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+    rank: int = 0,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+    # used to summary the stats over iterations
+    tot_loss = MetricsTracker()
+    return_sample = None
+
+    with torch.no_grad():
+        for batch_idx, batch in enumerate(valid_dl):
+            batch_size = len(batch["text"])
+            audio, audio_lens, features, features_lens, tokens, tokens_lens = \
+                prepare_input(batch, tokenizer, device)
+
+            loss_info = MetricsTracker()
+            loss_info['samples'] = batch_size
+
+            # forward discriminator
+            loss_d, stats_d = model(
+                text=tokens,
+                text_lengths=tokens_lens,
+                feats=features,
+                feats_lengths=features_lens,
+                speech=audio,
+                speech_lengths=audio_lens,
+                forward_generator=False,
+            )
+            assert loss_d.requires_grad is False
+            for k, v in stats_d.items():
+                loss_info[k] = v * batch_size
+
+            # forward generator
+            loss_g, stats_g = model(
+                text=tokens,
+                text_lengths=tokens_lens,
+                feats=features,
+                feats_lengths=features_lens,
+                speech=audio,
+                speech_lengths=audio_lens,
+                forward_generator=True,
+            )
+            assert loss_g.requires_grad is False
+            for k, v in stats_g.items():
+                loss_info[k] = v * batch_size
+
+            # summary stats
+            tot_loss = tot_loss + loss_info
+
+            # infer for first batch:
+            if batch_idx == 0 and rank == 0:
+                inner_model = model.module if isinstance(model, DDP) else model
+                audio_pred, _, duration = inner_model.inference(text=tokens[0, :tokens_lens[0].item()])
+                audio_pred = audio_pred.data.cpu().numpy()
+                audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
+                assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred))
+                audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy()
+                return_sample = (audio_pred, audio_gt)
+
+    if world_size > 1:
+        tot_loss.reduce(device)
+
+    loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss, return_sample
+
+
 def scan_pessimistic_batches_for_oom(
     model: Union[nn.Module, DDP],
     train_dl: torch.utils.data.DataLoader,
+    tokenizer: Tokenizer,
     optimizer_g: torch.optim.Optimizer,
     optimizer_d: torch.optim.Optimizer,
     params: AttributeDict,
@@ -620,14 +702,8 @@ def scan_pessimistic_batches_for_oom(
     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
-        audio = batch["audio"].to(device)
-        features = batch["features"].to(device)
-        audio_lens = batch["audio_lens"].to(device)
-        features_lens = batch["features_lens"].to(device)
-        text = batch["text"]
-        tokens, tokens_lens = prepare_token_batch(text)
-        tokens = tokens.to(device)
-        tokens_lens = tokens_lens.to(device)
+        audio, audio_lens, features, features_lens, tokens, tokens_lens = \
+            prepare_input(batch, tokenizer, device)
         try:
             # for discriminator
             with autocast(enabled=params.use_fp16):

@@ -702,6 +778,11 @@ def run(rank, world_size, args):
         device = torch.device("cuda", rank)
     logging.info(f"Device: {device}")

+    tokenizer = Tokenizer(params.tokens)
+    params.blank_id = tokenizer.blank_id
+    params.oov_id = tokenizer.oov_id
+    params.vocab_size = tokenizer.vocab_size
+
     logging.info(params)

     logging.info("About to create model")
@@ -728,14 +809,14 @@ def run(rank, world_size, args):
         lr=params.lr,
         betas=(0.8, 0.99),
         eps=1e-9,
-        weight_decay=0,
+        # weight_decay=0,
     )
     optimizer_d = torch.optim.AdamW(
         discriminator.parameters(),
         lr=params.lr,
         betas=(0.8, 0.99),
         eps=1e-9,
-        weight_decay=0,
+        # weight_decay=0,
     )

     scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)

@@ -804,6 +885,7 @@ def run(rank, world_size, args):
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
+            tokenizer=tokenizer,
             optimizer_g=optimizer_g,
             optimizer_d=optimizer_d,
             params=params,

@@ -815,6 +897,8 @@ def run(rank, world_size, args):
         scaler.load_state_dict(checkpoints["grad_scaler"])

     for epoch in range(params.start_epoch, params.num_epochs + 1):
+        logging.info(f"Start epoch {epoch}")
+
         fix_random_seed(params.seed + epoch - 1)
         train_dl.sampler.set_epoch(epoch - 1)

@@ -826,6 +910,7 @@ def run(rank, world_size, args):
         train_one_epoch(
             params=params,
             model=model,
+            tokenizer=tokenizer,
             optimizer_g=optimizer_g,
             optimizer_d=optimizer_d,
             scheduler_g=scheduler_g,
@@ -131,7 +131,14 @@ class LJSpeechTtsDataModule:
             default=True,
             help="Whether to drop last batch. Used by sampler.",
         )
+        group.add_argument(
+            "--return-cuts",
+            type=str2bool,
+            default=False,
+            help="When enabled, each batch will have the "
+            "field: batch['cut'] with the cuts that "
+            "were used to construct it.",
+        )
         group.add_argument(
             "--num-workers",
             type=int,

@@ -163,6 +170,7 @@ class LJSpeechTtsDataModule:
             train = SpeechSynthesisDataset(
                 return_tokens=False,
                 feature_input_strategy=eval(self.args.input_strategy)(),
+                return_cuts=self.args.return_cuts,
             )

         if self.args.on_the_fly_feats:

@@ -176,6 +184,7 @@ class LJSpeechTtsDataModule:
             train = SpeechSynthesisDataset(
                 return_tokens=False,
                 feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
+                return_cuts=self.args.return_cuts,
            )

         if self.args.bucketing_sampler:

@@ -229,11 +238,13 @@ class LJSpeechTtsDataModule:
             validate = SpeechSynthesisDataset(
                 return_tokens=False,
                 feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
+                return_cuts=self.args.return_cuts,
             )
         else:
             validate = SpeechSynthesisDataset(
                 return_tokens=False,
                 feature_input_strategy=eval(self.args.input_strategy)(),
+                return_cuts=self.args.return_cuts,
             )
         valid_sampler = DynamicBucketingSampler(
             cuts_valid,

@@ -264,11 +275,13 @@ class LJSpeechTtsDataModule:
             test = SpeechSynthesisDataset(
                 return_tokens=False,
                 feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
+                return_cuts=self.args.return_cuts,
             )
         else:
             test = SpeechSynthesisDataset(
                 return_tokens=False,
                 feature_input_strategy=eval(self.args.input_strategy)(),
+                return_cuts=self.args.return_cuts,
             )
         test_sampler = DynamicBucketingSampler(
             cuts,
@@ -211,6 +211,7 @@ def intersperse(sequence, item=0):

 def prepare_token_batch(
     texts: List[str],
+    phonemes: Optional[List[str]] = None,
     intersperse_blank: bool = True,
     blank_id: int = 0,
     pad_id: int = 0,

@@ -222,41 +223,50 @@ def prepare_token_batch(
       blank_id: index of blank token
       pad_id: padding index
     """
-    # normalize text
-    normalized_texts = []
-    for text in texts:
-        text = convert_to_ascii(text)
-        text = lowercase(text)
-        text = expand_abbreviations(text)
-        normalized_texts.append(text)
-
-    # convert to phonemes
-    phonemes = phonemize(
-        normalized_texts,
-        language='en-us',
-        backend='espeak',
-        strip=True,
-        preserve_punctuation=True,
-        with_stress=True,
-    )
+    if phonemes is None:
+        # normalize text
+        normalized_texts = []
+        for text in texts:
+            text = convert_to_ascii(text)
+            text = lowercase(text)
+            text = expand_abbreviations(text)
+            normalized_texts.append(text)
+
+        # convert to phonemes
+        phonemes = phonemize(
+            normalized_texts,
+            language='en-us',
+            backend='espeak',
+            strip=True,
+            preserve_punctuation=True,
+            with_stress=True,
+        )
+        phonemes = [collapse_whitespace(sequence) for sequence in phonemes]

     # convert to symbol ids
     lengths = []
     sequences = []
+    skip = False
     for idx, sequence in enumerate(phonemes):
         try:
-            sequence = [symbol_to_id[symbol] for symbol in collapse_whitespace(sequence)]
-        except RuntimeError:
-            print(text[idx])
-            print(normalized_texts[idx])
+            sequence = [symbol_to_id[symbol] for symbol in sequence]
+        except Exception:
+            # print(texts[idx])
+            # print(normalized_texts[idx])
+            print(phonemes[idx])
+            skip = True
         if intersperse_blank:
             sequence = intersperse(sequence, blank_id)
-        sequences.append(torch.tensor(sequence, dtype=torch.int64))
+        try:
+            sequences.append(torch.tensor(sequence, dtype=torch.int64))
+        except Exception:
+            print(sequence)
+            skip = True
         lengths.append(len(sequence))

     sequences = pad_sequence(sequences, batch_first=True, padding_value=pad_id)
     lengths = torch.tensor(lengths, dtype=torch.int64)
-    return sequences, lengths
+    return sequences, lengths, skip
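Both prepare_token_batch and Tokenizer.texts_to_token_ids call intersperse to place a blank around every token. The helper itself is not shown in this diff; the sketch below is an assumed, common definition of that behavior, not necessarily the recipe's exact code:

    # Assumed behavior of utils.intersperse(sequence, item=0);
    # the real implementation in this recipe may differ in details.
    def intersperse(sequence, item=0):
        result = [item] * (len(sequence) * 2 + 1)
        result[1::2] = sequence
        return result

    print(intersperse([5, 6, 7]))  # [0, 5, 0, 6, 0, 7, 0]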
@@ -287,7 +297,7 @@ class MetricsTracker(collections.defaultdict):
             norm_value = "%.4g" % v
             ans += str(k) + "=" + str(norm_value) + ", "
         samples = "%.2f" % self["samples"]
-        ans += "over" + str(samples) + " samples."
+        ans += "over " + str(samples) + " samples."
         return ans

     def norm_items(self) -> List[Tuple[str, float]]:

@@ -468,3 +478,41 @@ def save_checkpoint_with_global_batch_idx(
         sampler=sampler,
         rank=rank,
     )
+
+
+# def plot_feature(feature):
+#     """
+#     Display the feature matrix as an image. Requires matplotlib to be installed.
+#     """
+#     import matplotlib.pyplot as plt
+#
+#     feature = np.flip(feature.transpose(1, 0), 0)
+#     return plt.matshow(feature)
+
+MATPLOTLIB_FLAG = False
+
+
+def plot_feature(spectrogram):
+    global MATPLOTLIB_FLAG
+    if not MATPLOTLIB_FLAG:
+        import matplotlib
+        matplotlib.use("Agg")
+        MATPLOTLIB_FLAG = True
+        mpl_logger = logging.getLogger('matplotlib')
+        mpl_logger.setLevel(logging.WARNING)
+    import matplotlib.pylab as plt
+    import numpy as np
+
+    fig, ax = plt.subplots(figsize=(10, 2))
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
+                   interpolation='none')
+    plt.colorbar(im, ax=ax)
+    plt.xlabel("Frames")
+    plt.ylabel("Channels")
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
@@ -241,6 +241,7 @@ class VITS(nn.Module):
         feats_lengths: torch.Tensor,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
+        return_sample: bool = False,
         sids: Optional[torch.Tensor] = None,
         spembs: Optional[torch.Tensor] = None,
         lids: Optional[torch.Tensor] = None,

@@ -276,6 +277,7 @@ class VITS(nn.Module):
                 feats_lengths=feats_lengths,
                 speech=speech,
                 speech_lengths=speech_lengths,
+                return_sample=return_sample,
                 sids=sids,
                 spembs=spembs,
                 lids=lids,

@@ -301,6 +303,7 @@ class VITS(nn.Module):
         feats_lengths: torch.Tensor,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
+        return_sample: bool = False,
         sids: Optional[torch.Tensor] = None,
         spembs: Optional[torch.Tensor] = None,
         lids: Optional[torch.Tensor] = None,

@@ -367,7 +370,12 @@ class VITS(nn.Module):

         # calculate losses
         with autocast(enabled=False):
-            mel_loss = self.mel_loss(speech_hat_, speech_)
+            if not return_sample:
+                mel_loss = self.mel_loss(speech_hat_, speech_)
+            else:
+                mel_loss, (mel_hat_, mel_) = self.mel_loss(
+                    speech_hat_, speech_, return_mel=True
+                )
             kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
             dur_loss = torch.sum(dur_nll.float())
             adv_loss = self.generator_adv_loss(p_hat)
@@ -389,6 +397,14 @@ class VITS(nn.Module):
                 generator_feat_match_loss=feat_match_loss.item(),
             )

+            if return_sample:
+                stats["return_sample"] = (
+                    speech_hat_[0].data.cpu().numpy(),
+                    speech_[0].data.cpu().numpy(),
+                    mel_hat_[0].data.cpu().numpy(),
+                    mel_[0].data.cpu().numpy(),
+                )
+
             # reset cache
             if reuse_cache or not self.training:
                 self._cache = None

@@ -564,4 +580,43 @@ class VITS(nn.Module):
             alpha=alpha,
             max_len=max_len,
         )
-        return dict(wav=wav.view(-1), att_w=att_w[0], duration=dur[0])
+        return wav.view(-1), att_w[0], dur[0]
+
+    def inference_batch(
+        self,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        durations: Optional[torch.Tensor] = None,
+        noise_scale: float = 0.667,
+        noise_scale_dur: float = 0.8,
+        alpha: float = 1.0,
+        max_len: Optional[int] = None,
+        use_teacher_forcing: bool = False,
+    ) -> Dict[str, torch.Tensor]:
+        """Run inference.
+
+        Args:
+            text (Tensor): Input text index tensor (B, T_text).
+            text_lengths (Tensor): Input text index tensor (B,).
+            noise_scale (float): Noise scale value for flow.
+            noise_scale_dur (float): Noise scale value for duration predictor.
+            alpha (float): Alpha parameter to control the speed of generated speech.
+            max_len (Optional[int]): Maximum length.
+
+        Returns:
+            Dict[str, Tensor]:
+                * wav (Tensor): Generated waveform tensor (B, T_wav).
+                * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text).
+                * duration (Tensor): Predicted duration tensor (B, T_text).
+
+        """
+        # inference
+        wav, att_w, dur = self.generator.inference(
+            text=text,
+            text_lengths=text_lengths,
+            noise_scale=noise_scale,
+            noise_scale_dur=noise_scale_dur,
+            alpha=alpha,
+            max_len=max_len,
+        )
+        return wav, att_w, dur