Add new tokenizer

Erwan 2024-02-26 09:34:46 +01:00
parent 0377cccc6f
commit 9c083b428f
7 changed files with 191 additions and 81 deletions

View File

@@ -17,7 +17,7 @@
 """
-This file reads the texts in given manifest and generates the file that maps tokens to IDs.
+This file generates the file that maps tokens to IDs.
 """

 import argparse
@@ -25,80 +25,38 @@ import logging
 from pathlib import Path
 from typing import Dict

-from lhotse import load_manifest
+from piper_phonemize import get_espeak_map


 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--manifest-file",
-        type=Path,
-        default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"),
-        help="Path to the manifest file",
-    )
     parser.add_argument(
         "--tokens",
         type=Path,
         default=Path("data/tokens.txt"),
-        help="Path to the tokens",
+        help="Path to the dict that maps the text tokens to IDs",
     )
     return parser.parse_args()


-def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
-    """Write a symbol to ID mapping to a file.
-
-    Note:
-      No need to implement `read_mapping` as it can be done
-      through :func:`k2.SymbolTable.from_file`.
-
-    Args:
-      filename:
-        Filename to save the mapping.
-      sym2id:
-        A dict mapping symbols to IDs.
-    Returns:
-      Return None.
-    """
+def get_token2id(filename: Path) -> Dict[str, int]:
+    """Get a dict that maps token to IDs, and save it to the given filename."""
+    all_tokens = get_espeak_map()  # token: [token_id]
+    all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()}
+    # sort by token_id
+    all_tokens = sorted(all_tokens.items(), key=lambda x: x[1])
     with open(filename, "w", encoding="utf-8") as f:
-        for sym, i in sym2id.items():
-            f.write(f"{sym} {i}\n")
-
-
-def get_token2id(manifest_file: Path) -> Dict[str, int]:
-    """Return a dict that maps token to IDs."""
-    extra_tokens = [
-        "<blk>",  # 0 for blank
-        "<sos/eos>",  # 1 for sos and eos symbols.
-        "<unk>",  # 2 for OOV
-    ]
-    all_tokens = set()
-    cut_set = load_manifest(manifest_file)
-    for cut in cut_set:
-        # Each cut only contain one supervision
-        assert len(cut.supervisions) == 1, len(cut.supervisions)
-        for t in cut.tokens:
-            all_tokens.add(t)
-    all_tokens = extra_tokens + list(all_tokens)
-
-    token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
-    return token2id
+        for token, token_id in all_tokens:
+            f.write(f"{token} {token_id}\n")


 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)

     args = get_args()
-    manifest_file = Path(args.manifest_file)
     out_file = Path(args.tokens)
-
-    token2id = get_token2id(manifest_file)
-    write_mapping(out_file, token2id)
+    get_token2id(out_file)
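The script now writes one "token id" pair per line to data/tokens.txt, with the space token producing a line that has a single field. As a companion sketch, here is how the file can be read back; read_token2id is a hypothetical helper that mirrors the parsing done by the new Tokenizer class later in this commit:

    from typing import Dict

    def read_token2id(filename: str) -> Dict[str, int]:
        """Read a 'token id' mapping written by the script above."""
        token2id: Dict[str, int] = {}
        with open(filename, "r", encoding="utf-8") as f:
            for line in f:
                info = line.rstrip().split()
                # a single-field line is the space token " "
                token, idx = (" ", info[0]) if len(info) == 1 else (info[0], info[1])
                token2id[token] = int(idx)
        return token2id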

View File

@@ -23,9 +23,9 @@ This file reads the texts in given manifest and save the new cuts with phoneme t
 import logging
 from pathlib import Path

-import g2p_en
 import tacotron_cleaner.cleaners
 from lhotse import CutSet, load_manifest
+from piper_phonemize import phonemize_espeak


 def prepare_tokens_ljspeech():
@@ -35,17 +35,20 @@ def prepare_tokens_ljspeech():
     partition = "all"

     cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
-    g2p = g2p_en.G2p()

     new_cuts = []
     for cut in cut_set:
         # Each cut only contains one supervision
-        assert len(cut.supervisions) == 1, len(cut.supervisions)
+        assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
         text = cut.supervisions[0].normalized_text
         # Text normalization
         text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
         # Convert to phonemes
-        cut.tokens = g2p(text)
+        tokens_list = phonemize_espeak(text, "en-us")
+        tokens = []
+        for t in tokens_list:
+            tokens.extend(t)
+        cut.tokens = tokens
         new_cuts.append(cut)

     new_cut_set = CutSet.from_cuts(new_cuts)
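piper_phonemize's phonemize_espeak() returns a list of phoneme lists, one per sentence-like chunk, which is why the loop above flattens before assigning cut.tokens. An equivalent flattening, as a sketch under the same assumption about the return shape:

    import itertools

    from piper_phonemize import phonemize_espeak

    text = "Hello world."
    # phonemize_espeak returns List[List[str]], one sub-list per chunk
    tokens = list(itertools.chain.from_iterable(phonemize_espeak(text, "en-us")))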

View File

@@ -218,8 +218,7 @@ def main():
     params.update(vars(args))

     tokenizer = Tokenizer(params.tokens)
-    params.blank_id = tokenizer.blank_id
-    params.oov_id = tokenizer.oov_id
+    params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size

     logging.info(params)

View File

@@ -21,6 +21,7 @@ from hifigan import HiFiGANGenerator
 from posterior_encoder import PosteriorEncoder
 from residual_coupling import ResidualAffineCouplingBlock
 from text_encoder import TextEncoder
+from torch.cuda.amp import autocast
 from utils import get_random_segments

 from icefall.utils import make_pad_mask
@@ -375,7 +376,7 @@ class VITSGenerator(torch.nn.Module):
         # forward duration predictor
         w = attn.sum(2)  # (B, 1, T_text)
-        if self.use_stochastic_duration_predictor:
-            dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
-            dur_nll = dur_nll / torch.sum(x_mask)
+        with autocast(enabled=False):
+            if self.use_stochastic_duration_predictor:
+                dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
+                dur_nll = dur_nll / torch.sum(x_mask)
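Wrapping the duration predictor in autocast(enabled=False) keeps that computation in float32 even when the training step runs under automatic mixed precision, which guards the duration NLL against float16 overflow and underflow. The general PyTorch pattern, with a toy module for illustration:

    import torch
    from torch.cuda.amp import autocast

    net = torch.nn.Linear(8, 8).cuda()
    x = torch.randn(2, 8, device="cuda")
    with autocast(enabled=True):           # mixed-precision region
        h = net(x)                         # may run in float16
        with autocast(enabled=False):      # opt out for sensitive math
            loss = (h.float() ** 2).sum()  # computed in float32
    print(loss.dtype)  # torch.float32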
@@ -455,9 +456,6 @@ class VITSGenerator(torch.nn.Module):
             Tensor: Duration tensor (B, T_text).
         """
-        # encoder
-        x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
-        x_mask = x_mask.to(x.dtype)
         g = None
         if self.spks is not None:
             # (B, global_channels, 1)
@@ -477,6 +475,10 @@
             else:
                 g = g + g_

+        # encoder
+        x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths, g=g)
+        x_mask = x_mask.to(x.dtype)
+
         if use_teacher_forcing:
             # forward posterior encoder
             z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)

View File

@@ -108,7 +108,9 @@ def main():
     model = OnnxModel(args.model_filename)

     text = "I went there to see the land, the people and how their system works, end quote."
-    tokens = tokenizer.texts_to_token_ids([text])
+    tokens = tokenizer.texts_to_token_ids(
+        [text], intersperse_blank=True, add_sos=True, add_eos=True
+    )
     tokens = torch.tensor(tokens)  # (1, T)
     tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)  # (1, T)
     audio = model(tokens, tokens_lens)  # (1, T')
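Put together, inference-side tokenization now looks roughly like this (a sketch; the tokens-file path and the "from tokenizer import Tokenizer" import are assumptions based on the files in this commit):

    import torch
    from tokenizer import Tokenizer

    tokenizer = Tokenizer("data/tokens.txt")
    token_ids = tokenizer.texts_to_token_ids(
        ["Hello world."], intersperse_blank=True, add_sos=True, add_eos=True
    )
    tokens = torch.tensor(token_ids)                                  # (1, T)
    tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)  # (1,)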

View File

@@ -1 +0,0 @@
-../../../ljspeech/TTS/vits/tokenizer.py

View File

@@ -0,0 +1,146 @@
+# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, List
+
+import tacotron_cleaner.cleaners
+from piper_phonemize import phonemize_espeak
+from utils import intersperse
+
+
+class Tokenizer(object):
+    def __init__(self, tokens: str):
+        """
+        Args:
+          tokens: the file that maps tokens to ids
+        """
+        # Parse token file
+        self.token2id: Dict[str, int] = {}
+        with open(tokens, "r", encoding="utf-8") as f:
+            for line in f.readlines():
+                info = line.rstrip().split()
+                if len(info) == 1:
+                    # case of space
+                    token = " "
+                    id = int(info[0])
+                else:
+                    token, id = info[0], int(info[1])
+                assert token not in self.token2id, token
+                self.token2id[token] = id
+
+        # Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md
+        self.pad_id = self.token2id["_"]  # padding
+        self.sos_id = self.token2id["^"]  # beginning of an utterance (bos)
+        self.eos_id = self.token2id["$"]  # end of an utterance (eos)
+        self.space_id = self.token2id[" "]  # word separator (whitespace)
+
+        self.vocab_size = len(self.token2id)
+
+    def texts_to_token_ids(
+        self,
+        texts: List[str],
+        intersperse_blank: bool = True,
+        add_sos: bool = False,
+        add_eos: bool = False,
+        lang: str = "en-us",
+    ) -> List[List[int]]:
+        """
+        Args:
+          texts:
+            A list of transcripts.
+          intersperse_blank:
+            Whether to intersperse blanks in the token sequence.
+          add_sos:
+            Whether to add sos token at the start.
+          add_eos:
+            Whether to add eos token at the end.
+          lang:
+            Language argument passed to phonemize_espeak().
+
+        Returns:
+          Return a list of token id list [utterance][token_id]
+        """
+        token_ids_list = []
+
+        for text in texts:
+            # Text normalization
+            text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
+            # Convert to phonemes
+            tokens_list = phonemize_espeak(text, lang)
+            tokens = []
+            for t in tokens_list:
+                tokens.extend(t)
+
+            token_ids = []
+            for t in tokens:
+                if t not in self.token2id:
+                    logging.warning(f"Skip OOV {t}")
+                    continue
+                token_ids.append(self.token2id[t])
+
+            if intersperse_blank:
+                token_ids = intersperse(token_ids, self.pad_id)
+            if add_sos:
+                token_ids = [self.sos_id] + token_ids
+            if add_eos:
+                token_ids = token_ids + [self.eos_id]
+
+            token_ids_list.append(token_ids)
+
+        return token_ids_list
+
+    def tokens_to_token_ids(
+        self,
+        tokens_list: List[List[str]],
+        intersperse_blank: bool = True,
+        add_sos: bool = False,
+        add_eos: bool = False,
+    ) -> List[List[int]]:
+        """
+        Args:
+          tokens_list:
+            A list of token list, each corresponding to one utterance.
+          intersperse_blank:
+            Whether to intersperse blanks in the token sequence.
+          add_sos:
+            Whether to add sos token at the start.
+          add_eos:
+            Whether to add eos token at the end.
+
+        Returns:
+          Return a list of token id list [utterance][token_id]
+        """
+        token_ids_list = []
+
+        for tokens in tokens_list:
+            token_ids = []
+            for t in tokens:
+                if t not in self.token2id:
+                    logging.warning(f"Skip OOV {t}")
+                    continue
+                token_ids.append(self.token2id[t])
+
+            if intersperse_blank:
+                token_ids = intersperse(token_ids, self.pad_id)
+            if add_sos:
+                token_ids = [self.sos_id] + token_ids
+            if add_eos:
+                token_ids = token_ids + [self.eos_id]
+
+            token_ids_list.append(token_ids)
+
+        return token_ids_list
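The intersperse helper imported from utils is not part of this diff; the usual VITS-style definition, which this tokenizer presumably relies on, inserts the pad token between and around all symbols:

    from typing import List, TypeVar

    T = TypeVar("T")

    def intersperse(sequence: List[T], item: T) -> List[T]:
        """Sketch of utils.intersperse: [1, 2, 3] with item=0 -> [0, 1, 0, 2, 0, 3, 0]."""
        result = [item] * (len(sequence) * 2 + 1)
        result[1::2] = sequence
        return result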

View File

@@ -296,14 +296,16 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     features_lens = batch["features_lens"].to(device)
     tokens = batch["tokens"]

-    tokens = tokenizer.tokens_to_token_ids(tokens)
+    tokens = tokenizer.tokens_to_token_ids(
+        tokens, intersperse_blank=True, add_sos=True, add_eos=True
+    )
     tokens = k2.RaggedTensor(tokens)
     row_splits = tokens.shape.row_splits(1)
     tokens_lens = row_splits[1:] - row_splits[:-1]
     tokens = tokens.to(device)
     tokens_lens = tokens_lens.to(device)
     # a tensor of shape (B, T)
-    tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+    tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)

     return audio, audio_lens, features, features_lens, tokens, tokens_lens
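For reference, the row_splits arithmetic above is how k2 recovers per-utterance lengths from a ragged batch before padding; a small illustration with made-up token IDs:

    import k2

    tokens = k2.RaggedTensor([[5, 0, 7], [9]])      # two utterances
    row_splits = tokens.shape.row_splits(1)         # [0, 3, 4]
    tokens_lens = row_splits[1:] - row_splits[:-1]  # [3, 1]
    padded = tokens.pad(mode="constant", padding_value=0)
    # padded -> [[5, 0, 7], [9, 0, 0]]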
@@ -813,8 +815,7 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")

     tokenizer = Tokenizer(params.tokens)
-    params.blank_id = tokenizer.blank_id
-    params.oov_id = tokenizer.oov_id
+    params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size

     logging.info(params)