diff --git a/egs/gigaspeech/ASR/zipformer/ctc_decode.py b/egs/gigaspeech/ASR/zipformer/ctc_decode.py
index 34385d3d3..2cb912c66 100755
--- a/egs/gigaspeech/ASR/zipformer/ctc_decode.py
+++ b/egs/gigaspeech/ASR/zipformer/ctc_decode.py
@@ -127,6 +127,10 @@ from gigaspeech_scoring import asr_text_post_processing
 from lhotse import set_caching_enabled
 from train_cr_aed import add_model_arguments, get_model, get_params
 
+from icefall.context_graph import ContextGraph, ContextState
+from icefall.ngram_lm import NgramLm, NgramLmStateCost
+from icefall.lm_wrapper import LmScorer
+
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -137,6 +141,7 @@ from icefall.decode import (
     ctc_greedy_search,
     ctc_prefix_beam_search,
     ctc_prefix_beam_search_attention_decoder_rescoring,
+    ctc_prefix_beam_search_shallow_fussion,
     get_lattice,
     nbest_decoding,
     nbest_oracle,
@@ -286,6 +291,23 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--lm-type",
+        type=str,
+        default="rnn",
+        help="Type of neural network LM",
+        choices=["rnn", "transformer"],
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.3,
+        help="""The scale of the neural network LM. Used only when
+        --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+
     parser.add_argument(
         "--hlg-scale",
         type=float,
@@ -320,8 +342,9 @@ def get_decoding_params() -> AttributeDict:
     params = AttributeDict(
         {
             "frame_shift_ms": 10,
-            "search_beam": 20,
-            "output_beam": 8,
+            "search_beam": 20,  # for k2 fsa composition
+            "output_beam": 8,  # for k2 fsa composition
+            "beam": 4,  # for prefix-beam-search
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
@@ -350,6 +373,7 @@ def decode_one_batch(
     batch: dict,
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
+    LM: Optional[LmScorer] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -453,6 +477,20 @@ def decode_one_batch(
             ans[a_scale_str] = hyps
         return ans
 
+    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+        token_ids = ctc_prefix_beam_search_shallow_fussion(
+            ctc_output=ctc_output,
+            encoder_out_lens=encoder_out_lens,
+            LM=LM,
+        )
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ...]
+        hyps = [s.split() for s in hyps]
+        key = "prefix-beam-search-shallow-fussion"
+        return {key: hyps}
+
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
@@ -626,6 +664,7 @@ def decode_dataset(
     bpe_model: Optional[spm.SentencePieceProcessor],
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
+    LM: Optional[LmScorer] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
 
@@ -676,6 +715,7 @@ def decode_dataset(
             batch=batch,
             word_table=word_table,
             G=G,
+            LM=LM,
         )
 
         for name, hyps in hyps_dict.items():
@@ -779,6 +819,7 @@ def main():
         "ctc-greedy-search",
         "prefix-beam-search",
         "ctc-prefix-beam-search-attention-decoder-rescoring",
+        "ctc-prefix-beam-search-shallow-fussion",
         "ctc-decoding",
         "1best",
         "nbest",
@@ -805,6 +846,11 @@ def main():
         params.suffix += f"_chunk-{params.chunk_size}"
         params.suffix += f"_left-context-{params.left_context_frames}"
 
+    if "prefix-beam-search" in params.decoding_method:
+        params.suffix += f"_beam-{params.beam}"
+    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+        params.suffix += f"_lm-scale-{params.lm_scale}"
+
     if params.use_averaged_model:
         params.suffix += "_use-averaged-model"
 
@@ -834,6 +880,7 @@ def main():
         "ctc-decoding",
         "prefix-beam-search",
         "ctc-prefix-beam-search-attention-decoder-rescoring",
+        "ctc-prefix-beam-search-shallow-fussion",
         "attention-decoder-rescoring-no-ngram",
     ]:
         HLG = None
@@ -912,6 +959,19 @@ def main():
     else:
         G = None
 
+    # only load the neural network LM if required
+    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+        LM = LmScorer(
+            lm_type=params.lm_type,
+            params=params,
+            device=device,
+            lm_scale=params.lm_scale,
+        )
+        LM.to(device)
+        LM.eval()
+    else:
+        LM = None
+
     logging.info("About to create model")
     model = get_model(params)
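For context on what the wired-up decoding method does: shallow fusion interleaves a neural LM score into CTC prefix beam search, adding `lm_scale * log P_LM(token | prefix)` to a hypothesis's score each time its prefix is extended by a non-blank token, before pruning to the `beam` (here 4) best prefixes. Below is a minimal single-utterance sketch of that idea, not icefall's batched `ctc_prefix_beam_search_shallow_fussion`; the function name and the toy `lm` callable are illustrative only.

```python
# Minimal sketch of CTC prefix beam search with LM shallow fusion.
# NOT icefall's implementation; `lm` is a hypothetical callable
# mapping (prefix, token) -> log P_LM(token | prefix).
import math
from collections import defaultdict
from typing import Callable, List, Optional, Sequence, Tuple

NEG_INF = float("-inf")

def log_add(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    return max(a, b) + math.log1p(math.exp(-abs(a - b)))

def prefix_beam_search_shallow_fusion(
    log_probs: Sequence[Sequence[float]],  # (T, V) log posteriors, e.g. ctc_output[i, :T_i].tolist()
    beam: int = 4,
    blank: int = 0,
    lm: Optional[Callable[[Tuple[int, ...], int], float]] = None,
    lm_scale: float = 0.3,
) -> List[int]:
    # Each prefix keeps two log-probabilities: mass of alignments ending
    # in blank (p_b) and mass ending in the prefix's last token (p_nb).
    beams = {(): (0.0, NEG_INF)}
    for frame in log_probs:
        next_beams = defaultdict(lambda: (NEG_INF, NEG_INF))
        for prefix, (p_b, p_nb) in beams.items():
            p_tot = log_add(p_b, p_nb)
            for v, p in enumerate(frame):
                if v == blank:
                    # Blank extends the alignment but not the prefix.
                    n_b, n_nb = next_beams[prefix]
                    next_beams[prefix] = (log_add(n_b, p_tot + p), n_nb)
                    continue
                new_prefix = prefix + (v,)
                # Shallow fusion: the scaled LM cost is paid whenever the
                # prefix is extended by token v.
                bonus = lm_scale * lm(prefix, v) if lm is not None else 0.0
                n_b, n_nb = next_beams[new_prefix]
                if prefix and v == prefix[-1]:
                    # Repeat token: only alignments ending in blank may
                    # extend the prefix; the rest collapse into `prefix`.
                    next_beams[new_prefix] = (n_b, log_add(n_nb, p_b + p + bonus))
                    o_b, o_nb = next_beams[prefix]
                    next_beams[prefix] = (o_b, log_add(o_nb, p_nb + p))
                else:
                    next_beams[new_prefix] = (n_b, log_add(n_nb, p_tot + p + bonus))
        # Prune to the `beam` best prefixes (cf. "beam": 4 above).
        beams = dict(
            sorted(next_beams.items(), key=lambda kv: log_add(*kv[1]), reverse=True)[:beam]
        )
    best = max(beams.items(), key=lambda kv: log_add(*kv[1]))
    return list(best[0])
```

icefall's `ctc_prefix_beam_search_shallow_fussion` operates on batched `ctc_output`/`encoder_out_lens` and scores tokens with the `LmScorer` loaded in `main()`; the sketch only illustrates where the LM term enters the search and why the LM is needed per prefix extension rather than as a one-shot rescoring pass.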