convert text to tokens in data preparation stage

yaozengwei 2023-11-13 21:56:18 +08:00
parent f55e80a7c5
commit 8791a4efb0
8 changed files with 139 additions and 39 deletions

View File

@@ -58,10 +58,10 @@ def compute_spectrogram_ljspeech():
     partition = "all"

     recordings = load_manifest(
-        src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet
+        src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet
     )
     supervisions = load_manifest(
-        src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet
+        src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
     )

     config = SpectrogramConfig(
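
For illustration only (not part of the commit): a minimal sketch of how the refactored path construction resolves. The value of src_dir and the variable assignments are assumptions based on the surrounding recipe, not shown in this hunk.

    from pathlib import Path

    src_dir = Path("data/manifests")  # assumed location, not shown in this hunk
    prefix, partition, suffix = "ljspeech", "all", "jsonl.gz"

    # Both manifests now share the same suffix variable:
    print(src_dir / f"{prefix}_recordings_{partition}.{suffix}")
    # -> data/manifests/ljspeech_recordings_all.jsonl.gz
    print(src_dir / f"{prefix}_supervisions_{partition}.{suffix}")
    # -> data/manifests/ljspeech_supervisions_all.jsonl.gz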

View File

@@ -22,12 +22,9 @@ This file reads the texts in the given manifest and generates the file that maps tokens to IDs.
 import argparse
 import logging
-from collections import Counter
 from pathlib import Path
 from typing import Dict

-import g2p_en
-import tacotron_cleaner.cleaners
 from lhotse import load_manifest
@@ -74,32 +71,23 @@ def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
 def get_token2id(manifest_file: Path) -> Dict[str, int]:
     """Return a dict that maps tokens to IDs."""
     extra_tokens = [
-        ("<blk>", None),  # 0 for blank
-        ("<sos/eos>", None),  # 1 for sos and eos symbols.
-        ("<unk>", None),  # 2 for OOV
+        "<blk>",  # 0 for blank
+        "<sos/eos>",  # 1 for sos and eos symbols.
+        "<unk>"  # 2 for OOV
     ]
+    all_tokens = set()
     cut_set = load_manifest(manifest_file)
-    g2p = g2p_en.G2p()
-    counter = Counter()

     for cut in cut_set:
         # Each cut only contains one supervision
         assert len(cut.supervisions) == 1, len(cut.supervisions)
-        text = cut.supervisions[0].normalized_text
-        # Text normalization
-        text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
-        # Convert to phonemes
-        tokens = g2p(text)
-        for t in tokens:
-            counter[t] += 1
-
-    # Sort by the number of occurrences in descending order
-    tokens_and_counts = sorted(counter.items(), key=lambda x: -x[1])
-    tokens_and_counts = extra_tokens + tokens_and_counts
-    token2id: Dict[str, int] = {token: i for i, (token, _) in enumerate(tokens_and_counts)}
+        for t in cut.tokens:
+            all_tokens.add(t)
+    all_tokens = extra_tokens + list(all_tokens)

+    token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
     return token2id
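
A hedged usage sketch of the rewritten get_token2id: the three special symbols always land on IDs 0-2, while the IDs of the phoneme tokens gathered from cut.tokens depend on set iteration order (string hashing is randomized across Python runs unless PYTHONHASHSEED is fixed).

    from pathlib import Path

    # Hypothetical manifest path, mirroring the recipe layout:
    token2id = get_token2id(Path("data/spectrogram/ljspeech_cuts_all.jsonl.gz"))

    assert token2id["<blk>"] == 0
    assert token2id["<sos/eos>"] == 1
    assert token2id["<unk>"] == 2
    # Every distinct phoneme token from the cuts gets some ID >= 3.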

View File

@@ -0,0 +1,63 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file reads the texts in the given manifest and saves the new cuts with phoneme tokens.
+"""
+
+import logging
+from pathlib import Path
+
+import g2p_en
+import tacotron_cleaner.cleaners
+from lhotse import CutSet, load_manifest
+
+
+def prepare_tokens_ljspeech():
+    output_dir = Path("data/spectrogram")
+    prefix = "ljspeech"
+    suffix = "jsonl.gz"
+    partition = "all"
+
+    cut_set = load_manifest(
+        output_dir / f"{prefix}_cuts_{partition}.{suffix}"
+    )
+    g2p = g2p_en.G2p()
+
+    new_cuts = []
+    for cut in cut_set:
+        # Each cut only contains one supervision
+        assert len(cut.supervisions) == 1, len(cut.supervisions)
+        text = cut.supervisions[0].normalized_text
+        # Text normalization
+        text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
+        # Convert to phonemes
+        cut.tokens = g2p(text)
+        new_cuts.append(cut)
+
+    new_cut_set = CutSet.from_cuts(new_cuts)
+    new_cut_set.to_file(
+        output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}"
+    )
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    prepare_tokens_ljspeech()
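
For reference, a small sketch of what g2p_en produces for a cleaned sentence; the output shown is illustrative of ARPAbet phonemes (with stress digits, plus spaces and punctuation as their own tokens), not a guaranteed result.

    import g2p_en

    g2p = g2p_en.G2p()
    tokens = g2p("Hello world.")
    # Roughly: ['HH', 'AH0', 'L', 'OW1', ' ', 'W', 'ER1', 'L', 'D', '.']
    # Lists like this are what the script stores on each cut as cut.tokens.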

View File

@@ -69,7 +69,17 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
 fi

 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 3: Split the LJSpeech cuts into train, valid and test sets"
+  log "Stage 3: Prepare phoneme tokens for LJSpeech"
+  if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then
+    ./local/prepare_tokens_ljspeech.py
+    mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \
+      data/spectrogram/ljspeech_cuts_all.jsonl.gz
+    touch data/spectrogram/.ljspeech_with_token.done
+  fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Split the LJSpeech cuts into train, valid and test sets"
   if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
     lhotse subset --last 600 \
       data/spectrogram/ljspeech_cuts_all.jsonl.gz \
@@ -91,8 +101,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   fi
 fi

-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "Stage 4: Generate token file"
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Generate token file"
   # We assume you have installed g2p_en and espnet_tts_frontend.
   # If not, please install them with:
   #   - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
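
Each stage guards itself with a .done marker file so that re-running prepare.sh skips completed work. A minimal Python sketch of the same idempotency pattern (a hypothetical helper, not part of the commit):

    from pathlib import Path

    def run_once(done_file: Path, step) -> None:
        """Run step() only if its marker file does not exist yet."""
        if not done_file.exists():
            step()
            done_file.touch()  # mark the stage as completed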

View File

@@ -128,10 +128,10 @@ def infer_dataset(
     futures = []
     with ThreadPoolExecutor(max_workers=1) as executor:
         for batch_idx, batch in enumerate(dl):
-            batch_size = len(batch["text"])
+            batch_size = len(batch["tokens"])

-            text = batch["text"]
-            tokens = tokenizer.texts_to_token_ids(text)
+            tokens = batch["tokens"]
+            tokens = tokenizer.tokens_to_token_ids(tokens)
             tokens = k2.RaggedTensor(tokens)
             row_splits = tokens.shape.row_splits(1)
             tokens_lens = row_splits[1:] - row_splits[:-1]
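
The row_splits arithmetic recovers per-utterance token counts from the ragged batch. A self-contained illustration with plain torch tensors (the values are made up):

    import torch

    # A ragged batch [[1, 2, 3], [4, 5]] has row_splits [0, 3, 5]:
    row_splits = torch.tensor([0, 3, 5])
    tokens_lens = row_splits[1:] - row_splits[:-1]
    print(tokens_lens)  # tensor([3, 2]) -- one length per token sequence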

View File

@@ -77,3 +77,30 @@ class Tokenizer(object):
             token_ids_list.append(token_ids)

         return token_ids_list
+
+    def tokens_to_token_ids(
+        self, tokens_list: List[List[str]], intersperse_blank: bool = True
+    ):
+        """
+        Args:
+          tokens_list:
+            A list of token lists, each corresponding to one utterance.
+          intersperse_blank:
+            Whether to intersperse blanks in the token sequence.
+
+        Returns:
+          Return a list of token id lists [utterance][token_id].
+        """
+        token_ids_list = []
+
+        for tokens in tokens_list:
+            token_ids = []
+            for t in tokens:
+                if t in self.token2id:
+                    token_ids.append(self.token2id[t])
+                else:
+                    token_ids.append(self.oov_id)
+
+            if intersperse_blank:
+                token_ids = intersperse(token_ids, self.blank_id)
+
+            token_ids_list.append(token_ids)
+
+        return token_ids_list
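
tokens_to_token_ids relies on an intersperse helper that is not shown in this diff. A common definition (hedged; the recipe's own helper may differ in detail) puts the blank ID between and around all tokens:

    from typing import List

    def intersperse(sequence: List[int], item: int = 0) -> List[int]:
        """[1, 2, 3] with item=0 becomes [0, 1, 0, 2, 0, 3, 0]."""
        result = [item] * (len(sequence) * 2 + 1)
        result[1::2] = sequence
        return result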

View File

@@ -295,9 +295,9 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     features = batch["features"].to(device)
     audio_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)
-    text = batch["text"]
+    tokens = batch["tokens"]

-    tokens = tokenizer.texts_to_token_ids(text)
+    tokens = tokenizer.tokens_to_token_ids(tokens)
     tokens = k2.RaggedTensor(tokens)
     row_splits = tokens.shape.row_splits(1)
     tokens_lens = row_splits[1:] - row_splits[:-1]

@@ -384,7 +384,7 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1

-        batch_size = len(batch["text"])
+        batch_size = len(batch["tokens"])

         audio, audio_lens, features, features_lens, tokens, tokens_lens = \
             prepare_input(batch, tokenizer, device)

@@ -554,7 +554,7 @@ def compute_validation_loss(
     with torch.no_grad():
         for batch_idx, batch in enumerate(valid_dl):
-            batch_size = len(batch["text"])
+            batch_size = len(batch["tokens"])

             audio, audio_lens, features, features_lens, tokens, tokens_lens = \
                 prepare_input(batch, tokenizer, device)

View File

@@ -168,7 +168,9 @@ class LJSpeechTtsDataModule:
         """
         logging.info("About to create train dataset")
         train = SpeechSynthesisDataset(
-            return_tokens=False,
+            return_token_ids=False,
+            return_text=False,
+            return_tokens=True,
             feature_input_strategy=eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
         )

@@ -182,7 +184,9 @@ class LJSpeechTtsDataModule:
                 use_fft_mag=True,
             )
             train = SpeechSynthesisDataset(
-                return_tokens=False,
+                return_token_ids=False,
+                return_text=False,
+                return_tokens=True,
                 feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                 return_cuts=self.args.return_cuts,
             )

@@ -236,13 +240,17 @@ class LJSpeechTtsDataModule:
                 use_fft_mag=True,
             )
             validate = SpeechSynthesisDataset(
-                return_tokens=False,
+                return_token_ids=False,
+                return_text=False,
+                return_tokens=True,
                 feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                 return_cuts=self.args.return_cuts,
             )
         else:
             validate = SpeechSynthesisDataset(
-                return_tokens=False,
+                return_token_ids=False,
+                return_text=False,
+                return_tokens=True,
                 feature_input_strategy=eval(self.args.input_strategy)(),
                 return_cuts=self.args.return_cuts,
             )

@@ -273,13 +281,17 @@ class LJSpeechTtsDataModule:
                 use_fft_mag=True,
             )
             test = SpeechSynthesisDataset(
-                return_tokens=False,
+                return_token_ids=False,
+                return_text=False,
+                return_tokens=True,
                 feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                 return_cuts=self.args.return_cuts,
             )
         else:
             test = SpeechSynthesisDataset(
-                return_tokens=False,
+                return_token_ids=False,
+                return_text=False,
+                return_tokens=True,
                 feature_input_strategy=eval(self.args.input_strategy)(),
                 return_cuts=self.args.return_cuts,
             )
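
With return_tokens=True, each batch is expected to carry the phoneme token lists that prepare_tokens_ljspeech.py stored on the cuts, which train.py and infer.py then feed to Tokenizer.tokens_to_token_ids. A hedged end-to-end sketch (names follow the training code):

    batch = next(iter(train_dl))  # train_dl built by LJSpeechTtsDataModule
    tokens = batch["tokens"]      # List[List[str]], one phoneme list per cut
    token_ids = tokenizer.tokens_to_token_ids(tokens)  # List[List[int]]
    batch_size = len(batch["tokens"])  # matches the new batch_size computation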