Use piper_phonemize as text tokenizer in ljspeech recipe (#1511)

* use piper_phonemize as text tokenizer in ljspeech recipe

* modify usage of tokenizer in vits/train.py

* update docs
Zengwei Yao 2024-02-29 10:13:22 +08:00 committed by GitHub
parent 291d06056c
commit d89f4ea149
11 changed files with 107 additions and 101 deletions


@@ -1,11 +1,11 @@
-VITS
+VITS-LJSpeech
 ===============
 
 This tutorial shows you how to train an VITS model
 with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.
 
 .. note::
 
     TTS related recipes require packages in ``requirements-tts.txt``.
 
 .. note::
@@ -120,4 +120,4 @@ Download pretrained models
 If you don't want to train from scratch, you can download the pretrained models
 by visiting the following link:
 
-- `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29>`_
+- `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_


@@ -1,11 +1,11 @@
-VITS
+VITS-VCTK
 ===============
 
 This tutorial shows you how to train an VITS model
 with the `VCTK <https://datashare.ed.ac.uk/handle/10283/3443>`_ dataset.
 
 .. note::
 
     TTS related recipes require packages in ``requirements-tts.txt``.
 
 .. note::


@@ -17,7 +17,7 @@
 """
-This file reads the texts in given manifest and generates the file that maps tokens to IDs.
+This file generates the file that maps tokens to IDs.
 """
 
 import argparse
@@ -25,80 +25,38 @@ import logging
 from pathlib import Path
 from typing import Dict
 
-from lhotse import load_manifest
+from piper_phonemize import get_espeak_map
 
 
 def get_args():
     parser = argparse.ArgumentParser()
 
-    parser.add_argument(
-        "--manifest-file",
-        type=Path,
-        default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"),
-        help="Path to the manifest file",
-    )
-
     parser.add_argument(
         "--tokens",
         type=Path,
        default=Path("data/tokens.txt"),
-        help="Path to the tokens",
+        help="Path to the dict that maps the text tokens to IDs",
     )
 
     return parser.parse_args()
 
 
-def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
-    """Write a symbol to ID mapping to a file.
-
-    Note:
-      No need to implement `read_mapping` as it can be done
-      through :func:`k2.SymbolTable.from_file`.
-
-    Args:
-      filename:
-        Filename to save the mapping.
-      sym2id:
-        A dict mapping symbols to IDs.
-    Returns:
-      Return None.
-    """
+def get_token2id(filename: Path) -> Dict[str, int]:
+    """Get a dict that maps token to IDs, and save it to the given filename."""
+    all_tokens = get_espeak_map()  # token: [token_id]
+    all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()}
+    # sort by token_id
+    all_tokens = sorted(all_tokens.items(), key=lambda x: x[1])
+
     with open(filename, "w", encoding="utf-8") as f:
-        for sym, i in sym2id.items():
-            f.write(f"{sym} {i}\n")
-
-
-def get_token2id(manifest_file: Path) -> Dict[str, int]:
-    """Return a dict that maps token to IDs."""
-    extra_tokens = [
-        "<blk>",  # 0 for blank
-        "<sos/eos>",  # 1 for sos and eos symbols.
-        "<unk>",  # 2 for OOV
-    ]
-    all_tokens = set()
-
-    cut_set = load_manifest(manifest_file)
-    for cut in cut_set:
-        # Each cut only contain one supervision
-        assert len(cut.supervisions) == 1, len(cut.supervisions)
-        for t in cut.tokens:
-            all_tokens.add(t)
-
-    all_tokens = extra_tokens + list(all_tokens)
-
-    token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
-    return token2id
+        for token, token_id in all_tokens:
+            f.write(f"{token} {token_id}\n")
 
 
 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     args = get_args()
-    manifest_file = Path(args.manifest_file)
     out_file = Path(args.tokens)
-
-    token2id = get_token2id(manifest_file)
-    write_mapping(out_file, token2id)
+    get_token2id(out_file)

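The rewritten local/prepare_token_file.py no longer needs a manifest: the token inventory comes straight from espeak via piper_phonemize. A minimal sketch of the same idea, assuming piper_phonemize is installed (get_espeak_map() returns a dict mapping each token to a list whose first element is the ID):

# Sketch only: mirrors the logic of the new get_token2id() above.
from piper_phonemize import get_espeak_map

token2id = {tok: ids[0] for tok, ids in get_espeak_map().items()}
with open("data/tokens.txt", "w", encoding="utf-8") as f:
    for tok, idx in sorted(token2id.items(), key=lambda kv: kv[1]):
        f.write(f"{tok} {idx}\n")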

@@ -23,9 +23,9 @@ This file reads the texts in given manifest and save the new cuts with phoneme t
 import logging
 from pathlib import Path
 
-import g2p_en
 import tacotron_cleaner.cleaners
 from lhotse import CutSet, load_manifest
+from piper_phonemize import phonemize_espeak
 
 
 def prepare_tokens_ljspeech():
@@ -35,17 +35,20 @@ def prepare_tokens_ljspeech():
     partition = "all"
 
     cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
-    g2p = g2p_en.G2p()
 
     new_cuts = []
     for cut in cut_set:
         # Each cut only contains one supervision
-        assert len(cut.supervisions) == 1, len(cut.supervisions)
+        assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
         text = cut.supervisions[0].normalized_text
         # Text normalization
         text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
         # Convert to phonemes
-        cut.tokens = g2p(text)
+        tokens_list = phonemize_espeak(text, "en-us")
+        tokens = []
+        for t in tokens_list:
+            tokens.extend(t)
+        cut.tokens = tokens
         new_cuts.append(cut)
 
     new_cut_set = CutSet.from_cuts(new_cuts)

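The switch from g2p_en to phonemize_espeak also changes the shape of the output: phonemize_espeak(text, "en-us") returns a list of phoneme lists (roughly one per sentence or clause), which is why the loop above flattens it before storing cut.tokens. A small sketch of that step, assuming piper_phonemize and espnet_tts_frontend (which provides tacotron_cleaner) are installed; the example text is arbitrary:

# Sketch of the new tokenization step.
import tacotron_cleaner.cleaners
from piper_phonemize import phonemize_espeak

text = tacotron_cleaner.cleaners.custom_english_cleaners("Hello world. How are you?")
tokens_list = phonemize_espeak(text, "en-us")  # list of per-sentence phoneme lists
tokens = [t for sub in tokens_list for t in sub]  # flattened, as in the recipe
print(tokens_list)
print(tokens)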

@@ -30,7 +30,7 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
     cd vits/monotonic_align
     python setup.py build_ext --inplace
     cd ../../
   else
     log "monotonic_align lib already built"
   fi
 fi
@@ -80,6 +80,11 @@ fi
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "Stage 3: Prepare phoneme tokens for LJSpeech"
+  # We assume you have installed piper_phonemize and espnet_tts_frontend.
+  # If not, please install them with:
+  # - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize,
+  #   could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
+  # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
   if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then
     ./local/prepare_tokens_ljspeech.py
     mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \
@@ -113,13 +118,12 @@ fi
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "Stage 5: Generate token file"
-  # We assume you have installed g2p_en and espnet_tts_frontend.
+  # We assume you have installed piper_phonemize and espnet_tts_frontend.
   # If not, please install them with:
-  # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
+  # - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize,
+  #   could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
   # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
   if [ ! -e data/tokens.txt ]; then
-    ./local/prepare_token_file.py \
-      --manifest-file data/spectrogram/ljspeech_cuts_train.jsonl.gz \
-      --tokens data/tokens.txt
+    ./local/prepare_token_file.py --tokens data/tokens.txt
   fi
 fi

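Before running stages 3 and 5 it is worth confirming that the dependencies named in the comments above are importable. The snippet below is only a sanity check; the package and module names are taken from those comments and from the recipe scripts (piper_phonemize from the pre-built wheels, espnet_tts_frontend via pip):

# Sanity check for the assumed dependencies.
from piper_phonemize import get_espeak_map, phonemize_espeak
import tacotron_cleaner.cleaners  # provided by espnet_tts_frontend

print(len(get_espeak_map()), "espeak tokens available")
print(phonemize_espeak("sanity check", "en-us"))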

@@ -218,8 +218,7 @@ def main():
     params.update(vars(args))
 
     tokenizer = Tokenizer(params.tokens)
-    params.blank_id = tokenizer.blank_id
-    params.oov_id = tokenizer.oov_id
+    params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
 
     logging.info(params)


@@ -130,14 +130,16 @@ def infer_dataset(
            batch_size = len(batch["tokens"])
 
            tokens = batch["tokens"]
-            tokens = tokenizer.tokens_to_token_ids(tokens)
+            tokens = tokenizer.tokens_to_token_ids(
+                tokens, intersperse_blank=True, add_sos=True, add_eos=True
+            )
            tokens = k2.RaggedTensor(tokens)
            row_splits = tokens.shape.row_splits(1)
            tokens_lens = row_splits[1:] - row_splits[:-1]
            tokens = tokens.to(device)
            tokens_lens = tokens_lens.to(device)
            # tensor of shape (B, T)
-            tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+            tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
 
            audio = batch["audio"]
            audio_lens = batch["audio_lens"].tolist()
@@ -201,8 +203,7 @@ def main():
     device = torch.device("cuda", 0)
 
     tokenizer = Tokenizer(params.tokens)
-    params.blank_id = tokenizer.blank_id
-    params.oov_id = tokenizer.oov_id
+    params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
 
     logging.info(f"Device: {device}")

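The infer_dataset() change above keeps the existing k2-based batching; only the padding value and the sos/eos handling move into the tokenizer call. A small sketch of that batching pattern with hypothetical token IDs (pad_id stands in for tokenizer.pad_id, the ID of "_"):

# Sketch: packing variable-length token-id lists into a padded (B, T) tensor.
import k2

token_ids = [[52, 3, 19, 2], [7, 11], [4, 4, 4, 4, 4]]  # hypothetical IDs
pad_id = 0  # placeholder for tokenizer.pad_id

ragged = k2.RaggedTensor(token_ids)
row_splits = ragged.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]  # per-utterance lengths, shape (B,)
tokens = ragged.pad(mode="constant", padding_value=pad_id)  # shape (B, T)
print(tokens)
print(tokens_lens)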

@@ -108,7 +108,9 @@ def main():
     model = OnnxModel(args.model_filename)
 
     text = "I went there to see the land, the people and how their system works, end quote."
-    tokens = tokenizer.texts_to_token_ids([text])
+    tokens = tokenizer.texts_to_token_ids(
+        [text], intersperse_blank=True, add_sos=True, add_eos=True
+    )
     tokens = torch.tensor(tokens)  # (1, T)
     tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)  # (1, T)
     audio = model(tokens, tokens_lens)  # (1, T')


@@ -1,4 +1,4 @@
-# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao)
 #
 # See ../../LICENSE for clarification regarding multiple authors
 #
@@ -14,10 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 from typing import Dict, List
 
-import g2p_en
 import tacotron_cleaner.cleaners
+from piper_phonemize import phonemize_espeak
 
 from utils import intersperse
@@ -38,21 +39,37 @@ class Tokenizer(object):
                     id = int(info[0])
                 else:
                     token, id = info[0], int(info[1])
+                assert token not in self.token2id, token
                 self.token2id[token] = id
 
-        self.blank_id = self.token2id["<blk>"]
-        self.oov_id = self.token2id["<unk>"]
+        # Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md
+        self.pad_id = self.token2id["_"]  # padding
+        self.sos_id = self.token2id["^"]  # beginning of an utterance (bos)
+        self.eos_id = self.token2id["$"]  # end of an utterance (eos)
+        self.space_id = self.token2id[" "]  # word separator (whitespace)
+
         self.vocab_size = len(self.token2id)
 
-        self.g2p = g2p_en.G2p()
-
-    def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
+    def texts_to_token_ids(
+        self,
+        texts: List[str],
+        intersperse_blank: bool = True,
+        add_sos: bool = False,
+        add_eos: bool = False,
+        lang: str = "en-us",
+    ) -> List[List[int]]:
         """
         Args:
           texts:
             A list of transcripts.
           intersperse_blank:
             Whether to intersperse blanks in the token sequence.
+          add_sos:
+            Whether to add sos token at the start.
+          add_eos:
+            Whether to add eos token at the end.
+          lang:
+            Language argument passed to phonemize_espeak().
 
         Returns:
           Return a list of token id list [utterance][token_id]
@@ -63,30 +80,46 @@ class Tokenizer(object):
             # Text normalization
             text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
             # Convert to phonemes
-            tokens = self.g2p(text)
+            tokens_list = phonemize_espeak(text, lang)
+            tokens = []
+            for t in tokens_list:
+                tokens.extend(t)
 
             token_ids = []
             for t in tokens:
-                if t in self.token2id:
-                    token_ids.append(self.token2id[t])
-                else:
-                    token_ids.append(self.oov_id)
+                if t not in self.token2id:
+                    logging.warning(f"Skip OOV {t}")
+                    continue
+                token_ids.append(self.token2id[t])
 
             if intersperse_blank:
-                token_ids = intersperse(token_ids, self.blank_id)
+                token_ids = intersperse(token_ids, self.pad_id)
+            if add_sos:
+                token_ids = [self.sos_id] + token_ids
+            if add_eos:
+                token_ids = token_ids + [self.eos_id]
 
             token_ids_list.append(token_ids)
 
         return token_ids_list
 
     def tokens_to_token_ids(
-        self, tokens_list: List[str], intersperse_blank: bool = True
-    ):
+        self,
+        tokens_list: List[str],
+        intersperse_blank: bool = True,
+        add_sos: bool = False,
+        add_eos: bool = False,
+    ) -> List[List[int]]:
         """
         Args:
          tokens_list:
            A list of token list, each corresponding to one utterance.
          intersperse_blank:
            Whether to intersperse blanks in the token sequence.
+          add_sos:
+            Whether to add sos token at the start.
+          add_eos:
+            Whether to add eos token at the end.
 
         Returns:
          Return a list of token id list [utterance][token_id]
@@ -96,13 +129,17 @@ class Tokenizer(object):
         for tokens in tokens_list:
             token_ids = []
             for t in tokens:
-                if t in self.token2id:
-                    token_ids.append(self.token2id[t])
-                else:
-                    token_ids.append(self.oov_id)
+                if t not in self.token2id:
+                    logging.warning(f"Skip OOV {t}")
+                    continue
+                token_ids.append(self.token2id[t])
 
             if intersperse_blank:
-                token_ids = intersperse(token_ids, self.blank_id)
+                token_ids = intersperse(token_ids, self.pad_id)
+            if add_sos:
+                token_ids = [self.sos_id] + token_ids
+            if add_eos:
+                token_ids = token_ids + [self.eos_id]
 
             token_ids_list.append(token_ids)

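With these changes the Tokenizer exposes pad_id, sos_id and eos_id (mapped to "_", "^" and "$" as in piper's TRAINING.md) instead of blank_id and oov_id, and OOV symbols are skipped with a warning rather than mapped to an unknown token. A short usage sketch, assuming data/tokens.txt was produced by local/prepare_token_file.py; the import path is an assumption based on the vits/ scripts in this recipe:

# Sketch of the updated Tokenizer API.
from tokenizer import Tokenizer  # assumed to be vits/tokenizer.py on sys.path

tokenizer = Tokenizer("data/tokens.txt")
print(tokenizer.pad_id, tokenizer.sos_id, tokenizer.eos_id, tokenizer.vocab_size)

token_ids = tokenizer.texts_to_token_ids(
    ["I went there to see the land, the people and how their system works, end quote."],
    intersperse_blank=True,
    add_sos=True,
    add_eos=True,
)
print(token_ids[0][:10])  # first few IDs of the only utterance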

@@ -296,14 +296,16 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     features_lens = batch["features_lens"].to(device)
 
     tokens = batch["tokens"]
-    tokens = tokenizer.tokens_to_token_ids(tokens)
+    tokens = tokenizer.tokens_to_token_ids(
+        tokens, intersperse_blank=True, add_sos=True, add_eos=True
+    )
     tokens = k2.RaggedTensor(tokens)
     row_splits = tokens.shape.row_splits(1)
     tokens_lens = row_splits[1:] - row_splits[:-1]
     tokens = tokens.to(device)
     tokens_lens = tokens_lens.to(device)
     # a tensor of shape (B, T)
-    tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+    tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
 
     return audio, audio_lens, features, features_lens, tokens, tokens_lens
@@ -742,8 +744,7 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")
 
     tokenizer = Tokenizer(params.tokens)
-    params.blank_id = tokenizer.blank_id
-    params.oov_id = tokenizer.oov_id
+    params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
 
     logging.info(params)


@@ -3,4 +3,5 @@ matplotlib==3.8.2
 cython==3.0.6
 numba==0.58.1
 g2p_en==2.1.0
 espnet_tts_frontend==0.0.3
+# piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize, could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5