From 4501821fd98821a6cf3a238c6dc5c01422643fdb Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Fri, 9 Dec 2022 16:46:44 +0800
Subject: [PATCH] Support using OpenFst to compile HLG. (#606)

* Support using OpenFst to compile HLG.

* Fix style issues
---
 .../ASR/local/compile_hlg_using_openfst.py | 184 ++++++++++++++++++
 egs/librispeech/ASR/prepare.sh             |  41 +++-
 icefall/shared/convert-k2-to-openfst.py    | 102 ++++++++++
 requirements.txt                           |   1 +
 4 files changed, 325 insertions(+), 3 deletions(-)
 create mode 100755 egs/librispeech/ASR/local/compile_hlg_using_openfst.py
 create mode 100755 icefall/shared/convert-k2-to-openfst.py

diff --git a/egs/librispeech/ASR/local/compile_hlg_using_openfst.py b/egs/librispeech/ASR/local/compile_hlg_using_openfst.py
new file mode 100755
index 000000000..9e5e3df69
--- /dev/null
+++ b/egs/librispeech/ASR/local/compile_hlg_using_openfst.py
@@ -0,0 +1,184 @@
+#!/usr/bin/env python3
+# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input lang_dir and generates HLG from
+
+    - H, the CTC topology, built from tokens contained in lang_dir/lexicon.txt
+
+    - L, the lexicon, built from lang_dir/L_disambig.fst
+
+      Caution: We use a lexicon that contains disambiguation symbols
+
+    - G, the LM, built from data/lm/G_3_gram.fst.txt
+
+The generated HLG is saved in $lang_dir/HLG_fst.pt
+
+When should you use this script instead of ./local/compile_hlg.py?
+If G is very large, ./local/compile_hlg.py may throw an out-of-memory
+(OOM) error during determinization. In that case, you can use this
+script to compile HLG.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import kaldifst
+import torch
+
+from icefall.lexicon import Lexicon
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        """,
+    )
+
+    return parser.parse_args()
+
+
+def compile_HLG(lang_dir: str) -> k2.Fsa:
+    """
+    Args:
+      lang_dir:
+        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+
+    Returns:
+      An FSA in k2 format representing HLG.
+    """
+    L = kaldifst.StdVectorFst.read(f"{lang_dir}/L_disambig.fst")
+    logging.info("Arc sort L")
+    kaldifst.arcsort(L, sort_type="olabel")
+    logging.info(f"L: #states {L.num_states}")
+
+    G_filename_txt = "data/lm/G_3_gram.fst.txt"
+    G_filename_binary = "data/lm/G_3_gram.fst"
+    if Path(G_filename_binary).is_file():
+        logging.info(f"Loading {G_filename_binary}")
+        G = kaldifst.StdVectorFst.read(G_filename_binary)
+    else:
+        logging.info(f"Loading {G_filename_txt}")
+        with open(G_filename_txt) as f:
+            G = kaldifst.compile(s=f.read(), acceptor=False)
+        logging.info(f"Saving G to {G_filename_binary}")
+        G.write(G_filename_binary)
+
+    logging.info("Arc sort G")
+    kaldifst.arcsort(G, sort_type="ilabel")
+
+    logging.info(f"G: #states {G.num_states}")
+
+    logging.info("Compose L and G and connect LG")
+    LG = kaldifst.compose(L, G, connect=True)
+    logging.info(f"LG: #states {LG.num_states}")
+
+    logging.info("Determinize-star LG")
+    kaldifst.determinize_star(LG)
+    logging.info(f"LG after determinize_star: #states {LG.num_states}")
+
+    logging.info("Minimize encoded LG")
+    kaldifst.minimize_encoded(LG)
+    logging.info(f"LG after minimize_encoded: #states {LG.num_states}")
+
+    logging.info("Converting LG to k2 format")
+    LG = k2.Fsa.from_openfst(LG.to_str(is_acceptor=False), acceptor=False)
+    logging.info(f"LG in k2: #states: {LG.shape[0]}, #arcs: {LG.num_arcs}")
+
+    lexicon = Lexicon(lang_dir)
+
+    first_token_disambig_id = lexicon.token_table["#0"]
+    first_word_disambig_id = lexicon.word_table["#0"]
+    logging.info(f"token id for #0: {first_token_disambig_id}")
+    logging.info(f"word id for #0: {first_word_disambig_id}")
+
+    max_token_id = max(lexicon.tokens)
+    modified = False
+    logging.info(
+        f"Building ctc_topo. modified: {modified}, max_token_id: {max_token_id}"
+    )
+
+    H = k2.ctc_topo(max_token_id, modified=modified)
+    logging.info(f"H: #states: {H.shape[0]}, #arcs: {H.num_arcs}")
+
+    logging.info("Removing disambiguation symbols on LG")
+    LG.labels[LG.labels >= first_token_disambig_id] = 0
+    LG.aux_labels[LG.aux_labels >= first_word_disambig_id] = 0
+
+    # See https://github.com/k2-fsa/k2/issues/874
+    # for why we need to set LG.properties to None
+    LG.__dict__["_properties"] = None
+
+    logging.info("Removing epsilons from LG")
+    LG = k2.remove_epsilon(LG)
+    logging.info(
+        f"LG after k2.remove_epsilon: #states: {LG.shape[0]}, #arcs: {LG.num_arcs}"
+    )
+
+    logging.info("Connecting LG after removing epsilons")
+    LG = k2.connect(LG)
+    LG.aux_labels = LG.aux_labels.remove_values_eq(0)
+    logging.info(f"LG after k2.connect: #states: {LG.shape[0]}, #arcs: {LG.num_arcs}")
+
+    logging.info("Arc sorting LG")
+    LG = k2.arc_sort(LG)
+
+    logging.info("Composing H and LG")
+
+    HLG = k2.compose(H, LG, inner_labels="tokens")
+    logging.info(
+        f"HLG after k2.compose: #states: {HLG.shape[0]}, #arcs: {HLG.num_arcs}"
+    )
+
+    logging.info("Connecting HLG")
+    HLG = k2.connect(HLG)
+    logging.info(
+        f"HLG after k2.connect: #states: {HLG.shape[0]}, #arcs: {HLG.num_arcs}"
+    )
+
+    logging.info("Arc sorting HLG")
+    HLG = k2.arc_sort(HLG)
+
+    return HLG
+
+
+def main():
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
+
+    filename = lang_dir / "HLG_fst.pt"
+
+    if filename.is_file():
+        logging.info(f"{filename} already exists - skipping")
+        return
+
+    HLG = compile_HLG(lang_dir)
+    logging.info(f"Saving HLG to {filename}")
+    torch.save(HLG.as_dict(), filename)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
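
Note that the script above stores HLG as a k2 FSA serialized via torch.save of
HLG.as_dict(). To sanity-check the output, it can be loaded back by inverting
that step. A minimal sketch (the lang_dir path is illustrative; any directory
that already contains HLG_fst.pt works):

    import k2
    import torch

    # HLG_fst.pt holds the dict produced by HLG.as_dict() above.
    d = torch.load("data/lang_phone/HLG_fst.pt", map_location="cpu")
    HLG = k2.Fsa.from_dict(d)
    print(f"HLG: #states: {HLG.shape[0]}, #arcs: {HLG.num_arcs}")
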
+ """ + + L = kaldifst.StdVectorFst.read(f"{lang_dir}/L_disambig.fst") + logging.info("Arc sort L") + kaldifst.arcsort(L, sort_type="olabel") + logging.info(f"L: #states {L.num_states}") + + G_filename_txt = "data/lm/G_3_gram.fst.txt" + G_filename_binary = "data/lm/G_3_gram.fst" + if Path(G_filename_binary).is_file(): + logging.info(f"Loading {G_filename_binary}") + G = kaldifst.StdVectorFst.read(G_filename_binary) + else: + logging.info(f"Loading {G_filename_txt}") + with open(G_filename_txt) as f: + G = kaldifst.compile(s=f.read(), acceptor=False) + logging.info(f"Saving G to {G_filename_binary}") + G.write(G_filename_binary) + + logging.info("Arc sort G") + kaldifst.arcsort(G, sort_type="ilabel") + + logging.info(f"G: #states {G.num_states}") + + logging.info("Compose L and G and connect LG") + LG = kaldifst.compose(L, G, connect=True) + logging.info(f"LG: #states {LG.num_states}") + + logging.info("Determinizestar LG") + kaldifst.determinize_star(LG) + logging.info(f"LG after determinize_star: #states {LG.num_states}") + + logging.info("Minimize encoded LG") + kaldifst.minimize_encoded(LG) + logging.info(f"LG after minimize_encoded: #states {LG.num_states}") + + logging.info("Converting LG to k2 format") + LG = k2.Fsa.from_openfst(LG.to_str(is_acceptor=False), acceptor=False) + logging.info(f"LG in k2: #states: {LG.shape[0]}, #arcs: {LG.num_arcs}") + + lexicon = Lexicon(lang_dir) + + first_token_disambig_id = lexicon.token_table["#0"] + first_word_disambig_id = lexicon.word_table["#0"] + logging.info(f"token id for #0: {first_token_disambig_id}") + logging.info(f"word id for #0: {first_word_disambig_id}") + + max_token_id = max(lexicon.tokens) + modified = False + logging.info( + f"Building ctc_topo. modified: {modified}, max_token_id: {max_token_id}" + ) + + H = k2.ctc_topo(max_token_id, modified=modified) + logging.info(f"H: #states: {H.shape[0]}, #arcs: {H.num_arcs}") + + logging.info("Removing disambiguation symbols on LG") + LG.labels[LG.labels >= first_token_disambig_id] = 0 + LG.aux_labels[LG.aux_labels >= first_word_disambig_id] = 0 + + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set LG.properties to None + LG.__dict__["_properties"] = None + + logging.info("Removing epsilons from LG") + LG = k2.remove_epsilon(LG) + logging.info( + f"LG after k2.remove_epsilon: #states: {LG.shape[0]}, #arcs: {LG.num_arcs}" + ) + + logging.info("Connecting LG after removing epsilons") + LG = k2.connect(LG) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) + logging.info(f"LG after k2.connect: #states: {LG.shape[0]}, #arcs: {LG.num_arcs}") + + logging.info("Arc sorting LG") + LG = k2.arc_sort(LG) + + logging.info("Composing H and LG") + + HLG = k2.compose(H, LG, inner_labels="tokens") + logging.info( + f"HLG after k2.compose: #states: {HLG.shape[0]}, #arcs: {HLG.num_arcs}" + ) + + logging.info("Connecting HLG") + HLG = k2.connect(HLG) + logging.info( + f"HLG after k2.connect: #states: {HLG.shape[0]}, #arcs: {HLG.num_arcs}" + ) + + logging.info("Arc sorting LG") + HLG = k2.arc_sort(HLG) + + return HLG + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + filename = lang_dir / "HLG_fst.pt" + + if filename.is_file(): + logging.info(f"{filename} already exists - skipping") + return + + HLG = compile_HLG(lang_dir) + logging.info(f"Saving HLG to {filename}") + torch.save(HLG.as_dict(), filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, 
diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py
new file mode 100755
index 000000000..29a2cd7f7
--- /dev/null
+++ b/icefall/shared/convert-k2-to-openfst.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python3
+# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script takes as input an FST in k2 format and converts it
+to an FST in OpenFst format.
+
+The generated FST is saved to a binary file; its type is
+StdVectorFst.
+
+Usage examples:
+(1) Convert an acceptor
+
+  ./convert-k2-to-openfst.py in.pt binary.fst
+
+(2) Convert a transducer
+
+  ./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import kaldifst.utils
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--olabels",
+        type=str,
+        default=None,
+        help="""If not empty, the input FST is assumed to be a transducer
+        and we use its attribute specified by "olabels" as the output labels.
+        """,
+    )
+    parser.add_argument(
+        "input_filename",
+        type=str,
+        help="Path to the input FST in k2 format",
+    )
+
+    parser.add_argument(
+        "output_filename",
+        type=str,
+        help="Path to the output FST in OpenFst format",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    logging.info(f"{vars(args)}")
+
+    input_filename = args.input_filename
+    output_filename = args.output_filename
+    olabels = args.olabels
+
+    if Path(output_filename).is_file():
+        logging.info(f"{output_filename} already exists - skipping")
+        return
+
+    assert Path(input_filename).is_file(), f"{input_filename} does not exist"
+    logging.info(f"Loading {input_filename}")
+    k2_fst = k2.Fsa.from_dict(torch.load(input_filename))
+    if olabels:
+        assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}"
+
+    p = Path(output_filename).parent
+    if not p.is_dir():
+        logging.info(f"Creating {p}")
+        p.mkdir(parents=True)
+
+    logging.info("Converting (may take some time if the input FST is large)")
+    fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels)
+    logging.info(f"Saving to {output_filename}")
+    fst.write(output_filename)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
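
The actual conversion is a single call to kaldifst.utils.k2_to_openfst, which
can also be used programmatically without going through a .pt file. A
self-contained round-trip sketch on a toy acceptor (the FSA string and the
output path toy.fst are illustrative):

    import k2
    import kaldifst.utils

    # A toy acceptor in k2's text format: "src dst label score";
    # label -1 marks the arc entering the final state.
    s = "\n".join(["0 1 10 0.1", "1 2 -1 0.2", "2"])
    fsa = k2.Fsa.from_str(s)

    # olabels=None converts an acceptor; pass e.g. "aux_labels"
    # to convert a transducer, as the script's --olabels flag does.
    fst = kaldifst.utils.k2_to_openfst(fsa, olabels=None)
    fst.write("toy.fst")  # same binary StdVectorFst format as above
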
diff --git a/requirements.txt b/requirements.txt
index 5e32af853..a07f6b7c7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+kaldifst
 kaldilm
 kaldialign
 sentencepiece>=0.1.96
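
kaldifst is the only new dependency added by this patch. After running
pip install -r requirements.txt, a quick import check confirms it is
available to the two new scripts:

    import kaldifst
    import kaldifst.utils

    # StdVectorFst is the OpenFst type read and written by both scripts.
    print(kaldifst.StdVectorFst)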