#!/usr/bin/env python3
"""
This file generates the file tokens.txt.
Usage:
python3 ./local/generate_tokens.py > data/tokens.txt
"""

import argparse
from typing import List

from pypinyin import Style, lazy_pinyin, pinyin_dict


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--tokens",
        type=str,
        required=True,
        help="Path to save tokens.txt.",
    )

    return parser
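

# A quick illustration of the pypinyin call used in generate_token_list()
# below (a sketch; the exact output can vary with the installed pypinyin
# version):
#
#   >>> lazy_pinyin("中", style=Style.TONE3, tone_sandhi=True)
#   ['zhong1']
#
# Style.TONE3 appends the tone number (1-4) to each syllable; neutral-tone
# syllables carry no trailing digit.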


def generate_token_list() -> List[str]:
    token_set = set()

    word_dict = pinyin_dict.pinyin_dict
    for key in word_dict:
        # Keep only code points from the CJK Unified Ideographs block.
        if not (0x4E00 <= key <= 0x9FFF):
            continue

        w = chr(key)
        t = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0]
        token_set.add(t)

    # Collect the toneless form of every syllable by stripping the
    # trailing tone digit, if present.
    no_digit = set()
    for t in token_set:
        if t[-1] not in "1234":
            no_digit.add(t)
        else:
            no_digit.add(t[:-1])
no_digit.add("dei")
no_digit.add("tou")
no_digit.add("dia")
for t in no_digit:
token_set.add(t)
for i in range(1, 5):
token_set.add(f"{t}{i}")

    ans = list(token_set)
    ans.sort()

    punctuations = list(",.!?:\"'")
    ans = punctuations + ans

    # Use ID 0 for the blank symbol " " and ID 1 for the padding symbol "_".
    ans.insert(0, " ")
    ans.insert(1, "_")

    return ans
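

# For reference, the generated tokens.txt can be read back into a
# token-to-ID mapping like this (a sketch; each line has the form
# "<token> <id>", as written in main() below):
#
#   def read_tokens(filename: str) -> dict:
#       token2id = {}
#       with open(filename, encoding="utf-8") as f:
#           for line in f:
#               token, idx = line.rstrip("\n").rsplit(" ", 1)
#               token2id[token] = int(idx)
#       return token2id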


def main():
    args = get_parser().parse_args()
    token_list = generate_token_list()

    with open(args.tokens, "w", encoding="utf-8") as f:
        for indx, token in enumerate(token_list):
            f.write(f"{token} {indx}\n")


if __name__ == "__main__":
    main()