Add fast_beam_search_LG (#622)

* Add fast_beam_search_LG

* add fast_beam_search_LG to commonly used recipes

* fix ci

* fix ci

* Fix error
This commit is contained in:
Wei Kang 2022-11-03 16:29:30 +08:00 committed by GitHub
parent d2a1c65c5c
commit 163d929601
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 113 additions and 63 deletions

View File

@ -83,4 +83,5 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
done
rm pruned_transducer_stateless2/exp/*.pt
rm -r data/lang_bpe_500
fi

View File

@ -82,4 +82,5 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
done
rm pruned_transducer_stateless3/exp/*.pt
rm -r data/lang_bpe_500
fi

View File

@ -206,6 +206,7 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
@ -230,7 +231,7 @@ def get_parser():
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
@ -241,7 +242,7 @@ def get_parser():
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
It specifies the scale for n-gram LM scores.
""",
)
@ -250,7 +251,7 @@ def get_parser():
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is fast_beam_search_LG
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -259,7 +260,7 @@ def get_parser():
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is fast_beam_search_LG
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -355,8 +356,8 @@ def decode_one_batch(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
@ -387,7 +388,10 @@ def decode_one_batch(
)
hyps = []
if params.decoding_method == "fast_beam_search":
if (
params.decoding_method == "fast_beam_search"
or params.decoding_method == "fast_beam_search_LG"
):
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
@ -397,8 +401,12 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
if params.decoding_method == "fast_beam_search":
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
@ -526,8 +534,8 @@ def decode_dataset(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
@ -643,6 +651,7 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
@ -737,7 +746,7 @@ def main():
model.device = device
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"

View File

@ -15,7 +15,7 @@
# limitations under the License.
import warnings
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
import k2
@ -727,7 +727,7 @@ class Hypothesis:
# timestamp[i] is the frame index after subsampling
# on which ys[i] is decoded
timestamp: List[int]
timestamp: List[int] = field(default_factory=list)
state_cost: Optional[NgramLmStateCost] = None

View File

@ -212,6 +212,7 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
@ -247,8 +248,8 @@ def get_parser():
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
Used only when --decoding_method is fast_beam_search_LG and
fast_beam_search_nbest_LG. It specifies the scale for n-gram LM scores.
""",
)
@ -256,7 +257,7 @@ def get_parser():
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is fast_beam_search_LG
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -265,7 +266,7 @@ def get_parser():
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is fast_beam_search_LG
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -363,9 +364,10 @@ def decode_one_batch(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_oracle, and
fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -401,7 +403,10 @@ def decode_one_batch(
hyps = []
if params.decoding_method == "fast_beam_search":
if (
params.decoding_method == "fast_beam_search"
or params.decoding_method == "fast_beam_search_LG"
):
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
@ -411,8 +416,12 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
if params.decoding_method == "fast_beam_search":
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
@ -548,9 +557,10 @@ def decode_dataset(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_oracle, and
fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -663,6 +673,7 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
@ -757,7 +768,7 @@ def main():
model.device = device
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"

View File

@ -202,6 +202,7 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
@ -226,7 +227,7 @@ def get_parser():
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
@ -237,7 +238,7 @@ def get_parser():
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
It specifies the scale for n-gram LM scores.
""",
)
@ -246,7 +247,7 @@ def get_parser():
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -255,7 +256,7 @@ def get_parser():
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is, fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -440,8 +441,8 @@ def decode_one_batch(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
G:
Optional. Used only when decoding method is fast_beam_search,
@ -483,7 +484,10 @@ def decode_one_batch(
hyps = []
if params.decoding_method == "fast_beam_search":
if (
params.decoding_method == "fast_beam_search"
or params.decoding_method == "fast_beam_search_LG"
):
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
@ -494,8 +498,12 @@ def decode_one_batch(
max_states=params.max_states,
temperature=params.temperature,
)
if params.decoding_method == "fast_beam_search":
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
@ -714,8 +722,8 @@ def decode_dataset(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
G:
Optional. Used only when decoding method is fast_beam_search,
@ -901,6 +909,7 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
@ -1002,7 +1011,7 @@ def main():
G = None
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"

View File

@ -243,6 +243,7 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
@ -267,7 +268,7 @@ def get_parser():
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
@ -278,7 +279,7 @@ def get_parser():
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
It specifies the scale for n-gram LM scores.
""",
)
@ -287,7 +288,7 @@ def get_parser():
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -296,7 +297,7 @@ def get_parser():
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -394,8 +395,8 @@ def decode_one_batch(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result and timestamps. See above description for the
@ -430,7 +431,10 @@ def decode_one_batch(
x=feature, x_lens=feature_lens
)
if params.decoding_method == "fast_beam_search":
if (
params.decoding_method == "fast_beam_search"
or params.decoding_method == "fast_beam_search_LG"
):
res = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
@ -579,8 +583,8 @@ def decode_dataset(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
@ -742,6 +746,7 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
@ -886,7 +891,7 @@ def main():
model.eval()
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"

View File

@ -210,6 +210,7 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
@ -234,7 +235,7 @@ def get_parser():
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
@ -245,7 +246,7 @@ def get_parser():
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
It specifies the scale for n-gram LM scores.
""",
)
@ -254,7 +255,7 @@ def get_parser():
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -263,7 +264,7 @@ def get_parser():
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
help="""Used only when --decoding-method is fast_beam_search_LG,
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
@ -361,8 +362,8 @@ def decode_one_batch(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
@ -399,7 +400,10 @@ def decode_one_batch(
hyps = []
if params.decoding_method == "fast_beam_search":
if (
params.decoding_method == "fast_beam_search"
or params.decoding_method == "fast_beam_search_LG"
):
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
@ -409,8 +413,12 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
if params.decoding_method == "fast_beam_search":
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
@ -538,8 +546,8 @@ def decode_dataset(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
@ -653,6 +661,7 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
@ -797,7 +806,7 @@ def main():
model.eval()
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"

View File

@ -1369,6 +1369,7 @@ def parse_hyp_and_timestamp(
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
@ -1388,6 +1389,7 @@ def parse_hyp_and_timestamp(
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
@ -1400,7 +1402,10 @@ def parse_hyp_and_timestamp(
N = len(res.tokens)
assert len(res.timestamps) == N
use_word_table = False
if decoding_method == "fast_beam_search_nbest_LG":
if (
decoding_method == "fast_beam_search_nbest_LG"
and decoding_method == "fast_beam_search_LG"
):
assert word_table is not None
use_word_table = True