From bd69e4be325cb799665697c61ab5f1308fedff55 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 28 Jul 2021 12:22:09 +0800
Subject: [PATCH] Use attention decoder for rescoring.

---
 egs/librispeech/ASR/conformer_ctc/decode.py  |  43 +++-
 .../ASR/conformer_ctc/transformer.py         |   2 +-
 egs/librispeech/ASR/local/compile_hlg.py     |   9 +-
 icefall/decode.py                            | 215 ++++++++++++++++--
 4 files changed, 239 insertions(+), 30 deletions(-)
 mode change 100644 => 100755 egs/librispeech/ASR/local/compile_hlg.py

diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index a9d2b465c..625afeda3 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -21,6 +21,7 @@ from icefall.decode import (
     get_lattice,
     nbest_decoding,
     one_best_decoding,
+    rescore_with_attention_decoder,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
@@ -82,9 +83,12 @@ def get_params() -> AttributeDict:
             # - nbest
             # - nbest-rescoring
             # - whole-lattice-rescoring
-            "method": "nbest-rescoring",
-            # num_paths is used when method is "nbest" and "nbest-rescoring"
-            "num_paths": 100,
+            # - attention-decoder
+            # "method": "whole-lattice-rescoring",
+            "method": "attention-decoder",
+            # num_paths is used when method is "nbest", "nbest-rescoring",
+            # and "attention-decoder"
+            "num_paths": 1000,
         }
     )
     return params
@@ -147,7 +151,7 @@ def decode_one_batch(
 
     supervisions = batch["supervisions"]
 
-    nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+    nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
     # nnet_output is [N, C, T]
     nnet_output = nnet_output.permute(0, 2, 1)
 
@@ -191,7 +195,11 @@ def decode_one_batch(
         hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
         return {key: hyps}
 
-    assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"]
+    assert params.method in [
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ]
 
     lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
     lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
@@ -203,10 +211,25 @@ def decode_one_batch(
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
         )
-    else:
+    elif params.method == "whole-lattice-rescoring":
         best_path_dict = rescore_with_whole_lattice(
             lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
         )
+    elif params.method == "attention-decoder":
+        # The lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
+        rescored_lattice = rescore_with_whole_lattice(
+            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+        )
+
+        best_path_dict = rescore_with_attention_decoder(
+            lattice=rescored_lattice,
+            num_paths=params.num_paths,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+        )
+    else:
+        assert False, f"Unsupported decoding method: {params.method}"
 
     ans = dict()
     for lm_scale_str, best_path in best_path_dict.items():
@@ -351,7 +374,11 @@ def main():
     if not hasattr(HLG, "lm_scores"):
         HLG.lm_scores = HLG.scores.clone()
 
-    if params.method in ["nbest-rescoring", "whole-lattice-rescoring"]:
+    if params.method in (
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ):
         if not (params.lm_dir / "G_4_gram.pt").is_file():
             logging.info("Loading G_4_gram.fst.txt")
             logging.warning("It may take 8 minutes.")
@@ -374,7 +401,7 @@ def main():
             d = torch.load(params.lm_dir / "G_4_gram.pt")
             G = k2.Fsa.from_dict(d).to(device)
 
-        if params.method == "whole-lattice-rescoring":
+        if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)
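
For context on the new "attention-decoder" branch above: the lattice is first
rescored with the whole-lattice 4-gram LM (passing lm_scale_list=None makes
rescore_with_whole_lattice return the rescored lattice itself rather than a
dict of best paths), and only then are n-best paths rescored with the
attention decoder. Downstream code tells the two kinds of results apart by
their dict keys alone. A minimal, self-contained sketch of the key formats
(the scale values below are illustrative samples, not the full grids):

    # Keys produced by rescore_with_n_best_list / rescore_with_whole_lattice:
    lm_scale_keys = [f"lm_scale_{s}" for s in (0.8, 0.9, 1.0)]

    # Keys produced by rescore_with_attention_decoder:
    attention_keys = [
        f"ngram_lm_scale_{n}_attention_scale_{a}"
        for n in (0.1, 0.5)
        for a in (0.5, 1.0)
    ]

    print(lm_scale_keys)   # ['lm_scale_0.8', 'lm_scale_0.9', 'lm_scale_1.0']
    print(attention_keys)  # ['ngram_lm_scale_0.1_attention_scale_0.5', ...]
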
diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py
index e302cfeaf..fc748a252 100644
--- a/egs/librispeech/ASR/conformer_ctc/transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/transformer.py
@@ -259,7 +259,7 @@ class Transformer(nn.Module):
         return decoder_loss
 
     def decoder_nll(
-        self, x: Tensor, encoder_mask: Tensor, token_ids: List[int] = None
+        self, x: Tensor, encoder_mask: Tensor, token_ids: List[List[int]] = None
     ) -> Tensor:
         """
         Args:
diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py
old mode 100644
new mode 100755
index dc9105418..b962be552
--- a/egs/librispeech/ASR/local/compile_hlg.py
+++ b/egs/librispeech/ASR/local/compile_hlg.py
@@ -32,7 +32,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     """
     lexicon = Lexicon(lang_dir)
     max_token_id = max(lexicon.tokens)
-    print(f"building ctc_top. max_token_id: {max_token_id}")
+    print(f"Building ctc_topo. max_token_id: {max_token_id}")
     H = k2.ctc_topo(max_token_id)
     L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
 
@@ -86,13 +86,18 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
 
     LG = k2.arc_sort(LG)
 
     print("Composing H and LG")
-    HLG = k2.compose(H, LG, inner_labels="phones")
+    # CAUTION: The name of the inner_labels is fixed
+    # to `tokens`. If you want to change it, please
+    # also change other places in icefall that are using
+    # it.
+    HLG = k2.compose(H, LG, inner_labels="tokens")
 
     print("Connecting LG")
     HLG = k2.connect(HLG)
 
     print("Arc sorting LG")
     HLG = k2.arc_sort(HLG)
+    print(f"HLG.shape: {HLG.shape}")
 
     return HLG
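
The CAUTION comment above matters because k2.Fsa attributes are looked up by
name: k2.compose stores the matched inner symbols (the token IDs shared by H
and LG) under whatever name inner_labels specifies, and the new
rescore_with_attention_decoder in icefall/decode.py later reads them back as
lattice.tokens. A toy sketch of that naming contract, using a hand-made
linear FSA instead of a real HLG (assumes k2 and torch are installed):

    import k2
    import torch

    # A linear acceptor over tokens [1, 2, 3]; it has 4 arcs in total,
    # counting the final arc labeled -1.
    fsa = k2.linear_fsa([1, 2, 3])

    # k2.Fsa supports arbitrary named per-arc attributes. inner_labels in
    # k2.compose attaches one under exactly the requested name, which is
    # why it must stay "tokens" here.
    fsa.tokens = torch.tensor([1, 2, 3, -1], dtype=torch.int32)

    # This lookup is, in effect, what rescore_with_attention_decoder does.
    assert hasattr(fsa, "tokens")
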
diff --git a/icefall/decode.py b/icefall/decode.py
index 0ab712b3b..4801185b8 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -1,8 +1,9 @@
 import logging
-from typing import Dict, List
+from typing import Dict, List, Optional, Tuple, Union
 
 import k2
 import torch
+import torch.nn as nn
 
 
 def _intersect_device(
@@ -9,9 +10,9 @@ def _intersect_device(
     a_fsas: k2.Fsa,
     b_fsas: k2.Fsa,
     b_to_a_map: torch.Tensor,
     sorted_match_a: bool,
     batch_size: int = 50,
-):
+) -> k2.Fsa:
     """This is a wrapper of k2.intersect_device and its purpose is to split
     b_fsas into several batches and process each batch separately to avoid
     CUDA OOM error.
@@ -55,7 +56,7 @@ def get_lattice(
     min_active_states: int,
     max_active_states: int,
     subsampling_factor: int = 1,
-):
+) -> k2.Fsa:
     """Get the decoding lattice from a decoding graph and neural
     network output.
 
@@ -129,7 +130,7 @@ def one_best_decoding(
 
 def nbest_decoding(
     lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True
-):
+) -> k2.Fsa:
     """It implements something like CTC prefix beam search using n-best lists.
 
-    The basic idea is to first extra n-best paths from the given lattice,
+    The basic idea is to first extract n-best paths from the given lattice,
@@ -253,11 +254,11 @@ def nbest_decoding(
     return best_path_fsa
 
 
-def compute_am_scores(
+def compute_am_and_lm_scores(
     lattice: k2.Fsa,
     word_fsa_with_epsilon_loops: k2.Fsa,
     path_to_seq_map: torch.Tensor,
-) -> torch.Tensor:
-    """Compute AM scores of n-best lists (represented as word_fsas).
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Compute AM and LM scores of n-best lists (represented as word_fsas).
 
     Args:
@@ -272,8 +273,8 @@ def compute_am_scores(
         which sequence the i-th Fsa in word_fsa_with_epsilon_loops belongs to.
         path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0().
     Returns:
-      Return a 1-D torch.Tensor containing the AM scores of each path.
-      `ans.numel() == word_fsas_with_epsilon_loops.shape[0]`
+      Return a tuple containing two 1-D torch.Tensors: (am_scores, lm_scores).
+      Each tensor's `numel()` equals `word_fsas_with_epsilon_loops.shape[0]`.
     """
     assert len(lattice.shape) == 3
     assert hasattr(lattice, "lm_scores")
@@ -293,23 +294,29 @@ def compute_am_scores(
     del inv_lattice.aux_labels
     inv_lattice = k2.arc_sort(inv_lattice)
 
-    am_path_lattice = _intersect_device(
+    path_lattice = _intersect_device(
         inv_lattice,
         word_fsa_with_epsilon_loops,
         b_to_a_map=path_to_seq_map,
         sorted_match_a=True,
     )
 
-    am_path_lattice = k2.top_sort(k2.connect(am_path_lattice))
+    path_lattice = k2.top_sort(k2.connect(path_lattice))
 
     # The `scores` of every arc consists of `am_scores` and `lm_scores`
-    am_path_lattice.scores = am_path_lattice.scores - am_path_lattice.lm_scores
+    path_lattice.scores = path_lattice.scores - path_lattice.lm_scores
 
-    am_scores = am_path_lattice.get_tot_scores(
+    am_scores = path_lattice.get_tot_scores(
         use_double_scores=True, log_semiring=False
     )
 
-    return am_scores
+    path_lattice.scores = path_lattice.lm_scores
+
+    lm_scores = path_lattice.get_tot_scores(
+        use_double_scores=True, log_semiring=False
+    )
+
+    return am_scores.to(torch.float32), lm_scores.to(torch.float32)
 
 
 def rescore_with_n_best_list(
@@ -395,7 +402,7 @@ def rescore_with_n_best_list(
 
     word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
 
-    am_scores = compute_am_scores(
+    am_scores, _ = compute_am_and_lm_scores(
         lattice, word_fsa_with_epsilon_loops, path_to_seq_map
     )
 
@@ -456,8 +463,10 @@
 
 
 def rescore_with_whole_lattice(
-    lattice: k2.Fsa, G_with_epsilon_loops: k2.Fsa, lm_scale_list: List[float]
-) -> Dict[str, k2.Fsa]:
+    lattice: k2.Fsa,
+    G_with_epsilon_loops: k2.Fsa,
+    lm_scale_list: Optional[List[float]] = None,
+) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
     """Use whole lattice to rescore.
 
     Args:
@@ -467,10 +476,13 @@
       An FsaVec representing the language model (LM). Note that it
       is an FsaVec, but it contains only one Fsa.
     lm_scale_list:
-      A list containing lm_scale values.
+      A list containing lm_scale values or None.
     Returns:
-      A dict of FsaVec, whose key is a lm_scale and the value represents the
-      best decoding path for each sequence in the lattice.
+      If lm_scale_list is not None, return a dict of FsaVec, whose key
+      is an lm_scale and the value represents the best decoding path for
+      each sequence in the lattice.
+      If lm_scale_list is None, return a lattice that is rescored
+      with the given LM.
     """
     assert len(lattice.shape) == 3
     assert hasattr(lattice, "lm_scores")
@@ -517,6 +529,9 @@
     # and word IDs as aux_labels.
     lat = k2.invert(rescoring_lattice)
 
+    if lm_scale_list is None:
+        return lat
+
     ans = dict()
     #
     # The following implements
@@ -532,3 +547,165 @@
         key = f"lm_scale_{lm_scale}"
         ans[key] = best_path
     return ans
+
+
+def rescore_with_attention_decoder(
+    lattice: k2.Fsa,
+    num_paths: int,
+    model: nn.Module,
+    memory: torch.Tensor,
+    memory_key_padding_mask: torch.Tensor,
+) -> Dict[str, k2.Fsa]:
+    """This function extracts `num_paths` paths from the given lattice and
+    uses an attention decoder to rescore them. The path with the highest
+    score is used as the decoding output.
+
+    lattice:
+      An FsaVec. It can be the return value of :func:`get_lattice`.
+    num_paths:
+      Number of paths to extract from the given lattice for rescoring.
+    model:
+      A transformer model. See the class "Transformer" in
+      conformer_ctc/transformer.py for its interface.
+    memory:
+      The encoder memory of the given model. It is the output of
+      the last torch.nn.TransformerEncoder layer in the given model.
+      Its shape is `[T, N, C]`.
+    memory_key_padding_mask:
+      The padding mask for memory with shape `[N, T]`.
+    Returns:
+      A dict of FsaVec, whose key has the form
+      `ngram_lm_scale_{x}_attention_scale_{y}` and whose value is the
+      best decoding path for each sequence in the lattice.
+    """
+    # First, extract `num_paths` paths for each sequence.
+    # path is a k2.RaggedInt with axes [seq][path][arc_pos]
+    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
+
+    # word_seq is a k2.RaggedInt sharing the same shape as `path`
+    # but it contains word IDs. Note that it also contains 0s and -1s.
+    # The last entry in each sublist is -1.
+    word_seq = k2.index(lattice.aux_labels, path)
+
+    # Remove epsilons and -1 from word_seq
+    word_seq = k2.ragged.remove_values_leq(word_seq, 0)
+
+    # Remove paths that have identical word sequences.
+    #
+    # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word]
+    # except that there are no repeated paths with the same word_seq
+    # within a sequence.
+    #
+    # num_repeats is also a k2.RaggedInt with 2 axes containing the
+    # multiplicities of each path.
+    # num_repeats.num_elements() == unique_word_seq.num_elements()
+    #
+    # Since k2.ragged.unique_sequences will reorder paths within a seq,
+    # `new2old` is a 1-D torch.Tensor mapping from the output path index
+    # to the input path index.
+    # new2old.numel() == unique_word_seq.tot_size(1)
+    unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
+        word_seq, need_num_repeats=True, need_new2old_indexes=True
+    )
+
+    seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)
+
+    # path_to_seq_map is a 1-D torch.Tensor.
+    # path_to_seq_map[i] is the seq to which the i-th path
+    # belongs.
+    path_to_seq_map = seq_to_path_shape.row_ids(1)
+
+    # Remove the seq axis.
+    # Now unique_word_seq has only two axes [path][word]
+    unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)
+
+    # word_fsa is an FsaVec with axes [path][state][arc]
+    word_fsa = k2.linear_fsa(unique_word_seq)
+
+    word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
+
+    am_scores, ngram_lm_scores = compute_am_and_lm_scores(
+        lattice, word_fsa_with_epsilon_loops, path_to_seq_map
+    )
+
+    # Now we use the attention decoder to compute another
+    # score: attention_scores.
+    #
+    # To do that, we have to get the input and output for the attention
+    # decoder.
+
+    # CAUTION: The "tokens" attribute is set in the file
+    # local/compile_hlg.py
+    token_seq = k2.index(lattice.tokens, path)
+
+    # Remove epsilons and -1 from token_seq
+    token_seq = k2.ragged.remove_values_leq(token_seq, 0)
+
+    # Remove the seq axis.
+    token_seq = k2.ragged.remove_axis(token_seq, 0)
+
+    token_seq, _ = k2.ragged.index(
+        token_seq, indexes=new2old, axis=0, need_value_indexes=False
+    )
+
+    # Now each path in unique_word_seq has its corresponding token IDs.
+    token_ids = k2.ragged.to_list(token_seq)
+
+    num_word_seqs = new2old.numel()
+
+    path_to_seq_map_long = path_to_seq_map.to(torch.long)
+    expanded_memory = memory.index_select(1, path_to_seq_map_long)
+
+    expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
+        0, path_to_seq_map_long
+    )
+
+    nll = model.decoder_nll(
+        expanded_memory, expanded_memory_key_padding_mask, token_ids
+    )
+    assert nll.ndim == 2
+    assert nll.shape[0] == num_word_seqs
+
+    attention_scores = -nll.sum(dim=1)
+    assert attention_scores.ndim == 1
+    assert attention_scores.numel() == num_word_seqs
+
+    ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+    ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+
+    attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+    attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+
+    path_2axes = k2.ragged.remove_axis(path, 0)
+
+    ans = dict()
+    for n_scale in ngram_lm_scale_list:
+        for a_scale in attention_scale_list:
+            tot_scores = (
+                am_scores
+                + n_scale * ngram_lm_scores
+                + a_scale * attention_scores
+            )
+            ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape, tot_scores)
+            argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
+
+            best_path_indexes = k2.index(new2old, argmax_indexes)
+
+            # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
+            best_path = k2.index(path_2axes, best_path_indexes)
+
+            # labels is a k2.RaggedInt with 2 axes [path][token_id]
+            # Note that it contains -1s.
+            labels = k2.index(lattice.labels.contiguous(), best_path)
+
+            labels = k2.ragged.remove_values_eq(labels, -1)
+
+            # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
+            # aux_labels is also a k2.RaggedInt with 2 axes
+            aux_labels = k2.index(lattice.aux_labels, best_path.values())
+
+            best_path_fsa = k2.linear_fsa(labels)
+            best_path_fsa.aux_labels = aux_labels
+
+            key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
+            ans[key] = best_path_fsa
+    return ans
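
A note on the grid search at the end of rescore_with_attention_decoder: for
each (n_scale, a_scale) pair it merely re-weights three per-path score
vectors that were computed once, then takes an argmax per utterance. A
self-contained toy example of that combination step, with made-up numbers
(real values come from the lattice and the decoder):

    import torch

    # Scores for 3 unique paths of a single utterance (made-up numbers).
    am_scores = torch.tensor([-10.0, -12.0, -11.0])
    ngram_lm_scores = torch.tensor([-5.0, -3.0, -4.0])
    attention_scores = torch.tensor([-6.0, -4.5, -7.0])

    n_scale, a_scale = 0.5, 1.0
    tot_scores = (
        am_scores + n_scale * ngram_lm_scores + a_scale * attention_scores
    )

    # k2.ragged.argmax_per_sublist does this per utterance in the real code.
    print(tot_scores.argmax().item())  # 1: the second path wins here
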