diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py
index 85ceb61b8..435a79e7f 100755
--- a/egs/librispeech/ASR/zipformer/ctc_decode.py
+++ b/egs/librispeech/ASR/zipformer/ctc_decode.py
@@ -21,7 +21,16 @@
 """
 Usage:
 
-(1) ctc-decoding
+(1) ctc-greedy-search
+./zipformer/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer/exp \
+    --use-ctc 1 \
+    --max-duration 600 \
+    --decoding-method ctc-greedy-search
+
+(2) ctc-decoding
 ./zipformer/ctc_decode.py \
     --epoch 30 \
     --avg 15 \
@@ -30,7 +39,7 @@ Usage:
     --max-duration 600 \
     --decoding-method ctc-decoding
 
-(2) 1best
+(3) 1best
 ./zipformer/ctc_decode.py \
     --epoch 30 \
     --avg 15 \
@@ -40,7 +49,7 @@ Usage:
     --hlg-scale 0.6 \
     --decoding-method 1best
 
-(3) nbest
+(4) nbest
 ./zipformer/ctc_decode.py \
     --epoch 30 \
     --avg 15 \
@@ -50,7 +59,7 @@ Usage:
     --hlg-scale 0.6 \
     --decoding-method nbest
 
-(4) nbest-rescoring
+(5) nbest-rescoring
 ./zipformer/ctc_decode.py \
     --epoch 30 \
     --avg 15 \
@@ -62,7 +71,7 @@ Usage:
     --lm-dir data/lm \
     --decoding-method nbest-rescoring
 
-(5) whole-lattice-rescoring
+(6) whole-lattice-rescoring
 ./zipformer/ctc_decode.py \
     --epoch 30 \
     --avg 15 \
@@ -74,7 +83,7 @@ Usage:
     --lm-dir data/lm \
     --decoding-method whole-lattice-rescoring
 
-(6) attention-decoder-rescoring-no-ngram
+(7) attention-decoder-rescoring-no-ngram
 ./zipformer/ctc_decode.py \
     --epoch 30 \
     --avg 15 \
@@ -84,7 +93,7 @@ Usage:
     --max-duration 100 \
     --decoding-method attention-decoder-rescoring-no-ngram
 
-(7) attention-decoder-rescoring-with-ngram
+(8) attention-decoder-rescoring-with-ngram
 ./zipformer/ctc_decode.py \
     --epoch 30 \
     --avg 15 \
@@ -120,6 +129,7 @@ from icefall.checkpoint import (
     load_checkpoint,
 )
 from icefall.decode import (
+    ctc_greedy_search,
     get_lattice,
     nbest_decoding,
     nbest_oracle,
@@ -220,26 +230,29 @@ def get_parser():
         default="ctc-decoding",
         help="""Decoding method.
         Supported values are:
-        - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece
+        - (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece
           model, i.e., lang_dir/bpe.model, to convert word pieces to words.
           It needs neither a lexicon nor an n-gram LM.
-        - (2) 1best. Extract the best path from the decoding lattice as the
+        - (2) ctc-decoding. Use CTC decoding. It uses a sentence piece
+          model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+          It needs neither a lexicon nor an n-gram LM.
+        - (3) 1best. Extract the best path from the decoding lattice as the
           decoding result.
-        - (3) nbest. Extract n paths from the decoding lattice; the path
+        - (4) nbest. Extract n paths from the decoding lattice; the path
           with the highest score is the decoding result.
-        - (4) nbest-rescoring. Extract n paths from the decoding lattice,
+        - (5) nbest-rescoring. Extract n paths from the decoding lattice,
           rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
           the highest score is the decoding result.
-        - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
+        - (6) whole-lattice-rescoring. Rescore the decoding lattice with an
           n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
           is the decoding result.
           you have trained an RNN LM using ./rnn_lm/train.py
-        - (6) nbest-oracle. Its WER is the lower bound of any n-best
+        - (7) nbest-oracle. Its WER is the lower bound of any n-best
           rescoring method can achieve. Useful for debugging n-best
           rescoring method.
-        - (7) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
+        - (8) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
           lattice, rescore them with the attention decoder.
-        - (8) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
+        - (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
           rescored lattice, rescore them with the attention decoder.
         """,
     )
@@ -381,6 +394,15 @@ def decode_one_batch(
     encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
     ctc_output = model.ctc_output(encoder_out)  # (N, T, C)
 
+    if params.decoding_method == "ctc-greedy-search":
+        hyps = ctc_greedy_search(ctc_output, encoder_out_lens)
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(hyps)
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "ctc-greedy-search"
+        return {key: hyps}
+
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
@@ -684,6 +706,7 @@ def main():
     params.update(vars(args))
 
     assert params.decoding_method in (
+        "ctc-greedy-search",
        "ctc-decoding",
        "1best",
        "nbest",
@@ -733,7 +756,9 @@ def main():
     params.eos_id = 1
     params.sos_id = 1
 
-    if params.decoding_method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
+    if params.decoding_method in [
+        "ctc-greedy-search", "ctc-decoding", "attention-decoder-rescoring-no-ngram"
+    ]:
         HLG = None
         H = k2.ctc_topo(
             max_token=max_token_id,
diff --git a/icefall/decode.py b/icefall/decode.py
index 3abd5648a..b17de0ba7 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -1473,3 +1473,34 @@ def rescore_with_rnn_lm(
                 key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}_rnn_lm_scale_{r_scale}"  # noqa
                 ans[key] = best_path
     return ans
+
+
+def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
+    # from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
+    new_hyp: List[int] = []
+    cur = 0
+    while cur < len(hyp):
+        if hyp[cur] != 0:
+            new_hyp.append(hyp[cur])
+        prev = cur
+        while cur < len(hyp) and hyp[cur] == hyp[prev]:
+            cur += 1
+    return new_hyp
+
+
+def ctc_greedy_search(
+    ctc_output: torch.Tensor, encoder_out_lens: torch.Tensor
+) -> List[List[int]]:
+    """CTC greedy search.
+
+    Args:
+      ctc_output: (batch, seq_len, vocab_size)
+      encoder_out_lens: (batch,)
+    Returns:
+      List[List[int]]: greedy search result
+    """
+    batch = ctc_output.shape[0]
+    index = ctc_output.argmax(dim=-1)  # (batch, seq_len)
+    hyps = [index[i].tolist()[:encoder_out_lens[i]] for i in range(batch)]
+    hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
+    return hyps
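
Note (reviewer sketch, not part of the patch above): a minimal, self-contained example of how the new ctc_greedy_search helper is expected to behave once the patch is applied. The toy vocabulary, the blank-id-0 convention, and the fake posteriors below are illustrative assumptions only.

# Illustrative sketch only; assumes the patch above is applied and that 0 is the blank id.
import torch

from icefall.decode import ctc_greedy_search

# Fake CTC log-probs: 2 utterances, 6 frames, vocab of 5 tokens (0 is the blank).
ctc_output = torch.full((2, 6, 5), -10.0)

# Utterance 0 argmaxes to [0, 3, 3, 0, 4, 4] -> collapses to [3, 4].
for t, tok in enumerate([0, 3, 3, 0, 4, 4]):
    ctc_output[0, t, tok] = 0.0

# Utterance 1 has only 3 valid frames, argmaxing to [2, 2, 1] -> collapses to [2, 1].
for t, tok in enumerate([2, 2, 1, 0, 0, 0]):
    ctc_output[1, t, tok] = 0.0

encoder_out_lens = torch.tensor([6, 3])

print(ctc_greedy_search(ctc_output, encoder_out_lens))
# Expected: [[3, 4], [2, 1]]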
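
A possible simplification (again a sketch, not part of the patch): the per-token Python loop in remove_duplicates_and_blank can also be expressed with torch.unique_consecutive, which collapses repeated frames before the blank id (assumed to be 0) is dropped. The hypothetical helper below should give the same output as the patched ctc_greedy_search.

# Hypothetical alternative to the patched ctc_greedy_search; not in the diff above.
from typing import List

import torch


def ctc_greedy_search_unique(
    ctc_output: torch.Tensor, encoder_out_lens: torch.Tensor
) -> List[List[int]]:
    """Collapse repeats with torch.unique_consecutive, then drop the blank (id 0)."""
    index = ctc_output.argmax(dim=-1)  # (batch, seq_len)
    hyps = []
    for i in range(ctc_output.shape[0]):
        tokens = torch.unique_consecutive(index[i, : encoder_out_lens[i]])
        hyps.append(tokens[tokens != 0].tolist())
    return hyps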