mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-06 15:44:17 +00:00

replace phonemizer with g2p

This commit is contained in:
parent 3df16b3f2b
commit b719581e2f

116 egs/ljspeech/tts/local/prepare_token_file.py (Executable file)
@@ -0,0 +1,116 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file reads the texts in the given manifest and generates the file that maps tokens to IDs.
"""

import argparse
import logging
from collections import Counter
from pathlib import Path
from typing import Dict

import g2p_en
import tacotron_cleaner.cleaners
from lhotse import load_manifest


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--manifest-file",
        type=Path,
        default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"),
        help="Path to the manifest file",
    )

    parser.add_argument(
        "--tokens",
        type=Path,
        default=Path("data/tokens.txt"),
        help="Path to the tokens",
    )

    return parser.parse_args()


def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
    """Write a symbol-to-ID mapping to a file.

    Note:
      No need to implement `read_mapping` as it can be done
      through :func:`k2.SymbolTable.from_file`.

    Args:
      filename:
        Filename to save the mapping.
      sym2id:
        A dict mapping symbols to IDs.
    Returns:
      Return None.
    """
    with open(filename, "w", encoding="utf-8") as f:
        for sym, i in sym2id.items():
            f.write(f"{sym} {i}\n")


def get_token2id(manifest_file: Path) -> Dict[str, int]:
    """Return a dict that maps tokens to IDs."""
    extra_tokens = {
        "<blk>": 0,  # blank
        "<sos/eos>": 1,  # sos and eos symbols.
        "<unk>": 2,  # OOV
    }
    cut_set = load_manifest(manifest_file)
    g2p = g2p_en.G2p()
    counter = Counter()

    for cut in cut_set:
        # Each cut contains only one supervision
        assert len(cut.supervisions) == 1, len(cut.supervisions)
        text = cut.supervisions[0].normalized_text
        # Text normalization
        text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
        # Convert to phonemes
        tokens = g2p(text)
        for t in tokens:
            counter[t] += 1

    # Sort by the number of occurrences in descending order
    tokens_and_counts = sorted(counter.items(), key=lambda x: -x[1])

    for token, idx in extra_tokens.items():
        tokens_and_counts.insert(idx, (token, None))

    token2id: Dict[str, int] = {token: i for i, (token, count) in enumerate(tokens_and_counts)}
    return token2id


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    manifest_file = Path(args.manifest_file)
    out_file = Path(args.tokens)

    token2id = get_token2id(manifest_file)
    write_mapping(out_file, token2id)
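For reference, a minimal sketch of how the ID assignment above behaves, using a toy counter in place of the real manifest (so neither lhotse nor g2p_en is needed): the three extra symbols take IDs 0-2 and the remaining phonemes follow in descending frequency order.

from collections import Counter
from typing import Dict

# Toy phoneme counts standing in for what g2p_en would produce from the manifest.
counter = Counter({"AH0": 5, "T": 3, "HH": 1})

extra_tokens = {"<blk>": 0, "<sos/eos>": 1, "<unk>": 2}
tokens_and_counts = sorted(counter.items(), key=lambda x: -x[1])
for token, idx in extra_tokens.items():
    tokens_and_counts.insert(idx, (token, None))

token2id: Dict[str, int] = {t: i for i, (t, _) in enumerate(tokens_and_counts)}
print(token2id)
# {'<blk>': 0, '<sos/eos>': 1, '<unk>': 2, 'AH0': 3, 'T': 4, 'HH': 5}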
@@ -52,7 +52,8 @@ def main():
    manifest_dir = Path(args.manifest_dir)
    prefix = "ljspeech"
    suffix = "jsonl.gz"
    all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all.{suffix}")
    # all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all.{suffix}")
    all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all_phonemized.{suffix}")

    cut_ids = list(all_cuts.ids)
    random.shuffle(cut_ids)
@@ -66,11 +66,50 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  fi
fi

# if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
#   log "Stage 3: Phonemize the transcripts for LJSpeech"
#   if [ ! -e data/spectrogram/.ljspeech_phonemized.done ]; then
#     ./local/phonemize_text.py data/spectrogram
#     touch data/spectrogram/.ljspeech_phonemized.done
#   fi
# fi

# if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
#   log "Stage 4: Split the LJSpeech cuts into three sets"
#   if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
#     ./local/split_subsets.py data/spectrogram
#     touch data/spectrogram/.ljspeech_split.done
#   fi
# fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Split the LJSpeech cuts into three sets"
  log "Stage 3: Split the LJSpeech cuts into train, valid and test sets"
  if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
    ./local/split_subsets.py data/spectrogram
    touch data/spectrogram/.ljspeech_split.done
    lhotse subset --last 600 \
      data/spectrogram/ljspeech_cuts_all.jsonl.gz \
      data/spectrogram/ljspeech_cuts_validtest.jsonl.gz
    lhotse subset --first 100 \
      data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \
      data/spectrogram/ljspeech_cuts_valid.jsonl.gz
    lhotse subset --last 500 \
      data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \
      data/spectrogram/ljspeech_cuts_test.jsonl.gz
    rm data/spectrogram/ljspeech_cuts_validtest.jsonl.gz

    n=$(( $(gunzip -c data/spectrogram/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 ))
    lhotse subset --first $n \
      data/spectrogram/ljspeech_cuts_all.jsonl.gz \
      data/spectrogram/ljspeech_cuts_train.jsonl.gz
    touch data/spectrogram/.ljspeech_split.done
  fi
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Generate token file"
  if [ ! -e data/tokens.txt ]; then
    ./local/prepare_token_file.py \
      --manifest-file data/spectrogram/ljspeech_cuts_train.jsonl.gz \
      --tokens data/tokens.txt
  fi
fi

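As a sanity check on the split above (a hypothetical helper, not part of the recipe), the manifest sizes can be counted directly; with the 13,100 LJSpeech cuts this should report 12500 train, 100 valid and 500 test.

import gzip

def num_cuts(path: str) -> int:
    # Each line of a lhotse .jsonl.gz manifest is one cut.
    with gzip.open(path, "rt") as f:
        return sum(1 for _ in f)

for name in ["train", "valid", "test"]:
    path = f"data/spectrogram/ljspeech_cuts_{name}.jsonl.gz"
    print(name, num_cuts(path))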
@@ -515,10 +515,12 @@ class VITSGenerator(torch.nn.Module):
        cum_dur_flat = cum_dur.view(b * t_x)
        path = torch.arange(t_y, dtype=dur.dtype, device=dur.device)
        path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)
        path = path.view(b, t_x, t_y).to(dtype=mask.dtype)
        # path = path.view(b, t_x, t_y).to(dtype=mask.dtype)
        path = path.view(b, t_x, t_y).to(dtype=torch.float)
        # path will be like (t_x = 3, t_y = 5):
        # [[[1., 1., 0., 0., 0.],    [[[1., 1., 0., 0., 0.],
        #   [1., 1., 1., 1., 0.], -->  [0., 0., 1., 1., 0.],
        #   [1., 1., 1., 1., 1.]]]     [0., 0., 0., 0., 1.]]]
        path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
        # path = path.to(dtype=mask.dtype)
        return path.unsqueeze(1).transpose(2, 3) * mask

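A standalone sketch of the duration-to-path trick shown in the comment above, with assumed toy shapes (b=1, t_x=3, t_y=5): the cumulative-duration comparison yields the left matrix, and subtracting its row-shifted copy keeps only the frames newly covered by each token.

import torch
import torch.nn.functional as F

dur = torch.tensor([[2.0, 2.0, 1.0]])      # (b, t_x) durations per text token
b, t_x = dur.shape
t_y = int(dur.sum())                       # total number of frames

cum_dur = torch.cumsum(dur, dim=-1)        # [[2., 4., 5.]]
cum_dur_flat = cum_dur.view(b * t_x)
path = torch.arange(t_y, dtype=dur.dtype)  # [0., 1., 2., 3., 4.]
path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)
path = path.view(b, t_x, t_y).to(dtype=torch.float)  # cumulative mask (left matrix)
path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
print(path)
# tensor([[[1., 1., 0., 0., 0.],
#          [0., 0., 1., 1., 0.],
#          [0., 0., 0., 0., 1.]]])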
366 egs/ljspeech/tts/vits/infer.py (Executable file)
@@ -0,0 +1,366 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
#                                                 Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method greedy_search

(2) beam search (not recommended)
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method beam_search \
    --beam-size 4

(3) modified beam search
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method modified_beam_search \
    --beam-size 4

(4) fast beam search (one best)
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method fast_beam_search \
    --beam 20.0 \
    --max-contexts 8 \
    --max-states 64

(5) fast beam search (nbest)
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method fast_beam_search_nbest \
    --beam 20.0 \
    --max-contexts 8 \
    --max-states 64 \
    --num-paths 200 \
    --nbest-scale 0.5

(6) fast beam search (nbest oracle WER)
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method fast_beam_search_nbest_oracle \
    --beam 20.0 \
    --max-contexts 8 \
    --max-states 64 \
    --num-paths 200 \
    --nbest-scale 0.5

(7) fast beam search (with LG)
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method fast_beam_search_nbest_LG \
    --beam 20.0 \
    --max-contexts 8 \
    --max-states 64
"""


import argparse
import logging
import math
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torchaudio

from train2 import get_model, get_params

from icefall.checkpoint import (
    average_checkpoints,
    find_checkpoints,
    load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    make_pad_mask,
    setup_logger,
    store_transcripts,
    str2bool,
    write_error_stats,
)
from tts_datamodule import LJSpeechTtsDataModule
from utils import prepare_token_batch

LOG_EPS = math.log(1e-10)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="The experiment dir",
    )

    return parser


def infer_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      word_table:
        The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_7" if beam size of 7 is used.
      Its value is a list of tuples. Each tuple contains two elements:
      The first is the reference transcript, and the second is the
      predicted result.
    """
    # Background worker that saves audio files to disk.
    def _save_worker(
        batch_size: int,
        cut_ids: List[str],
        audio: torch.Tensor,
        audio_pred: torch.Tensor,
        audio_lens: List[int],
        audio_lens_pred: List[int],
    ):
        for i in range(batch_size):
            torchaudio.save(
                str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
                audio[i:i + 1, :audio_lens[i]],
                sample_rate=params.sampling_rate,
            )
            torchaudio.save(
                str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
                audio_pred[i:i + 1, :audio_lens_pred[i]],
                sample_rate=params.sampling_rate,
            )

    device = next(model.parameters()).device
    num_cuts = 0
    log_interval = 10

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    futures = []
    with ThreadPoolExecutor(max_workers=1) as executor:
        # We only want one background worker so that serialization is deterministic.
        for batch_idx, batch in enumerate(dl):
            batch_size = len(batch["text"])
            text = batch["text"]
            tokens, tokens_lens = prepare_token_batch(text)
            tokens = tokens.to(device)
            tokens_lens = tokens_lens.to(device)

            audio = batch["audio"]
            audio_lens = batch["audio_lens"].tolist()
            cut_ids = [cut.id for cut in batch["cut"]]

            audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens)
            audio_pred = audio_pred.detach().cpu()
            # convert to samples
            audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()

            # import pdb
            # pdb.set_trace()

            futures.append(
                executor.submit(
                    _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred
                )
            )

            num_cuts += batch_size

            if batch_idx % log_interval == 0:
                batch_str = f"{batch_idx}/{num_batches}"

                logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
        # return results
        for f in futures:
            f.result()


@torch.no_grad()
def main():
    parser = get_parser()
    LJSpeechTtsDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    params.res_dir = params.exp_dir / "infer" / params.suffix
    params.save_wav_dir = params.res_dir / "wav"
    params.save_wav_dir.mkdir(parents=True, exist_ok=True)

    setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
    logging.info("Infer started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    if params.iter > 0:
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            : params.avg
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        elif len(filenames) < params.avg:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))
    elif params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = []
        for i in range(start, params.epoch + 1):
            if i >= 1:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))

    model.to(device)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    ljspeech = LJSpeechTtsDataModule(args)

    test_cuts = ljspeech.test_cuts()
    test_dl = ljspeech.test_dataloaders(test_cuts)

    infer_dataset(
        dl=test_dl,
        params=params,
        model=model,
    )

    # save_results(
    #     params=params,
    #     test_set_name=test_set,
    #     results_dict=results_dict,
    # )

    logging.info("Done!")


# torch.set_num_threads(1)
# torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()
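After a run, the generated files can be inspected with a few lines following the res_dir/save_wav_dir layout set up in main() above (the exp-dir and epoch/avg values here are placeholders):

from pathlib import Path

save_wav_dir = Path("vits/exp") / "infer" / "epoch-30-avg-15" / "wav"
gt = sorted(save_wav_dir.glob("*_gt.wav"))
pred = sorted(save_wav_dir.glob("*_pred.wav"))
print(len(gt), len(pred))  # one ground-truth / prediction pair per test cut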
@@ -241,7 +241,8 @@ class MelSpectrogramLoss(torch.nn.Module):
        self,
        y_hat: torch.Tensor,
        y: torch.Tensor,
    ) -> torch.Tensor:
        return_mel: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
        """Calculate Mel-spectrogram loss.

        Args:
@@ -259,6 +260,9 @@ class MelSpectrogramLoss(torch.nn.Module):
        mel = self.wav_to_mel(y.squeeze(1))
        mel_loss = F.l1_loss(mel_hat, mel)

        if return_mel:
            return mel_loss, (mel_hat, mel)

        return mel_loss

80 egs/ljspeech/tts/vits/tokenizer.py (Normal file)
@@ -0,0 +1,80 @@
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List

import g2p_en
import tacotron_cleaner.cleaners

from utils import intersperse


class Tokenizer(object):
    def __init__(self, tokens: str):
        """
        Args:
          tokens: the file that maps tokens to IDs
        """
        # Parse token file
        self.token2id: Dict[str, int] = {}
        with open(tokens, "r", encoding="utf-8") as f:
            for line in f.readlines():
                info = line.rstrip().split()
                if len(info) == 1:
                    # case of space
                    token = " "
                    id = int(info[0])
                else:
                    token, id = info[0], int(info[1])
                self.token2id[token] = id

        self.blank_id = self.token2id["<blk>"]
        self.oov_id = self.token2id["<unk>"]
        self.vocab_size = len(self.token2id)

        self.g2p = g2p_en.G2p()

    def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
        """
        Args:
          texts:
            A list of transcripts.
          intersperse_blank:
            Whether to intersperse blanks in the token sequence.

        Returns:
          Return a list of token-id lists [utterance][token_id]
        """
        token_ids_list = []

        for text in texts:
            # Text normalization
            text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
            # Convert to phonemes
            tokens = self.g2p(text)
            token_ids = []
            for t in tokens:
                if t in self.token2id:
                    token_ids.append(self.token2id[t])
                else:
                    token_ids.append(self.oov_id)

            if intersperse_blank:
                token_ids = intersperse(token_ids, self.blank_id)

            token_ids_list.append(token_ids)

        return token_ids_list
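The parser above implies a simple tokens.txt layout: one "<token> <id>" pair per line, with the space token stored as a bare ID. A dependency-free sketch of the same parsing logic, with illustrative file contents:

token2id = {}
lines = ["<blk> 0", "<sos/eos> 1", "<unk> 2", "AH0 3", " 4"]  # " 4" is the space token
for line in lines:
    info = line.rstrip().split()
    if len(info) == 1:
        token, idx = " ", int(info[0])  # only the ID survives split() for the space token
    else:
        token, idx = info[0], int(info[1])
    token2id[token] = idx
print(token2id["<blk>"], token2id[" "])  # 0 4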
@@ -1,10 +1,32 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
#                                            Wei Kang,
#                                            Mingshuang Luo,
#                                            Zengwei Yao,
#                                            Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Union

import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
@@ -27,10 +49,10 @@ from icefall.utils import (
    str2bool,
)

from symbols import symbol_table
from tokenizer import Tokenizer
from utils import (
    MetricsTracker,
    prepare_token_batch,
    plot_feature,
    save_checkpoint,
    save_checkpoint_with_global_batch_idx,
)
@@ -101,6 +123,13 @@ def get_parser():
        """,
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/tokens.txt",
        help="""Path to tokens.txt.""",
    )

    parser.add_argument(
        "--lr", type=float, default=2.0e-4, help="The base learning rate."
    )
@@ -213,16 +242,16 @@ def get_params() -> AttributeDict:
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            "batch_idx_train": -1,  # 0
            "log_interval": 50,
            "log_interval": 10,
            "draw_interval": 500,
            # "reset_interval": 200,
            "valid_interval": 500,
            "valid_interval": 200,
            "env_info": get_env_info(),
            "sampling_rate": 22050,
            "frame_shift": 256,
            "frame_length": 1024,
            "feature_dim": 513,  # 1024 // 2 + 1, 1024 is fft_length
            "vocab_size": len(symbol_table),
            "mel_loss_params": {
                "frame_shift": 256,
                "frame_length": 1024,
                "n_mels": 80,
            },
            "lambda_adv": 1.0,  # loss scaling coefficient for adversarial loss
@@ -287,11 +316,16 @@ def load_checkpoint_if_available(


def get_model(params: AttributeDict) -> nn.Module:
    mel_loss_params = params.mel_loss_params
    mel_loss_params.update(
        frame_length=params.frame_length,
        frame_shift=params.frame_shift,
    )
    model = VITS(
        vocab_size=params.vocab_size,
        feature_dim=params.feature_dim,
        sampling_rate=params.sampling_rate,
        mel_loss_params=params.mel_loss_params,
        mel_loss_params=mel_loss_params,
        lambda_adv=params.lambda_adv,
        lambda_mel=params.lambda_mel,
        lambda_feat_match=params.lambda_feat_match,
@@ -301,79 +335,30 @@ def get_model(params: AttributeDict) -> nn.Module:
    return model


def compute_validation_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
) -> MetricsTracker:
    """Run the validation process."""
    model.eval()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
    """Parse batch data"""
    audio = batch["audio"].to(device)
    features = batch["features"].to(device)
    audio_lens = batch["audio_lens"].to(device)
    features_lens = batch["features_lens"].to(device)
    text = batch["text"]

    # used to summarize the stats over iterations
    tot_loss = MetricsTracker()
    tokens = tokenizer.texts_to_token_ids(text)
    tokens = k2.RaggedTensor(tokens)
    row_splits = tokens.shape.row_splits(1)
    tokens_lens = row_splits[1:] - row_splits[:-1]
    tokens = tokens.to(device)
    tokens_lens = tokens_lens.to(device)
    # a tensor of shape (B, T)
    tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)

    with torch.no_grad():
        for batch_idx, batch in enumerate(valid_dl):
            batch_size = len(batch["text"])
            audio = batch["audio"].to(device)
            features = batch["features"].to(device)
            audio_lens = batch["audio_lens"].to(device)
            features_lens = batch["features_lens"].to(device)
            text = batch["text"]
            tokens, tokens_lens = prepare_token_batch(text)
            tokens = tokens.to(device)
            tokens_lens = tokens_lens.to(device)

            loss_info = MetricsTracker()
            loss_info['samples'] = batch_size

            # forward discriminator
            loss_d, stats_d = model(
                text=tokens,
                text_lengths=tokens_lens,
                feats=features,
                feats_lengths=features_lens,
                speech=audio,
                speech_lengths=audio_lens,
                forward_generator=False,
            )
            assert loss_d.requires_grad is False
            for k, v in stats_d.items():
                loss_info[k] = v * batch_size

            # forward generator
            loss_g, stats_g = model(
                text=tokens,
                text_lengths=tokens_lens,
                feats=features,
                feats_lengths=features_lens,
                speech=audio,
                speech_lengths=audio_lens,
                forward_generator=True,
            )
            assert loss_g.requires_grad is False
            for k, v in stats_g.items():
                loss_info[k] = v * batch_size

            # summary stats
            tot_loss = tot_loss + loss_info

    if world_size > 1:
        tot_loss.reduce(device)

    loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
    if loss_value < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = loss_value

    return tot_loss
    return audio, audio_lens, features, features_lens, tokens, tokens_lens
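For context, a small sketch of the k2.RaggedTensor bookkeeping used by prepare_input above (requires k2; the token IDs are made up): row_splits yield per-utterance lengths, and pad() turns the ragged list into a (B, T) tensor filled with the blank ID.

import k2
import torch

tokens = k2.RaggedTensor([[5, 0, 7], [3, 9]])           # two utterances of made-up token IDs
row_splits = tokens.shape.row_splits(1)                 # [0, 3, 5]
tokens_lens = row_splits[1:] - row_splits[:-1]          # per-utterance lengths: [3, 2]
padded = tokens.pad(mode="constant", padding_value=0)   # blank_id assumed to be 0
print(tokens_lens, padded)
# expected: lengths [3, 2] and padded rows [[5, 0, 7], [3, 9, 0]]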


def train_one_epoch(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    tokenizer: Tokenizer,
    optimizer_g: Optimizer,
    optimizer_d: Optimizer,
    scheduler_g: LRSchedulerType,
@@ -442,18 +427,13 @@ def train_one_epoch(
        params.batch_idx_train += 1

        batch_size = len(batch["text"])
        audio = batch["audio"].to(device)
        features = batch["features"].to(device)
        audio_lens = batch["audio_lens"].to(device)
        features_lens = batch["features_lens"].to(device)
        text = batch["text"]
        tokens, tokens_lens = prepare_token_batch(text)
        tokens = tokens.to(device)
        tokens_lens = tokens_lens.to(device)
        audio, audio_lens, features, features_lens, tokens, tokens_lens = \
            prepare_input(batch, tokenizer, device)

        loss_info = MetricsTracker()
        loss_info['samples'] = batch_size

        return_sample = params.batch_idx_train % params.log_interval == 0
        try:
            with autocast(enabled=params.use_fp16):
                # forward discriminator
@@ -483,9 +463,13 @@ def train_one_epoch(
                    speech=audio,
                    speech_lengths=audio_lens,
                    forward_generator=True,
                    return_sample=return_sample,
                )
                for k, v in stats_g.items():
                    loss_info[k] = v * batch_size
                    if "return_sample" not in k:
                        loss_info[k] = v * batch_size
                if return_sample:
                    speech_hat_, speech_, mel_hat_, mel_ = stats_g["return_sample"]
                # update generator
                optimizer_g.zero_grad()
                scaler.scale(loss_g).backward()
@@ -577,13 +561,27 @@ def train_one_epoch(
                tb_writer.add_scalar(
                    "train/grad_scale", cur_grad_scale, params.batch_idx_train
                )
                if return_sample:
                    tb_writer.add_audio(
                        "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate
                    )
                    tb_writer.add_audio(
                        "train/speech_", speech_, params.batch_idx_train, params.sampling_rate
                    )
                    tb_writer.add_image(
                        "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC'
                    )
                    tb_writer.add_image(
                        "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC'
                    )

            # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
            if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
                logging.info("Computing validation loss")
                valid_info = compute_validation_loss(
                valid_info, (speech_hat, speech) = compute_validation_loss(
                    params=params,
                    model=model,
                    tokenizer=tokenizer,
                    valid_dl=valid_dl,
                    world_size=world_size,
                )
@@ -596,6 +594,12 @@ def train_one_epoch(
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
                )
                tb_writer.add_audio(
                    "train/valid_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate
                )
                tb_writer.add_audio(
                    "train/valid_speech", speech, params.batch_idx_train, params.sampling_rate
                )

    loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
    params.train_loss = loss_value
@@ -604,9 +608,87 @@ def train_one_epoch(
        params.best_train_loss = params.train_loss


def compute_validation_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    tokenizer: Tokenizer,
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
    rank: int = 0,
) -> MetricsTracker:
    """Run the validation process."""
    model.eval()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    # used to summarize the stats over iterations
    tot_loss = MetricsTracker()
    return_sample = None

    with torch.no_grad():
        for batch_idx, batch in enumerate(valid_dl):
            batch_size = len(batch["text"])
            audio, audio_lens, features, features_lens, tokens, tokens_lens = \
                prepare_input(batch, tokenizer, device)

            loss_info = MetricsTracker()
            loss_info['samples'] = batch_size

            # forward discriminator
            loss_d, stats_d = model(
                text=tokens,
                text_lengths=tokens_lens,
                feats=features,
                feats_lengths=features_lens,
                speech=audio,
                speech_lengths=audio_lens,
                forward_generator=False,
            )
            assert loss_d.requires_grad is False
            for k, v in stats_d.items():
                loss_info[k] = v * batch_size

            # forward generator
            loss_g, stats_g = model(
                text=tokens,
                text_lengths=tokens_lens,
                feats=features,
                feats_lengths=features_lens,
                speech=audio,
                speech_lengths=audio_lens,
                forward_generator=True,
            )
            assert loss_g.requires_grad is False
            for k, v in stats_g.items():
                loss_info[k] = v * batch_size

            # summary stats
            tot_loss = tot_loss + loss_info

            # infer for the first batch:
            if batch_idx == 0 and rank == 0:
                inner_model = model.module if isinstance(model, DDP) else model
                audio_pred, _, duration = inner_model.inference(text=tokens[0, :tokens_lens[0].item()])
                audio_pred = audio_pred.data.cpu().numpy()
                audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
                assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred))
                audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy()
                return_sample = (audio_pred, audio_gt)

    if world_size > 1:
        tot_loss.reduce(device)

    loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
    if loss_value < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = loss_value

    return tot_loss, return_sample


def scan_pessimistic_batches_for_oom(
    model: Union[nn.Module, DDP],
    train_dl: torch.utils.data.DataLoader,
    tokenizer: Tokenizer,
    optimizer_g: torch.optim.Optimizer,
    optimizer_d: torch.optim.Optimizer,
    params: AttributeDict,
@@ -620,14 +702,8 @@ def scan_pessimistic_batches_for_oom(
    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        audio = batch["audio"].to(device)
        features = batch["features"].to(device)
        audio_lens = batch["audio_lens"].to(device)
        features_lens = batch["features_lens"].to(device)
        text = batch["text"]
        tokens, tokens_lens = prepare_token_batch(text)
        tokens = tokens.to(device)
        tokens_lens = tokens_lens.to(device)
        audio, audio_lens, features, features_lens, tokens, tokens_lens = \
            prepare_input(batch, tokenizer, device)
        try:
            # for discriminator
            with autocast(enabled=params.use_fp16):
@@ -702,6 +778,11 @@ def run(rank, world_size, args):
        device = torch.device("cuda", rank)
    logging.info(f"Device: {device}")

    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.blank_id
    params.oov_id = tokenizer.oov_id
    params.vocab_size = tokenizer.vocab_size

    logging.info(params)

    logging.info("About to create model")
@@ -728,14 +809,14 @@ def run(rank, world_size, args):
        lr=params.lr,
        betas=(0.8, 0.99),
        eps=1e-9,
        weight_decay=0,
        # weight_decay=0,
    )
    optimizer_d = torch.optim.AdamW(
        discriminator.parameters(),
        lr=params.lr,
        betas=(0.8, 0.99),
        eps=1e-9,
        weight_decay=0,
        # weight_decay=0,
    )

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
@@ -804,6 +885,7 @@ def run(rank, world_size, args):
        scan_pessimistic_batches_for_oom(
            model=model,
            train_dl=train_dl,
            tokenizer=tokenizer,
            optimizer_g=optimizer_g,
            optimizer_d=optimizer_d,
            params=params,
@@ -815,6 +897,8 @@ def run(rank, world_size, args):
        scaler.load_state_dict(checkpoints["grad_scaler"])

    for epoch in range(params.start_epoch, params.num_epochs + 1):
        logging.info(f"Start epoch {epoch}")

        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)

@@ -826,6 +910,7 @@ def run(rank, world_size, args):
        train_one_epoch(
            params=params,
            model=model,
            tokenizer=tokenizer,
            optimizer_g=optimizer_g,
            optimizer_d=optimizer_d,
            scheduler_g=scheduler_g,

@@ -131,7 +131,14 @@ class LJSpeechTtsDataModule:
            default=True,
            help="Whether to drop last batch. Used by sampler.",
        )

        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=False,
            help="When enabled, each batch will have the "
            "field: batch['cut'] with the cuts that "
            "were used to construct it.",
        )
        group.add_argument(
            "--num-workers",
            type=int,
@@ -163,6 +170,7 @@ class LJSpeechTtsDataModule:
            train = SpeechSynthesisDataset(
                return_tokens=False,
                feature_input_strategy=eval(self.args.input_strategy)(),
                return_cuts=self.args.return_cuts,
            )

        if self.args.on_the_fly_feats:
@@ -176,6 +184,7 @@ class LJSpeechTtsDataModule:
            train = SpeechSynthesisDataset(
                return_tokens=False,
                feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                return_cuts=self.args.return_cuts,
            )

        if self.args.bucketing_sampler:
@@ -229,11 +238,13 @@ class LJSpeechTtsDataModule:
            validate = SpeechSynthesisDataset(
                return_tokens=False,
                feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                return_cuts=self.args.return_cuts,
            )
        else:
            validate = SpeechSynthesisDataset(
                return_tokens=False,
                feature_input_strategy=eval(self.args.input_strategy)(),
                return_cuts=self.args.return_cuts,
            )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
@@ -264,11 +275,13 @@ class LJSpeechTtsDataModule:
            test = SpeechSynthesisDataset(
                return_tokens=False,
                feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                return_cuts=self.args.return_cuts,
            )
        else:
            test = SpeechSynthesisDataset(
                return_tokens=False,
                feature_input_strategy=eval(self.args.input_strategy)(),
                return_cuts=self.args.return_cuts,
            )
        test_sampler = DynamicBucketingSampler(
            cuts,

@@ -211,6 +211,7 @@ def intersperse(sequence, item=0):

def prepare_token_batch(
    texts: List[str],
    phonemes: Optional[List[str]] = None,
    intersperse_blank: bool = True,
    blank_id: int = 0,
    pad_id: int = 0,
@@ -222,41 +223,50 @@ def prepare_token_batch(
        blank_id: index of blank token
        pad_id: padding index
    """
    # normalize text
    normalized_texts = []
    for text in texts:
        text = convert_to_ascii(text)
        text = lowercase(text)
        text = expand_abbreviations(text)
        normalized_texts.append(text)
    if phonemes is None:
        # normalize text
        normalized_texts = []
        for text in texts:
            text = convert_to_ascii(text)
            text = lowercase(text)
            text = expand_abbreviations(text)
            normalized_texts.append(text)

    # convert to phonemes
    phonemes = phonemize(
        normalized_texts,
        language='en-us',
        backend='espeak',
        strip=True,
        preserve_punctuation=True,
        with_stress=True,
    )
        # convert to phonemes
        phonemes = phonemize(
            normalized_texts,
            language='en-us',
            backend='espeak',
            strip=True,
            preserve_punctuation=True,
            with_stress=True,
        )
        phonemes = [collapse_whitespace(sequence) for sequence in phonemes]

    # convert to symbol ids
    lengths = []
    sequences = []
    skip = False
    for idx, sequence in enumerate(phonemes):
        try:
            sequence = [symbol_to_id[symbol] for symbol in collapse_whitespace(sequence)]
        except RuntimeError:
            print(text[idx])
            print(normalized_texts[idx])
            sequence = [symbol_to_id[symbol] for symbol in sequence]
        except Exception:
            # print(texts[idx])
            # print(normalized_texts[idx])
            print(phonemes[idx])
            skip = True
        if intersperse_blank:
            sequence = intersperse(sequence, blank_id)
        sequences.append(torch.tensor(sequence, dtype=torch.int64))
        try:
            sequences.append(torch.tensor(sequence, dtype=torch.int64))
        except Exception:
            print(sequence)
            skip = True
        lengths.append(len(sequence))

    sequences = pad_sequence(sequences, batch_first=True, padding_value=pad_id)
    lengths = torch.tensor(lengths, dtype=torch.int64)
    return sequences, lengths
    return sequences, lengths, skip


class MetricsTracker(collections.defaultdict):
@@ -287,7 +297,7 @@ class MetricsTracker(collections.defaultdict):
            norm_value = "%.4g" % v
            ans += str(k) + "=" + str(norm_value) + ", "
        samples = "%.2f" % self["samples"]
        ans += "over" + str(samples) + " samples."
        ans += "over " + str(samples) + " samples."
        return ans

    def norm_items(self) -> List[Tuple[str, float]]:
@@ -468,3 +478,41 @@ def save_checkpoint_with_global_batch_idx(
        sampler=sampler,
        rank=rank,
    )


# def plot_feature(feature):
#     """
#     Display the feature matrix as an image. Requires matplotlib to be installed.
#     """
#     import matplotlib.pyplot as plt
#
#     feature = np.flip(feature.transpose(1, 0), 0)
#     return plt.matshow(feature)

MATPLOTLIB_FLAG = False


def plot_feature(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger('matplotlib')
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data

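For reference, a minimal reimplementation of the intersperse helper named in the hunk header above (a sketch consistent with how tokenizer.py uses it, not necessarily the exact body in utils.py): it places item before, between and after the elements of the sequence.

def intersperse(sequence, item=0):
    # [a, b, c] -> [item, a, item, b, item, c, item]
    result = [item] * (len(sequence) * 2 + 1)
    result[1::2] = sequence
    return result

print(intersperse([5, 7, 9]))  # [0, 5, 0, 7, 0, 9, 0]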
@@ -241,6 +241,7 @@ class VITS(nn.Module):
        feats_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        return_sample: bool = False,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
@@ -276,6 +277,7 @@ class VITS(nn.Module):
            feats_lengths=feats_lengths,
            speech=speech,
            speech_lengths=speech_lengths,
            return_sample=return_sample,
            sids=sids,
            spembs=spembs,
            lids=lids,
@@ -301,6 +303,7 @@ class VITS(nn.Module):
        feats_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        return_sample: bool = False,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
@@ -367,7 +370,12 @@ class VITS(nn.Module):

        # calculate losses
        with autocast(enabled=False):
            mel_loss = self.mel_loss(speech_hat_, speech_)
            if not return_sample:
                mel_loss = self.mel_loss(speech_hat_, speech_)
            else:
                mel_loss, (mel_hat_, mel_) = self.mel_loss(
                    speech_hat_, speech_, return_mel=True
                )
            kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
            dur_loss = torch.sum(dur_nll.float())
            adv_loss = self.generator_adv_loss(p_hat)
@@ -389,6 +397,14 @@ class VITS(nn.Module):
            generator_feat_match_loss=feat_match_loss.item(),
        )

        if return_sample:
            stats["return_sample"] = (
                speech_hat_[0].data.cpu().numpy(),
                speech_[0].data.cpu().numpy(),
                mel_hat_[0].data.cpu().numpy(),
                mel_[0].data.cpu().numpy(),
            )

        # reset cache
        if reuse_cache or not self.training:
            self._cache = None
@@ -564,4 +580,43 @@ class VITS(nn.Module):
            alpha=alpha,
            max_len=max_len,
        )
        return dict(wav=wav.view(-1), att_w=att_w[0], duration=dur[0])
        return wav.view(-1), att_w[0], dur[0]

    def inference_batch(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        durations: Optional[torch.Tensor] = None,
        noise_scale: float = 0.667,
        noise_scale_dur: float = 0.8,
        alpha: float = 1.0,
        max_len: Optional[int] = None,
        use_teacher_forcing: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Run inference.

        Args:
            text (Tensor): Input text index tensor (B, T_text).
            text_lengths (Tensor): Input text length tensor (B,).
            noise_scale (float): Noise scale value for flow.
            noise_scale_dur (float): Noise scale value for duration predictor.
            alpha (float): Alpha parameter to control the speed of generated speech.
            max_len (Optional[int]): Maximum length.

        Returns:
            Dict[str, Tensor]:
                * wav (Tensor): Generated waveform tensor (B, T_wav).
                * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text).
                * duration (Tensor): Predicted duration tensor (B, T_text).

        """
        # inference
        wav, att_w, dur = self.generator.inference(
            text=text,
            text_lengths=text_lengths,
            noise_scale=noise_scale,
            noise_scale_dur=noise_scale_dur,
            alpha=alpha,
            max_len=max_len,
        )
        return wav, att_w, dur
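A small sketch (with made-up numbers) of how callers such as infer.py turn the returned per-token durations into waveform lengths: durations are counted in feature frames, so multiplying the frame count by frame_shift (256 samples, matching get_params()) gives the number of output samples.

import torch

frame_shift = 256                           # hop size in samples, as in get_params()
durations = torch.tensor([[2.0, 3.0, 1.0],  # predicted frames per token, batch of 2
                          [4.0, 0.0, 0.0]])
num_samples = (durations.sum(1) * frame_shift).to(dtype=torch.int64)
print(num_samples.tolist())  # [1536, 1024]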