From 7886da9b5913732f479e3e01cb90228dfcca2219 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Tue, 14 Nov 2023 20:15:43 +0800
Subject: [PATCH] add attention-decoder-rescoring

---
 .../ASR/zipformer/attention_decoder.py      |  14 +-
 egs/librispeech/ASR/zipformer/ctc_decode.py |  64 ++++-
 egs/librispeech/ASR/zipformer/train.py      |  15 +-
 icefall/decode.py                           | 231 ++++++++++++++++++
 4 files changed, 311 insertions(+), 13 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/attention_decoder.py b/egs/librispeech/ASR/zipformer/attention_decoder.py
index 0ee24f9a8..6c6cabec5 100644
--- a/egs/librispeech/ASR/zipformer/attention_decoder.py
+++ b/egs/librispeech/ASR/zipformer/attention_decoder.py
@@ -24,9 +24,10 @@ from typing import List, Tuple
 import k2
 import torch
 import torch.nn as nn
-from label_smoothing import LabelSmoothingLoss
+from label_smoothing import LabelSmoothingLoss
 
 from icefall.utils import add_eos, add_sos, make_pad_mask
+from scaling import penalize_abs_values_gt
 
 
 class AttentionDecoderModel(nn.Module):
@@ -355,6 +356,17 @@ class MultiHeadedAttention(nn.Module):
         # (batch, head, time1, time2)
         attn_output_weights = torch.matmul(q, k) / self.scale
 
