mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
Support using log_add in LG decoding with fast_beam_search.
This commit is contained in:
parent
feb9fd05e8
commit
ec9e4cffe2
@ -86,6 +86,7 @@ def fast_beam_search_nbest(
|
||||
num_paths: int,
|
||||
nbest_scale: float = 0.5,
|
||||
use_double_scores: bool = True,
|
||||
use_max: bool = True,
|
||||
) -> List[List[int]]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
@ -121,6 +122,8 @@ def fast_beam_search_nbest(
|
||||
use_double_scores:
|
||||
True to use double precision for computation. False to use
|
||||
single precision.
|
||||
use_max:
|
||||
False to use log-add to compute total scores. True to use max.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
"""
|
||||
@ -141,14 +144,47 @@ def fast_beam_search_nbest(
|
||||
nbest_scale=nbest_scale,
|
||||
)
|
||||
|
||||
# at this point, nbest.fsa.scores are all zeros.
|
||||
# The following code is modified from nbest.intersect()
|
||||
word_fsa = k2.invert(nbest.fsa)
|
||||
if hasattr(lattice, "aux_labels"):
|
||||
# delete token IDs as they are not needed
|
||||
del word_fsa.aux_labels
|
||||
word_fsa.scores.zero_()
|
||||
word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
|
||||
path_to_utt_map = nbest.shape.row_ids(1)
|
||||
|
||||
nbest = nbest.intersect(lattice)
|
||||
# Now nbest.fsa.scores contains acoustic scores
|
||||
if hasattr(lattice, "aux_labels"):
|
||||
# lattice has token IDs as labels and word IDs as aux_labels.
|
||||
# inv_lattice has word IDs as labels and token IDs as aux_labels
|
||||
inv_lattice = k2.invert(lattice)
|
||||
inv_lattice = k2.arc_sort(inv_lattice)
|
||||
else:
|
||||
inv_lattice = k2.arc_sort(lattice)
|
||||
|
||||
max_indexes = nbest.tot_scores().argmax()
|
||||
if inv_lattice.shape[0] == 1:
|
||||
path_lattice = k2.intersect_device(
|
||||
inv_lattice,
|
||||
word_fsa_with_epsilon_loops,
|
||||
b_to_a_map=torch.zeros_like(path_to_utt_map),
|
||||
sorted_match_a=True,
|
||||
)
|
||||
else:
|
||||
path_lattice = k2.intersect_device(
|
||||
inv_lattice,
|
||||
word_fsa_with_epsilon_loops,
|
||||
b_to_a_map=path_to_utt_map,
|
||||
sorted_match_a=True,
|
||||
)
|
||||
|
||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||
# path_lattice has word IDs as labels and token IDs as aux_labels
|
||||
path_lattice = k2.top_sort(k2.connect(path_lattice))
|
||||
tot_scores = path_lattice.get_tot_scores(
|
||||
use_double_scores=use_double_scores,
|
||||
log_semiring=(False if use_max else True),
|
||||
)
|
||||
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||
best_hyp_indexes = ragged_tot_scores.argmax()
|
||||
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
|
||||
|
@ -154,7 +154,7 @@ def get_parser():
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
type=Path,
|
||||
default="data/lang_bpe_500",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
@ -195,8 +195,8 @@ def get_parser():
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to use an LG graph for FSA-based beam search.
|
||||
Used only when --decoding_method is fast_beam_search. If setting true,
|
||||
it assumes there is an LG.pt file in lang_dir.""",
|
||||
Used only when --decoding_method is fast_beam_search. If true,
|
||||
it uses lang_dir/LG.pt during decoding.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -320,6 +320,7 @@ def decode_one_batch(
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
@ -351,6 +352,7 @@ def decode_one_batch(
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
use_max=params.use_max,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
@ -563,7 +565,10 @@ def main():
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
params.suffix += f"-use-max-{params.use_max}"
|
||||
if params.use_LG:
|
||||
params.suffix += f"-use-max-{params.use_max}"
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
@ -632,8 +637,10 @@ def main():
|
||||
if params.use_LG:
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/LG.pt", map_location=device)
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user