mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
Add new tokenizer
This commit is contained in:
parent
0377cccc6f
commit
9c083b428f
@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This file reads the texts in given manifest and generates the file that maps tokens to IDs.
|
This file generates the file that maps tokens to IDs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@ -25,80 +25,38 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from lhotse import load_manifest
|
from piper_phonemize import get_espeak_map
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--manifest-file",
|
|
||||||
type=Path,
|
|
||||||
default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"),
|
|
||||||
help="Path to the manifest file",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tokens",
|
"--tokens",
|
||||||
type=Path,
|
type=Path,
|
||||||
default=Path("data/tokens.txt"),
|
default=Path("data/tokens.txt"),
|
||||||
help="Path to the tokens",
|
help="Path to the dict that maps the text tokens to IDs",
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
|
def get_token2id(filename: Path) -> Dict[str, int]:
|
||||||
"""Write a symbol to ID mapping to a file.
|
"""Get a dict that maps token to IDs, and save it to the given filename."""
|
||||||
|
all_tokens = get_espeak_map() # token: [token_id]
|
||||||
|
all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()}
|
||||||
|
# sort by token_id
|
||||||
|
all_tokens = sorted(all_tokens.items(), key=lambda x: x[1])
|
||||||
|
|
||||||
Note:
|
|
||||||
No need to implement `read_mapping` as it can be done
|
|
||||||
through :func:`k2.SymbolTable.from_file`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filename:
|
|
||||||
Filename to save the mapping.
|
|
||||||
sym2id:
|
|
||||||
A dict mapping symbols to IDs.
|
|
||||||
Returns:
|
|
||||||
Return None.
|
|
||||||
"""
|
|
||||||
with open(filename, "w", encoding="utf-8") as f:
|
with open(filename, "w", encoding="utf-8") as f:
|
||||||
for sym, i in sym2id.items():
|
for token, token_id in all_tokens:
|
||||||
f.write(f"{sym} {i}\n")
|
f.write(f"{token} {token_id}\n")
|
||||||
|
|
||||||
|
|
||||||
def get_token2id(manifest_file: Path) -> Dict[str, int]:
|
|
||||||
"""Return a dict that maps token to IDs."""
|
|
||||||
extra_tokens = [
|
|
||||||
"<blk>", # 0 for blank
|
|
||||||
"<sos/eos>", # 1 for sos and eos symbols.
|
|
||||||
"<unk>", # 2 for OOV
|
|
||||||
]
|
|
||||||
all_tokens = set()
|
|
||||||
|
|
||||||
cut_set = load_manifest(manifest_file)
|
|
||||||
|
|
||||||
for cut in cut_set:
|
|
||||||
# Each cut only contain one supervision
|
|
||||||
assert len(cut.supervisions) == 1, len(cut.supervisions)
|
|
||||||
for t in cut.tokens:
|
|
||||||
all_tokens.add(t)
|
|
||||||
|
|
||||||
all_tokens = extra_tokens + list(all_tokens)
|
|
||||||
|
|
||||||
token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
|
|
||||||
return token2id
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
manifest_file = Path(args.manifest_file)
|
|
||||||
out_file = Path(args.tokens)
|
out_file = Path(args.tokens)
|
||||||
|
get_token2id(out_file)
|
||||||
token2id = get_token2id(manifest_file)
|
|
||||||
write_mapping(out_file, token2id)
|
|
||||||
|
@ -23,9 +23,9 @@ This file reads the texts in given manifest and save the new cuts with phoneme t
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import g2p_en
|
|
||||||
import tacotron_cleaner.cleaners
|
import tacotron_cleaner.cleaners
|
||||||
from lhotse import CutSet, load_manifest
|
from lhotse import CutSet, load_manifest
|
||||||
|
from piper_phonemize import phonemize_espeak
|
||||||
|
|
||||||
|
|
||||||
def prepare_tokens_ljspeech():
|
def prepare_tokens_ljspeech():
|
||||||
@ -35,17 +35,20 @@ def prepare_tokens_ljspeech():
|
|||||||
partition = "all"
|
partition = "all"
|
||||||
|
|
||||||
cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
|
cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
|
||||||
g2p = g2p_en.G2p()
|
|
||||||
|
|
||||||
new_cuts = []
|
new_cuts = []
|
||||||
for cut in cut_set:
|
for cut in cut_set:
|
||||||
# Each cut only contains one supervision
|
# Each cut only contains one supervision
|
||||||
assert len(cut.supervisions) == 1, len(cut.supervisions)
|
assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
|
||||||
text = cut.supervisions[0].normalized_text
|
text = cut.supervisions[0].normalized_text
|
||||||
# Text normalization
|
# Text normalization
|
||||||
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
|
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
|
||||||
# Convert to phonemes
|
# Convert to phonemes
|
||||||
cut.tokens = g2p(text)
|
tokens_list = phonemize_espeak(text, "en-us")
|
||||||
|
tokens = []
|
||||||
|
for t in tokens_list:
|
||||||
|
tokens.extend(t)
|
||||||
|
cut.tokens = tokens
|
||||||
new_cuts.append(cut)
|
new_cuts.append(cut)
|
||||||
|
|
||||||
new_cut_set = CutSet.from_cuts(new_cuts)
|
new_cut_set = CutSet.from_cuts(new_cuts)
|
||||||
|
@ -218,8 +218,7 @@ def main():
|
|||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
tokenizer = Tokenizer(params.tokens)
|
tokenizer = Tokenizer(params.tokens)
|
||||||
params.blank_id = tokenizer.blank_id
|
params.blank_id = tokenizer.pad_id
|
||||||
params.oov_id = tokenizer.oov_id
|
|
||||||
params.vocab_size = tokenizer.vocab_size
|
params.vocab_size = tokenizer.vocab_size
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
@ -21,6 +21,7 @@ from hifigan import HiFiGANGenerator
|
|||||||
from posterior_encoder import PosteriorEncoder
|
from posterior_encoder import PosteriorEncoder
|
||||||
from residual_coupling import ResidualAffineCouplingBlock
|
from residual_coupling import ResidualAffineCouplingBlock
|
||||||
from text_encoder import TextEncoder
|
from text_encoder import TextEncoder
|
||||||
|
from torch.cuda.amp import autocast
|
||||||
from utils import get_random_segments
|
from utils import get_random_segments
|
||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
from icefall.utils import make_pad_mask
|
||||||
@ -375,18 +376,18 @@ class VITSGenerator(torch.nn.Module):
|
|||||||
|
|
||||||
# forward duration predictor
|
# forward duration predictor
|
||||||
w = attn.sum(2) # (B, 1, T_text)
|
w = attn.sum(2) # (B, 1, T_text)
|
||||||
|
with autocast(enabled=False):
|
||||||
if self.use_stochastic_duration_predictor:
|
if self.use_stochastic_duration_predictor:
|
||||||
dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
|
dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
|
||||||
dur_nll = dur_nll / torch.sum(x_mask)
|
dur_nll = dur_nll / torch.sum(x_mask)
|
||||||
logw = self.duration_predictor(
|
logw = self.duration_predictor(
|
||||||
x, x_mask, g=g, inverse=True, noise_scale=1.0
|
x, x_mask, g=g, inverse=True, noise_scale=1.0
|
||||||
)
|
)
|
||||||
logw_ = torch.log(w + 1e-6) * x_mask
|
logw_ = torch.log(w + 1e-6) * x_mask
|
||||||
else:
|
else:
|
||||||
logw_ = torch.log(w + 1e-6) * x_mask
|
logw_ = torch.log(w + 1e-6) * x_mask
|
||||||
logw = self.duration_predictor(x, x_mask, g=g)
|
logw = self.duration_predictor(x, x_mask, g=g)
|
||||||
dur_nll = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
|
dur_nll = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
|
||||||
|
|
||||||
# expand the length to match with the feature sequence
|
# expand the length to match with the feature sequence
|
||||||
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
|
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
|
||||||
@ -455,9 +456,6 @@ class VITSGenerator(torch.nn.Module):
|
|||||||
Tensor: Duration tensor (B, T_text).
|
Tensor: Duration tensor (B, T_text).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# encoder
|
|
||||||
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
|
|
||||||
x_mask = x_mask.to(x.dtype)
|
|
||||||
g = None
|
g = None
|
||||||
if self.spks is not None:
|
if self.spks is not None:
|
||||||
# (B, global_channels, 1)
|
# (B, global_channels, 1)
|
||||||
@ -477,6 +475,10 @@ class VITSGenerator(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
g = g + g_
|
g = g + g_
|
||||||
|
|
||||||
|
# encoder
|
||||||
|
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths, g=g)
|
||||||
|
x_mask = x_mask.to(x.dtype)
|
||||||
|
|
||||||
if use_teacher_forcing:
|
if use_teacher_forcing:
|
||||||
# forward posterior encoder
|
# forward posterior encoder
|
||||||
z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
|
z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
|
||||||
|
@ -108,7 +108,9 @@ def main():
|
|||||||
model = OnnxModel(args.model_filename)
|
model = OnnxModel(args.model_filename)
|
||||||
|
|
||||||
text = "I went there to see the land, the people and how their system works, end quote."
|
text = "I went there to see the land, the people and how their system works, end quote."
|
||||||
tokens = tokenizer.texts_to_token_ids([text])
|
tokens = tokenizer.texts_to_token_ids(
|
||||||
|
[text], intersperse_blank=True, add_sos=True, add_eos=True
|
||||||
|
)
|
||||||
tokens = torch.tensor(tokens) # (1, T)
|
tokens = torch.tensor(tokens) # (1, T)
|
||||||
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
|
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
|
||||||
audio = model(tokens, tokens_lens) # (1, T')
|
audio = model(tokens, tokens_lens) # (1, T')
|
||||||
|
@ -1 +0,0 @@
|
|||||||
../../../ljspeech/TTS/vits/tokenizer.py
|
|
146
egs/ljspeech/TTS/vits2/tokenizer.py
Normal file
146
egs/ljspeech/TTS/vits2/tokenizer.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import tacotron_cleaner.cleaners
|
||||||
|
from piper_phonemize import phonemize_espeak
|
||||||
|
from utils import intersperse
|
||||||
|
|
||||||
|
|
||||||
|
class Tokenizer(object):
|
||||||
|
def __init__(self, tokens: str):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
tokens: the file that maps tokens to ids
|
||||||
|
"""
|
||||||
|
# Parse token file
|
||||||
|
self.token2id: Dict[str, int] = {}
|
||||||
|
with open(tokens, "r", encoding="utf-8") as f:
|
||||||
|
for line in f.readlines():
|
||||||
|
info = line.rstrip().split()
|
||||||
|
if len(info) == 1:
|
||||||
|
# case of space
|
||||||
|
token = " "
|
||||||
|
id = int(info[0])
|
||||||
|
else:
|
||||||
|
token, id = info[0], int(info[1])
|
||||||
|
assert token not in self.token2id, token
|
||||||
|
self.token2id[token] = id
|
||||||
|
|
||||||
|
# Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md
|
||||||
|
self.pad_id = self.token2id["_"] # padding
|
||||||
|
self.sos_id = self.token2id["^"] # beginning of an utterance (bos)
|
||||||
|
self.eos_id = self.token2id["$"] # end of an utterance (eos)
|
||||||
|
self.space_id = self.token2id[" "] # word separator (whitespace)
|
||||||
|
|
||||||
|
self.vocab_size = len(self.token2id)
|
||||||
|
|
||||||
|
def texts_to_token_ids(
|
||||||
|
self,
|
||||||
|
texts: List[str],
|
||||||
|
intersperse_blank: bool = True,
|
||||||
|
add_sos: bool = False,
|
||||||
|
add_eos: bool = False,
|
||||||
|
lang: str = "en-us",
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
texts:
|
||||||
|
A list of transcripts.
|
||||||
|
intersperse_blank:
|
||||||
|
Whether to intersperse blanks in the token sequence.
|
||||||
|
add_sos:
|
||||||
|
Whether to add sos token at the start.
|
||||||
|
add_eos:
|
||||||
|
Whether to add eos token at the end.
|
||||||
|
lang:
|
||||||
|
Language argument passed to phonemize_espeak().
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a list of token id list [utterance][token_id]
|
||||||
|
"""
|
||||||
|
token_ids_list = []
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
# Text normalization
|
||||||
|
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
|
||||||
|
# Convert to phonemes
|
||||||
|
tokens_list = phonemize_espeak(text, lang)
|
||||||
|
tokens = []
|
||||||
|
for t in tokens_list:
|
||||||
|
tokens.extend(t)
|
||||||
|
|
||||||
|
token_ids = []
|
||||||
|
for t in tokens:
|
||||||
|
if t not in self.token2id:
|
||||||
|
logging.warning(f"Skip OOV {t}")
|
||||||
|
continue
|
||||||
|
token_ids.append(self.token2id[t])
|
||||||
|
|
||||||
|
if intersperse_blank:
|
||||||
|
token_ids = intersperse(token_ids, self.pad_id)
|
||||||
|
if add_sos:
|
||||||
|
token_ids = [self.sos_id] + token_ids
|
||||||
|
if add_eos:
|
||||||
|
token_ids = token_ids + [self.eos_id]
|
||||||
|
|
||||||
|
token_ids_list.append(token_ids)
|
||||||
|
|
||||||
|
return token_ids_list
|
||||||
|
|
||||||
|
def tokens_to_token_ids(
|
||||||
|
self,
|
||||||
|
tokens_list: List[str],
|
||||||
|
intersperse_blank: bool = True,
|
||||||
|
add_sos: bool = False,
|
||||||
|
add_eos: bool = False,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
tokens_list:
|
||||||
|
A list of token list, each corresponding to one utterance.
|
||||||
|
intersperse_blank:
|
||||||
|
Whether to intersperse blanks in the token sequence.
|
||||||
|
add_sos:
|
||||||
|
Whether to add sos token at the start.
|
||||||
|
add_eos:
|
||||||
|
Whether to add eos token at the end.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a list of token id list [utterance][token_id]
|
||||||
|
"""
|
||||||
|
token_ids_list = []
|
||||||
|
|
||||||
|
for tokens in tokens_list:
|
||||||
|
token_ids = []
|
||||||
|
for t in tokens:
|
||||||
|
if t not in self.token2id:
|
||||||
|
logging.warning(f"Skip OOV {t}")
|
||||||
|
continue
|
||||||
|
token_ids.append(self.token2id[t])
|
||||||
|
|
||||||
|
if intersperse_blank:
|
||||||
|
token_ids = intersperse(token_ids, self.pad_id)
|
||||||
|
if add_sos:
|
||||||
|
token_ids = [self.sos_id] + token_ids
|
||||||
|
if add_eos:
|
||||||
|
token_ids = token_ids + [self.eos_id]
|
||||||
|
|
||||||
|
token_ids_list.append(token_ids)
|
||||||
|
|
||||||
|
return token_ids_list
|
@ -296,14 +296,16 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
|
|||||||
features_lens = batch["features_lens"].to(device)
|
features_lens = batch["features_lens"].to(device)
|
||||||
tokens = batch["tokens"]
|
tokens = batch["tokens"]
|
||||||
|
|
||||||
tokens = tokenizer.tokens_to_token_ids(tokens)
|
tokens = tokenizer.tokens_to_token_ids(
|
||||||
|
tokens, intersperse_blank=True, add_sos=True, add_eos=True
|
||||||
|
)
|
||||||
tokens = k2.RaggedTensor(tokens)
|
tokens = k2.RaggedTensor(tokens)
|
||||||
row_splits = tokens.shape.row_splits(1)
|
row_splits = tokens.shape.row_splits(1)
|
||||||
tokens_lens = row_splits[1:] - row_splits[:-1]
|
tokens_lens = row_splits[1:] - row_splits[:-1]
|
||||||
tokens = tokens.to(device)
|
tokens = tokens.to(device)
|
||||||
tokens_lens = tokens_lens.to(device)
|
tokens_lens = tokens_lens.to(device)
|
||||||
# a tensor of shape (B, T)
|
# a tensor of shape (B, T)
|
||||||
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
|
tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
|
||||||
|
|
||||||
return audio, audio_lens, features, features_lens, tokens, tokens_lens
|
return audio, audio_lens, features, features_lens, tokens, tokens_lens
|
||||||
|
|
||||||
@ -813,8 +815,7 @@ def run(rank, world_size, args):
|
|||||||
logging.info(f"Device: {device}")
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
tokenizer = Tokenizer(params.tokens)
|
tokenizer = Tokenizer(params.tokens)
|
||||||
params.blank_id = tokenizer.blank_id
|
params.blank_id = tokenizer.pad_id
|
||||||
params.oov_id = tokenizer.oov_id
|
|
||||||
params.vocab_size = tokenizer.vocab_size
|
params.vocab_size = tokenizer.vocab_size
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user