#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script takes as input `lang_dir`, which should contain::

    - lang_dir/words.txt

and generates the following files in the directory `lang_dir`:

    - lexicon.txt
    - lexicon_disambig.txt
    - L.pt
    - L_disambig.pt
    - tokens.txt
"""

import argparse
from pathlib import Path
from typing import Dict, List

import k2
import torch
from prepare_lang import (
    Lexicon,
    add_disambig_symbols,
    add_self_loops,
    write_lexicon,
    write_mapping,
)

from icefall.utils import text_to_pinyin


def get_parser():
    parser = argparse.ArgumentParser(
        description="Prepare lang for pinyin",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument("--lang-dir", type=str, help="The lang directory.")

    parser.add_argument(
        "--token-type",
        default="full_with_tone",
        type=str,
        help="""The type of pinyin, should be in:
        full_with_tone: zhōng guó
        full_no_tone: zhong guo
        partial_with_tone: zh ōng g uó
        partial_no_tone: zh ong g uo
        """,
    )

    parser.add_argument(
        "--pinyin-errors",
        default="split",
        type=str,
        help="""How to handle characters that have no pinyin,
        see `text_to_pinyin` in icefall/utils.py for details.
        """,
    )

    return parser

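# A minimal sketch of what `text_to_pinyin` is expected to return for each
# token type, based on the help string above (illustrative only; the
# authoritative behavior is `text_to_pinyin` in icefall/utils.py):
#
#   >>> text_to_pinyin("中国", mode="full_with_tone", errors="split")
#   ['zhōng', 'guó']
#   >>> text_to_pinyin("中国", mode="partial_with_tone", errors="split")
#   ['zh', 'ōng', 'g', 'uó']
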
""" loop_state = 0 # words enter and leave from here next_state = 1 # the next un-allocated state, will be incremented as we go arcs = [] # The blank symbol is defined in local/train_bpe_model.py assert token2id[""] == 0 assert word2id[""] == 0 eps = 0 for word, pieces in lexicon: assert len(pieces) > 0, f"{word} has no pronunciations" cur_state = loop_state word = word2id[word] pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] for i in range(len(pieces) - 1): w = word if i == 0 else eps arcs.append([cur_state, next_state, pieces[i], w, 0]) cur_state = next_state next_state += 1 # now for the last piece of this word i = len(pieces) - 1 w = word if i == 0 else eps arcs.append([cur_state, loop_state, pieces[i], w, 0]) if need_self_loops: disambig_token = token2id["#0"] disambig_word = word2id["#0"] arcs = add_self_loops( arcs, disambig_token=disambig_token, disambig_word=disambig_word, ) final_state = next_state arcs.append([loop_state, final_state, -1, -1, 0]) arcs.append([final_state]) arcs = sorted(arcs, key=lambda arc: arc[0]) arcs = [[str(i) for i in arc] for arc in arcs] arcs = [" ".join(arc) for arc in arcs] arcs = "\n".join(arcs) fsa = k2.Fsa.from_str(arcs, acceptor=False) return fsa def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: """Check if all the given tokens are in token symbol table. Args: token_sym_table: Token symbol table that contains all the valid tokens. tokens: A list of tokens. Returns: Return True if there is any token not in the token_sym_table, otherwise False. """ for tok in tokens: if tok not in token_sym_table: return True return False def generate_lexicon( args, token_sym_table: Dict[str, int], words: List[str] ) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. Args: token_sym_table: Token symbol table that mapping token to token ids. words: A list of strings representing words. Returns: Return a dict whose keys are words and values are the corresponding tokens. """ lexicon = [] for word in words: tokens = text_to_pinyin( word.strip(), mode=args.token_type, errors=args.pinyin_errors ) if contain_oov(token_sym_table, tokens): print(f"Word : {word} contains OOV token, skipping.") continue lexicon.append((word, tokens)) # The OOV word is lexicon.append(("", [""])) return lexicon def generate_tokens(args, words: List[str]) -> Dict[str, int]: """Generate tokens from the given word list. Args: words: A list that contains words to generate tokens. Returns: Return a dict whose keys are tokens and values are token ids ranged from 0 to len(keys) - 1. 
""" tokens: Dict[str, int] = dict() tokens[""] = 0 tokens[""] = 1 tokens[""] = 2 for word in words: word = word.strip() tokens_list = text_to_pinyin( word, mode=args.token_type, errors=args.pinyin_errors ) for token in tokens_list: if token not in tokens: tokens[token] = len(tokens) return tokens def main(): parser = get_parser() args = parser.parse_args() lang_dir = Path(args.lang_dir) word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") words = word_sym_table.symbols excluded = ["", "!SIL", "", "", "#0", "", ""] for w in excluded: if w in words: words.remove(w) token_sym_table = generate_tokens(args, words) lexicon = generate_lexicon(args, token_sym_table, words) lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) next_token_id = max(token_sym_table.values()) + 1 for i in range(max_disambig + 1): disambig = f"#{i}" assert disambig not in token_sym_table token_sym_table[disambig] = next_token_id next_token_id += 1 word_sym_table.add("#0") word_sym_table.add("") word_sym_table.add("") write_mapping(lang_dir / "tokens.txt", token_sym_table) write_lexicon(lang_dir / "lexicon.txt", lexicon) write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) L = lexicon_to_fst_no_sil( lexicon, token2id=token_sym_table, word2id=word_sym_table, ) L_disambig = lexicon_to_fst_no_sil( lexicon_disambig, token2id=token_sym_table, word2id=word_sym_table, need_self_loops=True, ) torch.save(L.as_dict(), lang_dir / "L.pt") torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") if __name__ == "__main__": main()