From 0669aa8ab9f3b209f636e6a5f5b14f808c5b59a9 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Mon, 9 Aug 2021 12:47:11 +0800
Subject: [PATCH] Add attention rescore pipeline

---
 egs/librispeech/ASR/conformer_ctc/decode.py |  20 ++
 egs/librispeech/ASR/prepare.sh              |   2 +-
 icefall/decode.py                           | 245 ++++++++++++++++++++
 icefall/nbest.py                            | 156 ++++---------
 icefall/utils.py                            |   2 +-
 5 files changed, 317 insertions(+), 108 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 889a0a474..7aec25a84 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -100,6 +100,7 @@ def decode_one_batch(
     model: nn.Module,
     HLG: k2.Fsa,
     batch: dict,
+    batch_idx: int,
     lexicon: Lexicon,
     sos_id: int,
     eos_id: int,
@@ -201,6 +202,7 @@ def decode_one_batch(
         "nbest-rescoring",
         "whole-lattice-rescoring",
         "attention-decoder",
+        "attention-decoder-v2",
     ]
 
     lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
@@ -232,6 +234,23 @@
             sos_id=sos_id,
             eos_id=eos_id,
         )
+    elif params.method == "attention-decoder-v2":
+        # lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
+        rescored_lattice = rescore_with_whole_lattice(
+            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+        )
+        best_path_dict = rescore_with_attention_decoder_v2(
+            lattice=rescored_lattice,
+            batch_idx=batch_idx,
+            dump_best_matching_feature=params.dump_feature,
+            num_paths=params.num_paths,
+            top_k=params.top_k,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
 
@@ -295,6 +314,7 @@ def decode_dataset(
             model=model,
             HLG=HLG,
             batch=batch,
+            batch_idx=batch_idx,
             lexicon=lexicon,
             G=G,
             sos_id=sos_id,
diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index ae676b199..f0fb5039d 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -25,7 +25,7 @@ stop_stage=100
 #  - librispeech-vocab.txt
 #  - librispeech-lexicon.txt
 #
-#  - $do_dir/musan
+#  - $dl_dir/musan
 #      This directory contains the following directories downloaded from
 #      http://www.openslr.org/17/
 #
diff --git a/icefall/decode.py b/icefall/decode.py
index 0e9baf2e4..c7d3a86e1 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -721,3 +721,248 @@ def rescore_with_attention_decoder(
         key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
         ans[key] = best_path_fsa
     return ans
+
+
+def rescore_nbest_with_attention_decoder(
+    nbest: Nbest,
+    model: nn.Module,
+    memory: torch.Tensor,
+    memory_key_padding_mask: torch.Tensor,
+    sos_id: int,
+    eos_id: int,
+) -> Nbest:
+    """This function rescores an nbest list with an attention decoder.
+    The paths with the rescored scores are returned as a new Nbest.
+
+    Args:
+      nbest:
+        An Nbest containing the n-best paths of the given sequences.
+        It can be the return value of :func:`generate_nbest_list`.
+      model:
+        A transformer model. See the class "Transformer" in
+        conformer_ctc/transformer.py for its interface.
+      memory:
+        The encoder memory of the given model. It is the output of
+        the last torch.nn.TransformerEncoder layer in the given model.
+        Its shape is `[T, N, C]`.
+      memory_key_padding_mask:
+        The padding mask for memory with shape [N, T].
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
+    Returns:
+      An Nbest with the same paths as the input, whose arc scores are
+      replaced by the attention-decoder scores.
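+
+    Note:
+      The scores placed on the returned paths are taken from the decoder
+      negative log-likelihoods, one value per token plus one extra value
+      for the final arc of each path.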
+    """
+    num_seqs = nbest.shape.Dim0()
+    token_seq = k2.RaggedInt(nbest.shape, nbest.fsa.labels.contiguous())
+
+    # Remove -1 from token_seq. There are no epsilon tokens in token_seq;
+    # they were removed when generating the nbest list.
+    token_seq = k2.ragged.remove_values_leq(token_seq, -1)
+
+    token_ids = k2.ragged.to_list(token_seq)
+
+    path_to_seq_map_long = token_seq.shape.row_ids(1).to(torch.long)
+    expanded_memory = memory.index_select(1, path_to_seq_map_long)
+
+    expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
+        0, path_to_seq_map_long
+    )
+
+    # TODO: pass the sos_token_id and eos_token_id via function arguments
+    nll = model.decoder_nll(
+        memory=expanded_memory,
+        memory_key_padding_mask=expanded_memory_key_padding_mask,
+        token_ids=token_ids,
+        sos_id=sos_id,
+        eos_id=eos_id,
+    )
+    assert nll.ndim == 2
+    assert nll.shape[0] == num_seqs
+
+    attention_scores = torch.zeros(
+        nbest.fsa.labels.size(0),
+        dtype=torch.float32,
+        device=nbest.fsa.device
+    )
+    start_index = 0
+    for i in range(num_seqs):
+        # Plus 1 to fill in the score of the final arc
+        tokens_num = len(token_ids[i]) + 1
+        attention_scores[start_index: start_index + tokens_num] = \
+            nll[i][0: tokens_num]
+        start_index += tokens_num
+
+    fsas = nbest.fsa.clone()
+    fsas.scores = attention_scores
+    return Nbest(fsas, nbest.shape.clone())
+
+
+def rescore_with_attention_decoder_v2(
+    lattice: k2.Fsa,
+    batch_idx: int,
+    dump_best_matching_feature: bool,
+    num_paths: int,
+    top_k: int,
+    model: nn.Module,
+    memory: torch.Tensor,
+    memory_key_padding_mask: torch.Tensor,
+    sos_id: int,
+    eos_id: int,
+) -> Dict[str, k2.Fsa]:
+    """This function extracts n paths from the given lattice and uses
+    an attention decoder to rescore them. The path with the highest
+    score is used as the decoding output.
+
+    Args:
+      lattice:
+        An FsaVec. It can be the return value of :func:`get_lattice`.
+      batch_idx:
+        The index of the current batch; intended for use when dumping
+        best-matching features.
+      dump_best_matching_feature:
+        If True, split the n-best list into two halves, rescore them with
+        the attention decoder and compute best-matching statistics between
+        them; the statistics are intended to be dumped as training data
+        for the rescore estimation model.
+      num_paths:
+        Number of paths to extract from the given lattice for rescoring.
+      top_k:
+        Number of paths per utterance that are rescored exactly with the
+        attention decoder.
+      model:
+        A transformer model. See the class "Transformer" in
+        conformer_ctc/transformer.py for its interface.
+      memory:
+        The encoder memory of the given model. It is the output of
+        the last torch.nn.TransformerEncoder layer in the given model.
+        Its shape is `[T, N, C]`.
+      memory_key_padding_mask:
+        The padding mask for memory with shape [N, T].
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
+    Returns:
+      A dict of FsaVec, whose key contains a string
+      ngram_lm_scale_attention_scale and the value is the
+      best decoding path for each sequence in the lattice.
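+
+    Note:
+      The intended pipeline is: (1) generate an n-best list from the
+      lattice; (2) rescore the top-k paths of each utterance exactly with
+      the attention decoder; (3) compute best-matching statistics between
+      the rescored top-k paths and the remaining paths and feed them to a
+      rescore estimation model that estimates scores for the remaining
+      paths; (4) exactly rescore the most promising remaining paths and
+      pick the overall best path for each utterance.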
+    """
+    nbest = generate_nbest_list(lattice, num_paths)
+    # Intersect with the lattice so that the paths in nbest carry
+    # lattice scores
+    nbest = nbest.intersect(lattice)
+
+    if dump_best_matching_feature:
+        nbest_k, nbest_q = nbest.split(k=top_k, sort=False)
+        rescored_nbest_k = rescore_nbest_with_attention_decoder(
+            nbest=nbest_k,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+        stats_tensor = get_best_matching_stats(
+            rescored_nbest_k,
+            nbest_q,
+            max_order=3
+        )
+        rescored_nbest_q = rescore_nbest_with_attention_decoder(
+            nbest=nbest_q,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+        # TODO: return the features & labels or dump them to a file
+
+    nbest_topk, nbest_remain = nbest.split(k=top_k)
+
+    rescored_nbest_topk = rescore_nbest_with_attention_decoder(
+        nbest=nbest_topk,
+        model=model,
+        memory=memory,
+        memory_key_padding_mask=memory_key_padding_mask,
+        sos_id=sos_id,
+        eos_id=eos_id,
+    )
+    stats_tensor = get_best_matching_stats(
+        rescored_nbest_topk,
+        nbest_remain,
+        max_order=3
+    )
+    # Run the rescore estimation model to get the mean and var of each token.
+    # TODO: rescore_est_model is not defined in this patch yet.
+    mean, var = rescore_est_model(stats_tensor)
+    # Calculate the estimated scores of nbest_remain and select its top-k
+    nbest_remain_topk = nbest_remain.top_k(k=top_k)
+    rescored_nbest_remain_topk = rescore_nbest_with_attention_decoder(
+        nbest=nbest_remain_topk,
+        model=model,
+        memory=memory,
+        memory_key_padding_mask=memory_key_padding_mask,
+        sos_id=sos_id,
+        eos_id=eos_id,
+    )
+    # TODO: get_best_path_from_nbests is not defined in this patch yet.
+    best_path_dict = get_best_path_from_nbests(
+        rescored_nbest_topk,
+        rescored_nbest_remain_topk,
+    )
+
+    return best_path_dict
+
+
+def generate_nbest_list(
+    lats: k2.Fsa,
+    num_paths: int,
+    aux_labels: bool = False
+) -> Nbest:
+    '''Generate an n-best list from a lattice.
+
+    Args:
+      lats:
+        The decoding lattice from the first pass after LM rescoring.
+        lats is an FsaVec. It can be the return value of
+        :func:`rescore_with_whole_lattice`.
+      num_paths:
+        Size of n for the n-best list. CAUTION: After removing paths
+        that represent the same word sequences, the number of paths
+        in different sequences may not be equal.
+      aux_labels:
+        If True, build the n-best list from the aux_labels (word IDs)
+        of the lattice; otherwise build it from the labels (token IDs).
+    Return:
+      Return an Nbest object. Note the returned FSAs don't have epsilon
+      self-loops.
+    '''
+    assert len(lats.shape) == 3
+
+    # First, extract `num_paths` paths for each sequence.
+    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
+    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
+
+    # seqs is a k2.RaggedInt sharing the same shape as `paths`.
+    # Note that it also contains 0s and -1s.
+    # The last entry in each sublist is -1.
+    # Its axes are [seq][path][word_id]
+    if aux_labels:
+        # If aux_labels is enabled, seqs contains word IDs
+        assert hasattr(lats, "aux_labels")
+        seqs = k2.index(lats.aux_labels, paths)
+    else:
+        # CAUTION: We use `phones` instead of `tokens` here because
+        # :func:`compile_HLG` uses `phones`
+        #
+        # Note: compile_HLG is from k2-fsa/snowfall
+        assert hasattr(lats, 'phones')
+
+        assert not hasattr(lats, 'tokens')
+        lats.tokens = lats.phones
+        seqs = k2.index(lats.tokens, paths)
+
+    # Remove epsilons (0s) and -1 from seqs
+    seqs = k2.ragged.remove_values_leq(seqs, 0)
+
+    # unique_seqs is still a k2.RaggedInt with axes [seq][path][word_id].
+    # But the number of paths in each sequence may be different.
+    unique_seqs, _, _ = k2.ragged.unique_sequences(
+        seqs, need_num_repeats=False, need_new2old_indexes=False)
+
+    seq_to_path_shape = k2.ragged.get_layer(unique_seqs.shape(), 0)
+
+    # Remove the seq axis.
+    # Now unique_seqs has only two axes [path][word_id]
+    unique_seqs = k2.ragged.remove_axis(unique_seqs, 0)
+
+    fsas = k2.linear_fsa(unique_seqs)
+
+    return Nbest(fsa=fsas, shape=seq_to_path_shape)
+
diff --git a/icefall/nbest.py b/icefall/nbest.py
index 1a5394673..14d44e227 100644
--- a/icefall/nbest.py
+++ b/icefall/nbest.py
@@ -5,10 +5,9 @@
 # See https://github.com/k2-fsa/snowfall/issues/232 for more details
 #
 import logging
-from typing import List
+from typing import List, Tuple
 
 import torch
-import _k2
 import k2
 
 # Note: We use `utterance` and `sequence` interchangeably in the comment
@@ -19,7 +18,7 @@ class Nbest(object):
     An Nbest object contains two fields:
 
         (1) fsa, its type is k2.Fsa
-        (2) shape, its type is k2.RaggedShape (alias to _k2.RaggedShape)
+        (2) shape, its type is k2.RaggedShape
 
     The field `fsa` is an FsaVec containing a vector of **linear** FSAs.
 
@@ -29,7 +28,7 @@
     of paths, which is also the number of FSAs in `fsa`.
     '''
 
-    def __init__(self, fsa: k2.Fsa, shape: _k2.RaggedShape) -> None:
+    def __init__(self, fsa: k2.Fsa, shape: k2.RaggedShape) -> None:
         assert len(fsa.shape) == 3, f'fsa.shape: {fsa.shape}'
         assert shape.num_axes() == 2, f'num_axes: {shape.num_axes()}'
 
@@ -85,7 +84,7 @@
 
         return Nbest(fsa=one_best, shape=self.shape)
 
-    def total_scores(self) -> _k2.RaggedFloat:
+    def total_scores(self) -> k2.RaggedFloat:
        '''Get total scores of the FSAs in this Nbest.
 
        Note:
@@ -99,7 +98,7 @@
                                             log_semiring=False)
        # We use single precision here since we only wrap k2.RaggedFloat.
        # If k2.RaggedDouble is wrapped, we can use double precision here.
-       return _k2.RaggedFloat(self.shape, scores.float())
+       return k2.RaggedFloat(self.shape, scores.float())
 
    def top_k(self, k: int) -> 'Nbest':
        '''Get a subset of paths in the Nbest. The resulting Nbest is regular
@@ -144,121 +143,66 @@ class Nbest(object):
 
         return Nbest(top_k_fsas, top_k_shape)
 
-def whole_lattice_rescoring(lats: k2.Fsa, G_with_epsilon_loops: k2.Fsa) -> k2.Fsa:
-    '''Rescore the 1st pass lattice with an LM.
+    def split(self, k: int, sort: bool = True) -> Tuple['Nbest', 'Nbest']:
+        '''Split the paths in the Nbest into two parts: the first part
+        contains the first k paths of each sequence in the Nbest; the
+        second part contains the remaining paths. The first part may have
+        fewer than k paths for some sequences.
 
-    In general, the G in HLG used to obtain `lats` is a 3-gram LM.
-    This function replaces the 3-gram LM in `lats` with a 4-gram LM.
+        If the sort flag is true, the top-k paths are selected according
+        to the total_scores of the paths in descending order. If an
+        utterance has fewer than k paths, the first part contains all of
+        its paths and the second part is empty for that utterance.
 
-    Args:
-      lats:
-        The decoding lattice from the 1st pass. We assume it is the result
-        of intersecting HLG with the network output.
-      G_with_epsilon_loops:
-        An LM. It is usually a 4-gram LM with epsilon self-loops.
-        It should be arc sorted.
-    Returns:
-      Return a new lattice rescored with a given G.
-    '''
-    assert len(lats.shape) == 3, f'{lats.shape}'
-    assert hasattr(lats, 'lm_scores')
-    assert G_with_epsilon_loops.shape == (1, None, None), \
-        f'{G_with_epsilon_loops.shape}'
+        Args:
+          k:
+            Number of paths in the first part of each utterance.
+          sort:
+            If True, sort the paths of each utterance by their total
+            scores in descending order before splitting.
+        Returns:
+          Return a tuple of two new Nbest objects.
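+
+        Example:
+          If an utterance has 5 paths and k is 3, the first Nbest gets 3
+          paths (the 3 highest-scoring ones when sort is True, otherwise
+          the first 3) and the second Nbest gets the remaining 2 paths.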
+ ''' + # indexes contains idx01's for self.shape + indexes = torch.arange( + self.shape.num_elements(), dtype=torch.int32, + device=self.shape.device + ) - device = lats.device - lats.scores = lats.scores - lats.lm_scores - # Now lats contains only acoustic scores + if sort: + ragged_scores = self.total_scores() - # We will use lm_scores from the given G, so remove lats.lm_scores here - del lats.lm_scores - assert hasattr(lats, 'lm_scores') is False + # ragged_scores.values()[indexes] is sorted + indexes = k2.ragged.sort_sublist( + ragged_scores, descending=True, need_new2old_indexes=True + ) - # inverted_lats has word IDs as labels. - # Its aux_labels are token IDs, which is a ragged tensor k2.RaggedInt - # if lats.aux_labels is a ragged tensor - inverted_lats = k2.invert(lats) - num_seqs = lats.shape[0] + ragged_indexes = k2.RaggedInt(self.shape, indexes) - b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32) + padded_indexes = k2.ragged.pad(ragged_indexes, value=-1) - while True: - try: - rescoring_lats = k2.intersect_device(G_with_epsilon_loops, - inverted_lats, - b_to_a_map, - sorted_match_a=True) - break - except RuntimeError as e: - logging.info(f'Caught exception:\n{e}\n') - # Usually, this is an OOM exception. We reduce - # the size of the lattice and redo k2.intersect_device() + # Select the idx01's of top-k paths of each utterance + first_indexes = padded_indexes[:, :k].flatten().contiguous() - # NOTE(fangjun): The choice of the threshold 1e-5 is arbitrary here - # to avoid OOM. We may need to fine tune it. - logging.info(f'num_arcs before: {inverted_lats.num_arcs}') - inverted_lats = k2.prune_on_arc_post(inverted_lats, 1e-5, True) - logging.info(f'num_arcs after: {inverted_lats.num_arcs}') + # Remove the padding elements + first_indexes = first_indexes[first_indexes >= 0] - rescoring_lats = k2.top_sort(k2.connect(rescoring_lats)) + first_fsas = k2.index_fsa(self.fsa, first_indexes) - # inv_rescoring_lats has token IDs as labels - # and word IDs as aux_labels. - inv_rescoring_lats = k2.invert(rescoring_lats) - return inv_rescoring_lats + first_row_ids = k2.index(self.shape.row_ids(1), first_indexes) + first_shape = k2.ragged.create_ragged_shape2(row_ids=first_row_ids) + first_nbest = Nbest(first_fsas, first_shape) -def generate_nbest_list(lats: k2.Fsa, num_paths: int) -> Nbest: - '''Generate an n-best list from a lattice. + # Select the idx01's of remaining paths of each utterance + second_indexes = padded_indexes[:, k:].flatten().contiguous() - Args: - lats: - The decoding lattice from the first pass after LM rescoring. - lats is an FsaVec. It can be the return value of - :func:`whole_lattice_rescoring` - num_paths: - Size of n for n-best list. CAUTION: After removing paths - that represent the same token sequences, the number of paths - in different sequences may not be equal. - Return: - Return an Nbest object. Note the returned FSAs don't have epsilon - self-loops. 
- ''' - assert len(lats.shape) == 3 + # Remove the padding elements + second_indexes = second_indexes[second_indexes >= 0] - # CAUTION: We use `phones` instead of `tokens` here because - # :func:`compile_HLG` uses `phones` - # - # Note: compile_HLG is from k2-fsa/snowfall - assert hasattr(lats, 'phones') + second_fsas = k2.index_fsa(self.fsa, second_indexes) - assert not hasattr(lats, 'tokens') - lats.tokens = lats.phones - # we use tokens instead of phones in the following code + second_row_ids = k2.index(self.shape.row_ids(1), second_indexes) + second_shape = k2.ragged.create_ragged_shape2(row_ids=second_row_ids) - # First, extract `num_paths` paths for each sequence. - # paths is a k2.RaggedInt with axes [seq][path][arc_pos] - paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True) + second_nbest = Nbest(second_fsas, second_shape) - # token_seqs is a k2.RaggedInt sharing the same shape as `paths` - # but it contains token IDs. Note that it also contains 0s and -1s. - # The last entry in each sublist is -1. - # Its axes are [seq][path][token_id] - token_seqs = k2.index(lats.tokens, paths) + return first_nbest, second_nbest - # Remove epsilons (0s) and -1 from token_seqs - token_seqs = k2.ragged.remove_values_leq(token_seqs, 0) - - # unique_token_seqs is still a k2.RaggedInt with axes [seq][path]token_id]. - # But then number of pathsin each sequence may be different. - unique_token_seqs, _, _ = k2.ragged.unique_sequences( - token_seqs, need_num_repeats=False, need_new2old_indexes=False) - - seq_to_path_shape = k2.ragged.get_layer(unique_token_seqs.shape(), 0) - - # Remove the seq axis. - # Now unique_token_seqs has only two axes [path][token_id] - unique_token_seqs = k2.ragged.remove_axis(unique_token_seqs, 0) - - token_fsas = k2.linear_fsa(unique_token_seqs) - - return Nbest(fsa=token_fsas, shape=seq_to_path_shape) diff --git a/icefall/utils.py b/icefall/utils.py index ca8a338a0..1bf0f88de 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -5,7 +5,7 @@ import subprocess from collections import defaultdict from contextlib import contextmanager from datetime import datetime -from nbest import Nbest +from icefall.nbest import Nbest from pathlib import Path from typing import Dict, Iterable, List, TextIO, Tuple, Union
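
Below is a minimal usage sketch of the new `Nbest.split()` API added by this
patch. It is illustrative only: it assumes `k2`, `torch`, and the patched
`icefall.nbest.Nbest` are importable, and the labels are made-up toy values.

    import torch
    import k2

    from icefall.nbest import Nbest

    # Three linear FSAs (paths): paths 0 and 1 belong to utterance 0,
    # path 2 belongs to utterance 1 (toy values).
    fsa = k2.linear_fsa([[1, 2, 3], [1, 5], [7, 8, 9]])

    # A ragged shape with axes [utt][path]: utterance 0 has 2 paths,
    # utterance 1 has 1 path.
    shape = k2.ragged.create_ragged_shape2(
        row_splits=torch.tensor([0, 2, 3], dtype=torch.int32)
    )

    nbest = Nbest(fsa=fsa, shape=shape)

    # Keep at most one path per utterance in `first`; the rest go to
    # `second`. With sort=False the original path order is kept; with
    # sort=True the highest-scoring paths (by total_scores) are selected.
    first, second = nbest.split(k=1, sort=False)

This is the same split used by rescore_with_attention_decoder_v2 to separate
the paths that are rescored exactly from those whose scores are only
estimated.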