Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)

Use piper_phonemize as text tokenizer in vctk TTS recipe (#1522)

* to align with PR #1524

Commit eec12f053d, parent 9b0eae3b4a
egs/vctk/TTS/README.md

@@ -10,7 +10,7 @@ The above information is from the [CSTR VCTK website](https://datashare.ed.ac.uk
 
 This recipe provides a VITS model trained on the VCTK dataset.
 
-Pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-vctk-vits-2023-12-05), note that this model was pretrained on the Edinburgh DataShare VCTK dataset.
+Pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-vctk-vits-2024-03-18), note that this model was pretrained on the Edinburgh DataShare VCTK dataset.
 
 For tutorial and more details, please refer to the [VITS documentation](https://k2-fsa.github.io/icefall/recipes/TTS/vctk/vits.html).
 
@@ -21,7 +21,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --world-size 4 \
   --num-epochs 1000 \
   --start-epoch 1 \
-  --use-fp16 1 \
   --exp-dir vits/exp \
   --tokens data/tokens.txt
   --max-duration 350
egs/vctk/TTS/local/prepare_token_file.py (104 lines deleted; replaced by a symbolic link, see below)

@@ -1,104 +0,0 @@
-#!/usr/bin/env python3
-# Copyright    2023  Xiaomi Corp.  (authors: Zengwei Yao)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-This file reads the texts in given manifest and generates the file that maps tokens to IDs.
-"""
-
-import argparse
-import logging
-from pathlib import Path
-from typing import Dict
-
-from lhotse import load_manifest
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument(
-        "--manifest-file",
-        type=Path,
-        default=Path("data/spectrogram/vctk_cuts_all.jsonl.gz"),
-        help="Path to the manifest file",
-    )
-
-    parser.add_argument(
-        "--tokens",
-        type=Path,
-        default=Path("data/tokens.txt"),
-        help="Path to the tokens",
-    )
-
-    return parser.parse_args()
-
-
-def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
-    """Write a symbol to ID mapping to a file.
-
-    Note:
-      No need to implement `read_mapping` as it can be done
-      through :func:`k2.SymbolTable.from_file`.
-
-    Args:
-      filename:
-        Filename to save the mapping.
-      sym2id:
-        A dict mapping symbols to IDs.
-    Returns:
-      Return None.
-    """
-    with open(filename, "w", encoding="utf-8") as f:
-        for sym, i in sym2id.items():
-            f.write(f"{sym} {i}\n")
-
-
-def get_token2id(manifest_file: Path) -> Dict[str, int]:
-    """Return a dict that maps token to IDs."""
-    extra_tokens = [
-        "<blk>",  # 0 for blank
-        "<sos/eos>",  # 1 for sos and eos symbols.
-        "<unk>",  # 2 for OOV
-    ]
-    all_tokens = set()
-
-    cut_set = load_manifest(manifest_file)
-
-    for cut in cut_set:
-        # Each cut only contain one supervision
-        assert len(cut.supervisions) == 1, len(cut.supervisions)
-        for t in cut.tokens:
-            all_tokens.add(t)
-
-    all_tokens = extra_tokens + list(all_tokens)
-
-    token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
-    return token2id
-
-
-if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-
-    args = get_args()
-    manifest_file = Path(args.manifest_file)
-    out_file = Path(args.tokens)
-
-    token2id = get_token2id(manifest_file)
-    write_mapping(out_file, token2id)
egs/vctk/TTS/local/prepare_token_file.py (new symbolic link, 1 line)

@@ -0,0 +1 @@
+../../../ljspeech/TTS/local/prepare_token_file.py
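Note: the script is now shared with the ljspeech recipe via the symlink above. For orientation, here is a minimal sketch of how a token file can be derived from piper_phonemize's espeak phoneme table. It is not the shared script itself; it assumes `get_espeak_map()`, a helper piper_phonemize exposes that is assumed here to return a dict mapping each phoneme to a list of IDs:

```python
# Hedged sketch only: assumes piper_phonemize.get_espeak_map() returns
# something like {"a": [3], "b": [4], ...}; the actual shared ljspeech
# script may differ in details.
from piper_phonemize import get_espeak_map


def write_token_file(filename: str = "data/tokens.txt") -> None:
    token2id = get_espeak_map()
    with open(filename, "w", encoding="utf-8") as f:
        for token, ids in token2id.items():
            # Each phoneme is expected to map to exactly one ID.
            assert len(ids) == 1, (token, ids)
            f.write(f"{token} {ids[0]}\n")


if __name__ == "__main__":
    write_token_file()
```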
egs/vctk/TTS/local/prepare_tokens_vctk.py

@@ -24,9 +24,9 @@ This file reads the texts in given manifest and save the new cuts with phoneme t
 import logging
 from pathlib import Path
 
-import g2p_en
 import tacotron_cleaner.cleaners
 from lhotse import CutSet, load_manifest
+from piper_phonemize import phonemize_espeak
 from tqdm.auto import tqdm
 
 
@@ -37,17 +37,20 @@ def prepare_tokens_vctk():
     partition = "all"
 
     cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
-    g2p = g2p_en.G2p()
 
     new_cuts = []
     for cut in tqdm(cut_set):
         # Each cut only contains one supervision
-        assert len(cut.supervisions) == 1, len(cut.supervisions)
+        assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
         text = cut.supervisions[0].text
         # Text normalization
         text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
         # Convert to phonemes
-        cut.tokens = g2p(text)
+        tokens_list = phonemize_espeak(text, "en-us")
+        tokens = []
+        for t in tokens_list:
+            tokens.extend(t)
+        cut.tokens = tokens
         new_cuts.append(cut)
 
     new_cut_set = CutSet.from_cuts(new_cuts)
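Note: the flattening loop above is needed because `phonemize_espeak` returns one phoneme list per sentence rather than a single flat list. An illustrative snippet (the exact phonemes shown are indicative, not guaranteed output):

```python
from piper_phonemize import phonemize_espeak

# phonemize_espeak splits the input into sentences and returns a list of
# phoneme lists, one per sentence, so multi-sentence input needs flattening.
tokens_list = phonemize_espeak("Hello. World.", "en-us")
# tokens_list is something like:
#   [['h', 'ə', 'l', 'ˈoʊ', '.'], ['w', 'ˈɜː', 'l', 'd', '.']]

tokens = []
for t in tokens_list:
    tokens.extend(t)
# tokens is now one flat list of phoneme tokens, as stored in cut.tokens.
```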
egs/vctk/TTS/prepare.sh

@@ -78,6 +78,13 @@ fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "Stage 3: Prepare phoneme tokens for VCTK"
+  # We assume you have installed piper_phonemize and espnet_tts_frontend.
+  # If not, please install them with:
+  # - piper_phonemize:
+  #     refer to https://github.com/rhasspy/piper-phonemize,
+  #     could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
+  # - espnet_tts_frontend:
+  #     `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
   if [ ! -e data/spectrogram/.vctk_with_token.done ]; then
     ./local/prepare_tokens_vctk.py
     mv data/spectrogram/vctk_cuts_with_tokens_all.jsonl.gz \
@@ -111,14 +118,15 @@ fi
 
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "Stage 5: Generate token file"
-  # We assume you have installed g2p_en and espnet_tts_frontend.
+  # We assume you have installed piper_phonemize and espnet_tts_frontend.
   # If not, please install them with:
-  # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
-  # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
+  # - piper_phonemize:
+  #     refer to https://github.com/rhasspy/piper-phonemize,
+  #     could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
+  # - espnet_tts_frontend:
+  #     `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
   if [ ! -e data/tokens.txt ]; then
-    ./local/prepare_token_file.py \
-      --manifest-file data/spectrogram/vctk_cuts_train.jsonl.gz \
-      --tokens data/tokens.txt
+    ./local/prepare_token_file.py --tokens data/tokens.txt
   fi
 fi
 
egs/vctk/TTS/vits/export-onnx.py

@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #
-# Copyright    2023  Xiaomi Corporation  (Author: Zengwei Yao)
+# Copyright    2023-2024  Xiaomi Corporation  (Author: Zengwei Yao,
+#                                                      Zengrui Jin,)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -97,7 +98,7 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
     for key, value in meta_data.items():
         meta = model.metadata_props.add()
         meta.key = key
-        meta.value = value
+        meta.value = str(value)
 
     onnx.save(model, filename)
 
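Note: the `str(value)` cast is required because ONNX `metadata_props` values are protobuf string fields, and the new metadata (next hunk) includes integers such as `n_speakers` and `sample_rate`. A minimal reproduction, assuming only that the `onnx` package is installed:

```python
import onnx

model = onnx.ModelProto()
meta = model.metadata_props.add()
meta.key = "n_speakers"
# meta.value = 108     # raises TypeError: protobuf string fields reject ints
meta.value = str(108)  # casting to str, as in the diff above, works
```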
@@ -160,6 +161,7 @@ def export_model_onnx(
     model: nn.Module,
     model_filename: str,
     vocab_size: int,
+    n_speakers: int,
     opset_version: int = 11,
 ) -> None:
     """Export the given generator model to ONNX format.
@@ -212,10 +214,15 @@
     )
 
     meta_data = {
-        "model_type": "VITS",
+        "model_type": "vits",
         "version": "1",
         "model_author": "k2-fsa",
-        "comment": "VITS generator",
+        "comment": "icefall",  # must be icefall for models from icefall
+        "language": "English",
+        "voice": "en-us",  # Choose your language appropriately
+        "has_espeak": 1,
+        "n_speakers": n_speakers,
+        "sample_rate": 22050,  # Must match the real sample rate
     }
     logging.info(f"meta_data: {meta_data}")
 
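Note: downstream consumers (e.g. sherpa-onnx) can read this metadata back at load time; the `comment: icefall` marker and `n_speakers` let them identify the model family and size. A sketch of reading it with onnxruntime; the model filename here is hypothetical:

```python
import onnxruntime

# Hypothetical filename; use whatever export-onnx.py actually produced.
sess = onnxruntime.InferenceSession(
    "vits-vctk.onnx", providers=["CPUExecutionProvider"]
)
meta = sess.get_modelmeta().custom_metadata_map
print(meta["model_type"])  # "vits"
print(meta["comment"])     # "icefall"
print(meta["n_speakers"])  # note: values come back as strings, e.g. "108"
```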
@@ -231,8 +238,7 @@ def main():
     params.update(vars(args))
 
     tokenizer = Tokenizer(params.tokens)
-    params.blank_id = tokenizer.blank_id
-    params.oov_id = tokenizer.oov_id
+    params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
 
     with open(args.speakers) as f:
@@ -265,6 +271,7 @@ def main():
         model,
         model_filename,
         params.vocab_size,
+        params.num_spks,
         opset_version=opset_version,
     )
     logging.info(f"Exported generator to {model_filename}")
egs/vctk/TTS/vits/infer.py

@@ -135,14 +135,16 @@ def infer_dataset(
             batch_size = len(batch["tokens"])
 
             tokens = batch["tokens"]
-            tokens = tokenizer.tokens_to_token_ids(tokens)
+            tokens = tokenizer.tokens_to_token_ids(
+                tokens, intersperse_blank=True, add_sos=True, add_eos=True
+            )
             tokens = k2.RaggedTensor(tokens)
             row_splits = tokens.shape.row_splits(1)
             tokens_lens = row_splits[1:] - row_splits[:-1]
             tokens = tokens.to(device)
             tokens_lens = tokens_lens.to(device)
             # tensor of shape (B, T)
-            tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+            tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
             speakers = (
                 torch.Tensor([speaker_map[sid] for sid in batch["speakers"]])
                 .int()
@@ -214,8 +216,7 @@ def main():
     device = torch.device("cuda", 0)
 
     tokenizer = Tokenizer(params.tokens)
-    params.blank_id = tokenizer.blank_id
-    params.oov_id = tokenizer.oov_id
+    params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
 
     # we need cut ids to display recognition results.
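Note: the new keyword arguments follow the Tokenizer introduced for ljspeech in PR #1524. A minimal sketch of the assumed semantics; the real Tokenizer lives in the icefall codebase, so the names and defaults below are illustrative, not the actual implementation:

```python
from typing import Dict, List


def tokens_to_token_ids_sketch(
    tokens: List[str],
    token2id: Dict[str, int],
    blank_id: int = 0,     # assumed: the pad/blank token ID
    sos_eos_id: int = 1,   # assumed: shared sos/eos token ID
    intersperse_blank: bool = True,
    add_sos: bool = True,
    add_eos: bool = True,
) -> List[int]:
    ids = [token2id[t] for t in tokens]
    if intersperse_blank:
        # Place a blank before, between, and after every token;
        # VITS-style models are commonly trained this way.
        result = [blank_id] * (2 * len(ids) + 1)
        result[1::2] = ids
        ids = result
    if add_sos:
        ids = [sos_eos_id] + ids
    if add_eos:
        ids = ids + [sos_eos_id]
    return ids
```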
egs/vctk/TTS/vits/test_onnx.py

@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #
-# Copyright    2023  Xiaomi Corporation  (Author: Zengwei Yao)
+# Copyright    2023-2024  Xiaomi Corporation  (Author: Zengwei Yao,
+#                                                      Zengrui Jin,)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -122,7 +123,9 @@ def main():
     model = OnnxModel(args.model_filename)
 
     text = "I went there to see the land, the people and how their system works, end quote."
-    tokens = tokenizer.texts_to_token_ids([text])
+    tokens = tokenizer.texts_to_token_ids(
+        [text], intersperse_blank=True, add_sos=True, add_eos=True
+    )
     tokens = torch.tensor(tokens)  # (1, T)
     tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)  # (1, T)
     speaker = torch.tensor([1], dtype=torch.int64)  # (1, )
egs/vctk/TTS/vits/train.py

@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
-# Copyright    2023  Xiaomi Corp.  (authors: Zengwei Yao)
+# Copyright    2023-2024  Xiaomi Corporation  (Author: Zengwei Yao,
+#                                                      Zengrui Jin,)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -342,14 +343,16 @@ def prepare_input(
         torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device)
     )
 
-    tokens = tokenizer.tokens_to_token_ids(tokens)
+    tokens = tokenizer.tokens_to_token_ids(
+        tokens, intersperse_blank=True, add_sos=True, add_eos=True
+    )
     tokens = k2.RaggedTensor(tokens)
     row_splits = tokens.shape.row_splits(1)
    tokens_lens = row_splits[1:] - row_splits[:-1]
     tokens = tokens.to(device)
     tokens_lens = tokens_lens.to(device)
     # a tensor of shape (B, T)
-    tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+    tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
 
     return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers
 
@@ -812,8 +815,7 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")
 
     tokenizer = Tokenizer(params.tokens)
-    params.blank_id = tokenizer.blank_id
-    params.oov_id = tokenizer.oov_id
+    params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
 
     vctk = VctkTtsDataModule(args)
egs/vctk/TTS/vits/tts_datamodule.py

@@ -1,6 +1,7 @@
 # Copyright    2021  Piotr Żelasko
-# Copyright    2022-2023  Xiaomi Corporation  (Authors: Mingshuang Luo,
-#                                                        Zengwei Yao)
+# Copyright    2022-2024  Xiaomi Corporation  (Authors: Mingshuang Luo,
+#                                                        Zengwei Yao,
+#                                                        Zengrui Jin,)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #