Add fast_beam_search_LG (#622)

* Add fast_beam_search_LG

* add fast_beam_search_LG to commonly used recipes

* fix ci

* fix ci

* Fix error
This commit is contained in:
Wei Kang 2022-11-03 16:29:30 +08:00 committed by GitHub
parent d2a1c65c5c
commit 163d929601
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 113 additions and 63 deletions

View File

@ -83,4 +83,5 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
done done
rm pruned_transducer_stateless2/exp/*.pt rm pruned_transducer_stateless2/exp/*.pt
rm -r data/lang_bpe_500
fi fi

View File

@ -82,4 +82,5 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
done done
rm pruned_transducer_stateless3/exp/*.pt rm pruned_transducer_stateless3/exp/*.pt
rm -r data/lang_bpe_500
fi fi

View File

@ -206,6 +206,7 @@ def get_parser():
- beam_search - beam_search
- modified_beam_search - modified_beam_search
- fast_beam_search - fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest - fast_beam_search_nbest
- fast_beam_search_nbest_oracle - fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG - fast_beam_search_nbest_LG
@ -230,7 +231,7 @@ def get_parser():
help="""A floating point value to calculate the cutoff score during beam help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi. `beam` in Kaldi.
Used only when --decoding-method is fast_beam_search, Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_LG, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle and fast_beam_search_nbest_oracle
""", """,
@ -241,7 +242,7 @@ def get_parser():
type=float, type=float,
default=0.01, default=0.01,
help=""" help="""
Used only when --decoding_method is fast_beam_search_nbest_LG. Used only when --decoding_method is fast_beam_search_nbest_LG or fast_beam_search_LG.
It specifies the scale for n-gram LM scores. It specifies the scale for n-gram LM scores.
""", """,
) )
@ -250,7 +251,7 @@ def get_parser():
"--max-contexts", "--max-contexts",
type=int, type=int,
default=8, default=8,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""", and fast_beam_search_nbest_oracle""",
) )
@ -259,7 +260,7 @@ def get_parser():
"--max-states", "--max-states",
type=int, type=int,
default=8, default=8,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""", and fast_beam_search_nbest_oracle""",
) )
@ -355,8 +356,8 @@ def decode_one_batch(
word_table: word_table:
The word symbol table. The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest, only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Returns:
Return the decoding result. See above description for the format of Return the decoding result. See above description for the format of
@ -387,7 +388,10 @@ def decode_one_batch(
) )
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if (
params.decoding_method == "fast_beam_search"
or params.decoding_method == "fast_beam_search_LG"
):
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
model=model, model=model,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
@ -397,8 +401,12 @@ def decode_one_batch(
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
for hyp in sp.decode(hyp_tokens): if params.decoding_method == "fast_beam_search":
hyps.append(hyp.split()) for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest_LG": elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG( hyp_tokens = fast_beam_search_nbest_LG(
model=model, model=model,
@ -526,8 +534,8 @@ def decode_dataset(
word_table: word_table:
The word symbol table. The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest, only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Returns:
Return a dict, whose key may be "greedy_search" if greedy search Return a dict, whose key may be "greedy_search" if greedy search
@ -643,6 +651,7 @@ def main():
"greedy_search", "greedy_search",
"beam_search", "beam_search",
"fast_beam_search", "fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest", "fast_beam_search_nbest",
"fast_beam_search_nbest_LG", "fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle", "fast_beam_search_nbest_oracle",
@ -737,7 +746,7 @@ def main():
model.device = device model.device = device
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG": if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir) lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt" lg_filename = params.lang_dir / "LG.pt"

View File

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import k2 import k2
@ -727,7 +727,7 @@ class Hypothesis:
# timestamp[i] is the frame index after subsampling # timestamp[i] is the frame index after subsampling
# on which ys[i] is decoded # on which ys[i] is decoded
timestamp: List[int] timestamp: List[int] = field(default_factory=list)
state_cost: Optional[NgramLmStateCost] = None state_cost: Optional[NgramLmStateCost] = None

View File

@ -212,6 +212,7 @@ def get_parser():
- beam_search - beam_search
- modified_beam_search - modified_beam_search
- fast_beam_search - fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest - fast_beam_search_nbest
- fast_beam_search_nbest_oracle - fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG - fast_beam_search_nbest_LG
@ -247,8 +248,8 @@ def get_parser():
type=float, type=float,
default=0.01, default=0.01,
help=""" help="""
Used only when --decoding_method is fast_beam_search_nbest_LG. Used only when --decoding_method is fast_beam_search_LG or
It specifies the scale for n-gram LM scores. fast_beam_search_nbest_LG. It specifies the scale for n-gram LM scores.
""", """,
) )
@ -256,7 +257,7 @@ def get_parser():
"--max-contexts", "--max-contexts",
type=int, type=int,
default=8, default=8,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""", and fast_beam_search_nbest_oracle""",
) )
@ -265,7 +266,7 @@ def get_parser():
"--max-states", "--max-states",
type=int, type=int,
default=64, default=64,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""", and fast_beam_search_nbest_oracle""",
) )
@ -363,9 +364,10 @@ def decode_one_batch(
word_table: word_table:
The word symbol table. The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest, only when --decoding_method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. fast_beam_search_nbest, fast_beam_search_nbest_oracle, and
fast_beam_search_nbest_LG.
Returns: Returns:
Return the decoding result. See above description for the format of Return the decoding result. See above description for the format of
the returned dict. the returned dict.
@ -401,7 +403,10 @@ def decode_one_batch(
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if (
params.decoding_method == "fast_beam_search"
or params.decoding_method == "fast_beam_search_LG"
):
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
model=model, model=model,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
@ -411,8 +416,12 @@ def decode_one_batch(
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
for hyp in sp.decode(hyp_tokens): if params.decoding_method == "fast_beam_search":
hyps.append(hyp.split()) for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest_LG": elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG( hyp_tokens = fast_beam_search_nbest_LG(
model=model, model=model,
@ -548,9 +557,10 @@ def decode_dataset(
word_table: word_table:
The word symbol table. The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest, only when --decoding_method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. fast_beam_search_nbest, fast_beam_search_nbest_oracle, and
fast_beam_search_nbest_LG.
Returns: Returns:
Return a dict, whose key may be "greedy_search" if greedy search Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used. is used, or it may be "beam_7" if beam size of 7 is used.
@ -663,6 +673,7 @@ def main():
"greedy_search", "greedy_search",
"beam_search", "beam_search",
"fast_beam_search", "fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest", "fast_beam_search_nbest",
"fast_beam_search_nbest_LG", "fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle", "fast_beam_search_nbest_oracle",
@ -757,7 +768,7 @@ def main():
model.device = device model.device = device
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG": if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir) lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt" lg_filename = params.lang_dir / "LG.pt"

View File

@ -202,6 +202,7 @@ def get_parser():
- beam_search - beam_search
- modified_beam_search - modified_beam_search
- fast_beam_search - fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest - fast_beam_search_nbest
- fast_beam_search_nbest_oracle - fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG - fast_beam_search_nbest_LG
@ -226,7 +227,7 @@ def get_parser():
help="""A floating point value to calculate the cutoff score during beam help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi. `beam` in Kaldi.
Used only when --decoding-method is fast_beam_search, Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_LG, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle and fast_beam_search_nbest_oracle
""", """,
@ -237,7 +238,7 @@ def get_parser():
type=float, type=float,
default=0.01, default=0.01,
help=""" help="""
Used only when --decoding_method is fast_beam_search_nbest_LG. Used only when --decoding_method is fast_beam_search_nbest_LG or fast_beam_search_LG.
It specifies the scale for n-gram LM scores. It specifies the scale for n-gram LM scores.
""", """,
) )
@ -246,7 +247,7 @@ def get_parser():
"--max-contexts", "--max-contexts",
type=int, type=int,
default=8, default=8,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""", and fast_beam_search_nbest_oracle""",
) )
@ -255,7 +256,7 @@ def get_parser():
"--max-states", "--max-states",
type=int, type=int,
default=64, default=64,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""", and fast_beam_search_nbest_oracle""",
) )
@ -440,8 +441,8 @@ def decode_one_batch(
word_table: word_table:
The word symbol table. The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest, only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
G: G:
Optional. Used only when decoding method is fast_beam_search, Optional. Used only when decoding method is fast_beam_search,
@ -483,7 +484,10 @@ def decode_one_batch(
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if (
params.decoding_method == "fast_beam_search"
or params.decoding_method == "fast_beam_search_LG"
):
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
model=model, model=model,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
@ -494,8 +498,12 @@ def decode_one_batch(
max_states=params.max_states, max_states=params.max_states,
temperature=params.temperature, temperature=params.temperature,
) )
for hyp in sp.decode(hyp_tokens): if params.decoding_method == "fast_beam_search":
hyps.append(hyp.split()) for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest_LG": elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG( hyp_tokens = fast_beam_search_nbest_LG(
model=model, model=model,
@ -714,8 +722,8 @@ def decode_dataset(
word_table: word_table:
The word symbol table. The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest, only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
G: G:
Optional. Used only when decoding method is fast_beam_search, Optional. Used only when decoding method is fast_beam_search,
@ -901,6 +909,7 @@ def main():
"greedy_search", "greedy_search",
"beam_search", "beam_search",
"fast_beam_search", "fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest", "fast_beam_search_nbest",
"fast_beam_search_nbest_LG", "fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle", "fast_beam_search_nbest_oracle",
@ -1002,7 +1011,7 @@ def main():
G = None G = None
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG": if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir) lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt" lg_filename = params.lang_dir / "LG.pt"

View File

@ -243,6 +243,7 @@ def get_parser():
- beam_search - beam_search
- modified_beam_search - modified_beam_search
- fast_beam_search - fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest - fast_beam_search_nbest
- fast_beam_search_nbest_oracle - fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG - fast_beam_search_nbest_LG
@ -267,7 +268,7 @@ def get_parser():
help="""A floating point value to calculate the cutoff score during beam help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi. `beam` in Kaldi.
Used only when --decoding-method is fast_beam_search, Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_LG, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle and fast_beam_search_nbest_oracle
""", """,
@ -278,7 +279,7 @@ def get_parser():
type=float, type=float,
default=0.01, default=0.01,
help=""" help="""
Used only when --decoding_method is fast_beam_search_nbest_LG. Used only when --decoding_method is fast_beam_search_nbest_LG or fast_beam_search_LG.
It specifies the scale for n-gram LM scores. It specifies the scale for n-gram LM scores.
""", """,
) )
@ -287,7 +288,7 @@ def get_parser():
"--max-contexts", "--max-contexts",
type=int, type=int,
default=8, default=8,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""", and fast_beam_search_nbest_oracle""",
) )
@ -296,7 +297,7 @@ def get_parser():
"--max-states", "--max-states",
type=int, type=int,
default=64, default=64,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""", and fast_beam_search_nbest_oracle""",
) )
@ -394,8 +395,8 @@ def decode_one_batch(
word_table: word_table:
The word symbol table. The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest, only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Returns:
Return the decoding result and timestamps. See above description for the Return the decoding result and timestamps. See above description for the
@ -430,7 +431,10 @@ def decode_one_batch(
x=feature, x_lens=feature_lens x=feature, x_lens=feature_lens
) )
if params.decoding_method == "fast_beam_search": if (
params.decoding_method == "fast_beam_search"
or params.decoding_method == "fast_beam_search_LG"
):
res = fast_beam_search_one_best( res = fast_beam_search_one_best(
model=model, model=model,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
@ -579,8 +583,8 @@ def decode_dataset(
word_table: word_table:
The word symbol table. The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest, only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Returns:
Return a dict, whose key may be "greedy_search" if greedy search Return a dict, whose key may be "greedy_search" if greedy search
@ -742,6 +746,7 @@ def main():
"greedy_search", "greedy_search",
"beam_search", "beam_search",
"fast_beam_search", "fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest", "fast_beam_search_nbest",
"fast_beam_search_nbest_LG", "fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle", "fast_beam_search_nbest_oracle",
@ -886,7 +891,7 @@ def main():
model.eval() model.eval()
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG": if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir) lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt" lg_filename = params.lang_dir / "LG.pt"

View File

@ -210,6 +210,7 @@ def get_parser():
- beam_search - beam_search
- modified_beam_search - modified_beam_search
- fast_beam_search - fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest - fast_beam_search_nbest
- fast_beam_search_nbest_oracle - fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG - fast_beam_search_nbest_LG
@ -234,7 +235,7 @@ def get_parser():
help="""A floating point value to calculate the cutoff score during beam help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi. `beam` in Kaldi.
Used only when --decoding-method is fast_beam_search, Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_LG, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle and fast_beam_search_nbest_oracle
""", """,
@ -245,7 +246,7 @@ def get_parser():
type=float, type=float,
default=0.01, default=0.01,
help=""" help="""
Used only when --decoding_method is fast_beam_search_nbest_LG. Used only when --decoding_method is fast_beam_search_nbest_LG or fast_beam_search_LG.
It specifies the scale for n-gram LM scores. It specifies the scale for n-gram LM scores.
""", """,
) )
@ -254,7 +255,7 @@ def get_parser():
"--max-contexts", "--max-contexts",
type=int, type=int,
default=8, default=8,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""", and fast_beam_search_nbest_oracle""",
) )
@ -263,7 +264,7 @@ def get_parser():
"--max-states", "--max-states",
type=int, type=int,
default=64, default=64,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""", and fast_beam_search_nbest_oracle""",
) )
@ -361,8 +362,8 @@ def decode_one_batch(
word_table: word_table:
The word symbol table. The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest, only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Returns:
Return the decoding result. See above description for the format of Return the decoding result. See above description for the format of
@ -399,7 +400,10 @@ def decode_one_batch(
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if (
params.decoding_method == "fast_beam_search"
or params.decoding_method == "fast_beam_search_LG"
):
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
model=model, model=model,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
@ -409,8 +413,12 @@ def decode_one_batch(
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
for hyp in sp.decode(hyp_tokens): if params.decoding_method == "fast_beam_search":
hyps.append(hyp.split()) for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest_LG": elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG( hyp_tokens = fast_beam_search_nbest_LG(
model=model, model=model,
@ -538,8 +546,8 @@ def decode_dataset(
word_table: word_table:
The word symbol table. The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest, only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Returns:
Return a dict, whose key may be "greedy_search" if greedy search Return a dict, whose key may be "greedy_search" if greedy search
@ -653,6 +661,7 @@ def main():
"greedy_search", "greedy_search",
"beam_search", "beam_search",
"fast_beam_search", "fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest", "fast_beam_search_nbest",
"fast_beam_search_nbest_LG", "fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle", "fast_beam_search_nbest_oracle",
@ -797,7 +806,7 @@ def main():
model.eval() model.eval()
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG": if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir) lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt" lg_filename = params.lang_dir / "LG.pt"

View File

@ -1369,6 +1369,7 @@ def parse_hyp_and_timestamp(
- beam_search - beam_search
- modified_beam_search - modified_beam_search
- fast_beam_search - fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest - fast_beam_search_nbest
- fast_beam_search_nbest_oracle - fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG - fast_beam_search_nbest_LG
@ -1388,6 +1389,7 @@ def parse_hyp_and_timestamp(
"greedy_search", "greedy_search",
"beam_search", "beam_search",
"fast_beam_search", "fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest", "fast_beam_search_nbest",
"fast_beam_search_nbest_LG", "fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle", "fast_beam_search_nbest_oracle",
@ -1400,7 +1402,10 @@ def parse_hyp_and_timestamp(
N = len(res.tokens) N = len(res.tokens)
assert len(res.timestamps) == N assert len(res.timestamps) == N
use_word_table = False use_word_table = False
if decoding_method == "fast_beam_search_nbest_LG": if (
decoding_method == "fast_beam_search_nbest_LG"
or decoding_method == "fast_beam_search_LG"
):
assert word_table is not None assert word_table is not None
use_word_table = True use_word_table = True