convert text to tokens in data preparation stage

This commit is contained in:
yaozengwei 2023-11-13 21:56:18 +08:00
parent f55e80a7c5
commit 8791a4efb0
8 changed files with 139 additions and 39 deletions

View File

@@ -58,10 +58,10 @@ def compute_spectrogram_ljspeech():
     partition = "all"

     recordings = load_manifest(
-        src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet
+        src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet
     )
     supervisions = load_manifest(
-        src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet
+        src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
     )

     config = SpectrogramConfig(

View File

@@ -22,12 +22,9 @@ This file reads the texts in given manifest and generates the file that maps tokens to IDs.
 import argparse
 import logging
-from collections import Counter
 from pathlib import Path
 from typing import Dict

-import g2p_en
-import tacotron_cleaner.cleaners
 from lhotse import load_manifest
@@ -74,32 +71,23 @@ def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:

 def get_token2id(manifest_file: Path) -> Dict[str, int]:
     """Return a dict that maps token to IDs."""
     extra_tokens = [
-        ("<blk>", None),  # 0 for blank
-        ("<sos/eos>", None),  # 1 for sos and eos symbols.
-        ("<unk>", None),  # 2 for OOV
+        "<blk>",  # 0 for blank
+        "<sos/eos>",  # 1 for sos and eos symbols.
+        "<unk>"  # 2 for OOV
     ]
+    all_tokens = set()

     cut_set = load_manifest(manifest_file)
-    g2p = g2p_en.G2p()
-    counter = Counter()

     for cut in cut_set:
-        # Each cut only contain one supervision
-        assert len(cut.supervisions) == 1, len(cut.supervisions)
-        text = cut.supervisions[0].normalized_text
-        # Text normalization
-        text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
-        # Convert to phonemes
-        tokens = g2p(text)
-        for t in tokens:
-            counter[t] += 1
+        for t in cut.tokens:
+            all_tokens.add(t)

-    # Sort by the number of occurrences in descending order
-    tokens_and_counts = sorted(counter.items(), key=lambda x: -x[1])
-    tokens_and_counts = extra_tokens + tokens_and_counts
-    token2id: Dict[str, int] = {token: i for i, (token, _) in enumerate(tokens_and_counts)}
+    all_tokens = extra_tokens + list(all_tokens)
+    token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}

     return token2id
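
With this change, get_token2id no longer runs G2P or counts occurrences; it only collects the tokens field that the new prepare_tokens_ljspeech.py wrote onto each cut. A minimal sketch of the resulting mapping, with hypothetical tokens AH, T, and a space standing in for what was gathered from cut.tokens:

    # Sketch only (not part of the commit): the three special symbols keep
    # fixed IDs 0-2 and the collected tokens follow.  Note that iterating a
    # Python set is not deterministic across runs (hash randomization), so
    # the non-special IDs can differ from run to run.
    extra_tokens = ["<blk>", "<sos/eos>", "<unk>"]
    all_tokens = {"AH", "T", " "}  # hypothetical tokens from cut.tokens
    token2id = {token: i for i, token in enumerate(extra_tokens + list(all_tokens))}
    assert token2id["<blk>"] == 0 and token2id["<unk>"] == 2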

View File

@@ -0,0 +1,63 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file reads the texts in the given manifest and saves the new cuts with phoneme tokens.
+"""
+
+import logging
+from pathlib import Path
+
+import g2p_en
+import tacotron_cleaner.cleaners
+from lhotse import CutSet, load_manifest
+
+
+def prepare_tokens_ljspeech():
+    output_dir = Path("data/spectrogram")
+    prefix = "ljspeech"
+    suffix = "jsonl.gz"
+    partition = "all"
+
+    cut_set = load_manifest(
+        output_dir / f"{prefix}_cuts_{partition}.{suffix}"
+    )
+    g2p = g2p_en.G2p()
+
+    new_cuts = []
+    for cut in cut_set:
+        # Each cut only contains one supervision
+        assert len(cut.supervisions) == 1, len(cut.supervisions)
+        text = cut.supervisions[0].normalized_text
+        # Text normalization
+        text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
+        # Convert to phonemes
+        cut.tokens = g2p(text)
+        new_cuts.append(cut)
+
+    new_cut_set = CutSet.from_cuts(new_cuts)
+    new_cut_set.to_file(
+        output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}"
+    )
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    prepare_tokens_ljspeech()
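
For reference, a hedged sketch of what the G2P step stores in cut.tokens, assuming g2p_en is installed; the sample sentence and the exact output shown are illustrative:

    # Sketch only: g2p_en maps text to ARPAbet phoneme tokens, keeping
    # word boundaries as ' ' tokens, which is what ends up in cut.tokens.
    import g2p_en

    g2p = g2p_en.G2p()
    print(g2p("Hello world"))
    # e.g. ['HH', 'AH0', 'L', 'OW1', ' ', 'W', 'ER1', 'L', 'D']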

View File

@@ -69,7 +69,17 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
 fi

 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 3: Split the LJSpeech cuts into train, valid and test sets"
+  log "Stage 3: Prepare phoneme tokens for LJSpeech"
+  if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then
+    ./local/prepare_tokens_ljspeech.py
+    mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \
+      data/spectrogram/ljspeech_cuts_all.jsonl.gz
+    touch data/spectrogram/.ljspeech_with_token.done
+  fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Split the LJSpeech cuts into train, valid and test sets"
   if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
     lhotse subset --last 600 \
       data/spectrogram/ljspeech_cuts_all.jsonl.gz \
@@ -91,8 +101,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   fi
 fi

-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "Stage 4: Generate token file"
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Generate token file"
   # We assume you have installed g2p_en and espnet_tts_frontend.
   # If not, please install them with:
   #   - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p

View File

@@ -128,10 +128,10 @@ def infer_dataset(
     futures = []
     with ThreadPoolExecutor(max_workers=1) as executor:
         for batch_idx, batch in enumerate(dl):
-            batch_size = len(batch["text"])
+            batch_size = len(batch["tokens"])

-            text = batch["text"]
-            tokens = tokenizer.texts_to_token_ids(text)
+            tokens = batch["tokens"]
+            tokens = tokenizer.tokens_to_token_ids(tokens)
             tokens = k2.RaggedTensor(tokens)
             row_splits = tokens.shape.row_splits(1)
             tokens_lens = row_splits[1:] - row_splits[:-1]
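
The surrounding ragged-tensor bookkeeping is unchanged; a small sketch of how per-utterance lengths fall out of row_splits (the token ids here are hypothetical):

    import k2

    tokens = k2.RaggedTensor([[5, 2, 9], [7, 1]])   # two hypothetical utterances
    row_splits = tokens.shape.row_splits(1)         # tensor([0, 3, 5], dtype=torch.int32)
    tokens_lens = row_splits[1:] - row_splits[:-1]  # tensor([3, 2], dtype=torch.int32)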

View File

@@ -77,3 +77,30 @@ class Tokenizer(object):
             token_ids_list.append(token_ids)

         return token_ids_list
+
+    def tokens_to_token_ids(self, tokens_list: List[str], intersperse_blank: bool = True):
+        """
+        Args:
+          tokens_list:
+            A list of token list, each corresponding to one utterance.
+          intersperse_blank:
+            Whether to intersperse blanks in the token sequence.
+
+        Returns:
+          Return a list of token id list [utterance][token_id]
+        """
+        token_ids_list = []
+
+        for tokens in tokens_list:
+            token_ids = []
+            for t in tokens:
+                if t in self.token2id:
+                    token_ids.append(self.token2id[t])
+                else:
+                    token_ids.append(self.oov_id)
+
+            if intersperse_blank:
+                token_ids = intersperse(token_ids, self.blank_id)
+
+            token_ids_list.append(token_ids)
+
+        return token_ids_list
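
A hedged usage sketch of the new method, assuming a Tokenizer loaded from the generated token file (the path and token list are illustrative):

    # Sketch only: map pre-computed phoneme tokens straight to ids.
    tokenizer = Tokenizer("data/tokens.txt")  # assumed constructor argument
    ids = tokenizer.tokens_to_token_ids([["HH", "AH0", "L", "OW1"]])
    # With intersperse_blank=True each utterance comes back as
    # [blk, id(HH), blk, id(AH0), blk, id(L), blk, id(OW1), blk];
    # unknown tokens map to self.oov_id.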

View File

@@ -295,9 +295,9 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     features = batch["features"].to(device)
     audio_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)
-    text = batch["text"]
+    tokens = batch["tokens"]

-    tokens = tokenizer.texts_to_token_ids(text)
+    tokens = tokenizer.tokens_to_token_ids(tokens)
     tokens = k2.RaggedTensor(tokens)
     row_splits = tokens.shape.row_splits(1)
     tokens_lens = row_splits[1:] - row_splits[:-1]
@@ -384,7 +384,7 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1

-        batch_size = len(batch["text"])
+        batch_size = len(batch["tokens"])

         audio, audio_lens, features, features_lens, tokens, tokens_lens = \
             prepare_input(batch, tokenizer, device)
@@ -554,7 +554,7 @@ def compute_validation_loss(
     with torch.no_grad():
         for batch_idx, batch in enumerate(valid_dl):
-            batch_size = len(batch["text"])
+            batch_size = len(batch["tokens"])

            audio, audio_lens, features, features_lens, tokens, tokens_lens = \
                prepare_input(batch, tokenizer, device)

View File

@@ -168,7 +168,9 @@ class LJSpeechTtsDataModule:
         """
         logging.info("About to create train dataset")
         train = SpeechSynthesisDataset(
-            return_tokens=False,
+            return_token_ids=False,
+            return_text=False,
+            return_tokens=True,
             feature_input_strategy=eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
         )
@@ -182,7 +184,9 @@ class LJSpeechTtsDataModule:
                 use_fft_mag=True,
             )
             train = SpeechSynthesisDataset(
-                return_tokens=False,
+                return_token_ids=False,
+                return_text=False,
+                return_tokens=True,
                 feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                 return_cuts=self.args.return_cuts,
             )
@@ -236,13 +240,17 @@ class LJSpeechTtsDataModule:
                 use_fft_mag=True,
             )
             validate = SpeechSynthesisDataset(
-                return_tokens=False,
+                return_token_ids=False,
+                return_text=False,
+                return_tokens=True,
                 feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                 return_cuts=self.args.return_cuts,
             )
         else:
             validate = SpeechSynthesisDataset(
-                return_tokens=False,
+                return_token_ids=False,
+                return_text=False,
+                return_tokens=True,
                 feature_input_strategy=eval(self.args.input_strategy)(),
                 return_cuts=self.args.return_cuts,
             )
@@ -273,13 +281,17 @@ class LJSpeechTtsDataModule:
                 use_fft_mag=True,
             )
             test = SpeechSynthesisDataset(
-                return_tokens=False,
+                return_token_ids=False,
+                return_text=False,
+                return_tokens=True,
                 feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                 return_cuts=self.args.return_cuts,
            )
         else:
             test = SpeechSynthesisDataset(
-                return_tokens=False,
+                return_token_ids=False,
+                return_text=False,
+                return_tokens=True,
                 feature_input_strategy=eval(self.args.input_strategy)(),
                 return_cuts=self.args.return_cuts,
             )
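
With return_text=False and return_tokens=True, lhotse's SpeechSynthesisDataset is expected to collate the pre-computed cut.tokens into batch["tokens"], which is what infer.py and train.py now read. A hedged sketch under that assumption (exact batch keys may vary with the lhotse version; PrecomputedFeatures stands in for the configured input strategy):

    from lhotse.dataset import PrecomputedFeatures, SpeechSynthesisDataset

    dataset = SpeechSynthesisDataset(
        return_token_ids=False,
        return_text=False,
        return_tokens=True,
        feature_input_strategy=PrecomputedFeatures(),  # assumed strategy
        return_cuts=True,
    )
    batch = dataset[cuts]     # cuts: a CutSet batch drawn from a sampler
    tokens = batch["tokens"]  # list of token lists, one per utterance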