From 5a05b957300ee21a4d2370039ac612e7265e0834 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Mon, 6 Feb 2023 23:21:46 +0800 Subject: [PATCH] add params.hlg_scale (#880) --- egs/librispeech/ASR/conformer_ctc3/decode.py | 199 ++++++------------- 1 file changed, 61 insertions(+), 138 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py index 8eca2ae02..39186e546 100755 --- a/egs/librispeech/ASR/conformer_ctc3/decode.py +++ b/egs/librispeech/ASR/conformer_ctc3/decode.py @@ -58,7 +58,6 @@ For example: --left-context 64 \ --manifest-dir data/fbank_ali Note: It supports calculating symbol delay with following decoding methods: - - ctc-greedy-search - ctc-decoding - 1best """ @@ -96,10 +95,8 @@ from icefall.decode import ( from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, - DecodingResults, get_texts, get_texts_with_timestamp, - make_pad_mask, parse_hyp_and_timestamp, setup_logger, store_transcripts_and_timestamps, @@ -177,20 +174,18 @@ def get_parser(): - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece model, i.e., lang_dir/bpe.model, to convert word pieces to words. It needs neither a lexicon nor an n-gram LM. - - (1) ctc-greedy-search. It only use CTC output and a sentence piece - model for decoding. It produces the same results with ctc-decoding. - - (2) 1best. Extract the best path from the decoding lattice as the + - (1) 1best. Extract the best path from the decoding lattice as the decoding result. - - (3) nbest. Extract n paths from the decoding lattice; the path + - (2) nbest. Extract n paths from the decoding lattice; the path with the highest score is the decoding result. - - (4) nbest-rescoring. Extract n paths from the decoding lattice, + - (3) nbest-rescoring. Extract n paths from the decoding lattice, rescore them with an n-gram LM (e.g., a 4-gram LM), the path with the highest score is the decoding result. - - (5) whole-lattice-rescoring. Rescore the decoding lattice with an + - (4) whole-lattice-rescoring. Rescore the decoding lattice with an n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice is the decoding result. you have trained an RNN LM using ./rnn_lm/train.py - - (6) nbest-oracle. Its WER is the lower bound of any n-best + - (5) nbest-oracle. Its WER is the lower bound of any n-best rescoring method can achieve. Useful for debugging n-best rescoring method. """, @@ -250,6 +245,14 @@ def get_parser(): help="left context can be seen during decoding (in frames after subsampling)", ) + parser.add_argument( + "--hlg-scale", + type=float, + default=0.8, + help="""The scale to be applied to `hlg.scores`. + """, + ) + add_model_arguments(parser) return parser @@ -270,47 +273,6 @@ def get_decoding_params() -> AttributeDict: return params -def ctc_greedy_search( - ctc_probs: torch.Tensor, - nnet_output_lens: torch.Tensor, -) -> List[List[int]]: - """Apply CTC greedy search - Args: - ctc_probs (torch.Tensor): (batch, max_len, feat_dim) - nnet_output_lens (torch.Tensor): (batch, ) - Returns: - List[List[int]]: best path result - """ - topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1) - topk_index = topk_index.squeeze(2) # (B, maxlen) - mask = make_pad_mask(nnet_output_lens) - topk_index = topk_index.masked_fill_(mask, 0) # (B, maxlen) - hyps = [hyp.tolist() for hyp in topk_index] - scores = topk_prob.max(1) - ret_hyps = [] - timestamps = [] - for i in range(len(hyps)): - hyp, time = remove_duplicates_and_blank(hyps[i]) - ret_hyps.append(hyp) - timestamps.append(time) - return ret_hyps, timestamps, scores - - -def remove_duplicates_and_blank(hyp: List[int]) -> Tuple[List[int], List[int]]: - # modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py - new_hyp: List[int] = [] - time: List[int] = [] - cur = 0 - while cur < len(hyp): - if hyp[cur] != 0: - new_hyp.append(hyp[cur]) - time.append(cur) - prev = cur - while cur < len(hyp) and hyp[cur] == hyp[prev]: - cur += 1 - return new_hyp, time - - def decode_one_batch( params: AttributeDict, model: nn.Module, @@ -402,26 +364,11 @@ def decode_one_batch( nnet_output = model.get_ctc_output(encoder_out) # nnet_output is (N, T, C) - if params.decoding_method == "ctc-greedy-search": - hyps, timestamps, _ = ctc_greedy_search( - nnet_output, - encoder_out_lens, - ) - res = DecodingResults(hyps=hyps, timestamps=timestamps) - hyps, timestamps = parse_hyp_and_timestamp( - res=res, - sp=bpe_model, - subsampling_factor=params.subsampling_factor, - frame_shift_ms=params.frame_shift_ms, - ) - key = "ctc-greedy-search" - return {key: (hyps, timestamps)} - supervision_segments = torch.stack( ( supervisions["sequence_idx"], supervisions["start_frame"] // params.subsampling_factor, - supervisions["num_frames"] // params.subsampling_factor, + encoder_out_lens.cpu(), ), 1, ).to(torch.int32) @@ -434,75 +381,6 @@ def decode_one_batch( assert bpe_model is not None decoding_graph = H - if params.decoding_method in ["1best", "nbest", "nbest-oracle"]: - hlg_scale_list = [0.2, 0.4, 0.6, 0.8, 1.0] - - ori_scores = decoding_graph.scores.clone() - - ans = {} - for hlg_scale in hlg_scale_list: - decoding_graph.scores = ori_scores * hlg_scale - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=decoding_graph, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - key_suffix = f"-HLG-scale-{hlg_scale}" - - if params.decoding_method == "nbest-oracle": - # Note: You can also pass rescored lattices to it. - # We choose the HLG decoded lattice for speed reasons - # as HLG decoding is faster and the oracle WER - # is only slightly worse than that of rescored lattices. - best_path = nbest_oracle( - lattice=lattice, - num_paths=params.num_paths, - ref_texts=supervisions["text"], - word_table=word_table, - nbest_scale=params.nbest_scale, - oov="", - ) - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - key = f"oracle-{params.num_paths}-nbest-scale-{params.nbest_scale}" # noqa - timestamps = [[] for _ in range(len(hyps))] - ans[key + key_suffix] = (hyps, timestamps) - - elif params.decoding_method in ["1best", "nbest"]: - if params.decoding_method == "1best": - best_path = one_best_decoding( - lattice=lattice, - use_double_scores=params.use_double_scores, - ) - key = "no-rescore" - res = get_texts_with_timestamp(best_path) - hyps, timestamps = parse_hyp_and_timestamp( - res=res, - subsampling_factor=params.subsampling_factor, - frame_shift_ms=params.frame_shift_ms, - word_table=word_table, - ) - else: - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, - nbest_scale=params.nbest_scale, - ) - key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - timestamps = [[] for _ in range(len(hyps))] - - ans[key + key_suffix] = (hyps, timestamps) - - return ans - lattice = get_lattice( nnet_output=nnet_output, decoding_graph=decoding_graph, @@ -532,6 +410,51 @@ def decode_one_batch( key = "ctc-decoding" return {key: (hyps, timestamps)} + if params.decoding_method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. + best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + timestamps = [[] for _ in range(len(hyps))] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}_hlg_scale_{params.hlg_scale}" # noqa + return {key: (hyps, timestamps)} + + if params.decoding_method in ["1best", "nbest"]: + if params.decoding_method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = f"no_rescore_hlg_scale_{params.hlg_scale}" + res = get_texts_with_timestamp(best_path) + hyps, timestamps = parse_hyp_and_timestamp( + res=res, + subsampling_factor=params.subsampling_factor, + frame_shift_ms=params.frame_shift_ms, + word_table=word_table, + ) + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}-hlg-scale-{params.hlg_scale}" # noqa + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + timestamps = [[] for _ in range(len(hyps))] + return {key: (hyps, timestamps)} + assert params.decoding_method in [ "nbest-rescoring", "whole-lattice-rescoring", @@ -757,7 +680,6 @@ def main(): params.update(vars(args)) assert params.decoding_method in ( - "ctc-greedy-search", "ctc-decoding", "1best", "nbest", @@ -811,7 +733,7 @@ def main(): params.sos_id = sos_id params.eos_id = eos_id - if params.decoding_method in ["ctc-decoding", "ctc-greedy-search"]: + if params.decoding_method == "ctc-decoding": HLG = None H = k2.ctc_topo( max_token=max_token_id, @@ -828,6 +750,7 @@ def main(): ) assert HLG.requires_grad is False + HLG.scores *= params.hlg_scale if not hasattr(HLG, "lm_scores"): HLG.lm_scores = HLG.scores.clone()