diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index ed6a6ea82..59a4f67df 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import warnings from dataclasses import dataclass from typing import Dict, List, Optional @@ -21,10 +22,11 @@ from typing import Dict, List, Optional import k2 import sentencepiece as spm import torch +from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence from model import Transducer -from icefall.decode import Nbest, one_best_decoding -from icefall.utils import get_texts +from icefall.decode import get_lattice, Nbest, one_best_decoding +from icefall.utils import get_alignments, get_texts def fast_beam_search_one_best( @@ -553,6 +555,9 @@ def greedy_search_batch( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, + decoding_graph: Optional[k2.Fsa] = None, + ngram_rescoring: bool = False, + gamma_blank: float = 1.0, ) -> List[List[int]]: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. Args: @@ -602,6 +607,18 @@ def greedy_search_batch( encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + if ngram_rescoring: + vocab_size = model.decoder.vocab_size + total_t = encoder_out.shape[0] + # cached all joiner outputs during greedy search, + # from which non-blank frames are selected before n-gram rescoring. + all_logits = torch.zeros([total_t, vocab_size], device=device) + + # A flag indicating a frame is a blank frame or not. + # 0 for blank frame and 1 for non-blank frame. + # Used to select non-blank frames for n-gram rescoring. + non_blank_flag = torch.zeros([total_t], device=device) + offset = 0 for batch_size in batch_size_list: start = offset @@ -616,10 +633,23 @@ def greedy_search_batch( logits = model.joiner( current_encoder_out, decoder_out.unsqueeze(1), project_input=False ) - # logits'shape (batch_size, 1, 1, vocab_size) + # logits'shape (batch_size, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) - assert logits.ndim == 2, logits.shape + + if ngram_rescoring: + all_logits[start:end] = logits + + assert logits.ndim == 2, logits.shape + logits_softmax = logits.softmax(dim=1) + + + # 0 for blank frame and 1 for non-blank frame. + non_blank_flag[start:end] = torch.where( + logits_softmax[:, 0] >= gamma_blank, 0, 1 + ) + + y = logits.argmax(dim=1).tolist() emitted = False for i, v in enumerate(y): @@ -643,6 +673,91 @@ def greedy_search_batch( for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) + if not ngram_rescoring: + return ans + + assert decoding_graph is not None + + # Transform logits to shape [N, T, vocab_size] format to make it easier + # to select non-blank frames. + packed_all_logits = PackedSequence( + all_logits, torch.tensor(batch_size_list) + ) + all_logits_unpacked, _ = pad_packed_sequence( + packed_all_logits, batch_first=True + ) + + # Transform non_blank_flag to shape [N, T] + packed_non_blank_flag = PackedSequence( + non_blank_flag, torch.tensor(batch_size_list) + ) + non_blank_flag_unpacked, _ = pad_packed_sequence( + packed_non_blank_flag, batch_first=True + ) + + non_blank_logits_lens = torch.sum(non_blank_flag_unpacked, dim=1) + max_frame_to_rescore = non_blank_logits_lens.max() + + non_blank_logits = torch.zeros( + [N, int(max_frame_to_rescore), vocab_size], device=device + ) + + # torch.index_select only acceptec a single dimension to index from. + # So we need generate non_blank_logits one by one. + # Maybe there is another efficient way to do this. + for i in range(N): + cur_non_blank_index = torch.where(non_blank_flag_unpacked[i, :] != 0)[0] + assert non_blank_logits_lens[i] == cur_non_blank_index.shape[0] + non_blank_logits[ + i, : int(non_blank_logits_lens[i]), : + ] = torch.index_select( + all_logits_unpacked[i, :], 0, cur_non_blank_index + ) + + + + number_selected_frames = non_blank_flag.sum() + logging.info(f"{number_selected_frames} are selected out of {total_t} frames") + # Split log_softmax into two seperate steps, + # so we cound do blank deweight in probability domain if needed. + logits_to_rescore_softmax = non_blank_logits.softmax(dim=2) + logits_to_rescore = logits_to_rescore_softmax.log() + + # In paper: https://arxiv.org/pdf/2101.06856.pdf + # blank deweight is applied before non_blank frames selected. + # However, in current setup, that results in a higher WER. + # So just put this blank deweight before ngram rescoring. + # (TODO): debug this blank deweight issue. + + blank_deweight = 0.0 + logits_to_rescore[:, :, 0] -= blank_deweight + + supervision_segments = torch.zeros([N, 3], dtype=torch.int32) + supervision_segments[:, 0] = torch.arange(0, N, dtype=torch.int32) + supervision_segments[:, 2] = non_blank_logits_lens.to(torch.int32) + + lattice = get_lattice( + nnet_output=logits_to_rescore, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=20, + output_beam=8, + min_active_states=30, + max_active_states=1000, + subsampling_factor=1, + ) + + best_path = one_best_decoding( + lattice=lattice, + use_double_scores=True, + ) + + token_ids = get_alignments(best_path, "labels", remove_zero_blank=True) + + ans = [] + for i in range(N): + usi = unsorted_indices[i] + ans.append(token_ids[usi][: int(non_blank_logits_lens[usi])]) return ans diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index 701cad73c..7aed94674 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -136,6 +136,28 @@ def get_parser(): "`epoch` are loaded for averaging. ", ) + parser.add_argument( + "--ngram-rescoring", + type=str2bool, + default=False, + help="Whether to use ngram_rescoring.", + ) + + parser.add_argument( + "--decoding-graph", + type=str, + default="trivial_graph", + help="one of [trivial_grpah, HLG, Trival_LG, LG]" + "used by greedy_search_batch with ngram-rescoring=True.", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="./data/lang_bpe_500/", + help="Path to decoding graphs", + ) + parser.add_argument( "--exp-dir", type=str, @@ -219,6 +241,12 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--gamma-blank", + type=int, + default=1.0, + ) + return parser @@ -293,6 +321,9 @@ def decode_one_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, + decoding_graph=decoding_graph, + ngram_rescoring=params.ngram_rescoring, + gamma_blank=params.gamma_blank, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -498,6 +529,11 @@ def main(): if params.use_averaged_model: params.suffix += "-use-averaged-model" + if params.ngram_rescoring: + params.suffix += "-ngram-rescoring" + params.suffix += f"-{params.decoding_graph}" + params.suffix += f"-gamma_blank-{params.gamma_blank}" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -605,6 +641,24 @@ def main(): else: decoding_graph = None + if params.ngram_rescoring and params.decoding_method == "greedy_search": + assert params.decoding_graph in [ + "trivial_graph", + "L", + ], f"Unsupported decoding graph {params.decoding_graph}" + if params.decoding_graph == "trivial_graph": + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) + else: + decoding_graph = k2.Fsa.from_dict( + torch.load( + f"data/lang_bpe_500/{params.decoding_graph}.pt", + map_location=device, + ) + ) + decoding_graph = k2.add_epsilon_self_loops(decoding_graph) + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") diff --git a/icefall/utils.py b/icefall/utils.py index 3bfd5e5b1..bf11eae27 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -236,7 +236,9 @@ def get_texts( return aux_labels.tolist() -def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]: +def get_alignments( + best_paths: k2.Fsa, kind: str, remove_zero_blank: bool = False +) -> List[List[int]]: """Extract labels or aux_labels from the best-path FSAs. Args: @@ -272,6 +274,8 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]: token_shape, getattr(best_paths, kind).contiguous() ) tokens = tokens.remove_values_eq(-1) + if remove_zero_blank: + tokens = tokens.remove_values_eq(0) return tokens.tolist()