+        # attn_output_weights = torch.matmul(q, k)
+        # # This is a harder way of limiting the attention scores to not be too large.
+        # # It incurs a penalty if any of them has an absolute value greater than 50.0.
+        # # This should be outside the normal range of the attention scores. We use
+        # # this mechanism instead of, say, a limit on entropy, because once the entropy
+        # # gets very small gradients through the softmax can become very small, and
+        # # some mechanisms like that become ineffective.
+        attn_output_weights = penalize_abs_values_gt(
+            attn_output_weights, limit=50.0, penalty=1.0e-04
+        )
+
         if mask is not None:
             attn_output_weights = attn_output_weights.masked_fill(
                 mask.unsqueeze(1), float("-inf")
diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py
index 4db50b981..32b1b64a0 100755
--- a/egs/librispeech/ASR/zipformer/ctc_decode.py
+++ b/egs/librispeech/ASR/zipformer/ctc_decode.py
@@ -103,6 +103,8 @@ from icefall.decode import (
     one_best_decoding,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
+    rescore_with_attention_decoder_no_ngram,
+    rescore_with_attention_decoder_with_ngram,
 )
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -406,6 +408,26 @@ def decode_one_batch(
         key = "ctc-decoding"
         return {key: hyps}
 
+    if params.decoding_method == "attention-decoder-rescoring-no-ngram":
+        best_path_dict = rescore_with_attention_decoder_no_ngram(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            attention_decoder=model.attention_decoder,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            nbest_scale=params.nbest_scale,
+        )
+        ans = dict()
+        for a_scale_str, best_path in best_path_dict.items():
+            # token_ids is a list-of-list of IDs
+            token_ids = get_texts(best_path)
+            # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+            hyps = bpe_model.decode(token_ids)
+            # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+            hyps = [s.split() for s in hyps]
+            ans[a_scale_str] = hyps
+        return ans
+
     if params.decoding_method == "nbest-oracle":
         # Note: You can also pass rescored lattices to it.
         # We choose the HLG decoded lattice for speed reasons
@@ -446,6 +468,7 @@ def decode_one_batch(
     assert params.decoding_method in [
         "nbest-rescoring",
         "whole-lattice-rescoring",
+        "attention-decoder-rescoring-with-ngram",
     ]
 
     lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
@@ -466,6 +489,21 @@ def decode_one_batch(
             G_with_epsilon_loops=G,
             lm_scale_list=lm_scale_list,
         )
+    elif params.decoding_method == "attention-decoder-rescoring-with-ngram":
+        # The lattice is built with a 3-gram LM. We rescore it with a 4-gram LM.
+        rescored_lattice = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=None,
+        )
+        best_path_dict = rescore_with_attention_decoder_with_ngram(
+            lattice=rescored_lattice,
+            num_paths=params.num_paths,
+            attention_decoder=model.attention_decoder,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            nbest_scale=params.nbest_scale,
+        )
     else:
         assert False, f"Unsupported decoding method: {params.decoding_method}"
 
@@ -564,12 +602,21 @@ def save_results(
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
+    if params.decoding_method in (
+        "attention-decoder-rescoring-with-ngram", "whole-lattice-rescoring"
+    ):
+        # Set it to False since there are too many logs.
+        enable_log = False
+    else:
+        enable_log = True
+
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
-        logging.info(f"The transcripts are stored in {recog_path}")
+        if enable_log:
+            logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
@@ -577,8 +624,8 @@ def save_results(
         with open(errs_filename, "w") as f:
             wer = write_error_stats(f, f"{test_set_name}-{key}", results)
             test_set_wers[key] = wer
-
-        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+        if enable_log:
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
@@ -616,6 +663,8 @@ def main():
         "nbest-rescoring",
         "whole-lattice-rescoring",
         "nbest-oracle",
+        "attention-decoder-rescoring-no-ngram",
+        "attention-decoder-rescoring-with-ngram",
     )
 
     params.res_dir = params.exp_dir / params.decoding_method
@@ -654,8 +703,10 @@ def main():
     params.vocab_size = num_classes
     # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = 0
+    params.eos_id = 1
+    params.sos_id = 1
 
-    if params.decoding_method == "ctc-decoding":
+    if params.decoding_method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
         HLG = None
         H = k2.ctc_topo(
             max_token=max_token_id,
@@ -679,6 +730,7 @@ def main():
     if params.decoding_method in (
         "nbest-rescoring",
         "whole-lattice-rescoring",
+        "attention-decoder-rescoring-with-ngram",
     ):
         if not (params.lm_dir / "G_4_gram.pt").is_file():
             logging.info("Loading G_4_gram.fst.txt")
@@ -710,7 +762,9 @@ def main():
             d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
             G = k2.Fsa.from_dict(d)
 
-        if params.decoding_method == "whole-lattice-rescoring":
+        if params.decoding_method in [
+            "whole-lattice-rescoring", "attention-decoder-rescoring-with-ngram"
+        ]:
             # Add epsilon self-loops to G as we will compose
             # it with the whole lattice later
             G = k2.add_epsilon_self_loops(G)
diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py
index b5bb789d2..d0f25ed01 100755
--- a/egs/librispeech/ASR/zipformer/train.py
+++ b/egs/librispeech/ASR/zipformer/train.py
@@ -453,13 +453,13 @@ def get_parser():
         help="Scale for attention-decoder loss.",
     )
 
-    parser.add_argument(
-        "--label-smoothing",
-        type=float,
-        default=0.1,
-        help="""Label smoothing rate used in attention decoder,
-        (0.0 means the conventional cross entropy loss)""",
-    )
+    # parser.add_argument(
+    #     "--label-smoothing",
+    #     type=float,
+    #     default=0.1,
+    #     help="""Label smoothing rate used in attention decoder,
+    #     (0.0 means the conventional cross entropy loss)""",
+    # )
 
     parser.add_argument(
         "--seed",
@@ -591,6 +591,7 @@ def get_params() -> AttributeDict:
             "subsampling_factor": 4,  # not passed in, this is fixed.
             # parameters for attention-decoder
             "ignore_id": -1,
+            "label_smoothing": 0.1,
             "warm_step": 2000,
             "env_info": get_env_info(),
         }
diff --git a/icefall/decode.py b/icefall/decode.py
index 23f9fb9b3..1d0991d87 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -1083,6 +1083,237 @@ def rescore_with_attention_decoder(
     return ans
 
 
+def rescore_with_attention_decoder_with_ngram(
+    lattice: k2.Fsa,
+    num_paths: int,
+    attention_decoder: torch.nn.Module,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    nbest_scale: float = 1.0,
+    ngram_lm_scale: Optional[float] = None,
+    attention_scale: Optional[float] = None,
+    use_double_scores: bool = True,
+) -> Dict[str, k2.Fsa]:
+    """This function extracts `num_paths` paths from the given lattice and uses
+    an attention decoder to rescore them. The path with the highest score is
+    the decoding output.
+
+    Args:
+      lattice:
+        An FsaVec with axes [utt][state][arc].
+      num_paths:
+        Number of paths to extract from the given lattice for rescoring.
+      attention_decoder:
+        An attention decoder model. See the class "AttentionDecoderModel"
+        in egs/librispeech/ASR/zipformer/attention_decoder.py for its interface.
+      encoder_out:
+        The output of the encoder of the given model. It is used as the
+        cross-attention memory of the attention decoder.
+        Its shape is `(N, T, C)`.
+      encoder_out_lens:
+        Length of encoder outputs, with shape of `(N,)`.
+      nbest_scale:
+        It's the scale applied to `lattice.scores`. A smaller value
+        leads to more unique paths at the risk of missing the correct path.
+      ngram_lm_scale:
+        Optional. It specifies the scale for n-gram LM scores.
+      attention_scale:
+        Optional. It specifies the scale for attention decoder scores.
+    Returns:
+      A dict of FsaVec, whose key contains a string
+      ngram_lm_scale_attention_scale and whose value is the
+      best decoding path for each utterance in the lattice.
+    """
+    max_loop_count = 10
+    loop_count = 0
+    while loop_count <= max_loop_count:
+        try:
+            nbest = Nbest.from_lattice(
+                lattice=lattice,
+                num_paths=num_paths,
+                use_double_scores=use_double_scores,
+                nbest_scale=nbest_scale,
+            )
+            # nbest.fsa.scores are all 0s at this point
+            nbest = nbest.intersect(lattice)
+            break
+        except RuntimeError as e:
+            logging.info(f"Caught exception:\n{e}\n")
+            logging.info(f"num_paths before decreasing: {num_paths}")
+            num_paths = int(num_paths / 2)
+            if loop_count >= max_loop_count or num_paths <= 0:
+                logging.info("Return None as the resulting lattice is too large.")
+                return None
+            logging.info(
+                "This OOM is not an error. You can ignore it. "
+                "If your model does not converge well, or --max-duration "
+                "is too large, or the input sound file is difficult to "
+                "decode, you will meet this exception."
+            )
+            logging.info(f"num_paths after decreasing: {num_paths}")
+            loop_count += 1
+
+    # Now nbest.fsa has its scores set.
+    # Also, nbest.fsa inherits the attributes from `lattice`.
+    assert hasattr(nbest.fsa, "lm_scores")
+
+    am_scores = nbest.compute_am_scores()
+    ngram_lm_scores = nbest.compute_lm_scores()
+
+    # The `tokens` attribute is set inside `compile_hlg.py`
+    assert hasattr(nbest.fsa, "tokens")
+    assert isinstance(nbest.fsa.tokens, torch.Tensor)
+
+    path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
+    # the shape of encoder_out is (N, T, C), so we use axis=0 here
+    expanded_encoder_out = encoder_out.index_select(0, path_to_utt_map)
+    expanded_encoder_out_lens = encoder_out_lens.index_select(0, path_to_utt_map)
+
+    # remove axis corresponding to states.
+    tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
+    tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
+    tokens = tokens.remove_values_leq(0)
+    token_ids = tokens.tolist()
+
+    nll = attention_decoder.nll(
+        encoder_out=expanded_encoder_out,
+        encoder_out_lens=expanded_encoder_out_lens,
+        token_ids=token_ids,
+    )
+    assert nll.ndim == 2
+    assert nll.shape[0] == len(token_ids)
+
+    attention_scores = -nll.sum(dim=1)
+
+    if ngram_lm_scale is None:
+        ngram_lm_scale_list = [0.01, 0.05, 0.08]
+        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+        ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
+    else:
+        ngram_lm_scale_list = [ngram_lm_scale]
+
+    if attention_scale is None:
+        attention_scale_list = [0.01, 0.05, 0.08]
+        attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+        attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
+    else:
+        attention_scale_list = [attention_scale]
+
+    ans = dict()
+    for n_scale in ngram_lm_scale_list:
+        for a_scale in attention_scale_list:
+            tot_scores = (
+                am_scores.values
+                + n_scale * ngram_lm_scores.values
+                + a_scale * attention_scores
+            )
+            ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+            max_indexes = ragged_tot_scores.argmax()
+            best_path = k2.index_fsa(nbest.fsa, max_indexes)
+
+            key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
+            ans[key] = best_path
+    return ans
+
+
+def rescore_with_attention_decoder_no_ngram(
+    lattice: k2.Fsa,
+    num_paths: int,
+    attention_decoder: torch.nn.Module,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    nbest_scale: float = 1.0,
+    attention_scale: Optional[float] = None,
+    use_double_scores: bool = True,
+) -> Dict[str, k2.Fsa]:
+    """This function extracts `num_paths` paths from the given lattice and uses
+    an attention decoder to rescore them. The path with the highest score is
+    the decoding output.
+
+    Args:
+      lattice:
+        An FsaVec with axes [utt][state][arc].
+      num_paths:
+        Number of paths to extract from the given lattice for rescoring.
+      attention_decoder:
+        An attention decoder model. See the class "AttentionDecoderModel"
+        in egs/librispeech/ASR/zipformer/attention_decoder.py for its interface.
+      encoder_out:
+        The output of the encoder of the given model. It is used as the
+        cross-attention memory of the attention decoder.
+        Its shape is `(N, T, C)`.
+      encoder_out_lens:
+        Length of encoder outputs, with shape of `(N,)`.
+      nbest_scale:
+        It's the scale applied to `lattice.scores`. A smaller value
+        leads to more unique paths at the risk of missing the correct path.
+      attention_scale:
+        Optional. It specifies the scale for attention decoder scores.
+
+    Returns:
+      A dict of FsaVec, whose key contains a string
+      attention_scale and whose value is the
+      best decoding path for each utterance in the lattice.
+    """
+    # path is a ragged tensor with dtype torch.int32.
+    # It has three axes [utt][path][arc_pos]
+    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
+    # Note that labels, aux_labels and scores contain 0s and -1s.
+    # The last entry in each sublist is -1.
+    # The axes are [path][token_id]
+    labels = k2.ragged.index(lattice.labels.contiguous(), path).remove_axis(0)
+    aux_labels = k2.ragged.index(lattice.aux_labels.contiguous(), path).remove_axis(0)
+    scores = k2.ragged.index(lattice.scores.contiguous(), path).remove_axis(0)
+
+    # Remove -1 from labels as we will use it to construct a linear FSA
+    labels = labels.remove_values_eq(-1)
+    fsa = k2.linear_fsa(labels)
+    fsa.aux_labels = aux_labels.values
+
+    # utt_to_path_shape has axes [utt][path]
+    utt_to_path_shape = path.shape.get_layer(0)
+    scores = k2.RaggedTensor(utt_to_path_shape, scores.sum())
+
+    path_to_utt_map = utt_to_path_shape.row_ids(1).to(torch.long)
+    # the shape of encoder_out is (N, T, C), so we use axis=0 here
+    expanded_encoder_out = encoder_out.index_select(0, path_to_utt_map)
+    expanded_encoder_out_lens = encoder_out_lens.index_select(0, path_to_utt_map)
+
+    token_ids = aux_labels.remove_values_leq(0).tolist()
+
+    nll = attention_decoder.nll(
+        encoder_out=expanded_encoder_out,
+        encoder_out_lens=expanded_encoder_out_lens,
+        token_ids=token_ids,
+    )
+    assert nll.ndim == 2
+    assert nll.shape[0] == len(token_ids)
+
+    attention_scores = -nll.sum(dim=1)
+
+    if attention_scale is None:
+        attention_scale_list = [0.01, 0.05, 0.08]
+        attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+        attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
+    else:
+        attention_scale_list = [attention_scale]
+
+    ans = dict()
+
+    for a_scale in attention_scale_list:
+        tot_scores = scores.values + a_scale * attention_scores
+        ragged_tot_scores = k2.RaggedTensor(utt_to_path_shape, tot_scores)
+        max_indexes = ragged_tot_scores.argmax()
+        best_path = k2.index_fsa(fsa, max_indexes)
+
+        key = f"attention_scale_{a_scale}"
+        ans[key] = best_path
+    return ans
+
+
 def rescore_with_rnn_lm(
     lattice: k2.Fsa,
     num_paths: int,