mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Support using log_add in LG decoding with fast_beam_search.
This commit is contained in:
parent
feb9fd05e8
commit
ec9e4cffe2
@ -86,6 +86,7 @@ def fast_beam_search_nbest(
|
|||||||
num_paths: int,
|
num_paths: int,
|
||||||
nbest_scale: float = 0.5,
|
nbest_scale: float = 0.5,
|
||||||
use_double_scores: bool = True,
|
use_double_scores: bool = True,
|
||||||
|
use_max: bool = True,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
@ -121,6 +122,8 @@ def fast_beam_search_nbest(
|
|||||||
use_double_scores:
|
use_double_scores:
|
||||||
True to use double precision for computation. False to use
|
True to use double precision for computation. False to use
|
||||||
single precision.
|
single precision.
|
||||||
|
use_max:
|
||||||
|
False to use log-add to compute total scores. True to use max.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoded result.
|
Return the decoded result.
|
||||||
"""
|
"""
|
||||||
@ -141,14 +144,47 @@ def fast_beam_search_nbest(
|
|||||||
nbest_scale=nbest_scale,
|
nbest_scale=nbest_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
# at this point, nbest.fsa.scores are all zeros.
|
# The following code is modified from nbest.intersect()
|
||||||
|
word_fsa = k2.invert(nbest.fsa)
|
||||||
|
if hasattr(lattice, "aux_labels"):
|
||||||
|
# delete token IDs as they are not needed
|
||||||
|
del word_fsa.aux_labels
|
||||||
|
word_fsa.scores.zero_()
|
||||||
|
word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
|
||||||
|
path_to_utt_map = nbest.shape.row_ids(1)
|
||||||
|
|
||||||
nbest = nbest.intersect(lattice)
|
if hasattr(lattice, "aux_labels"):
|
||||||
# Now nbest.fsa.scores contains acoustic scores
|
# lattice has token IDs as labels and word IDs as aux_labels.
|
||||||
|
# inv_lattice has word IDs as labels and token IDs as aux_labels
|
||||||
|
inv_lattice = k2.invert(lattice)
|
||||||
|
inv_lattice = k2.arc_sort(inv_lattice)
|
||||||
|
else:
|
||||||
|
inv_lattice = k2.arc_sort(lattice)
|
||||||
|
|
||||||
max_indexes = nbest.tot_scores().argmax()
|
if inv_lattice.shape[0] == 1:
|
||||||
|
path_lattice = k2.intersect_device(
|
||||||
|
inv_lattice,
|
||||||
|
word_fsa_with_epsilon_loops,
|
||||||
|
b_to_a_map=torch.zeros_like(path_to_utt_map),
|
||||||
|
sorted_match_a=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
path_lattice = k2.intersect_device(
|
||||||
|
inv_lattice,
|
||||||
|
word_fsa_with_epsilon_loops,
|
||||||
|
b_to_a_map=path_to_utt_map,
|
||||||
|
sorted_match_a=True,
|
||||||
|
)
|
||||||
|
|
||||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
# path_lattice has word IDs as labels and token IDs as aux_labels
|
||||||
|
path_lattice = k2.top_sort(k2.connect(path_lattice))
|
||||||
|
tot_scores = path_lattice.get_tot_scores(
|
||||||
|
use_double_scores=use_double_scores,
|
||||||
|
log_semiring=(False if use_max else True),
|
||||||
|
)
|
||||||
|
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||||
|
best_hyp_indexes = ragged_tot_scores.argmax()
|
||||||
|
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
|
||||||
|
|
||||||
hyps = get_texts(best_path)
|
hyps = get_texts(best_path)
|
||||||
|
|
||||||
|
@ -154,7 +154,7 @@ def get_parser():
|
|||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lang-dir",
|
"--lang-dir",
|
||||||
type=str,
|
type=Path,
|
||||||
default="data/lang_bpe_500",
|
default="data/lang_bpe_500",
|
||||||
help="The lang dir containing word table and LG graph",
|
help="The lang dir containing word table and LG graph",
|
||||||
)
|
)
|
||||||
@ -195,8 +195,8 @@ def get_parser():
|
|||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="""Whether to use an LG graph for FSA-based beam search.
|
help="""Whether to use an LG graph for FSA-based beam search.
|
||||||
Used only when --decoding_method is fast_beam_search. If setting true,
|
Used only when --decoding_method is fast_beam_search. If true,
|
||||||
it assumes there is an LG.pt file in lang_dir.""",
|
it uses lang_dir/LG.pt during decoding.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -320,6 +320,7 @@ def decode_one_batch(
|
|||||||
# at entry, feature is (N, T, C)
|
# at entry, feature is (N, T, C)
|
||||||
|
|
||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
|
|
||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
encoder_out, encoder_out_lens = model.encoder(
|
encoder_out, encoder_out_lens = model.encoder(
|
||||||
@ -351,6 +352,7 @@ def decode_one_batch(
|
|||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
num_paths=params.num_paths,
|
num_paths=params.num_paths,
|
||||||
nbest_scale=params.nbest_scale,
|
nbest_scale=params.nbest_scale,
|
||||||
|
use_max=params.use_max,
|
||||||
)
|
)
|
||||||
for hyp in hyp_tokens:
|
for hyp in hyp_tokens:
|
||||||
hyps.append([word_table[i] for i in hyp])
|
hyps.append([word_table[i] for i in hyp])
|
||||||
@ -563,7 +565,10 @@ def main():
|
|||||||
params.suffix += f"-beam-{params.beam}"
|
params.suffix += f"-beam-{params.beam}"
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
params.suffix += f"-max-states-{params.max_states}"
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
params.suffix += f"-use-max-{params.use_max}"
|
if params.use_LG:
|
||||||
|
params.suffix += f"-use-max-{params.use_max}"
|
||||||
|
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||||
|
params.suffix += f"-num-paths-{params.num_paths}"
|
||||||
elif "beam_search" in params.decoding_method:
|
elif "beam_search" in params.decoding_method:
|
||||||
params.suffix += (
|
params.suffix += (
|
||||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
@ -632,8 +637,10 @@ def main():
|
|||||||
if params.use_LG:
|
if params.use_LG:
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
word_table = lexicon.word_table
|
word_table = lexicon.word_table
|
||||||
|
lg_filename = params.lang_dir / "LG.pt"
|
||||||
|
logging.info(f"Loading {lg_filename}")
|
||||||
decoding_graph = k2.Fsa.from_dict(
|
decoding_graph = k2.Fsa.from_dict(
|
||||||
torch.load(f"{params.lang_dir}/LG.pt", map_location=device)
|
torch.load(lg_filename, map_location=device)
|
||||||
)
|
)
|
||||||
decoding_graph.scores *= params.ngram_lm_scale
|
decoding_graph.scores *= params.ngram_lm_scale
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user