From 284cbf7ed105db8a70b72d0800a3bcd8372161f4 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 21 Jun 2022 22:55:18 +0800 Subject: [PATCH] Support LG for fast beam search. --- .../ASR/pruned_transducer_stateless/decode.py | 3 +- .../pruned_transducer_stateless2/decode.py | 12 +- .../pruned_transducer_stateless3/decode.py | 159 +++++++++++----- .../pruned_transducer_stateless4/decode.py | 170 ++++++++++++------ .../pruned_transducer_stateless5/decode.py | 170 ++++++++++++------ 5 files changed, 361 insertions(+), 153 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 04bae8a1d..ea011b1f9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -482,7 +482,8 @@ def decode_dataset( The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 0a4673e98..56cfde8d2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -177,6 +177,13 @@ def get_parser(): help="Path to the BPE model", ) + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + parser.add_argument( "--decoding-method", type=str, @@ -482,8 +489,8 @@ def decode_dataset( The word symbol table. decoding_graph: The decoding graph. 
Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, - fast_beam_search_nbest, or fast_beam_search_nbest_oracle. + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -726,7 +733,6 @@ def main(): test_set_name=test_set, results_dict=results_dict, ) - break logging.info("Done!") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index d0c6f3684..bb04c3378 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -50,9 +50,9 @@ Usage: --exp-dir ./pruned_transducer_stateless3/exp \ --max-duration 600 \ --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 (5) fast beam search (nbest) ./pruned_transducer_stateless3/decode.py \ @@ -61,9 +61,9 @@ Usage: --exp-dir ./pruned_transducer_stateless3/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ --num-paths 200 \ --nbest-scale 0.5 @@ -74,11 +74,22 @@ Usage: --exp-dir ./pruned_transducer_stateless3/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_oracle \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ --num-paths 200 \ --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + 
--max-states 64 """ @@ -96,6 +107,7 @@ from asr_datamodule import AsrDataModule from beam_search import ( beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_LG, fast_beam_search_nbest_oracle, fast_beam_search_one_best, greedy_search, @@ -110,6 +122,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) +from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, setup_logger, @@ -165,6 +178,13 @@ def get_parser(): help="Path to the BPE model", ) + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + parser.add_argument( "--decoding-method", type=str, @@ -176,6 +196,9 @@ def get_parser(): - fast_beam_search - fast_beam_search_nbest - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. """, ) @@ -191,31 +214,42 @@ def get_parser(): parser.add_argument( "--beam", type=float, - default=4, + default=20.0, help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, or - fast_beam_search_nbest_oracle""", + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. 
+ """, ) parser.add_argument( "--max-contexts", type=int, - default=4, + default=8, help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, or - fast_beam_search_nbest_oracle""", + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", ) parser.add_argument( "--max-states", type=int, - default=8, + default=64, help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, or - fast_beam_search_nbest_oracle""", + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -238,9 +272,8 @@ def get_parser(): type=int, default=200, help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest or - fast_beam_search_nbest_oracle - """, + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -248,9 +281,8 @@ def get_parser(): type=float, default=0.5, help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest or - fast_beam_search_nbest_oracle - """, + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) return parser @@ -261,6 +293,7 @@ def decode_one_batch( model: nn.Module, sp: spm.SentencePieceProcessor, batch: dict, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the @@ -284,10 +317,12 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + word_table: + The word symbol table. decoding_graph: The decoding graph. 
Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, - fast_beam_search_nbest, or fast_beam_search_nbest_oracle. + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -319,6 +354,20 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest": hyp_tokens = fast_beam_search_nbest( model=model, @@ -403,16 +452,25 @@ def decode_one_batch( f"max_states_{params.max_states}" ): hyps } - elif "fast_beam_search_nbest" in params.decoding_method: + elif params.decoding_method == "fast_beam_search": return { ( f"beam_{params.beam}_" f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}_" - f"num_paths_{params.num_paths}_" - f"nbest_scale_{params.nbest_scale}" + f"max_states_{params.max_states}" ): hyps } + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}" + key += f"_nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} else: return {f"beam_size_{params.beam_size}": hyps} @@ -422,6 +480,7 @@ def decode_dataset( params: 
AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -435,10 +494,12 @@ def decode_dataset( The neural model. sp: The BPE model. + word_table: + The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, - fast_beam_search_nbest, or fast_beam_search_nbest_oracle. + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -466,6 +527,7 @@ def decode_dataset( params=params, model=model, sp=sp, + word_table=word_table, decoding_graph=decoding_graph, batch=batch, ) @@ -549,6 +611,7 @@ def main(): "beam_search", "fast_beam_search", "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", ) @@ -559,16 +622,15 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.decoding_method == "fast_beam_search": + if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" - elif "fast_beam_search_nbest" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - params.suffix += f"-num-paths-{params.num_paths}" - params.suffix += f"-nbest-scale-{params.nbest_scale}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + 
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" @@ -634,9 +696,23 @@ def main(): model.unk_id = params.unk_id if "fast_beam_search" in params.decoding_method: - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None + word_table = None num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -659,6 +735,7 @@ def main(): params=params, model=model, sp=sp, + word_table=word_table, decoding_graph=decoding_graph, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 20d1bf338..79dd7856e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -51,9 +51,9 @@ Usage: --exp-dir ./pruned_transducer_stateless4/exp \ --max-duration 600 \ --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 (5) fast beam search (nbest) ./pruned_transducer_stateless4/decode.py \ @@ -62,9 +62,9 @@ Usage: --exp-dir ./pruned_transducer_stateless3/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ --num-paths 200 \ --nbest-scale 
0.5 @@ -75,11 +75,22 @@ Usage: --exp-dir ./pruned_transducer_stateless4/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_oracle \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ --num-paths 200 \ --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless4/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 """ @@ -97,6 +108,7 @@ from asr_datamodule import LibriSpeechAsrDataModule from beam_search import ( beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_LG, fast_beam_search_nbest_oracle, fast_beam_search_one_best, greedy_search, @@ -111,6 +123,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) +from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, setup_logger, @@ -178,6 +191,13 @@ def get_parser(): help="Path to the BPE model", ) + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + parser.add_argument( "--decoding-method", type=str, @@ -189,6 +209,9 @@ def get_parser(): - fast_beam_search - fast_beam_search_nbest - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. """, ) @@ -204,31 +227,42 @@ def get_parser(): parser.add_argument( "--beam", type=float, - default=4, + default=20.0, help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. 
- Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, or - fast_beam_search_nbest_oracle""", + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, ) parser.add_argument( "--max-contexts", type=int, - default=4, + default=8, help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, or - fast_beam_search_nbest_oracle""", + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", ) parser.add_argument( "--max-states", type=int, - default=8, + default=64, help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, or - fast_beam_search_nbest_oracle""", + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -251,9 +285,8 @@ def get_parser(): type=int, default=200, help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest or - fast_beam_search_nbest_oracle - """, + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -261,9 +294,8 @@ def get_parser(): type=float, default=0.5, help="""Scale applied to lattice scores when computing nbest paths. 
- Used only when the decoding method is fast_beam_search_nbest or - fast_beam_search_nbest_oracle - """, + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) return parser @@ -274,6 +306,7 @@ def decode_one_batch( model: nn.Module, sp: spm.SentencePieceProcessor, batch: dict, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the @@ -297,9 +330,12 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + word_table: + The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return the decoding result. See above description for the format of the returned dict. 
@@ -331,6 +367,20 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest": hyp_tokens = fast_beam_search_nbest( model=model, @@ -407,24 +457,17 @@ def decode_one_batch( if params.decoding_method == "greedy_search": return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - elif "fast_beam_search_nbest" in params.decoding_method: - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}_" - f"num_paths_{params.num_paths}_" - f"nbest_scale_{params.nbest_scale}" - ): hyps - } + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}" + key += f"_nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} else: return {f"beam_size_{params.beam_size}": hyps} @@ -434,6 +477,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. 
@@ -447,10 +491,12 @@ def decode_dataset( The neural model. sp: The BPE model. + word_table: + The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, - fast_beam_search_nbest, or fast_beam_search_nbest_oracle. + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -479,6 +525,7 @@ def decode_dataset( model=model, sp=sp, decoding_graph=decoding_graph, + word_table=word_table, batch=batch, ) @@ -561,6 +608,7 @@ def main(): "beam_search", "fast_beam_search", "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", ) @@ -571,16 +619,15 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.decoding_method == "fast_beam_search": + if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" - elif "fast_beam_search_nbest" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - params.suffix += f"-num-paths-{params.num_paths}" - params.suffix += f"-nbest-scale-{params.nbest_scale}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" @@ -695,9 +742,23 @@ def main(): model.eval() if 
"fast_beam_search" in params.decoding_method: - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None + word_table = None num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -719,6 +780,7 @@ def main(): params=params, model=model, sp=sp, + word_table=word_table, decoding_graph=decoding_graph, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 5a8bdd733..d845eed51 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -51,9 +51,9 @@ Usage: --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 (5) fast beam search (nbest) ./pruned_transducer_stateless5/decode.py \ @@ -62,9 +62,9 @@ Usage: --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ --num-paths 200 \ --nbest-scale 0.5 @@ -75,11 +75,22 @@ Usage: --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_oracle \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 \ + --beam 20.0 \ + --max-contexts 8 \ + 
--max-states 64 \ --num-paths 200 \ --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 """ @@ -97,6 +108,7 @@ from asr_datamodule import LibriSpeechAsrDataModule from beam_search import ( beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_LG, fast_beam_search_nbest_oracle, fast_beam_search_one_best, greedy_search, @@ -111,6 +123,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) +from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, setup_logger, @@ -178,6 +191,13 @@ def get_parser(): help="Path to the BPE model", ) + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + parser.add_argument( "--decoding-method", type=str, @@ -189,6 +209,9 @@ def get_parser(): - fast_beam_search - fast_beam_search_nbest - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. """, ) @@ -204,31 +227,42 @@ def get_parser(): parser.add_argument( "--beam", type=float, - default=4, + default=20.0, help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, or - fast_beam_search_nbest_oracle""", + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. 
+ It specifies the scale for n-gram LM scores. + """, ) parser.add_argument( "--max-contexts", type=int, - default=4, + default=8, help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, or - fast_beam_search_nbest_oracle""", + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", ) parser.add_argument( "--max-states", type=int, - default=8, + default=64, help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, or - fast_beam_search_nbest_oracle""", + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -251,9 +285,8 @@ def get_parser(): type=int, default=200, help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest or - fast_beam_search_nbest_oracle - """, + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -261,9 +294,8 @@ def get_parser(): type=float, default=0.5, help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest or - fast_beam_search_nbest_oracle - """, + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) add_model_arguments(parser) @@ -276,6 +308,7 @@ def decode_one_batch( model: nn.Module, sp: spm.SentencePieceProcessor, batch: dict, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the @@ -299,9 +332,12 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + word_table: + The word symbol table. 
decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -333,6 +369,20 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest": hyp_tokens = fast_beam_search_nbest( model=model, @@ -409,24 +459,17 @@ def decode_one_batch( if params.decoding_method == "greedy_search": return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - elif "fast_beam_search_nbest" in params.decoding_method: - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}_" - f"num_paths_{params.num_paths}_" - f"nbest_scale_{params.nbest_scale}" - ): hyps - } + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}" + key += f"_nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: 
hyps} else: return {f"beam_size_{params.beam_size}": hyps} @@ -436,6 +479,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -449,10 +493,12 @@ def decode_dataset( The neural model. sp: The BPE model. + word_table: + The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, - fast_beam_search_nbest, or fast_beam_search_nbest_oracle. + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -481,6 +527,7 @@ def decode_dataset( model=model, sp=sp, decoding_graph=decoding_graph, + word_table=word_table, batch=batch, ) @@ -563,6 +610,7 @@ def main(): "beam_search", "fast_beam_search", "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", ) @@ -573,16 +621,15 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.decoding_method == "fast_beam_search": + if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" - elif "fast_beam_search_nbest" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - params.suffix += f"-num-paths-{params.num_paths}" - params.suffix += f"-nbest-scale-{params.nbest_scale}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + 
params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" @@ -697,9 +744,23 @@ def main(): model.eval() if "fast_beam_search" in params.decoding_method: - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None + word_table = None num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -721,6 +782,7 @@ def main(): params=params, model=model, sp=sp, + word_table=word_table, decoding_graph=decoding_graph, )