replace phonemizer with g2p

yaozengwei 2023-10-28 21:16:43 +08:00
parent 3df16b3f2b
commit b719581e2f
11 changed files with 935 additions and 126 deletions

View File

@ -0,0 +1,116 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file reads the texts in the given manifest and generates the file that maps tokens to IDs.
"""
import argparse
import logging
from collections import Counter
from pathlib import Path
from typing import Dict
import g2p_en
import tacotron_cleaner.cleaners
from lhotse import load_manifest
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--manifest-file",
type=Path,
default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"),
help="Path to the manifest file",
)
parser.add_argument(
"--tokens",
type=Path,
default=Path("data/tokens.txt"),
help="Path to the tokens",
)
return parser.parse_args()
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
Note:
No need to implement `read_mapping` as it can be done
through :func:`k2.SymbolTable.from_file`.
Args:
filename:
Filename to save the mapping.
sym2id:
A dict mapping symbols to IDs.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf-8") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")
def get_token2id(manifest_file: Path) -> Dict[str, int]:
"""Return a dict that maps token to IDs."""
extra_tokens = {
"<blk>": 0, # blank
"<sos/eos>": 1, # sos and eos symbols.
"<unk>": 2, # OOV
}
cut_set = load_manifest(manifest_file)
g2p = g2p_en.G2p()
counter = Counter()
for cut in cut_set:
# Each cut only contains one supervision
assert len(cut.supervisions) == 1, len(cut.supervisions)
text = cut.supervisions[0].normalized_text
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
tokens = g2p(text)
for t in tokens:
counter[t] += 1
# Sort by the number of occurrences in descending order
tokens_and_counts = sorted(counter.items(), key=lambda x: -x[1])
for token, idx in extra_tokens.items():
tokens_and_counts.insert(idx, (token, None))
token2id: Dict[str, int] = {token: i for i, (token, count) in enumerate(tokens_and_counts)}
return token2id
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
manifest_file = Path(args.manifest_file)
out_file = Path(args.tokens)
token2id = get_token2id(manifest_file)
write_mapping(out_file, token2id)
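For reference, the resulting data/tokens.txt is a plain two-column "symbol id" file: the three special tokens take IDs 0-2 and the phonemes follow in descending order of frequency. Below is a toy worked example of the ID assignment performed above; the phoneme counts are made up purely for illustration.

from collections import Counter

# Hypothetical phoneme counts; the real counts come from the manifest.
counter = Counter({" ": 40, "AH0": 30, "T": 25, "N": 20})
extra_tokens = {"<blk>": 0, "<sos/eos>": 1, "<unk>": 2}

tokens_and_counts = sorted(counter.items(), key=lambda x: -x[1])
for token, idx in extra_tokens.items():
    tokens_and_counts.insert(idx, (token, None))
token2id = {token: i for i, (token, _) in enumerate(tokens_and_counts)}
print(token2id)
# {'<blk>': 0, '<sos/eos>': 1, '<unk>': 2, ' ': 3, 'AH0': 4, 'T': 5, 'N': 6}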

View File

@ -52,7 +52,8 @@ def main():
manifest_dir = Path(args.manifest_dir)
prefix = "ljspeech"
suffix = "jsonl.gz"
all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all.{suffix}")
# all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all.{suffix}")
all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all_phonemized.{suffix}")
cut_ids = list(all_cuts.ids)
random.shuffle(cut_ids)

View File

@ -66,11 +66,50 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
fi
fi
# if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
# log "Stage 3: Phonemize the transcripts for LJSpeech"
# if [ ! -e data/spectrogram/.ljspeech_phonemized.done ]; then
# ./local/phonemize_text.py data/spectrogram
# touch data/spectrogram/.ljspeech_phonemized.done
# fi
# fi
# if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
# log "Stage 4: Split the LJSpeech cuts into three sets"
# if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
# ./local/split_subsets.py data/spectrogram
# touch data/spectrogram/.ljspeech_split.done
# fi
# fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Split the LJSpeech cuts into three sets"
log "Stage 3: Split the LJSpeech cuts into train, valid and test sets"
if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
./local/split_subsets.py data/spectrogram
touch data/spectrogram/.ljspeech_split.done
lhotse subset --last 600 \
data/spectrogram/ljspeech_cuts_all.jsonl.gz \
data/spectrogram/ljspeech_cuts_validtest.jsonl.gz
lhotse subset --first 100 \
data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \
data/spectrogram/ljspeech_cuts_valid.jsonl.gz
lhotse subset --last 500 \
data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \
data/spectrogram/ljspeech_cuts_test.jsonl.gz
rm data/spectrogram/ljspeech_cuts_validtest.jsonl.gz
n=$(( $(gunzip -c data/spectrogram/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 ))
lhotse subset --first $n \
data/spectrogram/ljspeech_cuts_all.jsonl.gz \
data/spectrogram/ljspeech_cuts_train.jsonl.gz
touch data/spectrogram/.ljspeech_split.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Generate token file"
if [ ! -e data/tokens.txt ]; then
./local/prepare_token_file.py \
--manifest-file data/spectrogram/ljspeech_cuts_train.jsonl.gz \
--tokens data/tokens.txt
fi
fi

View File

@ -515,10 +515,12 @@ class VITSGenerator(torch.nn.Module):
cum_dur_flat = cum_dur.view(b * t_x)
path = torch.arange(t_y, dtype=dur.dtype, device=dur.device)
path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)
path = path.view(b, t_x, t_y).to(dtype=mask.dtype)
# path = path.view(b, t_x, t_y).to(dtype=mask.dtype)
path = path.view(b, t_x, t_y).to(dtype=torch.float)
# path will be like (t_x = 3, t_y = 5):
# [[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.],
# [1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.],
# [1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]]
path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
# path = path.to(dtype=mask.dtype)
return path.unsqueeze(1).transpose(2, 3) * mask
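For clarity, here is a minimal standalone sketch of the monotonic-path construction above, using the durations implied by the comment (t_x = 3, t_y = 5) and omitting the final mask multiplication.

import torch
import torch.nn.functional as F

dur = torch.tensor([[2.0, 2.0, 1.0]])    # (b=1, t_x=3), frames per token
b, t_x = dur.shape
t_y = int(dur.sum().item())              # 5 frames in total
cum_dur_flat = torch.cumsum(dur, dim=-1).view(b * t_x)
path = torch.arange(t_y, dtype=dur.dtype)
path = (path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)).view(b, t_x, t_y).to(torch.float)
# keep only the frames newly covered by each token
path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
print(path)
# tensor([[[1., 1., 0., 0., 0.],
#          [0., 0., 1., 1., 0.],
#          [0., 0., 0., 0., 1.]]])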

egs/ljspeech/tts/vits/infer.py Executable file
View File

@ -0,0 +1,366 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
import argparse
import logging
import math
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torchaudio
from train2 import get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
from tts_datamodule import LJSpeechTtsDataModule
from utils import prepare_token_batch
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
return parser
def infer_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
) -> None:
"""Generate speech for each cut in the given dataset and save both the
ground-truth and the generated waveforms to disk.
Args:
dl:
PyTorch's dataloader containing the dataset.
params:
It is returned by :func:`get_params`.
model:
The neural model.
"""
# Background worker that saves audios to disk.
def _save_worker(
batch_size: int,
cut_ids: List[str],
audio: torch.Tensor,
audio_pred: torch.Tensor,
audio_lens: List[int],
audio_lens_pred: List[int],
):
for i in range(batch_size):
torchaudio.save(
str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
audio[i:i + 1, :audio_lens[i]],
sample_rate=params.sampling_rate,
)
torchaudio.save(
str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
audio_pred[i:i + 1, :audio_lens_pred[i]],
sample_rate=params.sampling_rate,
)
device = next(model.parameters()).device
num_cuts = 0
log_interval = 10
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
futures = []
with ThreadPoolExecutor(max_workers=1) as executor:
# We only want one background worker so that serialization is deterministic.
for batch_idx, batch in enumerate(dl):
batch_size = len(batch["text"])
text = batch["text"]
tokens, tokens_lens = prepare_token_batch(text)
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
audio = batch["audio"]
audio_lens = batch["audio_lens"].tolist()
cut_ids = [cut.id for cut in batch["cut"]]
audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens)
audio_pred = audio_pred.detach().cpu()
# convert to samples
audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
futures.append(
executor.submit(
_save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred
)
)
num_cuts += batch_size
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
# wait for all background save jobs to finish
for f in futures:
f.result()
@torch.no_grad()
def main():
parser = get_parser()
LJSpeechTtsDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
params.res_dir = params.exp_dir / "infer" / params.suffix
params.save_wav_dir = params.res_dir / "wav"
params.save_wav_dir.mkdir(parents=True, exist_ok=True)
setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
logging.info("Infer started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to name the saved audio files.
args.return_cuts = True
ljspeech = LJSpeechTtsDataModule(args)
test_cuts = ljspeech.test_cuts()
test_dl = ljspeech.test_dataloaders(test_cuts)
infer_dataset(
dl=test_dl,
params=params,
model=model,
)
# save_results(
# params=params,
# test_set_name=test_set,
# results_dict=results_dict,
# )
logging.info("Done!")
# torch.set_num_threads(1)
# torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -241,7 +241,8 @@ class MelSpectrogramLoss(torch.nn.Module):
self,
y_hat: torch.Tensor,
y: torch.Tensor,
) -> torch.Tensor:
return_mel: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
"""Calculate Mel-spectrogram loss.
Args:
@ -259,6 +260,9 @@ class MelSpectrogramLoss(torch.nn.Module):
mel = self.wav_to_mel(y.squeeze(1))
mel_loss = F.l1_loss(mel_hat, mel)
if return_mel:
return mel_loss, (mel_hat, mel)
return mel_loss

View File

@ -0,0 +1,80 @@
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List
import g2p_en
import tacotron_cleaner.cleaners
from utils import intersperse
class Tokenizer(object):
def __init__(self, tokens: str):
"""
Args:
tokens: the file that maps tokens to ids
"""
# Parse token file
self.token2id: Dict[str, int] = {}
with open(tokens, "r", encoding="utf-8") as f:
for line in f.readlines():
info = line.rstrip().split()
if len(info) == 1:
# case of space
token = " "
id = int(info[0])
else:
token, id = info[0], int(info[1])
self.token2id[token] = id
self.blank_id = self.token2id["<blk>"]
self.oov_id = self.token2id["<unk>"]
self.vocab_size = len(self.token2id)
self.g2p = g2p_en.G2p()
def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
"""
Args:
texts:
A list of transcripts.
intersperse_blank:
Whether to intersperse blanks in the token sequence.
Returns:
Return a list of token-ID lists, indexed as [utterance][token_id].
"""
token_ids_list = []
for text in texts:
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
tokens = self.g2p(text)
token_ids = []
for t in tokens:
if t in self.token2id:
token_ids.append(self.token2id[t])
else:
token_ids.append(self.oov_id)
if intersperse_blank:
token_ids = intersperse(token_ids, self.blank_id)
token_ids_list.append(token_ids)
return token_ids_list
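A hypothetical usage sketch of the Tokenizer above, assuming data/tokens.txt was generated by local/prepare_token_file.py; with intersperse_blank=True the blank ID is interleaved with the phoneme IDs, as is common for VITS-style models. The example transcripts are arbitrary.

tokenizer = Tokenizer("data/tokens.txt")
print(tokenizer.vocab_size, tokenizer.blank_id, tokenizer.oov_id)

# One list of token IDs per transcript; phonemes missing from the
# token file map to <unk>.
token_ids = tokenizer.texts_to_token_ids(["Hello world.", "How are you?"])
assert len(token_ids) == 2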

View File

@ -1,10 +1,32 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,
# Zengwei Yao,
# Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
@ -27,10 +49,10 @@ from icefall.utils import (
str2bool,
)
from symbols import symbol_table
from tokenizer import Tokenizer
from utils import (
MetricsTracker,
prepare_token_batch,
plot_feature,
save_checkpoint,
save_checkpoint_with_global_batch_idx,
)
@ -101,6 +123,13 @@ def get_parser():
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to tokens.txt.""",
)
parser.add_argument(
"--lr", type=float, default=2.0e-4, help="The base learning rate."
)
@ -213,16 +242,16 @@ def get_params() -> AttributeDict:
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": -1, # 0
"log_interval": 50,
"log_interval": 10,
"draw_interval": 500,
# "reset_interval": 200,
"valid_interval": 500,
"valid_interval": 200,
"env_info": get_env_info(),
"sampling_rate": 22050,
"frame_shift": 256,
"frame_length": 1024,
"feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length
"vocab_size": len(symbol_table),
"mel_loss_params": {
"frame_shift": 256,
"frame_length": 1024,
"n_mels": 80,
},
"lambda_adv": 1.0, # loss scaling coefficient for adversarial loss
@ -287,11 +316,16 @@ def load_checkpoint_if_available(
def get_model(params: AttributeDict) -> nn.Module:
mel_loss_params = params.mel_loss_params
mel_loss_params.update(
frame_length=params.frame_length,
frame_shift=params.frame_shift,
)
model = VITS(
vocab_size=params.vocab_size,
feature_dim=params.feature_dim,
sampling_rate=params.sampling_rate,
mel_loss_params=params.mel_loss_params,
mel_loss_params=mel_loss_params,
lambda_adv=params.lambda_adv,
lambda_mel=params.lambda_mel,
lambda_feat_match=params.lambda_feat_match,
@ -301,79 +335,30 @@ def get_model(params: AttributeDict) -> nn.Module:
return model
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
"""Parse batch data"""
audio = batch["audio"].to(device)
features = batch["features"].to(device)
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
text = batch["text"]
# used to summarize the stats over iterations
tot_loss = MetricsTracker()
tokens = tokenizer.texts_to_token_ids(text)
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# a tensor of shape (B, T)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl):
batch_size = len(batch["text"])
audio = batch["audio"].to(device)
features = batch["features"].to(device)
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
text = batch["text"]
tokens, tokens_lens = prepare_token_batch(text)
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
loss_info = MetricsTracker()
loss_info['samples'] = batch_size
# forward discriminator
loss_d, stats_d = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=False,
)
assert loss_d.requires_grad is False
for k, v in stats_d.items():
loss_info[k] = v * batch_size
# forward generator
loss_g, stats_g = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=True,
)
assert loss_g.requires_grad is False
for k, v in stats_g.items():
loss_info[k] = v * batch_size
# summary stats
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(device)
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
return audio, audio_lens, features, features_lens, tokens, tokens_lens
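A standalone sketch of the batching step used in prepare_input above: k2.RaggedTensor converts the variable-length token-ID lists into per-utterance lengths plus a padded (B, T) tensor. The IDs here are toy values, and 0 stands in for tokenizer.blank_id.

import k2

token_ids = [[5, 2, 9], [7, 1]]                  # two utterances, toy IDs
tokens = k2.RaggedTensor(token_ids)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]   # lengths: 3 and 2
tokens = tokens.pad(mode="constant", padding_value=0)
# padded to shape (2, 3): [[5, 2, 9], [7, 1, 0]]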
def train_one_epoch(
params: AttributeDict,
model: Union[nn.Module, DDP],
tokenizer: Tokenizer,
optimizer_g: Optimizer,
optimizer_d: Optimizer,
scheduler_g: LRSchedulerType,
@ -442,18 +427,13 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["text"])
audio = batch["audio"].to(device)
features = batch["features"].to(device)
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
text = batch["text"]
tokens, tokens_lens = prepare_token_batch(text)
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
audio, audio_lens, features, features_lens, tokens, tokens_lens = \
prepare_input(batch, tokenizer, device)
loss_info = MetricsTracker()
loss_info['samples'] = batch_size
return_sample = params.batch_idx_train % params.log_interval == 0
try:
with autocast(enabled=params.use_fp16):
# forward discriminator
@ -483,9 +463,13 @@ def train_one_epoch(
speech=audio,
speech_lengths=audio_lens,
forward_generator=True,
return_sample=return_sample,
)
for k, v in stats_g.items():
loss_info[k] = v * batch_size
if "return_sample" not in k:
loss_info[k] = v * batch_size
if return_sample:
speech_hat_, speech_, mel_hat_, mel_ = stats_g["return_sample"]
# update generator
optimizer_g.zero_grad()
scaler.scale(loss_g).backward()
@ -577,13 +561,27 @@ def train_one_epoch(
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
if return_sample:
tb_writer.add_audio(
"train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate
)
tb_writer.add_audio(
"train/speech_", speech_, params.batch_idx_train, params.sampling_rate
)
tb_writer.add_image(
"train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC'
)
tb_writer.add_image(
"train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC'
)
# if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
valid_info, (speech_hat, speech) = compute_validation_loss(
params=params,
model=model,
tokenizer=tokenizer,
valid_dl=valid_dl,
world_size=world_size,
)
@ -596,6 +594,12 @@ def train_one_epoch(
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
tb_writer.add_audio(
"train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate
)
tb_writer.add_audio(
"train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate
)
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
params.train_loss = loss_value
@ -604,9 +608,87 @@ def train_one_epoch(
params.best_train_loss = params.train_loss
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
tokenizer: Tokenizer,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
rank: int = 0,
) -> Tuple[MetricsTracker, Optional[Tuple]]:
"""Run the validation process."""
model.eval()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
# used to summarize the stats over iterations
tot_loss = MetricsTracker()
return_sample = None
with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl):
batch_size = len(batch["text"])
audio, audio_lens, features, features_lens, tokens, tokens_lens = \
prepare_input(batch, tokenizer, device)
loss_info = MetricsTracker()
loss_info['samples'] = batch_size
# forward discriminator
loss_d, stats_d = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=False,
)
assert loss_d.requires_grad is False
for k, v in stats_d.items():
loss_info[k] = v * batch_size
# forward generator
loss_g, stats_g = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=True,
)
assert loss_g.requires_grad is False
for k, v in stats_g.items():
loss_info[k] = v * batch_size
# summary stats
tot_loss = tot_loss + loss_info
# run inference on the first batch:
if batch_idx == 0 and rank == 0:
inner_model = model.module if isinstance(model, DDP) else model
audio_pred, _, duration = inner_model.inference(text=tokens[0, :tokens_lens[0].item()])
audio_pred = audio_pred.data.cpu().numpy()
audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred))
audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy()
return_sample = (audio_pred, audio_gt)
if world_size > 1:
tot_loss.reduce(device)
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss, return_sample
def scan_pessimistic_batches_for_oom(
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,
tokenizer: Tokenizer,
optimizer_g: torch.optim.Optimizer,
optimizer_d: torch.optim.Optimizer,
params: AttributeDict,
@ -620,14 +702,8 @@ def scan_pessimistic_batches_for_oom(
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
audio = batch["audio"].to(device)
features = batch["features"].to(device)
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
text = batch["text"]
tokens, tokens_lens = prepare_token_batch(text)
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
audio, audio_lens, features, features_lens, tokens, tokens_lens = \
prepare_input(batch, tokenizer, device)
try:
# for discriminator
with autocast(enabled=params.use_fp16):
@ -702,6 +778,11 @@ def run(rank, world_size, args):
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.blank_id
params.oov_id = tokenizer.oov_id
params.vocab_size = tokenizer.vocab_size
logging.info(params)
logging.info("About to create model")
@ -728,14 +809,14 @@ def run(rank, world_size, args):
lr=params.lr,
betas=(0.8, 0.99),
eps=1e-9,
weight_decay=0,
# weight_decay=0,
)
optimizer_d = torch.optim.AdamW(
discriminator.parameters(),
lr=params.lr,
betas=(0.8, 0.99),
eps=1e-9,
weight_decay=0,
# weight_decay=0,
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
@ -804,6 +885,7 @@ def run(rank, world_size, args):
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
tokenizer=tokenizer,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
params=params,
@ -815,6 +897,8 @@ def run(rank, world_size, args):
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1):
logging.info(f"Start epoch {epoch}")
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
@ -826,6 +910,7 @@ def run(rank, world_size, args):
train_one_epoch(
params=params,
model=model,
tokenizer=tokenizer,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,

View File

@ -131,7 +131,14 @@ class LJSpeechTtsDataModule:
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=False,
help="When enabled, each batch will have the "
"field: batch['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
@ -163,6 +170,7 @@ class LJSpeechTtsDataModule:
train = SpeechSynthesisDataset(
return_tokens=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
@ -176,6 +184,7 @@ class LJSpeechTtsDataModule:
train = SpeechSynthesisDataset(
return_tokens=False,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
@ -229,11 +238,13 @@ class LJSpeechTtsDataModule:
validate = SpeechSynthesisDataset(
return_tokens=False,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
return_cuts=self.args.return_cuts,
)
else:
validate = SpeechSynthesisDataset(
return_tokens=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
@ -264,11 +275,13 @@ class LJSpeechTtsDataModule:
test = SpeechSynthesisDataset(
return_tokens=False,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
return_cuts=self.args.return_cuts,
)
else:
test = SpeechSynthesisDataset(
return_tokens=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
test_sampler = DynamicBucketingSampler(
cuts,

View File

@ -211,6 +211,7 @@ def intersperse(sequence, item=0):
def prepare_token_batch(
texts: List[str],
phonemes: Optional[List[str]] = None,
intersperse_blank: bool = True,
blank_id: int = 0,
pad_id: int = 0,
@ -222,41 +223,50 @@ def prepare_token_batch(
blank_id: index of blank token
pad_id: padding index
"""
# normalize text
normalized_texts = []
for text in texts:
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_abbreviations(text)
normalized_texts.append(text)
if phonemes is None:
# normalize text
normalized_texts = []
for text in texts:
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_abbreviations(text)
normalized_texts.append(text)
# convert to phonemes
phonemes = phonemize(
normalized_texts,
language='en-us',
backend='espeak',
strip=True,
preserve_punctuation=True,
with_stress=True,
)
# convert to phonemes
phonemes = phonemize(
normalized_texts,
language='en-us',
backend='espeak',
strip=True,
preserve_punctuation=True,
with_stress=True,
)
phonemes = [collapse_whitespace(sequence) for sequence in phonemes]
# convert to symbol ids
lengths = []
sequences = []
skip = False
for idx, sequence in enumerate(phonemes):
try:
sequence = [symbol_to_id[symbol] for symbol in collapse_whitespace(sequence)]
except RuntimeError:
print(text[idx])
print(normalized_texts[idx])
sequence = [symbol_to_id[symbol] for symbol in sequence]
except Exception:
# print(texts[idx])
# print(normalized_texts[idx])
print(phonemes[idx])
skip = True
if intersperse_blank:
sequence = intersperse(sequence, blank_id)
sequences.append(torch.tensor(sequence, dtype=torch.int64))
try:
sequences.append(torch.tensor(sequence, dtype=torch.int64))
except Exception:
print(sequence)
skip = True
lengths.append(len(sequence))
sequences = pad_sequence(sequences, batch_first=True, padding_value=pad_id)
lengths = torch.tensor(lengths, dtype=torch.int64)
return sequences, lengths
return sequences, lengths, skip
class MetricsTracker(collections.defaultdict):
@ -287,7 +297,7 @@ class MetricsTracker(collections.defaultdict):
norm_value = "%.4g" % v
ans += str(k) + "=" + str(norm_value) + ", "
samples = "%.2f" % self["samples"]
ans += "over" + str(samples) + " samples."
ans += "over " + str(samples) + " samples."
return ans
def norm_items(self) -> List[Tuple[str, float]]:
@ -468,3 +478,41 @@ def save_checkpoint_with_global_batch_idx(
sampler=sampler,
rank=rank,
)
# def plot_feature(feature):
# """
# Display the feature matrix as an image. Requires matplotlib to be installed.
# """
# import matplotlib.pyplot as plt
#
# feature = np.flip(feature.transpose(1, 0), 0)
# return plt.matshow(feature)
MATPLOTLIB_FLAG = False
def plot_feature(spectrogram):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
interpolation='none')
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data

View File

@ -241,6 +241,7 @@ class VITS(nn.Module):
feats_lengths: torch.Tensor,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
return_sample: bool = False,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
@ -276,6 +277,7 @@ class VITS(nn.Module):
feats_lengths=feats_lengths,
speech=speech,
speech_lengths=speech_lengths,
return_sample=return_sample,
sids=sids,
spembs=spembs,
lids=lids,
@ -301,6 +303,7 @@ class VITS(nn.Module):
feats_lengths: torch.Tensor,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
return_sample: bool = False,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
@ -367,7 +370,12 @@ class VITS(nn.Module):
# calculate losses
with autocast(enabled=False):
mel_loss = self.mel_loss(speech_hat_, speech_)
if not return_sample:
mel_loss = self.mel_loss(speech_hat_, speech_)
else:
mel_loss, (mel_hat_, mel_) = self.mel_loss(
speech_hat_, speech_, return_mel=True
)
kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
dur_loss = torch.sum(dur_nll.float())
adv_loss = self.generator_adv_loss(p_hat)
@ -389,6 +397,14 @@ class VITS(nn.Module):
generator_feat_match_loss=feat_match_loss.item(),
)
if return_sample:
stats["return_sample"] = (
speech_hat_[0].data.cpu().numpy(),
speech_[0].data.cpu().numpy(),
mel_hat_[0].data.cpu().numpy(),
mel_[0].data.cpu().numpy(),
)
# reset cache
if reuse_cache or not self.training:
self._cache = None
@ -564,4 +580,43 @@ class VITS(nn.Module):
alpha=alpha,
max_len=max_len,
)
return dict(wav=wav.view(-1), att_w=att_w[0], duration=dur[0])
return wav.view(-1), att_w[0], dur[0]
def inference_batch(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
durations: Optional[torch.Tensor] = None,
noise_scale: float = 0.667,
noise_scale_dur: float = 0.8,
alpha: float = 1.0,
max_len: Optional[int] = None,
use_teacher_forcing: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Run inference.
Args:
text (Tensor): Input text index tensor (B, T_text).
text_lengths (Tensor): Input text index tensor (B,).
noise_scale (float): Noise scale value for flow.
noise_scale_dur (float): Noise scale value for duration predictor.
alpha (float): Alpha parameter to control the speed of generated speech.
max_len (Optional[int]): Maximum length.
Returns:
Tuple[Tensor, Tensor, Tensor]:
* wav (Tensor): Generated waveform tensor (B, T_wav).
* att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text).
* duration (Tensor): Predicted duration tensor (B, T_text).
"""
# inference
wav, att_w, dur = self.generator.inference(
text=text,
text_lengths=text_lengths,
noise_scale=noise_scale,
noise_scale_dur=noise_scale_dur,
alpha=alpha,
max_len=max_len,
)
return wav, att_w, dur
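A hypothetical call sketch mirroring how vits/infer.py uses this method: the returned durations are counted in feature frames, so multiplying by the hop size (params.frame_shift, 256 samples in this recipe) converts them to waveform lengths in samples.

audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens)
audio_lens_pred = (durations.sum(1) * 256).to(dtype=torch.int64).tolist()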