From e062c1b5cfb695ccd8a3ec5453fcdc0fe003096b Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sat, 18 Sep 2021 16:32:35 +0800
Subject: [PATCH] Fixes after refactoring.

---
 egs/librispeech/ASR/conformer_ctc/decode.py | 134 ++++++------------
 .../ASR/conformer_ctc/pretrained.py         |   2 +-
 egs/librispeech/ASR/tdnn_lstm_ctc/decode.py |   6 +-
 icefall/{decode2.py => decode.py}           |  62 ++++++++
 4 files changed, 114 insertions(+), 90 deletions(-)
 rename icefall/{decode2.py => decode.py} (92%)

diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 10add9cd3..80126ae4e 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -30,20 +30,14 @@ from conformer import Conformer
 
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.decode import get_lattice
 from icefall.decode import (
-    one_best_decoding,  # done
-    rescore_with_attention_decoder,  # done
-    rescore_with_n_best_list,  # done
-    rescore_with_whole_lattice,  # done
-    nbest_oracle,  # done
-)
-from icefall.decode2 import (
+    get_lattice,
     nbest_decoding,
-    nbest_oracle as nbest_oracle2,
-    rescore_with_n_best_list as rescore_with_n_best_list2,
-    rescore_with_whole_lattice as rescore_with_whole_lattice2,
-    rescore_with_attention_decoder as rescore_with_attention_decoder2,
+    nbest_oracle,
+    one_best_decoding,
+    rescore_with_attention_decoder,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
 )
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -250,29 +244,19 @@ def decode_one_batch(
         # Note: You can also pass rescored lattices to it.
         # We choose the HLG decoded lattice for speed reasons
         # as HLG decoding is faster and the oracle WER
-        # is slightly worse than that of rescored lattices.
-        if True:
-            # TODO: delete the `else` branch
-            best_path = nbest_oracle2(
-                lattice=lattice,
-                num_paths=params.num_paths,
-                ref_texts=supervisions["text"],
-                word_table=word_table,
-                lattice_score_scale=params.lattice_score_scale,
-                oov="",
-            )
-            hyps = get_texts(best_path)
-            hyps = [[word_table[i] for i in ids] for ids in hyps]
-            key = f"oracle_{num_paths}_lattice_score_scale_{params.lattice_score_scale}"  # noqa
-            return {key: hyps}
-        else:
-            return nbest_oracle(
-                lattice=lattice,
-                num_paths=params.num_paths,
-                ref_texts=supervisions["text"],
-                word_table=word_table,
-                scale=params.lattice_score_scale,
-            )
+        # is only slightly worse than that of rescored lattices.
+        best_path = nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            word_table=word_table,
+            lattice_score_scale=params.lattice_score_scale,
+            oov="",
+        )
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        key = f"oracle_{num_paths}_lattice_score_scale_{params.lattice_score_scale}"  # noqa
+        return {key: hyps}
 
     if params.method in ["1best", "nbest"]:
         if params.method == "1best":
@@ -304,65 +288,39 @@
     lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
 
     if params.method == "nbest-rescoring":
-        if True:
-            # TODO: remove the "else" branch
-            best_path_dict = rescore_with_n_best_list2(
-                lattice=lattice,
-                G=G,
-                num_paths=params.num_paths,
-                lm_scale_list=lm_scale_list,
-                lattice_score_scale=params.lattice_score_scale,
-            )
-        else:
-            best_path_dict = rescore_with_n_best_list(
-                lattice=lattice,
-                G=G,
-                num_paths=params.num_paths,
-                lm_scale_list=lm_scale_list,
-                scale=params.lattice_score_scale,
-            )
+        best_path_dict = rescore_with_n_best_list(
+            lattice=lattice,
+            G=G,
+            num_paths=params.num_paths,
+            lm_scale_list=lm_scale_list,
+            lattice_score_scale=params.lattice_score_scale,
+        )
     elif params.method == "whole-lattice-rescoring":
-        if True:
-            # TODO: remove "else" branch
-            best_path_dict = rescore_with_whole_lattice2(
-                lattice=lattice,
-                G_with_epsilon_loops=G,
-                lm_scale_list=lm_scale_list,
-            )
-        else:
-            best_path_dict = rescore_with_whole_lattice(
-                lattice=lattice,
-                G_with_epsilon_loops=G,
-                lm_scale_list=lm_scale_list,
-            )
+        best_path_dict = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
+        )
     elif params.method == "attention-decoder":
         # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
         rescored_lattice = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=None,
         )
+        # TODO: pass `lattice` instead of `rescored_lattice` to
+        # `rescore_with_attention_decoder`
 
-        if True:
-            best_path_dict = rescore_with_attention_decoder2(
-                lattice=rescored_lattice,
-                num_paths=params.num_paths,
-                model=model,
-                memory=memory,
-                memory_key_padding_mask=memory_key_padding_mask,
-                sos_id=sos_id,
-                eos_id=eos_id,
-                lattice_score_scale=params.lattice_score_scale,
-            )
-        else:
-            best_path_dict = rescore_with_attention_decoder(
-                lattice=rescored_lattice,
-                num_paths=params.num_paths,
-                model=model,
-                memory=memory,
-                memory_key_padding_mask=memory_key_padding_mask,
-                sos_id=sos_id,
-                eos_id=eos_id,
-                scale=params.lattice_score_scale,
-            )
+        best_path_dict = rescore_with_attention_decoder(
+            lattice=rescored_lattice,
+            num_paths=params.num_paths,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id,
+            lattice_score_scale=params.lattice_score_scale,
+        )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
 
diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py
index 913088777..c924b87bb 100755
--- a/egs/librispeech/ASR/conformer_ctc/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -336,7 +336,7 @@ def main():
             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=params.sos_id,
             eos_id=params.eos_id,
-            scale=params.lattice_score_scale,
+            lattice_score_scale=params.lattice_score_scale,
             ngram_lm_scale=params.ngram_lm_scale,
             attention_scale=params.attention_decoder_scale,
         )
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 7e5ec8c0d..8524ab1b9 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -229,6 +229,7 @@ def decode_one_batch(
             lattice=lattice,
             num_paths=params.num_paths,
             use_double_scores=params.use_double_scores,
+            lattice_score_scale=params.lattice_score_scale,
         )
         key = f"no_rescore-{params.num_paths}"
         hyps = get_texts(best_path)
@@ -247,10 +248,13 @@
             G=G,
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
+            lattice_score_scale=params.lattice_score_scale,
         )
     else:
         best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
         )
 
     ans = dict()
diff --git a/icefall/decode2.py b/icefall/decode.py
similarity index 92%
rename from icefall/decode2.py
rename to icefall/decode.py
index c5c127cbb..73e3b61f7 100644
--- a/icefall/decode2.py
+++ b/icefall/decode.py
@@ -64,6 +64,68 @@ def _intersect_device(
     return k2.cat(ans)
 
 
+def get_lattice(
+    nnet_output: torch.Tensor,
+    HLG: k2.Fsa,
+    supervision_segments: torch.Tensor,
+    search_beam: float,
+    output_beam: float,
+    min_active_states: int,
+    max_active_states: int,
+    subsampling_factor: int = 1,
+) -> k2.Fsa:
+    """Get the decoding lattice from a decoding graph and neural
+    network output.
+    Args:
+      nnet_output:
+        It is the output of a neural model of shape `[N, T, C]`.
+      HLG:
+        An Fsa, the decoding graph. See also `compile_HLG.py`.
+      supervision_segments:
+        A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns.
+        Each row contains information for a supervision segment. Column 0
+        is the `sequence_index` indicating which sequence this segment
+        comes from; column 1 specifies the `start_frame` of this segment
+        within the sequence; column 2 contains the `duration` of this
+        segment.
+      search_beam:
+        Decoding beam, e.g. 20. Smaller is faster, larger is more exact
+        (less pruning). This is the default value; it may be modified by
+        `min_active_states` and `max_active_states`.
+      output_beam:
+        Beam to prune output, similar to lattice-beam in Kaldi. Relative
+        to best path of output.
+      min_active_states:
+        Minimum number of FSA states that are allowed to be active on any given
+        frame for any given intersection/composition task. This is advisory,
+        in that it will try not to have fewer than this number active.
+        Set it to zero if there is no constraint.
+      max_active_states:
+        Maximum number of FSA states that are allowed to be active on any given
+        frame for any given intersection/composition task. This is advisory,
+        in that it will try not to exceed that but may not always succeed.
+        You can use a very large number if no constraint is needed.
+      subsampling_factor:
+        The subsampling factor of the model.
+    Returns:
+      A lattice containing the decoding result.
+    """
+    dense_fsa_vec = k2.DenseFsaVec(
+        nnet_output, supervision_segments, allow_truncate=subsampling_factor - 1
+    )
+
+    lattice = k2.intersect_dense_pruned(
+        HLG,
+        dense_fsa_vec,
+        search_beam=search_beam,
+        output_beam=output_beam,
+        min_active_states=min_active_states,
+        max_active_states=max_active_states,
+    )
+
+    return lattice
+
+
 # TODO(fangjun): Use Kangwei's C++ implementation that also
 # supports List[List[int]]
 def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
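
Below is a minimal usage sketch, not part of the patch itself, showing how the consolidated `icefall.decode` API touched by this change might be driven end to end: assembling `supervision_segments` in the three-column layout documented in the `get_lattice` docstring above, calling `get_lattice`, and extracting a 1-best result. The beam values and the `supervisions` dict keys (`sequence_idx`, `start_frame`, `num_frames`) are illustrative assumptions modeled on typical icefall recipes; only `get_lattice`, `one_best_decoding`, and `get_texts` come from the code shown in this patch or the existing icefall utilities.

# Hypothetical usage sketch; not part of the patch above.
import k2
import torch

from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import get_texts


def decode_batch(
    nnet_output: torch.Tensor,  # log-probs of shape [N, T, C]
    HLG: k2.Fsa,
    supervisions: dict,
    subsampling_factor: int = 4,
):
    # Three columns per row, as documented in get_lattice():
    # [sequence_index, start_frame, duration], counted in frames
    # after subsampling.
    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
            supervisions["start_frame"] // subsampling_factor,
            supervisions["num_frames"] // subsampling_factor,
        ),
        dim=1,
    ).to(torch.int32)

    lattice = get_lattice(
        nnet_output=nnet_output,
        HLG=HLG,
        supervision_segments=supervision_segments,
        search_beam=20,  # illustrative values; the recipes expose these as flags
        output_beam=8,
        min_active_states=30,
        max_active_states=10000,
        subsampling_factor=subsampling_factor,
    )

    # Plain 1-best decoding. The rescoring helpers shown in the hunks above
    # (rescore_with_n_best_list, rescore_with_whole_lattice,
    # rescore_with_attention_decoder) accept this same lattice and, after this
    # patch, take `lattice_score_scale` instead of the old `scale` keyword.
    best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
    return get_texts(best_path)  # word IDs per utterance

The sketch stops at 1-best decoding on purpose: the renamed rescoring entry points differ only in the extra arguments visible in the diff (`G`, `lm_scale_list`, or the attention-decoder inputs), so the lattice produced here can be reused with any of them.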