mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Add fast_beam_search_LG (#622)
* Add fast_beam_search_LG * add fast_beam_search_LG to commonly used recipes * fix ci * fix ci * Fix error
This commit is contained in:
parent
d2a1c65c5c
commit
163d929601
@ -83,4 +83,5 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
|
|||||||
done
|
done
|
||||||
|
|
||||||
rm pruned_transducer_stateless2/exp/*.pt
|
rm pruned_transducer_stateless2/exp/*.pt
|
||||||
|
rm -r data/lang_bpe_500
|
||||||
fi
|
fi
|
||||||
|
@ -82,4 +82,5 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
|
|||||||
done
|
done
|
||||||
|
|
||||||
rm pruned_transducer_stateless3/exp/*.pt
|
rm pruned_transducer_stateless3/exp/*.pt
|
||||||
|
rm -r data/lang_bpe_500
|
||||||
fi
|
fi
|
||||||
|
@ -206,6 +206,7 @@ def get_parser():
|
|||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
- fast_beam_search
|
- fast_beam_search
|
||||||
|
- fast_beam_search_LG
|
||||||
- fast_beam_search_nbest
|
- fast_beam_search_nbest
|
||||||
- fast_beam_search_nbest_oracle
|
- fast_beam_search_nbest_oracle
|
||||||
- fast_beam_search_nbest_LG
|
- fast_beam_search_nbest_LG
|
||||||
@ -230,7 +231,7 @@ def get_parser():
|
|||||||
help="""A floating point value to calculate the cutoff score during beam
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
`beam` in Kaldi.
|
`beam` in Kaldi.
|
||||||
Used only when --decoding-method is fast_beam_search,
|
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG
|
||||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
and fast_beam_search_nbest_oracle
|
and fast_beam_search_nbest_oracle
|
||||||
""",
|
""",
|
||||||
@ -241,7 +242,7 @@ def get_parser():
|
|||||||
type=float,
|
type=float,
|
||||||
default=0.01,
|
default=0.01,
|
||||||
help="""
|
help="""
|
||||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
|
||||||
It specifies the scale for n-gram LM scores.
|
It specifies the scale for n-gram LM scores.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -250,7 +251,7 @@ def get_parser():
|
|||||||
"--max-contexts",
|
"--max-contexts",
|
||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=8,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is fast_beam_search_LG
|
||||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
and fast_beam_search_nbest_oracle""",
|
and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
@ -259,7 +260,7 @@ def get_parser():
|
|||||||
"--max-states",
|
"--max-states",
|
||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=8,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is fast_beam_search_LG
|
||||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
and fast_beam_search_nbest_oracle""",
|
and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
@ -355,8 +356,8 @@ def decode_one_batch(
|
|||||||
word_table:
|
word_table:
|
||||||
The word symbol table.
|
The word symbol table.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
@ -387,7 +388,10 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
hyps = []
|
hyps = []
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if (
|
||||||
|
params.decoding_method == "fast_beam_search"
|
||||||
|
or params.decoding_method == "fast_beam_search_LG"
|
||||||
|
):
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
model=model,
|
model=model,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
@ -397,8 +401,12 @@ def decode_one_batch(
|
|||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
if params.decoding_method == "fast_beam_search":
|
||||||
hyps.append(hyp.split())
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
else:
|
||||||
|
for hyp in hyp_tokens:
|
||||||
|
hyps.append([word_table[i] for i in hyp])
|
||||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||||
hyp_tokens = fast_beam_search_nbest_LG(
|
hyp_tokens = fast_beam_search_nbest_LG(
|
||||||
model=model,
|
model=model,
|
||||||
@ -526,8 +534,8 @@ def decode_dataset(
|
|||||||
word_table:
|
word_table:
|
||||||
The word symbol table.
|
The word symbol table.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
@ -643,6 +651,7 @@ def main():
|
|||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
"fast_beam_search",
|
"fast_beam_search",
|
||||||
|
"fast_beam_search_LG",
|
||||||
"fast_beam_search_nbest",
|
"fast_beam_search_nbest",
|
||||||
"fast_beam_search_nbest_LG",
|
"fast_beam_search_nbest_LG",
|
||||||
"fast_beam_search_nbest_oracle",
|
"fast_beam_search_nbest_oracle",
|
||||||
@ -737,7 +746,7 @@ def main():
|
|||||||
model.device = device
|
model.device = device
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if "fast_beam_search" in params.decoding_method:
|
||||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
if "LG" in params.decoding_method:
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
word_table = lexicon.word_table
|
word_table = lexicon.word_table
|
||||||
lg_filename = params.lang_dir / "LG.pt"
|
lg_filename = params.lang_dir / "LG.pt"
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
@ -727,7 +727,7 @@ class Hypothesis:
|
|||||||
|
|
||||||
# timestamp[i] is the frame index after subsampling
|
# timestamp[i] is the frame index after subsampling
|
||||||
# on which ys[i] is decoded
|
# on which ys[i] is decoded
|
||||||
timestamp: List[int]
|
timestamp: List[int] = field(default_factory=list)
|
||||||
|
|
||||||
state_cost: Optional[NgramLmStateCost] = None
|
state_cost: Optional[NgramLmStateCost] = None
|
||||||
|
|
||||||
|
@ -212,6 +212,7 @@ def get_parser():
|
|||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
- fast_beam_search
|
- fast_beam_search
|
||||||
|
- fast_beam_search_LG
|
||||||
- fast_beam_search_nbest
|
- fast_beam_search_nbest
|
||||||
- fast_beam_search_nbest_oracle
|
- fast_beam_search_nbest_oracle
|
||||||
- fast_beam_search_nbest_LG
|
- fast_beam_search_nbest_LG
|
||||||
@ -247,8 +248,8 @@ def get_parser():
|
|||||||
type=float,
|
type=float,
|
||||||
default=0.01,
|
default=0.01,
|
||||||
help="""
|
help="""
|
||||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
Used only when --decoding_method is fast_beam_search_LG and
|
||||||
It specifies the scale for n-gram LM scores.
|
fast_beam_search_nbest_LG. It specifies the scale for n-gram LM scores.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -256,7 +257,7 @@ def get_parser():
|
|||||||
"--max-contexts",
|
"--max-contexts",
|
||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=8,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is fast_beam_search_LG
|
||||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
and fast_beam_search_nbest_oracle""",
|
and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
@ -265,7 +266,7 @@ def get_parser():
|
|||||||
"--max-states",
|
"--max-states",
|
||||||
type=int,
|
type=int,
|
||||||
default=64,
|
default=64,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is fast_beam_search_LG
|
||||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
and fast_beam_search_nbest_oracle""",
|
and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
@ -363,9 +364,10 @@ def decode_one_batch(
|
|||||||
word_table:
|
word_table:
|
||||||
The word symbol table.
|
The word symbol table.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
only when --decoding_method is fast_beam_search, fast_beam_search_LG,
|
||||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
fast_beam_search_nbest, fast_beam_search_nbest_oracle, and
|
||||||
|
fast_beam_search_nbest_LG.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
the returned dict.
|
the returned dict.
|
||||||
@ -401,7 +403,10 @@ def decode_one_batch(
|
|||||||
|
|
||||||
hyps = []
|
hyps = []
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if (
|
||||||
|
params.decoding_method == "fast_beam_search"
|
||||||
|
or params.decoding_method == "fast_beam_search_LG"
|
||||||
|
):
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
model=model,
|
model=model,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
@ -411,8 +416,12 @@ def decode_one_batch(
|
|||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
if params.decoding_method == "fast_beam_search":
|
||||||
hyps.append(hyp.split())
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
else:
|
||||||
|
for hyp in hyp_tokens:
|
||||||
|
hyps.append([word_table[i] for i in hyp])
|
||||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||||
hyp_tokens = fast_beam_search_nbest_LG(
|
hyp_tokens = fast_beam_search_nbest_LG(
|
||||||
model=model,
|
model=model,
|
||||||
@ -548,9 +557,10 @@ def decode_dataset(
|
|||||||
word_table:
|
word_table:
|
||||||
The word symbol table.
|
The word symbol table.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
only when --decoding_method is fast_beam_search, fast_beam_search_LG,
|
||||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
fast_beam_search_nbest, fast_beam_search_nbest_oracle, and
|
||||||
|
fast_beam_search_nbest_LG.
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
@ -663,6 +673,7 @@ def main():
|
|||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
"fast_beam_search",
|
"fast_beam_search",
|
||||||
|
"fast_beam_search_LG",
|
||||||
"fast_beam_search_nbest",
|
"fast_beam_search_nbest",
|
||||||
"fast_beam_search_nbest_LG",
|
"fast_beam_search_nbest_LG",
|
||||||
"fast_beam_search_nbest_oracle",
|
"fast_beam_search_nbest_oracle",
|
||||||
@ -757,7 +768,7 @@ def main():
|
|||||||
model.device = device
|
model.device = device
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if "fast_beam_search" in params.decoding_method:
|
||||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
if "LG" in params.decoding_method:
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
word_table = lexicon.word_table
|
word_table = lexicon.word_table
|
||||||
lg_filename = params.lang_dir / "LG.pt"
|
lg_filename = params.lang_dir / "LG.pt"
|
||||||
|
@ -202,6 +202,7 @@ def get_parser():
|
|||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
- fast_beam_search
|
- fast_beam_search
|
||||||
|
- fast_beam_search_LG
|
||||||
- fast_beam_search_nbest
|
- fast_beam_search_nbest
|
||||||
- fast_beam_search_nbest_oracle
|
- fast_beam_search_nbest_oracle
|
||||||
- fast_beam_search_nbest_LG
|
- fast_beam_search_nbest_LG
|
||||||
@ -226,7 +227,7 @@ def get_parser():
|
|||||||
help="""A floating point value to calculate the cutoff score during beam
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
`beam` in Kaldi.
|
`beam` in Kaldi.
|
||||||
Used only when --decoding-method is fast_beam_search,
|
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
|
||||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
and fast_beam_search_nbest_oracle
|
and fast_beam_search_nbest_oracle
|
||||||
""",
|
""",
|
||||||
@ -237,7 +238,7 @@ def get_parser():
|
|||||||
type=float,
|
type=float,
|
||||||
default=0.01,
|
default=0.01,
|
||||||
help="""
|
help="""
|
||||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
|
||||||
It specifies the scale for n-gram LM scores.
|
It specifies the scale for n-gram LM scores.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -246,7 +247,7 @@ def get_parser():
|
|||||||
"--max-contexts",
|
"--max-contexts",
|
||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=8,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is fast_beam_search_LG,
|
||||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
and fast_beam_search_nbest_oracle""",
|
and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
@ -255,7 +256,7 @@ def get_parser():
|
|||||||
"--max-states",
|
"--max-states",
|
||||||
type=int,
|
type=int,
|
||||||
default=64,
|
default=64,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is, fast_beam_search_LG,
|
||||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
and fast_beam_search_nbest_oracle""",
|
and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
@ -440,8 +441,8 @@ def decode_one_batch(
|
|||||||
word_table:
|
word_table:
|
||||||
The word symbol table.
|
The word symbol table.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||||
G:
|
G:
|
||||||
Optional. Used only when decoding method is fast_beam_search,
|
Optional. Used only when decoding method is fast_beam_search,
|
||||||
@ -483,7 +484,10 @@ def decode_one_batch(
|
|||||||
|
|
||||||
hyps = []
|
hyps = []
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if (
|
||||||
|
params.decoding_method == "fast_beam_search"
|
||||||
|
or params.decoding_method == "fast_beam_search_LG"
|
||||||
|
):
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
model=model,
|
model=model,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
@ -494,8 +498,12 @@ def decode_one_batch(
|
|||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
temperature=params.temperature,
|
temperature=params.temperature,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
if params.decoding_method == "fast_beam_search":
|
||||||
hyps.append(hyp.split())
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
else:
|
||||||
|
for hyp in hyp_tokens:
|
||||||
|
hyps.append([word_table[i] for i in hyp])
|
||||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||||
hyp_tokens = fast_beam_search_nbest_LG(
|
hyp_tokens = fast_beam_search_nbest_LG(
|
||||||
model=model,
|
model=model,
|
||||||
@ -714,8 +722,8 @@ def decode_dataset(
|
|||||||
word_table:
|
word_table:
|
||||||
The word symbol table.
|
The word symbol table.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||||
G:
|
G:
|
||||||
Optional. Used only when decoding method is fast_beam_search,
|
Optional. Used only when decoding method is fast_beam_search,
|
||||||
@ -901,6 +909,7 @@ def main():
|
|||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
"fast_beam_search",
|
"fast_beam_search",
|
||||||
|
"fast_beam_search_LG",
|
||||||
"fast_beam_search_nbest",
|
"fast_beam_search_nbest",
|
||||||
"fast_beam_search_nbest_LG",
|
"fast_beam_search_nbest_LG",
|
||||||
"fast_beam_search_nbest_oracle",
|
"fast_beam_search_nbest_oracle",
|
||||||
@ -1002,7 +1011,7 @@ def main():
|
|||||||
|
|
||||||
G = None
|
G = None
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if "fast_beam_search" in params.decoding_method:
|
||||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
if "LG" in params.decoding_method:
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
word_table = lexicon.word_table
|
word_table = lexicon.word_table
|
||||||
lg_filename = params.lang_dir / "LG.pt"
|
lg_filename = params.lang_dir / "LG.pt"
|
||||||
|
@ -243,6 +243,7 @@ def get_parser():
|
|||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
- fast_beam_search
|
- fast_beam_search
|
||||||
|
- fast_beam_search_LG
|
||||||
- fast_beam_search_nbest
|
- fast_beam_search_nbest
|
||||||
- fast_beam_search_nbest_oracle
|
- fast_beam_search_nbest_oracle
|
||||||
- fast_beam_search_nbest_LG
|
- fast_beam_search_nbest_LG
|
||||||
@ -267,7 +268,7 @@ def get_parser():
|
|||||||
help="""A floating point value to calculate the cutoff score during beam
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
`beam` in Kaldi.
|
`beam` in Kaldi.
|
||||||
Used only when --decoding-method is fast_beam_search,
|
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
|
||||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
and fast_beam_search_nbest_oracle
|
and fast_beam_search_nbest_oracle
|
||||||
""",
|
""",
|
||||||
@ -278,7 +279,7 @@ def get_parser():
|
|||||||
type=float,
|
type=float,
|
||||||
default=0.01,
|
default=0.01,
|
||||||
help="""
|
help="""
|
||||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
|
||||||
It specifies the scale for n-gram LM scores.
|
It specifies the scale for n-gram LM scores.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -287,7 +288,7 @@ def get_parser():
|
|||||||
"--max-contexts",
|
"--max-contexts",
|
||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=8,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is fast_beam_search_LG,
|
||||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
and fast_beam_search_nbest_oracle""",
|
and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
@ -296,7 +297,7 @@ def get_parser():
|
|||||||
"--max-states",
|
"--max-states",
|
||||||
type=int,
|
type=int,
|
||||||
default=64,
|
default=64,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is fast_beam_search_LG,
|
||||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
and fast_beam_search_nbest_oracle""",
|
and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
@ -394,8 +395,8 @@ def decode_one_batch(
|
|||||||
word_table:
|
word_table:
|
||||||
The word symbol table.
|
The word symbol table.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result and timestamps. See above description for the
|
Return the decoding result and timestamps. See above description for the
|
||||||
@ -430,7 +431,10 @@ def decode_one_batch(
|
|||||||
x=feature, x_lens=feature_lens
|
x=feature, x_lens=feature_lens
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if (
|
||||||
|
params.decoding_method == "fast_beam_search"
|
||||||
|
or params.decoding_method == "fast_beam_search_LG"
|
||||||
|
):
|
||||||
res = fast_beam_search_one_best(
|
res = fast_beam_search_one_best(
|
||||||
model=model,
|
model=model,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
@ -579,8 +583,8 @@ def decode_dataset(
|
|||||||
word_table:
|
word_table:
|
||||||
The word symbol table.
|
The word symbol table.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
@ -742,6 +746,7 @@ def main():
|
|||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
"fast_beam_search",
|
"fast_beam_search",
|
||||||
|
"fast_beam_search_LG",
|
||||||
"fast_beam_search_nbest",
|
"fast_beam_search_nbest",
|
||||||
"fast_beam_search_nbest_LG",
|
"fast_beam_search_nbest_LG",
|
||||||
"fast_beam_search_nbest_oracle",
|
"fast_beam_search_nbest_oracle",
|
||||||
@ -886,7 +891,7 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if "fast_beam_search" in params.decoding_method:
|
||||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
if "LG" in params.decoding_method:
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
word_table = lexicon.word_table
|
word_table = lexicon.word_table
|
||||||
lg_filename = params.lang_dir / "LG.pt"
|
lg_filename = params.lang_dir / "LG.pt"
|
||||||
|
@ -210,6 +210,7 @@ def get_parser():
|
|||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
- fast_beam_search
|
- fast_beam_search
|
||||||
|
- fast_beam_search_LG
|
||||||
- fast_beam_search_nbest
|
- fast_beam_search_nbest
|
||||||
- fast_beam_search_nbest_oracle
|
- fast_beam_search_nbest_oracle
|
||||||
- fast_beam_search_nbest_LG
|
- fast_beam_search_nbest_LG
|
||||||
@ -234,7 +235,7 @@ def get_parser():
|
|||||||
help="""A floating point value to calculate the cutoff score during beam
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
`beam` in Kaldi.
|
`beam` in Kaldi.
|
||||||
Used only when --decoding-method is fast_beam_search,
|
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
|
||||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
and fast_beam_search_nbest_oracle
|
and fast_beam_search_nbest_oracle
|
||||||
""",
|
""",
|
||||||
@ -245,7 +246,7 @@ def get_parser():
|
|||||||
type=float,
|
type=float,
|
||||||
default=0.01,
|
default=0.01,
|
||||||
help="""
|
help="""
|
||||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
|
||||||
It specifies the scale for n-gram LM scores.
|
It specifies the scale for n-gram LM scores.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -254,7 +255,7 @@ def get_parser():
|
|||||||
"--max-contexts",
|
"--max-contexts",
|
||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=8,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is fast_beam_search_LG,
|
||||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
and fast_beam_search_nbest_oracle""",
|
and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
@ -263,7 +264,7 @@ def get_parser():
|
|||||||
"--max-states",
|
"--max-states",
|
||||||
type=int,
|
type=int,
|
||||||
default=64,
|
default=64,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is fast_beam_search_LG,
|
||||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
and fast_beam_search_nbest_oracle""",
|
and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
@ -361,8 +362,8 @@ def decode_one_batch(
|
|||||||
word_table:
|
word_table:
|
||||||
The word symbol table.
|
The word symbol table.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
@ -399,7 +400,10 @@ def decode_one_batch(
|
|||||||
|
|
||||||
hyps = []
|
hyps = []
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if (
|
||||||
|
params.decoding_method == "fast_beam_search"
|
||||||
|
or params.decoding_method == "fast_beam_search_LG"
|
||||||
|
):
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
model=model,
|
model=model,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
@ -409,8 +413,12 @@ def decode_one_batch(
|
|||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
if params.decoding_method == "fast_beam_search":
|
||||||
hyps.append(hyp.split())
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
else:
|
||||||
|
for hyp in hyp_tokens:
|
||||||
|
hyps.append([word_table[i] for i in hyp])
|
||||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||||
hyp_tokens = fast_beam_search_nbest_LG(
|
hyp_tokens = fast_beam_search_nbest_LG(
|
||||||
model=model,
|
model=model,
|
||||||
@ -538,8 +546,8 @@ def decode_dataset(
|
|||||||
word_table:
|
word_table:
|
||||||
The word symbol table.
|
The word symbol table.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
@ -653,6 +661,7 @@ def main():
|
|||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
"fast_beam_search",
|
"fast_beam_search",
|
||||||
|
"fast_beam_search_LG",
|
||||||
"fast_beam_search_nbest",
|
"fast_beam_search_nbest",
|
||||||
"fast_beam_search_nbest_LG",
|
"fast_beam_search_nbest_LG",
|
||||||
"fast_beam_search_nbest_oracle",
|
"fast_beam_search_nbest_oracle",
|
||||||
@ -797,7 +806,7 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if "fast_beam_search" in params.decoding_method:
|
||||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
if "LG" in params.decoding_method:
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
word_table = lexicon.word_table
|
word_table = lexicon.word_table
|
||||||
lg_filename = params.lang_dir / "LG.pt"
|
lg_filename = params.lang_dir / "LG.pt"
|
||||||
|
@ -1369,6 +1369,7 @@ def parse_hyp_and_timestamp(
|
|||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
- fast_beam_search
|
- fast_beam_search
|
||||||
|
- fast_beam_search_LG
|
||||||
- fast_beam_search_nbest
|
- fast_beam_search_nbest
|
||||||
- fast_beam_search_nbest_oracle
|
- fast_beam_search_nbest_oracle
|
||||||
- fast_beam_search_nbest_LG
|
- fast_beam_search_nbest_LG
|
||||||
@ -1388,6 +1389,7 @@ def parse_hyp_and_timestamp(
|
|||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
"fast_beam_search",
|
"fast_beam_search",
|
||||||
|
"fast_beam_search_LG",
|
||||||
"fast_beam_search_nbest",
|
"fast_beam_search_nbest",
|
||||||
"fast_beam_search_nbest_LG",
|
"fast_beam_search_nbest_LG",
|
||||||
"fast_beam_search_nbest_oracle",
|
"fast_beam_search_nbest_oracle",
|
||||||
@ -1400,7 +1402,10 @@ def parse_hyp_and_timestamp(
|
|||||||
N = len(res.tokens)
|
N = len(res.tokens)
|
||||||
assert len(res.timestamps) == N
|
assert len(res.timestamps) == N
|
||||||
use_word_table = False
|
use_word_table = False
|
||||||
if decoding_method == "fast_beam_search_nbest_LG":
|
if (
|
||||||
|
decoding_method == "fast_beam_search_nbest_LG"
|
||||||
|
and decoding_method == "fast_beam_search_LG"
|
||||||
|
):
|
||||||
assert word_table is not None
|
assert word_table is not None
|
||||||
use_word_table = True
|
use_word_table = True
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user