Support using log_add in LG decoding with fast_beam_search.

This commit is contained in:
Fangjun Kuang 2022-06-21 15:08:07 +08:00
parent feb9fd05e8
commit ec9e4cffe2
2 changed files with 53 additions and 10 deletions

View File

@@ -86,6 +86,7 @@ def fast_beam_search_nbest(
num_paths: int,
nbest_scale: float = 0.5,
use_double_scores: bool = True,
use_max: bool = True,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
@@ -121,6 +122,8 @@ def fast_beam_search_nbest(
use_double_scores:
True to use double precision for computation. False to use
single precision.
use_max:
False to use log-add to compute total scores. True to use max.
Returns:
Return the decoded result.
"""
@@ -141,14 +144,47 @@ def fast_beam_search_nbest(
nbest_scale=nbest_scale,
)
# at this point, nbest.fsa.scores are all zeros.
# The following code is modified from nbest.intersect()
word_fsa = k2.invert(nbest.fsa)
if hasattr(lattice, "aux_labels"):
# delete token IDs as it is not needed
del word_fsa.aux_labels
word_fsa.scores.zero_()
word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
path_to_utt_map = nbest.shape.row_ids(1)
nbest = nbest.intersect(lattice)
# Now nbest.fsa.scores contains acoustic scores
if hasattr(lattice, "aux_labels"):
# lattice has token IDs as labels and word IDs as aux_labels.
# inv_lattice has word IDs as labels and token IDs as aux_labels
inv_lattice = k2.invert(lattice)
inv_lattice = k2.arc_sort(inv_lattice)
else:
inv_lattice = k2.arc_sort(lattice)
max_indexes = nbest.tot_scores().argmax()
if inv_lattice.shape[0] == 1:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=torch.zeros_like(path_to_utt_map),
sorted_match_a=True,
)
else:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_utt_map,
sorted_match_a=True,
)
best_path = k2.index_fsa(nbest.fsa, max_indexes)
# path_lattice has word IDs as labels and token IDs as aux_labels
path_lattice = k2.top_sort(k2.connect(path_lattice))
tot_scores = path_lattice.get_tot_scores(
use_double_scores=use_double_scores,
log_semiring=(False if use_max else True),
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
best_hyp_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
hyps = get_texts(best_path)

View File

@@ -154,7 +154,7 @@ def get_parser():
parser.add_argument(
"--lang-dir",
type=str,
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
@@ -195,8 +195,8 @@ def get_parser():
type=str2bool,
default=False,
help="""Whether to use an LG graph for FSA-based beam search.
Used only when --decoding_method is fast_beam_search. If setting true,
it assumes there is an LG.pt file in lang_dir.""",
Used only when --decoding_method is fast_beam_search. If true,
it uses lang_dir/LG.pt during decoding.""",
)
parser.add_argument(
@@ -320,6 +320,7 @@ def decode_one_batch(
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
@@ -351,6 +352,7 @@ def decode_one_batch(
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
use_max=params.use_max,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
@@ -563,7 +565,10 @@ def main():
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if params.use_LG:
params.suffix += f"-use-max-{params.use_max}"
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
@@ -632,8 +637,10 @@ def main():
if params.use_LG:
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/LG.pt", map_location=device)
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else: