mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-06 15:44:17 +00:00)

commit 8791a4efb0 (parent f55e80a7c5)

    convert text to tokens in data preparation stage
@@ -58,10 +58,10 @@ def compute_spectrogram_ljspeech():
     partition = "all"

     recordings = load_manifest(
-        src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet
+        src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet
     )
     supervisions = load_manifest(
-        src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet
+        src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
     )

     config = SpectrogramConfig(
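For reference, parameterizing the file extension makes the loader agnostic to the manifest format. A minimal sketch of the resulting calls, with prefix/suffix/partition matching values used elsewhere in this commit; src_dir is a hypothetical location for the raw manifests:

    from pathlib import Path

    from lhotse import RecordingSet, SupervisionSet, load_manifest

    # prefix/suffix/partition match values used elsewhere in this commit;
    # src_dir is a hypothetical location for the raw manifests.
    src_dir = Path("data/manifests")
    prefix = "ljspeech"
    suffix = "jsonl.gz"
    partition = "all"

    recordings = load_manifest(
        src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet
    )
    supervisions = load_manifest(
        src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
    )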
@@ -22,12 +22,9 @@ This file reads the texts in given manifest and generates the file that maps tokens to IDs.

 import argparse
 import logging
-from collections import Counter
 from pathlib import Path
 from typing import Dict

-import g2p_en
-import tacotron_cleaner.cleaners
 from lhotse import load_manifest

@@ -74,32 +71,23 @@ def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
 def get_token2id(manifest_file: Path) -> Dict[str, int]:
     """Return a dict that maps token to IDs."""
     extra_tokens = [
-        ("<blk>", None),  # 0 for blank
-        ("<sos/eos>", None),  # 1 for sos and eos symbols.
-        ("<unk>", None),  # 2 for OOV
+        "<blk>",  # 0 for blank
+        "<sos/eos>",  # 1 for sos and eos symbols.
+        "<unk>",  # 2 for OOV
     ]
+    all_tokens = set()

     cut_set = load_manifest(manifest_file)
-    g2p = g2p_en.G2p()
-    counter = Counter()

     for cut in cut_set:
         # Each cut only contains one supervision
         assert len(cut.supervisions) == 1, len(cut.supervisions)
-        text = cut.supervisions[0].normalized_text
-        # Text normalization
-        text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
-        # Convert to phonemes
-        tokens = g2p(text)
-        for t in tokens:
-            counter[t] += 1
+        for t in cut.tokens:
+            all_tokens.add(t)

-    # Sort by the number of occurrences in descending order
-    tokens_and_counts = sorted(counter.items(), key=lambda x: -x[1])
-
-    tokens_and_counts = extra_tokens + tokens_and_counts
+    all_tokens = extra_tokens + list(all_tokens)

-    token2id: Dict[str, int] = {token: i for i, (token, _) in enumerate(tokens_and_counts)}
+    token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}

     return token2id
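With this rewrite, get_token2id no longer runs G2P itself; it simply collects the tokens that the new prepare_tokens_ljspeech.py (added below) attached to each cut. A self-contained sketch of the resulting ID assignment, with a toy token set standing in for the manifest contents:

    from typing import Dict

    # Toy stand-in for the phoneme tokens gathered from cut.tokens.
    all_tokens = {"HH", "AH0", "L", "OW1"}

    extra_tokens = [
        "<blk>",  # 0 for blank
        "<sos/eos>",  # 1 for sos and eos symbols.
        "<unk>",  # 2 for OOV
    ]

    # Same construction as the patched get_token2id: the special symbols keep
    # fixed IDs 0/1/2; real tokens follow in set-iteration order (which is not
    # guaranteed to be identical across runs).
    tokens = extra_tokens + list(all_tokens)
    token2id: Dict[str, int] = {token: i for i, token in enumerate(tokens)}

    assert token2id["<blk>"] == 0
    assert token2id["<sos/eos>"] == 1
    assert token2id["<unk>"] == 2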
egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py (new executable file, 63 lines)

@@ -0,0 +1,63 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file reads the texts in the given manifest and saves the new cuts with phoneme tokens.
+"""
+
+import logging
+from pathlib import Path
+
+import g2p_en
+import tacotron_cleaner.cleaners
+from lhotse import CutSet, load_manifest
+
+
+def prepare_tokens_ljspeech():
+    output_dir = Path("data/spectrogram")
+    prefix = "ljspeech"
+    suffix = "jsonl.gz"
+    partition = "all"
+
+    cut_set = load_manifest(
+        output_dir / f"{prefix}_cuts_{partition}.{suffix}"
+    )
+    g2p = g2p_en.G2p()
+
+    new_cuts = []
+    for cut in cut_set:
+        # Each cut only contains one supervision
+        assert len(cut.supervisions) == 1, len(cut.supervisions)
+        text = cut.supervisions[0].normalized_text
+        # Text normalization
+        text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
+        # Convert to phonemes
+        cut.tokens = g2p(text)
+        new_cuts.append(cut)
+
+    new_cut_set = CutSet.from_cuts(new_cuts)
+    new_cut_set.to_file(
+        output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}"
+    )
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    prepare_tokens_ljspeech()
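As a usage note, g2p_en emits ARPABET phonemes and keeps spaces and punctuation as their own tokens, so cut.tokens ends up as a plain list of strings. A quick sketch; the printed output is illustrative and may vary with the installed g2p_en and espnet_tts_frontend versions:

    import g2p_en
    import tacotron_cleaner.cleaners

    g2p = g2p_en.G2p()

    text = tacotron_cleaner.cleaners.custom_english_cleaners("Hello, world.")
    print(g2p(text))
    # Illustrative output (exact tokens depend on the installed versions):
    # ['HH', 'AH0', 'L', 'OW1', ',', ' ', 'W', 'ER1', 'L', 'D', '.']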
@@ -69,7 +69,17 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
 fi

 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 3: Split the LJSpeech cuts into train, valid and test sets"
+  log "Stage 3: Prepare phoneme tokens for LJSpeech"
+  if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then
+    ./local/prepare_tokens_ljspeech.py
+    mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \
+      data/spectrogram/ljspeech_cuts_all.jsonl.gz
+    touch data/spectrogram/.ljspeech_with_token.done
+  fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Split the LJSpeech cuts into train, valid and test sets"
   if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
     lhotse subset --last 600 \
       data/spectrogram/ljspeech_cuts_all.jsonl.gz \

@@ -91,8 +101,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   fi
 fi

-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "Stage 4: Generate token file"
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Generate token file"
   # We assume you have installed g2p_en and espnet_tts_frontend.
   # If not, please install them with:
   #   - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
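The .done marker makes the new stage idempotent: once it exists, a rerun of prepare.sh skips token preparation. A hypothetical Python rendering of the same logic (prepare.sh does this in shell; the import path below is an assumption, since the script lives under local/):

    import shutil
    from pathlib import Path

    # Assumed import path for the new script above.
    from prepare_tokens_ljspeech import prepare_tokens_ljspeech

    spec_dir = Path("data/spectrogram")
    done_marker = spec_dir / ".ljspeech_with_token.done"

    if not done_marker.exists():
        prepare_tokens_ljspeech()
        # Replace the original cuts with the token-annotated ones.
        shutil.move(
            str(spec_dir / "ljspeech_cuts_with_tokens_all.jsonl.gz"),
            str(spec_dir / "ljspeech_cuts_all.jsonl.gz"),
        )
        done_marker.touch()  # a rerun will now skip this stage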
@@ -128,10 +128,10 @@ def infer_dataset(
     futures = []
     with ThreadPoolExecutor(max_workers=1) as executor:
         for batch_idx, batch in enumerate(dl):
-            batch_size = len(batch["text"])
+            batch_size = len(batch["tokens"])

-            text = batch["text"]
-            tokens = tokenizer.texts_to_token_ids(text)
+            tokens = batch["tokens"]
+            tokens = tokenizer.tokens_to_token_ids(tokens)
             tokens = k2.RaggedTensor(tokens)
             row_splits = tokens.shape.row_splits(1)
             tokens_lens = row_splits[1:] - row_splits[:-1]
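The row_splits arithmetic recovers per-utterance token counts from the ragged batch. The same computation with plain torch, assuming two utterances of 3 and 2 token ids (k2.RaggedTensor stores the flattened ids plus exactly these boundaries):

    import torch

    # Two utterances with 3 and 2 token ids; k2.RaggedTensor(token_ids) keeps
    # the flattened ids plus these row-split boundaries.
    token_ids = [[5, 9, 3], [7, 2]]
    row_splits = torch.tensor([0, 3, 5])

    tokens_lens = row_splits[1:] - row_splits[:-1]
    print(tokens_lens)  # tensor([3, 2])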
@@ -77,3 +77,30 @@ class Tokenizer(object):
             token_ids_list.append(token_ids)

         return token_ids_list
+
+    def tokens_to_token_ids(self, tokens_list: List[List[str]], intersperse_blank: bool = True):
+        """
+        Args:
+          tokens_list:
+            A list of token lists, each corresponding to one utterance.
+          intersperse_blank:
+            Whether to intersperse blanks in the token sequence.
+
+        Returns:
+          Return a list of token id lists, indexed as [utterance][token_id].
+        """
+        token_ids_list = []
+
+        for tokens in tokens_list:
+            token_ids = []
+            for t in tokens:
+                if t in self.token2id:
+                    token_ids.append(self.token2id[t])
+                else:
+                    token_ids.append(self.oov_id)
+
+            if intersperse_blank:
+                token_ids = intersperse(token_ids, self.blank_id)
+            token_ids_list.append(token_ids)
+
+        return token_ids_list
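tokens_to_token_ids mirrors the existing texts_to_token_ids but skips cleaning and G2P, since tokens are now precomputed at data-preparation time. The intersperse helper is imported from elsewhere in the recipe; a minimal sketch of its presumed behavior (the recipe's actual helper may differ in detail):

    from typing import List

    def intersperse(sequence: List[int], item: int) -> List[int]:
        # Place `item` before, between, and after the elements, so the
        # result has length 2 * len(sequence) + 1.
        result = [item] * (len(sequence) * 2 + 1)
        result[1::2] = sequence
        return result

    print(intersperse([5, 9, 3], 0))  # [0, 5, 0, 9, 0, 3, 0]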
@@ -295,9 +295,9 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     features = batch["features"].to(device)
     audio_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)
-    text = batch["text"]
+    tokens = batch["tokens"]

-    tokens = tokenizer.texts_to_token_ids(text)
+    tokens = tokenizer.tokens_to_token_ids(tokens)
     tokens = k2.RaggedTensor(tokens)
     row_splits = tokens.shape.row_splits(1)
     tokens_lens = row_splits[1:] - row_splits[:-1]

@@ -384,7 +384,7 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1

-        batch_size = len(batch["text"])
+        batch_size = len(batch["tokens"])
         audio, audio_lens, features, features_lens, tokens, tokens_lens = \
             prepare_input(batch, tokenizer, device)

@@ -554,7 +554,7 @@ def compute_validation_loss(

     with torch.no_grad():
         for batch_idx, batch in enumerate(valid_dl):
-            batch_size = len(batch["text"])
+            batch_size = len(batch["tokens"])
             audio, audio_lens, features, features_lens, tokens, tokens_lens = \
                 prepare_input(batch, tokenizer, device)
@@ -168,7 +168,9 @@ class LJSpeechTtsDataModule:
         """
         logging.info("About to create train dataset")
         train = SpeechSynthesisDataset(
-            return_tokens=False,
+            return_token_ids=False,
+            return_text=False,
+            return_tokens=True,
             feature_input_strategy=eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
         )

@@ -182,7 +184,9 @@ class LJSpeechTtsDataModule:
                 use_fft_mag=True,
             )
             train = SpeechSynthesisDataset(
-                return_tokens=False,
+                return_token_ids=False,
+                return_text=False,
+                return_tokens=True,
                 feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                 return_cuts=self.args.return_cuts,
             )

@@ -236,13 +240,17 @@ class LJSpeechTtsDataModule:
                 use_fft_mag=True,
             )
             validate = SpeechSynthesisDataset(
-                return_tokens=False,
+                return_token_ids=False,
+                return_text=False,
+                return_tokens=True,
                 feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                 return_cuts=self.args.return_cuts,
             )
         else:
             validate = SpeechSynthesisDataset(
-                return_tokens=False,
+                return_token_ids=False,
+                return_text=False,
+                return_tokens=True,
                 feature_input_strategy=eval(self.args.input_strategy)(),
                 return_cuts=self.args.return_cuts,
             )

@@ -273,13 +281,17 @@ class LJSpeechTtsDataModule:
                 use_fft_mag=True,
             )
             test = SpeechSynthesisDataset(
-                return_tokens=False,
+                return_token_ids=False,
+                return_text=False,
+                return_tokens=True,
                 feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                 return_cuts=self.args.return_cuts,
             )
         else:
             test = SpeechSynthesisDataset(
-                return_tokens=False,
+                return_token_ids=False,
+                return_text=False,
+                return_tokens=True,
                 feature_input_strategy=eval(self.args.input_strategy)(),
                 return_cuts=self.args.return_cuts,
            )
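With return_text=False and return_tokens=True, every batch carries the precomputed phoneme tokens instead of raw text, which is why train.py and infer.py above now read batch["tokens"]. A hypothetical consumer sketch under that assumption:

    from typing import Any, Dict, List

    def batch_tokens(batch: Dict[str, Any]) -> List[List[str]]:
        # The "tokens" field now carries per-utterance phoneme token lists;
        # "text" is no longer returned by the dataset.
        return batch["tokens"]

    batch = {"tokens": [["HH", "AH0"], ["W", "ER1"]]}
    print(len(batch_tokens(batch)))  # 2, i.e. the batch_size used in train.py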