From d2bedbe02e5170c29d1658dfa797fe493c57abf7 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 14 Sep 2021 13:45:51 +0800
Subject: [PATCH] Add nbest decoding.

Implement n-best decoding on top of the Nbest class from
icefall/decode2.py: extract an n-best list from the lattice,
intersect each path with the lattice to obtain its total score,
and return the highest-scoring path of each utterance.

Nbest.intersect() now uses k2.remove_epsilon_and_add_self_loops()
instead of k2.add_epsilon_self_loops(), since k2.invert() may
introduce epsilon arcs into the word FSA; the corresponding caution
in the docstring is removed accordingly.

Note: decode.py temporarily stops after the first 21 batches to
speed up testing.
---
 egs/librispeech/ASR/conformer_ctc/decode.py |  5 +-
 icefall/decode2.py                          | 62 +++++++++++++++++++++--
 2 files changed, 62 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index c9d31ff6c..97d0b63fd 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -32,13 +32,13 @@ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.decode import (
     get_lattice,
-    nbest_decoding,
     nbest_oracle,
     one_best_decoding,
     rescore_with_attention_decoder,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
+from icefall.decode2 import nbest_decoding
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -371,6 +371,9 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        # Caution: decode only the first 21 batches, for quick testing.
+        if batch_idx > 20:
+            break
 
         hyps_dict = decode_one_batch(
             params=params,
diff --git a/icefall/decode2.py b/icefall/decode2.py
index e49d0871a..ea76fb63b 100644
--- a/icefall/decode2.py
+++ b/icefall/decode2.py
@@ -148,8 +148,7 @@ class Nbest(object):
         path from the resulting FsaVec.
 
         Caution:
-          We assume FSAs in `self.fsa` don't have epsilon self-loops.
-          We also assume `self.fsa.labels` and `lattice.labels` are token IDs.
+          We assume `self.fsa.labels` and `lattice.labels` are token IDs.
 
         Args:
           lattice:
@@ -172,14 +171,22 @@ class Nbest(object):
             lattice.arcs.dim0() == self.shape.dim0
         ), f"{lattice.arcs.dim0()} vs {self.shape.dim0}"
 
+        # Note: We view each linear FSA in `self.fsa` as a word sequence
+        # and use the given lattice to assign each word sequence a score.
+        #
+        # We are NOT viewing each linear FSA as a token sequence.
+        #
+        # That is why we use k2.invert() below.
+
         # We use a word fsa to intersect with k2.invert(lattice)
         word_fsa = k2.invert(self.fsa)
 
         # delete token IDs as it is not needed
         del word_fsa.aux_labels
         word_fsa.scores.zero_()
-
-        word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
+        word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
+            word_fsa
+        )
 
         path_to_utt_map = self.shape.row_ids(1)
 
@@ -222,3 +229,50 @@ class Nbest(object):
             use_double_scores=False, log_semiring=False
         )
         return k2.RaggedTensor(self.shape, scores)
+
+
+def nbest_decoding(
+    lattice: k2.Fsa,
+    num_paths: int,
+    use_double_scores: bool = True,
+    scale: float = 1.0,
+) -> k2.Fsa:
+    """Select the best path from an n-best list extracted from a lattice.
+
+    The lattice is used to extract `num_paths` paths, which are turned
+    into an Nbest object. Each path is then intersected with the
+    lattice to obtain its total score, and the path with the highest
+    score within each utterance is returned.
+
+    Args:
+      lattice:
+        An FsaVec with axes [utt][state][arc].
+      num_paths:
+        The number of paths to extract from the lattice when building
+        the n-best list.
+      use_double_scores:
+        True to use double precision when computing path scores;
+        False to use single precision.
+      scale:
+        Scale applied to `lattice.scores` before extracting paths.
+        Smaller values tend to yield more unique paths.
+    Returns:
+      An FsaVec containing the best path of each utterance in the
+      input lattice.
+    """
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        scale=scale,
+    )
+
+    # Re-score each path by intersecting it with the lattice.
+    nbest = nbest.intersect(lattice)
+
+    # max_indexes[i] is the index of the path with the largest
+    # total score within the i-th utterance.
+    max_indexes = nbest.tot_scores().argmax()
+
+    best_path = k2.index_fsa(nbest.fsa, max_indexes)
+    return best_path
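--

A minimal usage sketch of the new nbest_decoding() (not part of the
patch): it assumes a `lattice` FsaVec has already been obtained, e.g.,
via get_lattice() from icefall.decode, and the num_paths and scale
values below are made-up examples:

    from icefall.decode2 import nbest_decoding

    # Assumption: `lattice` is an FsaVec with axes [utt][state][arc],
    # e.g., the output of get_lattice() for one batch.
    best_path = nbest_decoding(
        lattice=lattice,
        num_paths=100,  # hypothetical n-best list size
        use_double_scores=True,
        scale=0.5,  # hypothetical; values < 1.0 yield more unique paths
    )
    # best_path is an FsaVec with one linear FSA per utterance; its
    # aux_labels can be converted to word sequences, e.g., with
    # icefall.utils.get_texts(best_path).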
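For reference, tot_scores() returns a k2.RaggedTensor holding one
score per path, grouped by utterance, and its argmax() returns indexes
into the flattened list of paths, which is exactly the form that
k2.index_fsa() expects (this is how the patch uses it). A small
illustration with made-up scores:

    import k2

    # Two utterances; the first has 3 paths, the second has 2.
    scores = k2.RaggedTensor([[0.1, 0.5, 0.2], [1.0, 0.3]])
    max_indexes = scores.argmax()
    # max_indexes is [1, 3]: flattened index 1 is the best path of
    # utterance 0, and flattened index 3 (the first path of
    # utterance 1) is the best path of utterance 1.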