From 8791a4efb090faca1648ae8d8b8c6ddcb6ee9a0f Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Mon, 13 Nov 2023 21:56:18 +0800 Subject: [PATCH] convert text to tokens in data preparation stage --- .../TTS/local/compute_spectrogram_ljspeech.py | 4 +- egs/ljspeech/TTS/local/prepare_token_file.py | 30 +++------ .../TTS/local/prepare_tokens_ljspeech.py | 63 +++++++++++++++++++ egs/ljspeech/TTS/prepare.sh | 16 ++++- egs/ljspeech/TTS/vits/infer.py | 6 +- egs/ljspeech/TTS/vits/tokenizer.py | 27 ++++++++ egs/ljspeech/TTS/vits/train.py | 8 +-- egs/ljspeech/TTS/vits/tts_datamodule.py | 24 +++++-- 8 files changed, 139 insertions(+), 39 deletions(-) create mode 100755 egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py diff --git a/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py index edb22b276..eacf0df57 100755 --- a/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py +++ b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py @@ -58,10 +58,10 @@ def compute_spectrogram_ljspeech(): partition = "all" recordings = load_manifest( - src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet + src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet ) supervisions = load_manifest( - src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet + src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet ) config = SpectrogramConfig( diff --git a/egs/ljspeech/TTS/local/prepare_token_file.py b/egs/ljspeech/TTS/local/prepare_token_file.py index 167b73f2e..007bb299b 100755 --- a/egs/ljspeech/TTS/local/prepare_token_file.py +++ b/egs/ljspeech/TTS/local/prepare_token_file.py @@ -22,12 +22,9 @@ This file reads the texts in given manifest and generates the file that maps tok import argparse import logging -from collections import Counter from pathlib import Path from typing import Dict -import g2p_en -import tacotron_cleaner.cleaners from lhotse import load_manifest @@ -74,32 +71,23 
@@ def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
 def get_token2id(manifest_file: Path) -> Dict[str, int]:
     """Return a dict that maps token to IDs."""
     extra_tokens = [
-        ("<blk>", None),  # 0 for blank
-        ("<sos/eos>", None),  # 1 for sos and eos symbols.
-        ("<unk>", None),  # 2 for OOV
+        "<blk>",  # 0 for blank
+        "<sos/eos>",  # 1 for sos and eos symbols.
+        "<unk>",  # 2 for OOV
     ]
 
+    all_tokens = set()
+
     cut_set = load_manifest(manifest_file)
-    g2p = g2p_en.G2p()
-    counter = Counter()
 
     for cut in cut_set:
         # Each cut only contain one supervision
         assert len(cut.supervisions) == 1, len(cut.supervisions)
-        text = cut.supervisions[0].normalized_text
-        # Text normalization
-        text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
-        # Convert to phonemes
-        tokens = g2p(text)
-        for t in tokens:
-            counter[t] += 1
+        all_tokens.update(cut.tokens)
 
-    # Sort by the number of occurrences in descending order
-    tokens_and_counts = sorted(counter.items(), key=lambda x: -x[1])
-
-    tokens_and_counts = extra_tokens + tokens_and_counts
-
-    token2id: Dict[str, int] = {token: i for i, (token, _) in enumerate(tokens_and_counts)}
+    # Sort so that token IDs do not depend on set iteration order between runs
+    all_tokens = extra_tokens + sorted(all_tokens)
+    token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
     return token2id
 
 
diff --git a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py
new file mode 100755
index 000000000..f7fa7e2d2
--- /dev/null
+++ b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py
@@ -0,0 +1,63 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file reads the texts in the given manifest and saves new cuts with phoneme tokens.
+"""
+
+import logging
+from pathlib import Path
+
+import g2p_en
+import tacotron_cleaner.cleaners
+from lhotse import CutSet, load_manifest
+
+
+def prepare_tokens_ljspeech():
+    """Attach phoneme tokens to every LJSpeech cut and save the new cut set."""
+    output_dir = Path("data/spectrogram")
+    prefix = "ljspeech"
+    suffix = "jsonl.gz"
+    partition = "all"
+
+    cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
+    g2p = g2p_en.G2p()
+
+    new_cuts = []
+    for cut in cut_set:
+        # Each cut only contains one supervision
+        assert len(cut.supervisions) == 1, len(cut.supervisions)
+        text = cut.supervisions[0].normalized_text
+        # Text normalization
+        text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
+        # Convert to phonemes
+        cut.tokens = g2p(text)
+        new_cuts.append(cut)
+
+    new_cut_set = CutSet.from_cuts(new_cuts)
+    new_cut_set.to_file(
+        output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}"
+    )
+    logging.info(f"Saved {len(new_cuts)} cuts with phoneme tokens to {output_dir}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    prepare_tokens_ljspeech()
diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh
index 396d91b59..8ee40896e 100755
--- a/egs/ljspeech/TTS/prepare.sh
+++ b/egs/ljspeech/TTS/prepare.sh
@@ -69,7 +69,17 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
 fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 3: Split the LJSpeech cuts into train, valid and test sets"
+  log "Stage 3: Prepare phoneme tokens
for LJSpeech"
+  if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then
+    ./local/prepare_tokens_ljspeech.py
+    mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \
+      data/spectrogram/ljspeech_cuts_all.jsonl.gz
+    touch data/spectrogram/.ljspeech_with_token.done
+  fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Split the LJSpeech cuts into train, valid and test sets"
   if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
     lhotse subset --last 600 \
       data/spectrogram/ljspeech_cuts_all.jsonl.gz \
@@ -91,8 +101,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   fi
 fi
 
-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "Stage 4: Generate token file"
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Generate token file"
   # We assume you have installed g2p_en and espnet_tts_frontend.
   # If not, please install them with:
   #   - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py
index 4917a7ee9..a7c4a4c09 100755
--- a/egs/ljspeech/TTS/vits/infer.py
+++ b/egs/ljspeech/TTS/vits/infer.py
@@ -128,10 +128,10 @@ def infer_dataset(
     futures = []
     with ThreadPoolExecutor(max_workers=1) as executor:
         for batch_idx, batch in enumerate(dl):
-            batch_size = len(batch["text"])
+            batch_size = len(batch["tokens"])
 
-            text = batch["text"]
-            tokens = tokenizer.texts_to_token_ids(text)
+            tokens = batch["tokens"]
+            tokens = tokenizer.tokens_to_token_ids(tokens)
             tokens = k2.RaggedTensor(tokens)
             row_splits = tokens.shape.row_splits(1)
             tokens_lens = row_splits[1:] - row_splits[:-1]
diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py
index 8a61511ef..0678b26fe 100644
--- a/egs/ljspeech/TTS/vits/tokenizer.py
+++ b/egs/ljspeech/TTS/vits/tokenizer.py
@@ -77,3 +77,30 @@ class Tokenizer(object):
             token_ids_list.append(token_ids)
 
         return token_ids_list
+
+    def tokens_to_token_ids(
+        self, tokens_list: List[List[str]], intersperse_blank: bool = True
+    ) -> List[List[int]]:
+        """Map pre-tokenized utterances to token IDs.
+
+        Tokens that are not in ``self.token2id`` are mapped to ``self.oov_id``.
+
+        Args:
+          tokens_list:
+            A list of token lists, each corresponding to one utterance.
+          intersperse_blank:
+            Whether to intersperse blanks in the token sequence.
+
+        Returns:
+          Return a list of token id list [utterance][token_id]
+        """
+        token_ids_list = []
+
+        for tokens in tokens_list:
+            token_ids = [self.token2id.get(t, self.oov_id) for t in tokens]
+
+            if intersperse_blank:
+                token_ids = intersperse(token_ids, self.blank_id)
+            token_ids_list.append(token_ids)
+
+        return token_ids_list
diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py
index 1a2c934fe..eb43a4cc9 100755
--- a/egs/ljspeech/TTS/vits/train.py
+++ b/egs/ljspeech/TTS/vits/train.py
@@ -295,9 +295,9 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     features = batch["features"].to(device)
     audio_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)
-    text = batch["text"]
+    tokens = batch["tokens"]
 
-    tokens = tokenizer.texts_to_token_ids(text)
+    tokens = tokenizer.tokens_to_token_ids(tokens)
     tokens = k2.RaggedTensor(tokens)
     row_splits = tokens.shape.row_splits(1)
     tokens_lens = row_splits[1:] - row_splits[:-1]
@@ -384,7 +384,7 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
 
-        batch_size = len(batch["text"])
+        batch_size = len(batch["tokens"])
         audio, audio_lens, features, features_lens, tokens, tokens_lens = \
             prepare_input(batch, tokenizer, device)
 
@@ -554,7 +554,7 @@ def compute_validation_loss(
 
     with torch.no_grad():
         for batch_idx, batch in enumerate(valid_dl):
-            batch_size = len(batch["text"])
+            batch_size = len(batch["tokens"])
             audio, audio_lens, features, features_lens, tokens, tokens_lens = \
                 prepare_input(batch, tokenizer, device)
 
diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py
index 40e9c19dd..f27676670 100644
--- a/egs/ljspeech/TTS/vits/tts_datamodule.py +++ b/egs/ljspeech/TTS/vits/tts_datamodule.py @@ -168,7 +168,9 @@ class LJSpeechTtsDataModule: """ logging.info("About to create train dataset") train = SpeechSynthesisDataset( - return_tokens=False, + return_token_ids=False, + return_text=False, + return_tokens=True, feature_input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) @@ -182,7 +184,9 @@ class LJSpeechTtsDataModule: use_fft_mag=True, ) train = SpeechSynthesisDataset( - return_tokens=False, + return_token_ids=False, + return_text=False, + return_tokens=True, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), return_cuts=self.args.return_cuts, ) @@ -236,13 +240,17 @@ class LJSpeechTtsDataModule: use_fft_mag=True, ) validate = SpeechSynthesisDataset( - return_tokens=False, + return_token_ids=False, + return_text=False, + return_tokens=True, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), return_cuts=self.args.return_cuts, ) else: validate = SpeechSynthesisDataset( - return_tokens=False, + return_token_ids=False, + return_text=False, + return_tokens=True, feature_input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) @@ -273,13 +281,17 @@ class LJSpeechTtsDataModule: use_fft_mag=True, ) test = SpeechSynthesisDataset( - return_tokens=False, + return_token_ids=False, + return_text=False, + return_tokens=True, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), return_cuts=self.args.return_cuts, ) else: test = SpeechSynthesisDataset( - return_tokens=False, + return_token_ids=False, + return_text=False, + return_tokens=True, feature_input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, )