From 163d929601d0130f51bad266f107f9dfbaf6fde4 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Thu, 3 Nov 2022 16:29:30 +0800 Subject: [PATCH] Add fast_beam_search_LG (#622) * Add fast_beam_search_LG * add fast_beam_search_LG to commonly used recipes * fix ci * fix ci * Fix error --- ...pruned-transducer-stateless2-2022-04-29.sh | 1 + ...pruned-transducer-stateless3-2022-04-29.sh | 1 + .../ASR/pruned_transducer_stateless/decode.py | 33 ++++++++++------ .../beam_search.py | 4 +- .../pruned_transducer_stateless2/decode.py | 39 ++++++++++++------- .../pruned_transducer_stateless3/decode.py | 33 ++++++++++------ .../pruned_transducer_stateless4/decode.py | 25 +++++++----- .../pruned_transducer_stateless5/decode.py | 33 ++++++++++------ icefall/utils.py | 7 +++- 9 files changed, 113 insertions(+), 63 deletions(-) diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh index ae2bb6822..c3d07dc0e 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh @@ -83,4 +83,5 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == done rm pruned_transducer_stateless2/exp/*.pt + rm -r data/lang_bpe_500 fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh index 00580ca1f..22de3b45d 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh @@ -82,4 +82,5 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == done rm pruned_transducer_stateless3/exp/*.pt + rm -r data/lang_bpe_500 fi diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 3977f8443..ab23a5a83 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -206,6 +206,7 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG @@ -230,7 +231,7 @@ def get_parser(): help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, + Used only when --decoding-method is fast_beam_search, fast_beam_search_LG fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle """, @@ -241,7 +242,7 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. + Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG. It specifies the scale for n-gram LM scores. """, ) @@ -250,7 +251,7 @@ def get_parser(): "--max-contexts", type=int, default=8, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -259,7 +260,7 @@ def get_parser(): "--max-states", type=int, default=8, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -355,8 +356,8 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return the decoding result. See above description for the format of @@ -387,7 +388,10 @@ def decode_one_batch( ) hyps = [] - if params.decoding_method == "fast_beam_search": + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): hyp_tokens = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, @@ -397,8 +401,12 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + if params.decoding_method == "fast_beam_search": + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest_LG": hyp_tokens = fast_beam_search_nbest_LG( model=model, @@ -526,8 +534,8 @@ def decode_dataset( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search @@ -643,6 +651,7 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_LG", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", @@ -737,7 +746,7 @@ def main(): model.device = device if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": + if "LG" in params.decoding_method: lexicon = Lexicon(params.lang_dir) word_table = lexicon.word_table lg_filename = params.lang_dir / "LG.pt" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 0004a24eb..4f5016e94 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -15,7 +15,7 @@ # limitations under the License. import warnings -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Dict, List, Optional, Union import k2 @@ -727,7 +727,7 @@ class Hypothesis: # timestamp[i] is the frame index after subsampling # on which ys[i] is decoded - timestamp: List[int] + timestamp: List[int] = field(default_factory=list) state_cost: Optional[NgramLmStateCost] = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 3b834b919..99d4b5702 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -212,6 +212,7 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG @@ -247,8 +248,8 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. + Used only when --decoding_method is fast_beam_search_LG and + fast_beam_search_nbest_LG. It specifies the scale for n-gram LM scores. """, ) @@ -256,7 +257,7 @@ def get_parser(): "--max-contexts", type=int, default=8, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -265,7 +266,7 @@ def get_parser(): "--max-states", type=int, default=64, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -363,9 +364,10 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, + fast_beam_search_nbest, fast_beam_search_nbest_oracle, and + fast_beam_search_nbest_LG. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -401,7 +403,10 @@ def decode_one_batch( hyps = [] - if params.decoding_method == "fast_beam_search": + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): hyp_tokens = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, @@ -411,8 +416,12 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + if params.decoding_method == "fast_beam_search": + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest_LG": hyp_tokens = fast_beam_search_nbest_LG( model=model, @@ -548,9 +557,10 @@ def decode_dataset( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, + fast_beam_search_nbest, fast_beam_search_nbest_oracle, and + fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -663,6 +673,7 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_LG", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", @@ -757,7 +768,7 @@ def main(): model.device = device if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": + if "LG" in params.decoding_method: lexicon = Lexicon(params.lang_dir) word_table = lexicon.word_table lg_filename = params.lang_dir / "LG.pt" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 0f30792e3..f34cf1e1f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -202,6 +202,7 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG @@ -226,7 +227,7 @@ def get_parser(): help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, + Used only when --decoding-method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle """, @@ -237,7 +238,7 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. + Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG. It specifies the scale for n-gram LM scores. """, ) @@ -246,7 +247,7 @@ def get_parser(): "--max-contexts", type=int, default=8, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -255,7 +256,7 @@ def get_parser(): "--max-states", type=int, default=64, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is, fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -440,8 +441,8 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. G: Optional. Used only when decoding method is fast_beam_search, @@ -483,7 +484,10 @@ def decode_one_batch( hyps = [] - if params.decoding_method == "fast_beam_search": + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): hyp_tokens = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, @@ -494,8 +498,12 @@ def decode_one_batch( max_states=params.max_states, temperature=params.temperature, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + if params.decoding_method == "fast_beam_search": + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest_LG": hyp_tokens = fast_beam_search_nbest_LG( model=model, @@ -714,8 +722,8 @@ def decode_dataset( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. G: Optional. Used only when decoding method is fast_beam_search, @@ -901,6 +909,7 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_LG", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", @@ -1002,7 +1011,7 @@ def main(): G = None if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": + if "LG" in params.decoding_method: lexicon = Lexicon(params.lang_dir) word_table = lexicon.word_table lg_filename = params.lang_dir / "LG.pt" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 85097a01a..6afc21ce7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -243,6 +243,7 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG @@ -267,7 +268,7 @@ def get_parser(): help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, + Used only when --decoding-method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle """, @@ -278,7 +279,7 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. + Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG. It specifies the scale for n-gram LM scores. """, ) @@ -287,7 +288,7 @@ def get_parser(): "--max-contexts", type=int, default=8, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -296,7 +297,7 @@ def get_parser(): "--max-states", type=int, default=64, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -394,8 +395,8 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return the decoding result and timestamps. See above description for the @@ -430,7 +431,10 @@ def decode_one_batch( x=feature, x_lens=feature_lens ) - if params.decoding_method == "fast_beam_search": + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): res = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, @@ -579,8 +583,8 @@ def decode_dataset( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search @@ -742,6 +746,7 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_LG", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", @@ -886,7 +891,7 @@ def main(): model.eval() if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": + if "LG" in params.decoding_method: lexicon = Lexicon(params.lang_dir) word_table = lexicon.word_table lg_filename = params.lang_dir / "LG.pt" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 632932214..c27d78e34 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -210,6 +210,7 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG @@ -234,7 +235,7 @@ def get_parser(): help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, + Used only when --decoding-method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle """, @@ -245,7 +246,7 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. + Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG. It specifies the scale for n-gram LM scores. """, ) @@ -254,7 +255,7 @@ def get_parser(): "--max-contexts", type=int, default=8, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -263,7 +264,7 @@ def get_parser(): "--max-states", type=int, default=64, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -361,8 +362,8 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return the decoding result. See above description for the format of @@ -399,7 +400,10 @@ def decode_one_batch( hyps = [] - if params.decoding_method == "fast_beam_search": + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): hyp_tokens = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, @@ -409,8 +413,12 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + if params.decoding_method == "fast_beam_search": + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest_LG": hyp_tokens = fast_beam_search_nbest_LG( model=model, @@ -538,8 +546,8 @@ def decode_dataset( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search @@ -653,6 +661,7 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_LG", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", @@ -797,7 +806,7 @@ def main(): model.eval() if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": + if "LG" in params.decoding_method: lexicon = Lexicon(params.lang_dir) word_table = lexicon.word_table lg_filename = params.lang_dir / "LG.pt" diff --git a/icefall/utils.py b/icefall/utils.py index 45a49fb5c..93dd0b967 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1369,6 +1369,7 @@ def parse_hyp_and_timestamp( - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG @@ -1388,6 +1389,7 @@ def parse_hyp_and_timestamp( "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_LG", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", @@ -1400,7 +1402,10 @@ def parse_hyp_and_timestamp( N = len(res.tokens) assert len(res.timestamps) == N use_word_table = False - if decoding_method == "fast_beam_search_nbest_LG": + if ( + decoding_method == "fast_beam_search_nbest_LG" + and decoding_method == "fast_beam_search_LG" + ): assert word_table is not None use_word_table = True