diff --git a/egs/aishell/ASR/zipformer/ctc_decode.py b/egs/aishell/ASR/zipformer/ctc_decode.py index 84f7084e4..8810bac9a 100755 --- a/egs/aishell/ASR/zipformer/ctc_decode.py +++ b/egs/aishell/ASR/zipformer/ctc_decode.py @@ -24,8 +24,8 @@ Usage: (1) ctc-greedy-search (with cr-ctc) ./zipformer/ctc_decode.py \ - --epoch 50 \ - --avg 24 \ + --epoch 60 \ + --avg 28 \ --exp-dir ./zipformer/exp \ --use-cr-ctc 1 \ --use-ctc 1 \ @@ -47,40 +47,18 @@ import k2 import torch import torch.nn as nn from asr_datamodule import AishellAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) from lhotse.cut import Cut from train import add_model_arguments, get_model, get_params -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, find_checkpoints, load_checkpoint, ) -from icefall.context_graph import ContextGraph, ContextState from icefall.decode import ( ctc_greedy_search, ctc_prefix_beam_search, - ctc_prefix_beam_search_attention_decoder_rescoring, - ctc_prefix_beam_search_shallow_fussion, - get_lattice, - nbest_decoding, - nbest_oracle, - one_best_decoding, - rescore_with_attention_decoder_no_ngram, - rescore_with_attention_decoder_with_ngram, - rescore_with_n_best_list, - rescore_with_whole_lattice, ) from icefall.lexicon import Lexicon from icefall.utils import ( @@ -162,69 +140,11 @@ def get_parser(): - (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece model, i.e., lang_dir/bpe.model, to convert word pieces to words. It needs neither a lexicon nor an n-gram LM. + (2) ctc-prefix-beam-search. Extract n paths with the given beam, the best + path of the n paths is the decoding result. """, ) - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--ilme-scale", - type=float, - default=0.2, - help=""" - Used only when --decoding_method is fast_beam_search_LG. - It specifies the scale for the internal language model estimation. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search, fast_beam_search_LG, - and fast_beam_search_nbest_oracle""", - ) - parser.add_argument( "--context-size", type=int, @@ -232,42 +152,6 @@ def get_parser(): help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--blank-penalty", - type=float, - default=0.0, - help=""" - The penalty applied on blank symbol during decoding. - Note: It is a positive value that would be applied to logits like - this `logits[:, 0] -= blank_penalty` (suppose logits.shape is - [batch_size, vocab] and blank id is 0). - """, - ) - add_model_arguments(parser) return parser @@ -276,9 +160,7 @@ def decode_one_batch( params: AttributeDict, model: nn.Module, lexicon: Lexicon, - graph_compiler: CharCtcTrainingGraphCompiler, batch: dict, - decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, Tuple[List[List[str]], List[List[Tuple[float, float]]]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -299,10 +181,6 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -340,11 +218,16 @@ def decode_one_batch( hyp_tokens = [] hyps = [] - if params.decoding_method == "ctc-greedy-search" and params.max_sym_per_frame == 1: + if params.decoding_method == "ctc-greedy-search": hyp_tokens = ctc_greedy_search( ctc_output=ctc_output, encoder_out_lens=encoder_out_lens, ) + elif params.decoding_method == "ctc-prefix-beam-search": + hyp_tokens = ctc_prefix_beam_search( + ctc_output=ctc_output, + encoder_out_lens=encoder_out_lens, + ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" @@ -356,20 +239,10 @@ def decode_one_batch( key = f"blank_penalty_{params.blank_penalty}" if params.decoding_method == "ctc-greedy-search": return {"ctc-greedy-search_" + key: hyps} - elif "fast_beam_search" in params.decoding_method: - key += f"_beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ilme_scale_{params.ilme_scale}" - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} + elif params.decoding_method == "ctc-prefix-beam-search": + return {"ctc-prefix-beam-search_" + key: hyps} else: - return {f"beam_size_{params.beam_size}_" + key: hyps} + assert False, f"Unsupported decoding method: {params.decoding_method}" def decode_dataset( @@ -377,8 +250,6 @@ def decode_dataset( params: AttributeDict, model: nn.Module, lexicon: Lexicon, - graph_compiler: CharCtcTrainingGraphCompiler, - decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -389,10 +260,6 @@ def decode_dataset( It is returned by :func:`get_params`. model: The neural model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -406,10 +273,7 @@ def decode_dataset( except TypeError: num_batches = "?" - if params.decoding_method == "ctc-greedy-search": - log_interval = 50 - else: - log_interval = 20 + log_interval = 20 results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -421,8 +285,6 @@ def decode_dataset( params=params, model=model, lexicon=lexicon, - graph_compiler=graph_compiler, - decoding_graph=decoding_graph, batch=batch, ) for name, hyps in hyps_dict.items(): @@ -504,7 +366,8 @@ def main(): assert params.decoding_method in ( "ctc-greedy-search", - ) # only support ctc-greedy-search + "ctc-prefix-beam-search", + ) # support ctc-greedy-search and ctc-prefix-beam-search params.res_dir = params.exp_dir / params.decoding_method if params.iter > 0: @@ -522,22 +385,9 @@ def main(): params.suffix += f"-chunk-{params.chunk_size}" params.suffix += f"-left-context-{params.left_context_frames}" - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"_ilme_scale_{params.ilme_scale}" - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - params.suffix += f"-blank-penalty-{params.blank_penalty}" + if "prefix-beam-search" in params.decoding_method: + params.suffix += f"_beam-{params.beam}" + params.suffix += f"-context-{params.context_size}" if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -551,18 +401,12 @@ def main(): params.device = device logging.info(f"Device: {device}") - logging.info(params) lexicon = Lexicon(params.lang_dir) params.blank_id = lexicon.token_table[""] params.vocab_size = max(lexicon.tokens) + 1 - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - logging.info(params) logging.info("About to create model") @@ -648,20 +492,6 @@ def main(): model.to(device) model.eval() - if "fast_beam_search" in params.decoding_method: - if "LG" in params.decoding_method: - lexicon = Lexicon(params.lang_dir) - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -694,8 +524,6 @@ def main(): params=params, model=model, lexicon=lexicon, - graph_compiler=graph_compiler, - decoding_graph=decoding_graph, ) save_results(