From 5016ee3c95551842dc04333f12f5ca5791256ec1 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 14 Oct 2021 16:20:35 +0800 Subject: [PATCH] Give an informative message when users provide an unsupported decoding method (#77) --- .../ASR/conformer_ctc/pretrained.py | 219 +++++++++--------- 1 file changed, 106 insertions(+), 113 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 07d3e7269..be94e6875 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -20,23 +20,23 @@ import argparse import logging import math -import sentencepiece as spm from typing import List import k2 import kaldifeat +import sentencepiece as spm import torch import torchaudio from conformer import Conformer from torch.nn.utils.rnn import pad_sequence -from icefall.lexicon import Lexicon from icefall.decode import ( get_lattice, one_best_decoding, rescore_with_attention_decoder, rescore_with_whole_lattice, ) +from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, get_texts @@ -58,7 +58,7 @@ def get_parser(): "--lang-dir", type=str, required=True, - help="Path to lang bpe dir.", + help="Path to lang dir.", ) parser.add_argument( @@ -142,7 +142,7 @@ def get_parser(): parser.add_argument( "--sos-id", - type=float, + type=int, default=1, help=""" Used only when method is attention-decoder. @@ -152,7 +152,7 @@ def get_parser(): parser.add_argument( "--eos-id", - type=float, + type=int, default=1, help=""" Used only when method is attention-decoder. @@ -285,128 +285,121 @@ def main(): dtype=torch.int32, ) - try: - if params.method == "ctc-decoding": - logging.info("Building CTC topology") - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) - logging.info("Loading BPE model") - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(params.lang_dir + "/bpe.model") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.lang_dir + "/bpe.model") - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=H, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) - logging.info("Use CTC decoding") - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - token_ids = get_texts(best_path) - hyps = bpe_model.decode(token_ids) - hyps = [s.split() for s in hyps] + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "whole-lattice-rescoring", + "attention-decoder", + ]: + logging.info(f"Loading HLG from {params.lang_dir}/HLG.pt") + HLG = k2.Fsa.from_dict( + torch.load(params.lang_dir + "/HLG.pt", map_location="cpu") + ) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() if params.method in [ - "1best", "whole-lattice-rescoring", "attention-decoder", ]: - logging.info(f"Loading HLG from {params.lang_dir}/HLG.pt") - HLG = k2.Fsa.from_dict( - torch.load(params.lang_dir + "/HLG.pt", map_location="cpu") + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = G.to(device) + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores ) - HLG = HLG.to(device) - if not hasattr(HLG, "lm_scores"): - # For whole-lattice-rescoring and attention-decoder - HLG.lm_scores = HLG.scores.clone() - - if params.method in [ - "whole-lattice-rescoring", - "attention-decoder", - ]: - logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = G.to(device) - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G.lm_scores = G.scores.clone() - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], ) - - if params.method == "1best": - logging.info("Use HLG decoding") - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - elif params.method == "whole-lattice-rescoring": - logging.info("Use HLG decoding + LM rescoring") - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=[params.ngram_lm_scale], - ) - best_path = next(iter(best_path_dict.values())) - elif params.method == "attention-decoder": - logging.info( - "Use HLG + LM rescoring + attention decoder rescoring" - ) - rescored_lattice = rescore_with_whole_lattice( - lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None - ) - best_path_dict = rescore_with_attention_decoder( - lattice=rescored_lattice, - num_paths=params.num_paths, - model=model, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - sos_id=params.sos_id, - eos_id=params.eos_id, - nbest_scale=params.nbest_scale, - ngram_lm_scale=params.ngram_lm_scale, - attention_scale=params.attention_decoder_scale, - ) - best_path = next(iter(best_path_dict.values())) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file( - params.lang_dir + "/words.txt" + best_path = next(iter(best_path_dict.values())) + elif params.method == "attention-decoder": + logging.info("Use HLG + LM rescoring + attention decoder rescoring") + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None ) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=params.sos_id, + eos_id=params.eos_id, + nbest_scale=params.nbest_scale, + ngram_lm_scale=params.ngram_lm_scale, + attention_scale=params.attention_decoder_scale, + ) + best_path = next(iter(best_path_dict.values())) - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file( + params.lang_dir + "/words.txt" + ) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") - logging.info("Decoding Done") + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) - except Exception: - raise ValueError("Please use a supported decoding method.") + logging.info("Decoding Done") if __name__ == "__main__":