Use piper_phonemize as text tokenizer in ljspeech recipe (#1511)

* use piper_phonemize as text tokenizer in ljspeech recipe

* modify usage of tokenizer in vits/train.py

* update docs
This commit is contained in:
Zengwei Yao 2024-02-29 10:13:22 +08:00 committed by GitHub
parent 291d06056c
commit d89f4ea149
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 107 additions and 101 deletions

View File

@ -1,4 +1,4 @@
VITS
VITS-LJSpeech
===============
This tutorial shows you how to train an VITS model
@ -120,4 +120,4 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models
by visiting the following link:
- `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29>`_
- `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_

View File

@ -1,4 +1,4 @@
VITS
VITS-VCTK
===============
This tutorial shows you how to train an VITS model

View File

@ -17,7 +17,7 @@
"""
This file reads the texts in given manifest and generates the file that maps tokens to IDs.
This file generates the file that maps tokens to IDs.
"""
import argparse
@ -25,80 +25,38 @@ import logging
from pathlib import Path
from typing import Dict
from lhotse import load_manifest
from piper_phonemize import get_espeak_map
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--manifest-file",
type=Path,
default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"),
help="Path to the manifest file",
)
parser.add_argument(
"--tokens",
type=Path,
default=Path("data/tokens.txt"),
help="Path to the tokens",
help="Path to the dict that maps the text tokens to IDs",
)
return parser.parse_args()
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
def get_token2id(filename: Path) -> Dict[str, int]:
"""Get a dict that maps token to IDs, and save it to the given filename."""
all_tokens = get_espeak_map() # token: [token_id]
all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()}
# sort by token_id
all_tokens = sorted(all_tokens.items(), key=lambda x: x[1])
Note:
No need to implement `read_mapping` as it can be done
through :func:`k2.SymbolTable.from_file`.
Args:
filename:
Filename to save the mapping.
sym2id:
A dict mapping symbols to IDs.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf-8") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")
def get_token2id(manifest_file: Path) -> Dict[str, int]:
"""Return a dict that maps token to IDs."""
extra_tokens = [
"<blk>", # 0 for blank
"<sos/eos>", # 1 for sos and eos symbols.
"<unk>", # 2 for OOV
]
all_tokens = set()
cut_set = load_manifest(manifest_file)
for cut in cut_set:
# Each cut only contain one supervision
assert len(cut.supervisions) == 1, len(cut.supervisions)
for t in cut.tokens:
all_tokens.add(t)
all_tokens = extra_tokens + list(all_tokens)
token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
return token2id
for token, token_id in all_tokens:
f.write(f"{token} {token_id}\n")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
manifest_file = Path(args.manifest_file)
out_file = Path(args.tokens)
token2id = get_token2id(manifest_file)
write_mapping(out_file, token2id)
get_token2id(out_file)

View File

@ -23,9 +23,9 @@ This file reads the texts in given manifest and save the new cuts with phoneme t
import logging
from pathlib import Path
import g2p_en
import tacotron_cleaner.cleaners
from lhotse import CutSet, load_manifest
from piper_phonemize import phonemize_espeak
def prepare_tokens_ljspeech():
@ -35,17 +35,20 @@ def prepare_tokens_ljspeech():
partition = "all"
cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
g2p = g2p_en.G2p()
new_cuts = []
for cut in cut_set:
# Each cut only contains one supervision
assert len(cut.supervisions) == 1, len(cut.supervisions)
assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
text = cut.supervisions[0].normalized_text
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
cut.tokens = g2p(text)
tokens_list = phonemize_espeak(text, "en-us")
tokens = []
for t in tokens_list:
tokens.extend(t)
cut.tokens = tokens
new_cuts.append(cut)
new_cut_set = CutSet.from_cuts(new_cuts)

View File

@ -80,6 +80,11 @@ fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare phoneme tokens for LJSpeech"
# We assume you have installed piper_phonemize and espnet_tts_frontend.
# If not, please install them with:
# - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize,
# could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
# - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then
./local/prepare_tokens_ljspeech.py
mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \
@ -113,13 +118,12 @@ fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Generate token file"
# We assume you have installed g2p_en and espnet_tts_frontend.
# We assume you have installed piper_phonemize and espnet_tts_frontend.
# If not, please install them with:
# - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
# - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize,
# could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
# - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/tokens.txt ]; then
./local/prepare_token_file.py \
--manifest-file data/spectrogram/ljspeech_cuts_train.jsonl.gz \
--tokens data/tokens.txt
./local/prepare_token_file.py --tokens data/tokens.txt
fi
fi

View File

@ -218,8 +218,7 @@ def main():
params.update(vars(args))
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.blank_id
params.oov_id = tokenizer.oov_id
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
logging.info(params)

View File

@ -130,14 +130,16 @@ def infer_dataset(
batch_size = len(batch["tokens"])
tokens = batch["tokens"]
tokens = tokenizer.tokens_to_token_ids(tokens)
tokens = tokenizer.tokens_to_token_ids(
tokens, intersperse_blank=True, add_sos=True, add_eos=True
)
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# tensor of shape (B, T)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
audio = batch["audio"]
audio_lens = batch["audio_lens"].tolist()
@ -201,8 +203,7 @@ def main():
device = torch.device("cuda", 0)
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.blank_id
params.oov_id = tokenizer.oov_id
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
logging.info(f"Device: {device}")

View File

@ -108,7 +108,9 @@ def main():
model = OnnxModel(args.model_filename)
text = "I went there to see the land, the people and how their system works, end quote."
tokens = tokenizer.texts_to_token_ids([text])
tokens = tokenizer.texts_to_token_ids(
[text], intersperse_blank=True, add_sos=True, add_eos=True
)
tokens = torch.tensor(tokens) # (1, T)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
audio = model(tokens, tokens_lens) # (1, T')

View File

@ -1,4 +1,4 @@
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
@ -14,10 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, List
import g2p_en
import tacotron_cleaner.cleaners
from piper_phonemize import phonemize_espeak
from utils import intersperse
@ -38,21 +39,37 @@ class Tokenizer(object):
id = int(info[0])
else:
token, id = info[0], int(info[1])
assert token not in self.token2id, token
self.token2id[token] = id
self.blank_id = self.token2id["<blk>"]
self.oov_id = self.token2id["<unk>"]
# Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md
self.pad_id = self.token2id["_"] # padding
self.sos_id = self.token2id["^"] # beginning of an utterance (bos)
self.eos_id = self.token2id["$"] # end of an utterance (eos)
self.space_id = self.token2id[" "] # word separator (whitespace)
self.vocab_size = len(self.token2id)
self.g2p = g2p_en.G2p()
def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
def texts_to_token_ids(
self,
texts: List[str],
intersperse_blank: bool = True,
add_sos: bool = False,
add_eos: bool = False,
lang: str = "en-us",
) -> List[List[int]]:
"""
Args:
texts:
A list of transcripts.
intersperse_blank:
Whether to intersperse blanks in the token sequence.
add_sos:
Whether to add sos token at the start.
add_eos:
Whether to add eos token at the end.
lang:
Language argument passed to phonemize_espeak().
Returns:
Return a list of token id list [utterance][token_id]
@ -63,30 +80,46 @@ class Tokenizer(object):
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
tokens = self.g2p(text)
tokens_list = phonemize_espeak(text, lang)
tokens = []
for t in tokens_list:
tokens.extend(t)
token_ids = []
for t in tokens:
if t in self.token2id:
token_ids.append(self.token2id[t])
else:
token_ids.append(self.oov_id)
if t not in self.token2id:
logging.warning(f"Skip OOV {t}")
continue
token_ids.append(self.token2id[t])
if intersperse_blank:
token_ids = intersperse(token_ids, self.blank_id)
token_ids = intersperse(token_ids, self.pad_id)
if add_sos:
token_ids = [self.sos_id] + token_ids
if add_eos:
token_ids = token_ids + [self.eos_id]
token_ids_list.append(token_ids)
return token_ids_list
def tokens_to_token_ids(
self, tokens_list: List[str], intersperse_blank: bool = True
):
self,
tokens_list: List[str],
intersperse_blank: bool = True,
add_sos: bool = False,
add_eos: bool = False,
) -> List[List[int]]:
"""
Args:
tokens_list:
A list of token list, each corresponding to one utterance.
intersperse_blank:
Whether to intersperse blanks in the token sequence.
add_sos:
Whether to add sos token at the start.
add_eos:
Whether to add eos token at the end.
Returns:
Return a list of token id list [utterance][token_id]
@ -96,13 +129,17 @@ class Tokenizer(object):
for tokens in tokens_list:
token_ids = []
for t in tokens:
if t in self.token2id:
token_ids.append(self.token2id[t])
else:
token_ids.append(self.oov_id)
if t not in self.token2id:
logging.warning(f"Skip OOV {t}")
continue
token_ids.append(self.token2id[t])
if intersperse_blank:
token_ids = intersperse(token_ids, self.blank_id)
token_ids = intersperse(token_ids, self.pad_id)
if add_sos:
token_ids = [self.sos_id] + token_ids
if add_eos:
token_ids = token_ids + [self.eos_id]
token_ids_list.append(token_ids)

View File

@ -296,14 +296,16 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
features_lens = batch["features_lens"].to(device)
tokens = batch["tokens"]
tokens = tokenizer.tokens_to_token_ids(tokens)
tokens = tokenizer.tokens_to_token_ids(
tokens, intersperse_blank=True, add_sos=True, add_eos=True
)
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# a tensor of shape (B, T)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
return audio, audio_lens, features, features_lens, tokens, tokens_lens
@ -742,8 +744,7 @@ def run(rank, world_size, args):
logging.info(f"Device: {device}")
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.blank_id
params.oov_id = tokenizer.oov_id
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
logging.info(params)

View File

@ -4,3 +4,4 @@ cython==3.0.6
numba==0.58.1
g2p_en==2.1.0
espnet_tts_frontend==0.0.3
# piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize, could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5