diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index a175a8e03..40c442e7a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -75,7 +75,7 @@ def fast_beam_search_one_best( return hyps -def fast_beam_search_nbest( +def fast_beam_search_nbest_LG( model: Transducer, decoding_graph: k2.Fsa, encoder_out: torch.Tensor, @@ -86,7 +86,6 @@ def fast_beam_search_nbest( num_paths: int, nbest_scale: float = 0.5, use_double_scores: bool = True, - use_max: bool = True, ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. @@ -122,8 +121,6 @@ def fast_beam_search_nbest( use_double_scores: True to use double precision for computation. False to use single precision. - use_max: - False to use log-add to compute total scores. True to use max. Returns: Return the decoded result. """ @@ -180,8 +177,11 @@ def fast_beam_search_nbest( path_lattice = k2.top_sort(k2.connect(path_lattice)) tot_scores = path_lattice.get_tot_scores( use_double_scores=use_double_scores, - log_semiring=(False if use_max else True), + log_semiring=True, # Note: we always use True ) + # See https://github.com/k2-fsa/icefall/pull/420 for why + # we always use log_semiring=True + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) best_hyp_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) @@ -191,6 +191,86 @@ def fast_beam_search_nbest( return hyps +def fast_beam_search_nbest( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + nbest_scale: float = 0.5, + use_double_scores: bool = True, +) -> List[List[int]]: + """It limits the maximum number of symbols per frame to 1. + + The process to get the results is: + - (1) Use fast beam search to get a lattice + - (2) Select `num_paths` paths from the lattice using k2.random_paths() + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + use_double_scores: + True to use double precision for computation. False to use + single precision. + Returns: + Return the decoded result. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + max_indexes = nbest.tot_scores().argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + hyps = get_texts(best_path) + + return hyps + + def fast_beam_search_nbest_oracle( model: Transducer, decoding_graph: k2.Fsa, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 45c7d4437..ba6add265 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -50,20 +50,44 @@ Usage: --exp-dir ./pruned_transducer_stateless/exp \ --max-duration 600 \ --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 -(5) fast beam search using LG +(5) fast beam search (nbest) ./pruned_transducer_stateless/decode.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless/exp \ - --use-LG True \ - --use-max False \ --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 8 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ --max-contexts 8 \ --max-states 64 """ @@ -83,6 +107,8 @@ from asr_datamodule import LibriSpeechAsrDataModule from beam_search import ( beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, fast_beam_search_one_best, greedy_search, greedy_search_batch, @@ -100,7 +126,6 @@ from icefall.utils import ( AttributeDict, setup_logger, store_transcripts, - str2bool, write_error_stats, ) @@ -168,6 +193,11 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. """, ) @@ -183,30 +213,13 @@ def get_parser(): parser.add_argument( "--beam", type=float, - default=4, + default=8.0, help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--use-LG", - type=str2bool, - default=False, - help="""Whether to use an LG graph for FSA-based beam search. - Used only when --decoding_method is fast_beam_search. If true, - it uses lang_dir/LG.pt during decoding.""", - ) - - parser.add_argument( - "--use-max", - type=str2bool, - default=False, - help="""If True, use max-op to select the hypothesis that have the - max log_prob in case of duplicate hypotheses. - If False, use log_add. - Used only for beam_search, modified_beam_search, and fast_beam_search + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle """, ) @@ -215,7 +228,7 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search. + Used only when --decoding_method is fast_beam_search_nbest_LG. It specifies the scale for n-gram LM scores. """, ) @@ -225,7 +238,8 @@ def get_parser(): type=int, default=4, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -233,7 +247,8 @@ def get_parser(): type=int, default=8, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -256,9 +271,8 @@ def get_parser(): type=int, default=200, help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search and - --use-LG is True. - """, + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -266,9 +280,8 @@ def get_parser(): type=float, default=0.5, help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search and - --use-LG is True. - """, + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) return parser @@ -329,33 +342,60 @@ def decode_one_batch( hyps = [] if params.decoding_method == "fast_beam_search": - if not params.use_LG: - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - use_max=params.use_max, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 @@ -373,7 +413,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - use_max=params.use_max, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -395,7 +434,6 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, beam=params.beam_size, - use_max=params.use_max, ) else: raise ValueError( @@ -405,14 +443,17 @@ def decode_one_batch( if params.decoding_method == "greedy_search": return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} else: return {f"beam_size_{params.beam_size}": hyps} @@ -458,7 +499,7 @@ def decode_dataset( if params.decoding_method == "greedy_search": log_interval = 50 else: - log_interval = 10 + log_interval = 20 results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -551,6 +592,9 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", "modified_beam_search", ) params.res_dir = params.exp_dir / params.decoding_method @@ -561,19 +605,18 @@ def main(): params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" if "fast_beam_search" in params.decoding_method: - params.suffix += f"-use-LG-{params.use_LG}" params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" - if params.use_LG: - params.suffix += f"-use-max-{params.use_max}" + if "nbest" in params.decoding_method: params.suffix += f"-nbest-scale-{params.nbest_scale}" params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" ) - params.suffix += f"-use-max-{params.use_max}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -633,8 +676,8 @@ def main(): model.eval() model.device = device - if params.decoding_method == "fast_beam_search": - if params.use_LG: + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": lexicon = Lexicon(params.lang_dir) word_table = lexicon.word_table lg_filename = params.lang_dir / "LG.pt"