From 1bf2e174378c09fb6f8099e06eb03f73451aa815 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 14 Jun 2022 17:04:04 +0800 Subject: [PATCH] Add fast_beam_search_nbest. --- .../beam_search.py | 80 ++++++++++ .../ASR/pruned_transducer_stateless/decode.py | 58 ++++++-- .../beam_search.py | 84 ++++++++++- .../pruned_transducer_stateless2/decode.py | 121 ++++++++++++++-- .../pruned_transducer_stateless3/decode.py | 137 ++++++++++++------ .../pruned_transducer_stateless4/decode.py | 118 ++++++++++++++- .../pruned_transducer_stateless5/decode.py | 118 ++++++++++++++- icefall/decode.py | 6 +- 8 files changed, 633 insertions(+), 89 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index db23fd993..2be509e75 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -75,6 +75,86 @@ def fast_beam_search_one_best( return hyps +def fast_beam_search_nbest( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + nbest_scale: float = 0.5, + use_double_scores: bool = True, +) -> List[List[int]]: + """It limits the maximum number of symbols per frame to 1. + + The process to get the results is: + - (1) Use fast beam search to get a lattice + - (2) Select `num_paths` paths from the lattice using k2.random_paths() + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + use_double_scores: + True to use double precision for computation. False to use + single precision. + Returns: + Return the decoded result. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + max_indexes = nbest.tot_scores().argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + hyps = get_texts(best_path) + + return hyps + + def fast_beam_search_nbest_oracle( model: Transducer, decoding_graph: k2.Fsa, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index ea43836bd..e8aae7776 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -82,6 +82,7 @@ import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from beam_search import ( beam_search, + fast_beam_search_nbest, fast_beam_search_one_best, greedy_search, greedy_search_batch, @@ -250,6 +251,26 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search and + --use-LG is True. + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search and + --use-LG is True. + """, + ) + return parser @@ -307,21 +328,32 @@ def decode_one_batch( hyps = [] if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - if params.use_LG: - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - else: + if not params.use_LG: + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + else: + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 7c936b257..809e2a7f3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -37,7 +37,7 @@ def fast_beam_search_one_best( ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using modified beam search, and then + A lattice is first obtained using fast beam search, and then the shortest path within the lattice is used as the final output. Args: @@ -74,6 +74,86 @@ def fast_beam_search_one_best( return hyps +def fast_beam_search_nbest( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + nbest_scale: float = 0.5, + use_double_scores: bool = True, +) -> List[List[int]]: + """It limits the maximum number of symbols per frame to 1. + + The process to get the results is: + - (1) Use fast beam search to get a lattice + - (2) Select `num_paths` paths from the lattice using k2.random_paths() + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + use_double_scores: + True to use double precision for computation. False to use + single precision. + Returns: + Return the decoded result. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + max_indexes = nbest.tot_scores().argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + hyps = get_texts(best_path) + + return hyps + + def fast_beam_search_nbest_oracle( model: Transducer, decoding_graph: k2.Fsa, @@ -89,7 +169,7 @@ def fast_beam_search_nbest_oracle( ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using modified beam search, and then + A lattice is first obtained using fast beam search, and then we select `num_paths` linear paths from the lattice. The path that has the minimum edit distance with the given reference transcript is used as the output. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index d7d6b1202..14c33a946 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -43,7 +43,7 @@ Usage: --decoding-method modified_beam_search \ --beam-size 4 -(4) fast beam search +(4) fast beam search (one best) ./pruned_transducer_stateless2/decode.py \ --epoch 28 \ --avg 15 \ @@ -53,6 +53,32 @@ Usage: --beam 4 \ --max-contexts 4 \ --max-states 8 + +(5) fast beam search (nbest) +./pruned_transducer_stateless2/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless2/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 \ + --num-paths 200 \ + --nbest-scale 0.5 """ @@ -69,6 +95,8 @@ import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from beam_search import ( beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_oracle, fast_beam_search_one_best, greedy_search, greedy_search_batch, @@ -145,6 +173,8 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle """, ) @@ -164,7 +194,9 @@ def get_parser(): help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", + Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, or + fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -172,7 +204,8 @@ def get_parser(): type=int, default=4, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search, fast_beam_search_nbest, or + fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -180,7 +213,8 @@ def get_parser(): type=int, default=8, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search, fast_beam_search_nbest, or + fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -198,6 +232,26 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest or + fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest or + fast_beam_search_nbest_oracle + """, + ) + return parser @@ -231,7 +285,8 @@ def decode_one_batch( for the format of the `batch`. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is fast_beam_search, + fast_beam_search_nbest, or fast_beam_search_nbest_oracle. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -263,6 +318,35 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 @@ -318,6 +402,16 @@ def decode_one_batch( f"max_states_{params.max_states}" ): hyps } + elif "fast_beam_search_nbest" in params.decoding_method: + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}_" + f"num_paths_{params.num_paths}_" + f"nbest_scale_{params.nbest_scale}" + ): hyps + } else: return {f"beam_size_{params.beam_size}": hyps} @@ -342,7 +436,8 @@ def decode_dataset( The BPE model. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is fast_beam_search, + fast_beam_search_nbest, or fast_beam_search_nbest_oracle. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -360,7 +455,7 @@ def decode_dataset( if params.decoding_method == "greedy_search": log_interval = 50 else: - log_interval = 10 + log_interval = 20 results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -452,6 +547,8 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_oracle", "modified_beam_search", ) params.res_dir = params.exp_dir / params.decoding_method @@ -461,10 +558,16 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search": params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" + elif "fast_beam_search_nbest" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"-num-paths-{params.num_paths}" + params.suffix += f"-nbest-scale-{params.nbest_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" @@ -528,7 +631,7 @@ def main(): model.eval() model.device = device - if params.decoding_method == "fast_beam_search": + if "fast_beam_search" in params.decoding_method: decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 5b3dce853..d0c6f3684 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -19,40 +19,66 @@ Usage: (1) greedy search ./pruned_transducer_stateless3/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method greedy_search + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method greedy_search (2) beam search (not recommended) ./pruned_transducer_stateless3/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 (3) modified beam search ./pruned_transducer_stateless3/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 -(4) fast beam search +(4) fast beam search (one best) ./pruned_transducer_stateless3/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 + +(5) fast beam search (nbest) +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 \ + --num-paths 200 \ + --nbest-scale 0.5 """ @@ -69,6 +95,7 @@ import torch.nn as nn from asr_datamodule import AsrDataModule from beam_search import ( beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_oracle, fast_beam_search_one_best, greedy_search, @@ -147,6 +174,7 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_nbest - fast_beam_search_nbest_oracle """, ) @@ -168,7 +196,8 @@ def get_parser(): search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. Used only when --decoding-method is - fast_beam_search or fast_beam_search_nbest_oracle""", + fast_beam_search, fast_beam_search_nbest, or + fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -176,7 +205,8 @@ def get_parser(): type=int, default=4, help="""Used only when --decoding-method is - fast_beam_search or fast_beam_search_nbest_oracle""", + fast_beam_search, fast_beam_search_nbest, or + fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -184,7 +214,8 @@ def get_parser(): type=int, default=8, help="""Used only when --decoding-method is - fast_beam_search or fast_beam_search_nbest_oracle""", + fast_beam_search, fast_beam_search_nbest, or + fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -205,9 +236,10 @@ def get_parser(): parser.add_argument( "--num-paths", type=int, - default=100, - help="""Number of paths for computed nbest oracle WER - when the decoding method is fast_beam_search_nbest_oracle. + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest or + fast_beam_search_nbest_oracle """, ) @@ -216,9 +248,11 @@ def get_parser(): type=float, default=0.5, help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding_method is fast_beam_search_nbest_oracle. + Used only when the decoding method is fast_beam_search_nbest or + fast_beam_search_nbest_oracle """, ) + return parser @@ -252,8 +286,8 @@ def decode_one_batch( for the format of the `batch`. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is - fast_beam_search or fast_beam_search_nbest_oracle. + only when --decoding_method is fast_beam_search, + fast_beam_search_nbest, or fast_beam_search_nbest_oracle. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -285,6 +319,20 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif params.decoding_method == "fast_beam_search_nbest_oracle": hyp_tokens = fast_beam_search_nbest_oracle( model=model, @@ -355,7 +403,7 @@ def decode_one_batch( f"max_states_{params.max_states}" ): hyps } - elif params.decoding_method == "fast_beam_search_nbest_oracle": + elif "fast_beam_search_nbest" in params.decoding_method: return { ( f"beam_{params.beam}_" @@ -389,7 +437,8 @@ def decode_dataset( The BPE model. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is fast_beam_search, + fast_beam_search_nbest, or fast_beam_search_nbest_oracle. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -407,7 +456,7 @@ def decode_dataset( if params.decoding_method == "greedy_search": log_interval = 50 else: - log_interval = 10 + log_interval = 20 results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -499,6 +548,7 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_nbest", "fast_beam_search_nbest_oracle", "modified_beam_search", ) @@ -513,7 +563,7 @@ def main(): params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" - elif params.decoding_method == "fast_beam_search_nbest_oracle": + elif "fast_beam_search_nbest" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" @@ -539,9 +589,9 @@ def main(): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # and is defined in local/train_bpe_model.py + # and are defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") - params.unk_id = sp.unk_id() + params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() logging.info(params) @@ -583,10 +633,7 @@ def main(): model.device = device model.unk_id = params.unk_id - if params.decoding_method in ( - "fast_beam_search", - "fast_beam_search_nbest_oracle", - ): + if "fast_beam_search" in params.decoding_method: decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 70afc3ea3..20d1bf338 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -44,7 +44,7 @@ Usage: --decoding-method modified_beam_search \ --beam-size 4 -(4) fast beam search +(4) fast beam search (one best) ./pruned_transducer_stateless4/decode.py \ --epoch 30 \ --avg 15 \ @@ -54,6 +54,32 @@ Usage: --beam 4 \ --max-contexts 4 \ --max-states 8 + +(5) fast beam search (nbest) +./pruned_transducer_stateless4/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless4/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 \ + --num-paths 200 \ + --nbest-scale 0.5 """ @@ -70,6 +96,8 @@ import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from beam_search import ( beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_oracle, fast_beam_search_one_best, greedy_search, greedy_search_batch, @@ -159,6 +187,8 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle """, ) @@ -178,7 +208,9 @@ def get_parser(): help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", + Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, or + fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -186,7 +218,8 @@ def get_parser(): type=int, default=4, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search, fast_beam_search_nbest, or + fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -194,7 +227,8 @@ def get_parser(): type=int, default=8, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search, fast_beam_search_nbest, or + fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -212,6 +246,26 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest or + fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest or + fast_beam_search_nbest_oracle + """, + ) + return parser @@ -277,6 +331,35 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 @@ -332,6 +415,16 @@ def decode_one_batch( f"max_states_{params.max_states}" ): hyps } + elif "fast_beam_search_nbest" in params.decoding_method: + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}_" + f"num_paths_{params.num_paths}_" + f"nbest_scale_{params.nbest_scale}" + ): hyps + } else: return {f"beam_size_{params.beam_size}": hyps} @@ -356,7 +449,8 @@ def decode_dataset( The BPE model. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is fast_beam_search, + fast_beam_search_nbest, or fast_beam_search_nbest_oracle. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -374,7 +468,7 @@ def decode_dataset( if params.decoding_method == "greedy_search": log_interval = 50 else: - log_interval = 10 + log_interval = 20 results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -466,6 +560,8 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_oracle", "modified_beam_search", ) params.res_dir = params.exp_dir / params.decoding_method @@ -475,10 +571,16 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search": params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" + elif "fast_beam_search_nbest" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"-num-paths-{params.num_paths}" + params.suffix += f"-nbest-scale-{params.nbest_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" @@ -592,7 +694,7 @@ def main(): model.to(device) model.eval() - if params.decoding_method == "fast_beam_search": + if "fast_beam_search" in params.decoding_method: decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index c2ca07480..5a8bdd733 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -44,7 +44,7 @@ Usage: --decoding-method modified_beam_search \ --beam-size 4 -(4) fast beam search +(4) fast beam search (one best) ./pruned_transducer_stateless5/decode.py \ --epoch 28 \ --avg 15 \ @@ -54,6 +54,32 @@ Usage: --beam 4 \ --max-contexts 4 \ --max-states 8 + +(5) fast beam search (nbest) +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 \ + --num-paths 200 \ + --nbest-scale 0.5 """ @@ -70,6 +96,8 @@ import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from beam_search import ( beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_oracle, fast_beam_search_one_best, greedy_search, greedy_search_batch, @@ -128,7 +156,7 @@ def get_parser(): parser.add_argument( "--use-averaged-model", type=str2bool, - default=False, + default=True, help="Whether to load averaged model. Currently it only supports " "using --epoch. If True, it would decode with the averaged model " "over the epoch range from `epoch-avg` (excluded) to `epoch`." @@ -159,6 +187,8 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle """, ) @@ -178,7 +208,9 @@ def get_parser(): help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", + Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, or + fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -186,7 +218,8 @@ def get_parser(): type=int, default=4, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search, fast_beam_search_nbest, or + fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -194,7 +227,8 @@ def get_parser(): type=int, default=8, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search, fast_beam_search_nbest, or + fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -212,6 +246,26 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest or + fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest or + fast_beam_search_nbest_oracle + """, + ) + add_model_arguments(parser) return parser @@ -279,6 +333,35 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 @@ -334,6 +417,16 @@ def decode_one_batch( f"max_states_{params.max_states}" ): hyps } + elif "fast_beam_search_nbest" in params.decoding_method: + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}_" + f"num_paths_{params.num_paths}_" + f"nbest_scale_{params.nbest_scale}" + ): hyps + } else: return {f"beam_size_{params.beam_size}": hyps} @@ -358,7 +451,8 @@ def decode_dataset( The BPE model. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is fast_beam_search, + fast_beam_search_nbest, or fast_beam_search_nbest_oracle. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -468,6 +562,8 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_oracle", "modified_beam_search", ) params.res_dir = params.exp_dir / params.decoding_method @@ -477,10 +573,16 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search": params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" + elif "fast_beam_search_nbest" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"-num-paths-{params.num_paths}" + params.suffix += f"-nbest-scale-{params.nbest_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" @@ -594,7 +696,7 @@ def main(): model.to(device) model.eval() - if params.decoding_method == "fast_beam_search": + if "fast_beam_search" in params.decoding_method: decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None diff --git a/icefall/decode.py b/icefall/decode.py index 94f3e88ba..3ba899b4e 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -308,9 +308,7 @@ class Nbest(object): del word_fsa.aux_labels word_fsa.scores.zero_() - word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops( - word_fsa - ) + word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) path_to_utt_map = self.shape.row_ids(1) @@ -609,7 +607,7 @@ def rescore_with_n_best_list( num_paths: Size of nbest list. lm_scale_list: - A list of float representing LM score scales. + A list of floats representing LM score scales. nbest_scale: Scale to be applied to ``lattice.score`` when sampling paths using ``k2.random_paths``.