diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 21a2cecd2..10add9cd3 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -33,9 +33,9 @@ from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.decode import get_lattice
 from icefall.decode import (
     one_best_decoding,  # done
-    rescore_with_attention_decoder,
+    rescore_with_attention_decoder,  # done
     rescore_with_n_best_list,  # done
-    rescore_with_whole_lattice,
+    rescore_with_whole_lattice,  # done
     nbest_oracle,  # done
 )
 from icefall.decode2 import (
@@ -43,6 +43,7 @@ from icefall.decode2 import (
     nbest_oracle as nbest_oracle2,
     rescore_with_n_best_list as rescore_with_n_best_list2,
     rescore_with_whole_lattice as rescore_with_whole_lattice2,
+    rescore_with_attention_decoder as rescore_with_attention_decoder2,
 )
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -340,16 +341,28 @@ def decode_one_batch(
             lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
         )

-        best_path_dict = rescore_with_attention_decoder(
-            lattice=rescored_lattice,
-            num_paths=params.num_paths,
-            model=model,
-            memory=memory,
-            memory_key_padding_mask=memory_key_padding_mask,
-            sos_id=sos_id,
-            eos_id=eos_id,
-            scale=params.lattice_score_scale,
-        )
+        if True:
+            best_path_dict = rescore_with_attention_decoder2(
+                lattice=rescored_lattice,
+                num_paths=params.num_paths,
+                model=model,
+                memory=memory,
+                memory_key_padding_mask=memory_key_padding_mask,
+                sos_id=sos_id,
+                eos_id=eos_id,
+                lattice_score_scale=params.lattice_score_scale,
+            )
+        else:
+            best_path_dict = rescore_with_attention_decoder(
+                lattice=rescored_lattice,
+                num_paths=params.num_paths,
+                model=model,
+                memory=memory,
+                memory_key_padding_mask=memory_key_padding_mask,
+                sos_id=sos_id,
+                eos_id=eos_id,
+                scale=params.lattice_score_scale,
+            )
     else:
         assert False, f"Unsupported decoding method: {params.method}"

diff --git a/icefall/decode.py b/icefall/decode.py
index 29b76d973..52a134dfc 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -857,13 +857,15 @@ def rescore_with_attention_decoder(
     assert attention_scores.numel() == num_word_seqs

     if ngram_lm_scale is None:
-        ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        ngram_lm_scale_list = [0.01, 0.05, 0.08]
+        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
     else:
         ngram_lm_scale_list = [ngram_lm_scale]

     if attention_scale is None:
-        attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        attention_scale_list = [0.01, 0.05, 0.08]
+        attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
     else:
         attention_scale_list = [attention_scale]
diff --git a/icefall/decode2.py b/icefall/decode2.py
index bb38d6026..16e4dcd25 100644
--- a/icefall/decode2.py
+++ b/icefall/decode2.py
@@ -310,6 +310,9 @@ class Nbest(object):
         Hint:
           `self.fsa.scores` contains two parts: am scores and lm scores.

+        Caution:
+          We require that ``self.fsa`` has an attribute ``lm_scores``.
+
         Returns:
           Return a ragged tensor with 2 axes [utt][path_scores].
           Its dtype is torch.float64.
@@ -326,6 +329,35 @@
         return k2.RaggedTensor(self.shape, am_scores)

+    def compute_lm_scores(self) -> k2.RaggedTensor:
+        """Compute LM scores of each linear FSA (i.e., each path within
+        an utterance).
+
+        Hint:
+          `self.fsa.scores` contains two parts: am scores and lm scores.
+
+        Caution:
+          We require that ``self.fsa`` has an attribute ``lm_scores``.
+
+        Returns:
+          Return a ragged tensor with 2 axes [utt][path_scores].
+          Its dtype is torch.float64.
+        """
+        saved_scores = self.fsa.scores
+
+        # The `scores` of every arc consists of `am_scores` and `lm_scores`
+        self.fsa.scores = self.fsa.lm_scores
+
+        # Caution: self.fsa.lm_scores is per arc
+        # while lm_scores in the following is per path
+        #
+        lm_scores = self.fsa.get_tot_scores(
+            use_double_scores=True, log_semiring=False
+        )
+        self.fsa.scores = saved_scores
+
+        return k2.RaggedTensor(self.shape, lm_scores)
+
     def tot_scores(self) -> k2.RaggedTensor:
         """Get total scores of the FSAs in this Nbest.
@@ -547,7 +579,7 @@ def rescore_with_n_best_list(
     # nbest.fsa.scores are all 0s at this point

     nbest = nbest.intersect(lattice)
-    # Now nbest.fsa has it scores set
+    # Now nbest.fsa has its scores set
     assert hasattr(nbest.fsa, "lm_scores")

     am_scores = nbest.compute_am_scores()
@@ -639,3 +671,98 @@ def rescore_with_whole_lattice(
         key = f"lm_scale_{lm_scale}_yy"
         ans[key] = best_path
     return ans
+
+
+def rescore_with_attention_decoder(
+    lattice: k2.Fsa,
+    num_paths: int,
+    model: torch.nn.Module,
+    memory: torch.Tensor,
+    memory_key_padding_mask: Optional[torch.Tensor],
+    sos_id: int,
+    eos_id: int,
+    lattice_score_scale: float = 1.0,
+    ngram_lm_scale: Optional[float] = None,
+    attention_scale: Optional[float] = None,
+    use_double_scores: bool = True,
+) -> Dict[str, k2.Fsa]:
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        lattice_score_scale=lattice_score_scale,
+    )
+    # nbest.fsa.scores are all 0s at this point
+
+    nbest = nbest.intersect(lattice)
+    # Now nbest.fsa has its scores set.
+    # Also, nbest.fsa inherits the attributes from `lattice`.
+    assert hasattr(nbest.fsa, "lm_scores")
+
+    am_scores = nbest.compute_am_scores()
+    ngram_lm_scores = nbest.compute_lm_scores()
+
+    # The `tokens` attribute is set inside `compile_hlg.py`
+    assert hasattr(nbest.fsa, "tokens")
+    assert isinstance(nbest.fsa.tokens, torch.Tensor)
+
+    path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
+    # the shape of memory is (T, N, C), so we use axis=1 here
+    expanded_memory = memory.index_select(1, path_to_utt_map)
+
+    if memory_key_padding_mask is not None:
+        # The shape of memory_key_padding_mask is (N, T), so we
+        # use axis=0 here.
+        expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
+            0, path_to_utt_map
+        )
+    else:
+        expanded_memory_key_padding_mask = None
+
+    # remove axis corresponding to states.
+    tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
+    tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
+    tokens = tokens.remove_values_leq(0)
+    token_ids = tokens.tolist()
+
+    nll = model.decoder_nll(
+        memory=expanded_memory,
+        memory_key_padding_mask=expanded_memory_key_padding_mask,
+        token_ids=token_ids,
+        sos_id=sos_id,
+        eos_id=eos_id,
+    )
+    assert nll.ndim == 2
+    assert nll.shape[0] == len(token_ids)
+
+    attention_scores = -nll.sum(dim=1)
+
+    if ngram_lm_scale is None:
+        ngram_lm_scale_list = [0.01, 0.05, 0.08]
+        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    else:
+        ngram_lm_scale_list = [ngram_lm_scale]
+
+    if attention_scale is None:
+        attention_scale_list = [0.01, 0.05, 0.08]
+        attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    else:
+        attention_scale_list = [attention_scale]
+
+    ans = dict()
+    for n_scale in ngram_lm_scale_list:
+        for a_scale in attention_scale_list:
+            tot_scores = (
+                am_scores.values
+                + n_scale * ngram_lm_scores.values
+                + a_scale * attention_scores
+            )
+            ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+            max_indexes = ragged_tot_scores.argmax()
+            best_path = k2.index_fsa(nbest.fsa, max_indexes)
+
+            key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
+            ans[key] = best_path
+    return ans
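
Note on the new rescore_with_attention_decoder in icefall/decode2.py: for every (ngram_lm_scale, attention_scale) pair it sums the per-path AM, n-gram LM and attention scores and keeps, for each utterance, the path with the largest total; the extra 0.01/0.05/0.08 entries only widen that grid search toward smaller weights. A minimal sketch of the score combination follows (illustration only, not part of the patch; a plain row_splits list and made-up numbers stand in for the k2.RaggedTensor bookkeeping over axes [utt][path]):

import torch

# Three utterances with 2, 3 and 1 n-best paths; row_splits plays the role of nbest.shape.
row_splits = [0, 2, 5, 6]
am_scores = torch.tensor([-10.0, -12.0, -7.5, -8.0, -9.0, -11.0])
ngram_lm_scores = torch.tensor([-3.0, -2.0, -4.0, -3.5, -2.5, -5.0])
attention_scores = torch.tensor([-6.0, -5.5, -4.0, -4.5, -6.5, -3.0])

n_scale, a_scale = 0.5, 1.3  # one point of the grid searched by the function
tot_scores = am_scores + n_scale * ngram_lm_scores + a_scale * attention_scores

# Per-utterance argmax, mirroring k2.RaggedTensor.argmax() followed by k2.index_fsa().
best_path_indexes = [
    start + int(tot_scores[start:end].argmax())
    for start, end in zip(row_splits[:-1], row_splits[1:])
]
print(best_path_indexes)  # indexes into the flattened list of n-best paths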