Use piper_phonemize as text tokenizer in vctk TTS recipe (#1522)

* to align with PR #1524
This commit is contained in:
zr_jin 2024-03-18 17:53:52 +08:00 committed by GitHub
parent 9b0eae3b4a
commit eec12f053d
9 changed files with 56 additions and 135 deletions

View File

@@ -10,7 +10,7 @@ The above information is from the [CSTR VCTK website](https://datashare.ed.ac.uk
This recipe provides a VITS model trained on the VCTK dataset.
- Pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-vctk-vits-2023-12-05), note that this model was pretrained on the Edinburgh DataShare VCTK dataset.
+ Pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-vctk-vits-2024-03-18), note that this model was pretrained on the Edinburgh DataShare VCTK dataset.
For tutorial and more details, please refer to the [VITS documentation](https://k2-fsa.github.io/icefall/recipes/TTS/vctk/vits.html).
@@ -21,7 +21,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--world-size 4 \
--num-epochs 1000 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir vits/exp \
- --tokens data/tokens.txt
--max-duration 350

View File

@@ -1,104 +0,0 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file reads the texts in given manifest and generates the file that maps tokens to IDs.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict
from lhotse import load_manifest
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--manifest-file",
type=Path,
default=Path("data/spectrogram/vctk_cuts_all.jsonl.gz"),
help="Path to the manifest file",
)
parser.add_argument(
"--tokens",
type=Path,
default=Path("data/tokens.txt"),
help="Path to the tokens",
)
return parser.parse_args()
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
Note:
No need to implement `read_mapping` as it can be done
through :func:`k2.SymbolTable.from_file`.
Args:
filename:
Filename to save the mapping.
sym2id:
A dict mapping symbols to IDs.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf-8") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")
def get_token2id(manifest_file: Path) -> Dict[str, int]:
"""Return a dict that maps token to IDs."""
extra_tokens = [
"<blk>", # 0 for blank
"<sos/eos>", # 1 for sos and eos symbols.
"<unk>", # 2 for OOV
]
all_tokens = set()
cut_set = load_manifest(manifest_file)
for cut in cut_set:
# Each cut only contain one supervision
assert len(cut.supervisions) == 1, len(cut.supervisions)
for t in cut.tokens:
all_tokens.add(t)
all_tokens = extra_tokens + list(all_tokens)
token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
return token2id
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
manifest_file = Path(args.manifest_file)
out_file = Path(args.tokens)
token2id = get_token2id(manifest_file)
write_mapping(out_file, token2id)

View File

@@ -0,0 +1 @@
+ ../../../ljspeech/TTS/local/prepare_token_file.py
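The deleted VCTK-specific script above derived the token set from the training manifest; the new symlink reuses the ljspeech script, which instead writes tokens.txt from piper_phonemize's fixed espeak phoneme table. A minimal sketch of that approach, assuming the piper_phonemize get_espeak_map() API (an illustration, not a verbatim copy of the ljspeech file):

# Sketch only: build tokens.txt from piper_phonemize's fixed phoneme-ID map,
# so the token inventory no longer depends on the manifest.
from piper_phonemize import get_espeak_map

def write_tokens(filename: str) -> None:
    # get_espeak_map() returns {token: [token_id]}, with IDs fixed by piper_phonemize
    token2id = {tok: ids[0] for tok, ids in get_espeak_map().items()}
    with open(filename, "w", encoding="utf-8") as f:
        for tok, i in sorted(token2id.items(), key=lambda kv: kv[1]):
            f.write(f"{tok} {i}\n")

write_tokens("data/tokens.txt")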

View File

@@ -24,9 +24,9 @@ This file reads the texts in given manifest and save the new cuts with phoneme t
import logging
from pathlib import Path
- import g2p_en
import tacotron_cleaner.cleaners
from lhotse import CutSet, load_manifest
+ from piper_phonemize import phonemize_espeak
from tqdm.auto import tqdm
@@ -37,17 +37,20 @@ def prepare_tokens_vctk():
partition = "all"
cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
- g2p = g2p_en.G2p()
new_cuts = []
for cut in tqdm(cut_set):
# Each cut only contains one supervision
- assert len(cut.supervisions) == 1, len(cut.supervisions)
+ assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
text = cut.supervisions[0].text
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
- cut.tokens = g2p(text)
+ tokens_list = phonemize_espeak(text, "en-us")
+ tokens = []
+ for t in tokens_list:
+     tokens.extend(t)
+ cut.tokens = tokens
new_cuts.append(cut)
new_cut_set = CutSet.from_cuts(new_cuts)
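A note on the flattening loop added above: phonemize_espeak returns one phoneme list per sentence, so multi-sentence texts come back nested. A quick illustration, assuming piper_phonemize is installed (the phoneme symbols shown are indicative; actual output comes from espeak-ng):

from piper_phonemize import phonemize_espeak

tokens_list = phonemize_espeak("Hi. Bye.", "en-us")
# e.g. [['h', 'ˈ', 'a', 'ɪ', '.'], ['b', 'ˈ', 'a', 'ɪ', '.']] -- one list per sentence
tokens = [t for sentence in tokens_list for t in sentence]  # flattened, as in the loop above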

View File

@@ -78,6 +78,13 @@ fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare phoneme tokens for VCTK"
+ # We assume you have installed piper_phonemize and espnet_tts_frontend.
+ # If not, please install them with:
+ # - piper_phonemize:
+ #   refer to https://github.com/rhasspy/piper-phonemize;
+ #   you can install pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
+ # - espnet_tts_frontend:
+ #   `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/spectrogram/.vctk_with_token.done ]; then
./local/prepare_tokens_vctk.py
mv data/spectrogram/vctk_cuts_with_tokens_all.jsonl.gz \
@@ -111,14 +118,15 @@ fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Generate token file"
- # We assume you have installed g2p_en and espnet_tts_frontend.
+ # We assume you have installed piper_phonemize and espnet_tts_frontend.
# If not, please install them with:
- # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
- # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
+ # - piper_phonemize:
+ #   refer to https://github.com/rhasspy/piper-phonemize;
+ #   you can install pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
+ # - espnet_tts_frontend:
+ #   `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/tokens.txt ]; then
- ./local/prepare_token_file.py \
-     --manifest-file data/spectrogram/vctk_cuts_train.jsonl.gz \
-     --tokens data/tokens.txt
+ ./local/prepare_token_file.py --tokens data/tokens.txt
fi
fi
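For reference, installation along the lines of the comments above might look like this; the wheel file name below is a placeholder, so substitute the real one matching your Python version and platform from the linked release page:

# Illustrative only -- pick the actual wheel from the release page linked above
wget https://github.com/csukuangfj/piper-phonemize/releases/download/2023.12.5/piper_phonemize-<version>-<platform>.whl
pip install ./piper_phonemize-*.whl
pip install espnet_tts_frontend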

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
- # Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+ # Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao,
+ #                                                 Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -97,7 +98,7 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
- meta.value = value
+ meta.value = str(value)
onnx.save(model, filename)
@@ -160,6 +161,7 @@ def export_model_onnx(
model: nn.Module,
model_filename: str,
vocab_size: int,
+ n_speakers: int,
opset_version: int = 11,
) -> None:
"""Export the given generator model to ONNX format.
@@ -212,10 +214,15 @@
)
meta_data = {
"model_type": "VITS",
"model_type": "vits",
"version": "1",
"model_author": "k2-fsa",
"comment": "VITS generator",
"comment": "icefall", # must be icefall for models from icefall
"language": "English",
"voice": "en-us", # Choose your language appropriately
"has_espeak": 1,
"n_speakers": n_speakers,
"sample_rate": 22050, # Must match the real sample rate
}
logging.info(f"meta_data: {meta_data}")
@@ -231,8 +238,7 @@ def main():
params.update(vars(args))
tokenizer = Tokenizer(params.tokens)
- params.blank_id = tokenizer.blank_id
- params.oov_id = tokenizer.oov_id
+ params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
with open(args.speakers) as f:
@@ -265,6 +271,7 @@
model,
model_filename,
params.vocab_size,
+ params.num_spks,
opset_version=opset_version,
)
logging.info(f"Exported generator to {model_filename}")

View File

@@ -135,14 +135,16 @@ def infer_dataset(
batch_size = len(batch["tokens"])
tokens = batch["tokens"]
- tokens = tokenizer.tokens_to_token_ids(tokens)
+ tokens = tokenizer.tokens_to_token_ids(
+     tokens, intersperse_blank=True, add_sos=True, add_eos=True
+ )
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# tensor of shape (B, T)
- tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+ tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
speakers = (
torch.Tensor([speaker_map[sid] for sid in batch["speakers"]])
.int()
@@ -214,8 +216,7 @@ def main():
device = torch.device("cuda", 0)
tokenizer = Tokenizer(params.tokens)
- params.blank_id = tokenizer.blank_id
- params.oov_id = tokenizer.oov_id
+ params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
# we need cut ids to display recognition results.
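The new keyword arguments make the blank interspersing and sentence-boundary handling explicit at the call site. A sketch of the intended transformation, not icefall's Tokenizer implementation (the sos/eos IDs here are made up for illustration):

def to_ids(token_ids, blank_id=0, sos_id=1, eos_id=2):
    out = [blank_id]
    for t in token_ids:
        out += [t, blank_id]  # intersperse_blank: a blank between tokens and at both ends
    return [sos_id] + out + [eos_id]  # add_sos / add_eos

print(to_ids([5, 9]))  # [1, 0, 5, 0, 9, 0, 2]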

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
- # Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+ # Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao,
+ #                                                 Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -122,7 +123,9 @@ def main():
model = OnnxModel(args.model_filename)
text = "I went there to see the land, the people and how their system works, end quote."
- tokens = tokenizer.texts_to_token_ids([text])
+ tokens = tokenizer.texts_to_token_ids(
+     [text], intersperse_blank=True, add_sos=True, add_eos=True
+ )
tokens = torch.tensor(tokens) # (1, T)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
speaker = torch.tensor([1], dtype=torch.int64) # (1, )

View File

@@ -1,5 +1,6 @@
#!/usr/bin/env python3
- # Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+ # Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao,
+ #                                                 Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -342,14 +343,16 @@ def prepare_input(
torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device)
)
- tokens = tokenizer.tokens_to_token_ids(tokens)
+ tokens = tokenizer.tokens_to_token_ids(
+     tokens, intersperse_blank=True, add_sos=True, add_eos=True
+ )
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# a tensor of shape (B, T)
- tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+ tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers
@@ -812,8 +815,7 @@ def run(rank, world_size, args):
logging.info(f"Device: {device}")
tokenizer = Tokenizer(params.tokens)
- params.blank_id = tokenizer.blank_id
- params.oov_id = tokenizer.oov_id
+ params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
vctk = VctkTtsDataModule(args)
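The surrounding lines turn the per-utterance token lists into a padded (B, T) batch; the same k2 pattern in isolation, assuming k2 is installed:

import k2

tokens = k2.RaggedTensor([[5, 9, 7], [3, 4]])
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]  # tensor([3, 2]), one length per utterance
padded = tokens.pad(mode="constant", padding_value=0)  # pad with the tokenizer's pad id
# padded: tensor([[5, 9, 7], [3, 4, 0]])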

View File

@@ -1,6 +1,7 @@
# Copyright 2021 Piotr Żelasko
- # Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
- #                                                   Zengwei Yao)
+ # Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
+ #                                                   Zengwei Yao,
+ #                                                   Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#