diff --git a/egs/ljspeech/TTS/local/prepare_token_file.py b/egs/ljspeech/TTS/local/prepare_token_file.py
index df976804a..5b048b600 100755
--- a/egs/ljspeech/TTS/local/prepare_token_file.py
+++ b/egs/ljspeech/TTS/local/prepare_token_file.py
@@ -17,7 +17,7 @@
 
 
 """
-This file reads the texts in given manifest and generates the file that maps tokens to IDs.
+This file generates the file that maps tokens to IDs.
 """
 
 import argparse
@@ -25,80 +25,38 @@ import logging
 from pathlib import Path
 from typing import Dict
 
-from lhotse import load_manifest
+from piper_phonemize import get_espeak_map
 
 
 def get_args():
     parser = argparse.ArgumentParser()
 
-    parser.add_argument(
-        "--manifest-file",
-        type=Path,
-        default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"),
-        help="Path to the manifest file",
-    )
-
     parser.add_argument(
         "--tokens",
         type=Path,
        default=Path("data/tokens.txt"),
-        help="Path to the tokens",
+        help="Path to the dict that maps the text tokens to IDs",
     )
 
     return parser.parse_args()
 
 
-def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
-    """Write a symbol to ID mapping to a file.
+def get_token2id(filename: Path) -> Dict[str, int]:
+    """Get a dict that maps token to IDs, and save it to the given filename."""
+    all_tokens = get_espeak_map()  # token: [token_id]
+    all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()}
+    # sort by token_id
+    all_tokens = sorted(all_tokens.items(), key=lambda x: x[1])
 
-    Note:
-      No need to implement `read_mapping` as it can be done
-      through :func:`k2.SymbolTable.from_file`.
-
-    Args:
-      filename:
-        Filename to save the mapping.
-      sym2id:
-        A dict mapping symbols to IDs.
-    Returns:
-      Return None.
-    """
     with open(filename, "w", encoding="utf-8") as f:
-        for sym, i in sym2id.items():
-            f.write(f"{sym} {i}\n")
-
-
-def get_token2id(manifest_file: Path) -> Dict[str, int]:
-    """Return a dict that maps token to IDs."""
-    extra_tokens = [
-        "<blk>",  # 0 for blank
-        "<sos/eos>",  # 1 for sos and eos symbols.
-        "<unk>",  # 2 for OOV
-    ]
-    all_tokens = set()
-
-    cut_set = load_manifest(manifest_file)
-
-    for cut in cut_set:
-        # Each cut only contain one supervision
-        assert len(cut.supervisions) == 1, len(cut.supervisions)
-        for t in cut.tokens:
-            all_tokens.add(t)
-
-    all_tokens = extra_tokens + list(all_tokens)
-
-    token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
-    return token2id
+        for token, token_id in all_tokens:
+            f.write(f"{token} {token_id}\n")
 
 
 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     args = get_args()
-    manifest_file = Path(args.manifest_file)
     out_file = Path(args.tokens)
-
-    token2id = get_token2id(manifest_file)
-    write_mapping(out_file, token2id)
+    get_token2id(out_file)
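Note: the emitted data/tokens.txt stores one token and its ID per line; for the whitespace token, `split()` leaves only the ID field. A minimal sketch (not part of the patch) of reading the file back, mirroring the special case handled in the new Tokenizer.__init__ further below:

```python
# Sketch: read data/tokens.txt back into {token: id}.
token2id = {}
with open("data/tokens.txt", encoding="utf-8") as f:
    for line in f:
        info = line.rstrip("\n").split()
        if len(info) == 1:
            # the space token: split() leaves only the ID field
            token2id[" "] = int(info[0])
        else:
            token2id[info[0]] = int(info[1])
```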
- "", # 2 for OOV - ] - all_tokens = set() - - cut_set = load_manifest(manifest_file) - - for cut in cut_set: - # Each cut only contain one supervision - assert len(cut.supervisions) == 1, len(cut.supervisions) - for t in cut.tokens: - all_tokens.add(t) - - all_tokens = extra_tokens + list(all_tokens) - - token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)} - return token2id + for token, token_id in all_tokens: + f.write(f"{token} {token_id}\n") if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - manifest_file = Path(args.manifest_file) out_file = Path(args.tokens) - - token2id = get_token2id(manifest_file) - write_mapping(out_file, token2id) + get_token2id(out_file) diff --git a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py index fcd0137a0..08fe7430e 100755 --- a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py +++ b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py @@ -23,9 +23,9 @@ This file reads the texts in given manifest and save the new cuts with phoneme t import logging from pathlib import Path -import g2p_en import tacotron_cleaner.cleaners from lhotse import CutSet, load_manifest +from piper_phonemize import phonemize_espeak def prepare_tokens_ljspeech(): @@ -35,17 +35,20 @@ def prepare_tokens_ljspeech(): partition = "all" cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") - g2p = g2p_en.G2p() new_cuts = [] for cut in cut_set: # Each cut only contains one supervision - assert len(cut.supervisions) == 1, len(cut.supervisions) + assert len(cut.supervisions) == 1, (len(cut.supervisions), cut) text = cut.supervisions[0].normalized_text # Text normalization text = tacotron_cleaner.cleaners.custom_english_cleaners(text) # Convert to phonemes - cut.tokens = g2p(text) + tokens_list = phonemize_espeak(text, "en-us") + tokens = [] + for t in tokens_list: + tokens.extend(t) + cut.tokens = tokens new_cuts.append(cut) new_cut_set = CutSet.from_cuts(new_cuts) diff --git a/egs/ljspeech/TTS/vits2/export-onnx.py b/egs/ljspeech/TTS/vits2/export-onnx.py index f82f9dbe9..c607f0114 100755 --- a/egs/ljspeech/TTS/vits2/export-onnx.py +++ b/egs/ljspeech/TTS/vits2/export-onnx.py @@ -218,8 +218,7 @@ def main(): params.update(vars(args)) tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.blank_id - params.oov_id = tokenizer.oov_id + params.blank_id = tokenizer.pad_id params.vocab_size = tokenizer.vocab_size logging.info(params) diff --git a/egs/ljspeech/TTS/vits2/generator.py b/egs/ljspeech/TTS/vits2/generator.py index 17fd513a2..345c06211 100644 --- a/egs/ljspeech/TTS/vits2/generator.py +++ b/egs/ljspeech/TTS/vits2/generator.py @@ -21,6 +21,7 @@ from hifigan import HiFiGANGenerator from posterior_encoder import PosteriorEncoder from residual_coupling import ResidualAffineCouplingBlock from text_encoder import TextEncoder +from torch.cuda.amp import autocast from utils import get_random_segments from icefall.utils import make_pad_mask @@ -375,18 +376,18 @@ class VITSGenerator(torch.nn.Module): # forward duration predictor w = attn.sum(2) # (B, 1, T_text) - - if self.use_stochastic_duration_predictor: - dur_nll = self.duration_predictor(x, x_mask, w=w, g=g) - dur_nll = dur_nll / torch.sum(x_mask) - logw = self.duration_predictor( - x, x_mask, g=g, inverse=True, noise_scale=1.0 - ) - logw_ = torch.log(w + 1e-6) * x_mask - else: - logw_ = torch.log(w + 1e-6) 
diff --git a/egs/ljspeech/TTS/vits2/export-onnx.py b/egs/ljspeech/TTS/vits2/export-onnx.py
index f82f9dbe9..c607f0114 100755
--- a/egs/ljspeech/TTS/vits2/export-onnx.py
+++ b/egs/ljspeech/TTS/vits2/export-onnx.py
@@ -218,8 +218,7 @@ def main():
     params.update(vars(args))
 
     tokenizer = Tokenizer(params.tokens)
-    params.blank_id = tokenizer.blank_id
-    params.oov_id = tokenizer.oov_id
+    params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
 
     logging.info(params)
diff --git a/egs/ljspeech/TTS/vits2/generator.py b/egs/ljspeech/TTS/vits2/generator.py
index 17fd513a2..345c06211 100644
--- a/egs/ljspeech/TTS/vits2/generator.py
+++ b/egs/ljspeech/TTS/vits2/generator.py
@@ -21,6 +21,7 @@ from hifigan import HiFiGANGenerator
 from posterior_encoder import PosteriorEncoder
 from residual_coupling import ResidualAffineCouplingBlock
 from text_encoder import TextEncoder
+from torch.cuda.amp import autocast
 from utils import get_random_segments
 
 from icefall.utils import make_pad_mask
@@ -375,18 +376,18 @@ class VITSGenerator(torch.nn.Module):
 
         # forward duration predictor
         w = attn.sum(2)  # (B, 1, T_text)
-
-        if self.use_stochastic_duration_predictor:
-            dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
-            dur_nll = dur_nll / torch.sum(x_mask)
-            logw = self.duration_predictor(
-                x, x_mask, g=g, inverse=True, noise_scale=1.0
-            )
-            logw_ = torch.log(w + 1e-6) * x_mask
-        else:
-            logw_ = torch.log(w + 1e-6) * x_mask
-            logw = self.duration_predictor(x, x_mask, g=g)
-            dur_nll = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
+        with autocast(enabled=False):
+            if self.use_stochastic_duration_predictor:
+                dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
+                dur_nll = dur_nll / torch.sum(x_mask)
+                logw = self.duration_predictor(
+                    x, x_mask, g=g, inverse=True, noise_scale=1.0
+                )
+                logw_ = torch.log(w + 1e-6) * x_mask
+            else:
+                logw_ = torch.log(w + 1e-6) * x_mask
+                logw = self.duration_predictor(x, x_mask, g=g)
+                dur_nll = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
 
         # expand the length to match with the feature sequence
         # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
@@ -455,9 +456,6 @@ class VITSGenerator(torch.nn.Module):
             Tensor: Duration tensor (B, T_text).
 
         """
-        # encoder
-        x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
-        x_mask = x_mask.to(x.dtype)
         g = None
         if self.spks is not None:
             # (B, global_channels, 1)
@@ -477,6 +475,10 @@ class VITSGenerator(torch.nn.Module):
             else:
                 g = g + g_
 
+        # encoder
+        x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths, g=g)
+        x_mask = x_mask.to(x.dtype)
+
         if use_teacher_forcing:
             # forward posterior encoder
             z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
diff --git a/egs/ljspeech/TTS/vits2/test_onnx.py b/egs/ljspeech/TTS/vits2/test_onnx.py
index fcbc1d663..4f46e8e6c 100755
--- a/egs/ljspeech/TTS/vits2/test_onnx.py
+++ b/egs/ljspeech/TTS/vits2/test_onnx.py
@@ -108,7 +108,9 @@ def main():
     model = OnnxModel(args.model_filename)
 
     text = "I went there to see the land, the people and how their system works, end quote."
-    tokens = tokenizer.texts_to_token_ids([text])
+    tokens = tokenizer.texts_to_token_ids(
+        [text], intersperse_blank=True, add_sos=True, add_eos=True
+    )
     tokens = torch.tensor(tokens)  # (1, T)
     tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)  # (1, T)
     audio = model(tokens, tokens_lens)  # (1, T')
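Note on the generator.py hunk above: wrapping the duration-predictor branch in autocast(enabled=False) keeps it out of AMP's half-precision regions, since the log of near-zero durations is numerically fragile in fp16. A standalone sketch of the pattern; the explicit .float() casts are an illustrative assumption, not part of the patch:

```python
import torch
from torch.cuda.amp import autocast


def duration_terms_fp32(w: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
    # Ops inside this block are exempt from autocast; casting the inputs
    # makes the fp32 computation unconditional even if the surrounding
    # forward pass produced fp16 tensors.
    with autocast(enabled=False):
        return torch.log(w.float() + 1e-6) * x_mask.float()
```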
diff --git a/egs/ljspeech/TTS/vits2/tokenizer.py b/egs/ljspeech/TTS/vits2/tokenizer.py
deleted file mode 120000
index 057b0dc4b..000000000
--- a/egs/ljspeech/TTS/vits2/tokenizer.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../ljspeech/TTS/vits/tokenizer.py
\ No newline at end of file
diff --git a/egs/ljspeech/TTS/vits2/tokenizer.py b/egs/ljspeech/TTS/vits2/tokenizer.py
new file mode 100644
index 000000000..9a5a9090e
--- /dev/null
+++ b/egs/ljspeech/TTS/vits2/tokenizer.py
@@ -0,0 +1,146 @@
+# Copyright    2023-2024  Xiaomi Corp.  (authors: Zengwei Yao)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+from typing import Dict, List
+
+import tacotron_cleaner.cleaners
+from piper_phonemize import phonemize_espeak
+from utils import intersperse
+
+
+class Tokenizer(object):
+    def __init__(self, tokens: str):
+        """
+        Args:
+          tokens: the file that maps tokens to ids
+        """
+        # Parse token file
+        self.token2id: Dict[str, int] = {}
+        with open(tokens, "r", encoding="utf-8") as f:
+            for line in f.readlines():
+                info = line.rstrip().split()
+                if len(info) == 1:
+                    # case of space
+                    token = " "
+                    id = int(info[0])
+                else:
+                    token, id = info[0], int(info[1])
+                assert token not in self.token2id, token
+                self.token2id[token] = id
+
+        # Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md
+        self.pad_id = self.token2id["_"]  # padding
+        self.sos_id = self.token2id["^"]  # beginning of an utterance (bos)
+        self.eos_id = self.token2id["$"]  # end of an utterance (eos)
+        self.space_id = self.token2id[" "]  # word separator (whitespace)
+
+        self.vocab_size = len(self.token2id)
+
+    def texts_to_token_ids(
+        self,
+        texts: List[str],
+        intersperse_blank: bool = True,
+        add_sos: bool = False,
+        add_eos: bool = False,
+        lang: str = "en-us",
+    ) -> List[List[int]]:
+        """
+        Args:
+          texts:
+            A list of transcripts.
+          intersperse_blank:
+            Whether to intersperse blanks in the token sequence.
+          add_sos:
+            Whether to add sos token at the start.
+          add_eos:
+            Whether to add eos token at the end.
+          lang:
+            Language argument passed to phonemize_espeak().
+
+        Returns:
+          Return a list of token id list [utterance][token_id]
+        """
+        token_ids_list = []
+
+        for text in texts:
+            # Text normalization
+            text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
+            # Convert to phonemes
+            tokens_list = phonemize_espeak(text, lang)
+            tokens = []
+            for t in tokens_list:
+                tokens.extend(t)
+
+            token_ids = []
+            for t in tokens:
+                if t not in self.token2id:
+                    logging.warning(f"Skip OOV {t}")
+                    continue
+                token_ids.append(self.token2id[t])
+
+            if intersperse_blank:
+                token_ids = intersperse(token_ids, self.pad_id)
+            if add_sos:
+                token_ids = [self.sos_id] + token_ids
+            if add_eos:
+                token_ids = token_ids + [self.eos_id]
+
+            token_ids_list.append(token_ids)
+
+        return token_ids_list
+
+    def tokens_to_token_ids(
+        self,
+        tokens_list: List[str],
+        intersperse_blank: bool = True,
+        add_sos: bool = False,
+        add_eos: bool = False,
+    ) -> List[List[int]]:
+        """
+        Args:
+          tokens_list:
+            A list of token list, each corresponding to one utterance.
+          intersperse_blank:
+            Whether to intersperse blanks in the token sequence.
+          add_sos:
+            Whether to add sos token at the start.
+          add_eos:
+            Whether to add eos token at the end.
+
+        Returns:
+          Return a list of token id list [utterance][token_id]
+        """
+        token_ids_list = []
+
+        for tokens in tokens_list:
+            token_ids = []
+            for t in tokens:
+                if t not in self.token2id:
+                    logging.warning(f"Skip OOV {t}")
+                    continue
+                token_ids.append(self.token2id[t])
+
+            if intersperse_blank:
+                token_ids = intersperse(token_ids, self.pad_id)
+            if add_sos:
+                token_ids = [self.sos_id] + token_ids
+            if add_eos:
+                token_ids = token_ids + [self.eos_id]
+
+            token_ids_list.append(token_ids)
+
+        return token_ids_list
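Note: a usage sketch for the new Tokenizer, matching the calls in test_onnx.py above and train.py below; the token file path is the default produced by prepare_token_file.py:

```python
tokenizer = Tokenizer("data/tokens.txt")
ids = tokenizer.texts_to_token_ids(
    ["Hello world"], intersperse_blank=True, add_sos=True, add_eos=True
)[0]
# intersperse_blank places the pad ID between (and around) the phoneme IDs,
# [p1, p2] -> [pad, p1, pad, p2, pad], via utils.intersperse; add_sos and
# add_eos then wrap the sequence with the "^" and "$" IDs.
```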
diff --git a/egs/ljspeech/TTS/vits2/train.py b/egs/ljspeech/TTS/vits2/train.py
index d11c1674e..39e3195ec 100755
--- a/egs/ljspeech/TTS/vits2/train.py
+++ b/egs/ljspeech/TTS/vits2/train.py
@@ -296,14 +296,16 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     features_lens = batch["features_lens"].to(device)
 
     tokens = batch["tokens"]
-    tokens = tokenizer.tokens_to_token_ids(tokens)
+    tokens = tokenizer.tokens_to_token_ids(
+        tokens, intersperse_blank=True, add_sos=True, add_eos=True
+    )
     tokens = k2.RaggedTensor(tokens)
     row_splits = tokens.shape.row_splits(1)
     tokens_lens = row_splits[1:] - row_splits[:-1]
     tokens = tokens.to(device)
     tokens_lens = tokens_lens.to(device)
     # a tensor of shape (B, T)
-    tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+    tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
 
     return audio, audio_lens, features, features_lens, tokens, tokens_lens
 
@@ -813,8 +815,7 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")
 
     tokenizer = Tokenizer(params.tokens)
-    params.blank_id = tokenizer.blank_id
-    params.oov_id = tokenizer.oov_id
+    params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
 
     logging.info(params)
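Note: for reference, the ragged-to-padded conversion in prepare_input, shown on toy data; assumes k2 is available, the IDs are illustrative, and the real code pads with tokenizer.pad_id:

```python
import k2

tokens = k2.RaggedTensor([[1, 0, 5], [7, 2]])
row_splits = tokens.shape.row_splits(1)          # tensor([0, 3, 5])
tokens_lens = row_splits[1:] - row_splits[:-1]   # tensor([3, 2])
padded = tokens.pad(mode="constant", padding_value=0)
# tensor([[1, 0, 5],
#         [7, 2, 0]])
```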