diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 6b6190a09..646bcc618 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -19,6 +19,7 @@ from dataclasses import dataclass from typing import Dict, List, Optional import k2 +import sentencepiece as spm import torch from model import Transducer @@ -34,6 +35,7 @@ def fast_beam_search_one_best( beam: float, max_states: int, max_contexts: int, + temperature: float = 1.0, ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. @@ -56,6 +58,8 @@ def fast_beam_search_one_best( Max states per stream per frame. max_contexts: Max contexts pre stream per frame. + temperature: + Softmax temperature. Returns: Return the decoded result. """ @@ -67,6 +71,7 @@ def fast_beam_search_one_best( beam=beam, max_states=max_states, max_contexts=max_contexts, + temperature=temperature, ) best_path = one_best_decoding(lattice) @@ -85,6 +90,7 @@ def fast_beam_search_nbest_LG( num_paths: int, nbest_scale: float = 0.5, use_double_scores: bool = True, + temperature: float = 1.0, ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. @@ -131,6 +137,7 @@ def fast_beam_search_nbest_LG( beam=beam, max_states=max_states, max_contexts=max_contexts, + temperature=temperature, ) nbest = Nbest.from_lattice( @@ -201,6 +208,7 @@ def fast_beam_search_nbest( num_paths: int, nbest_scale: float = 0.5, use_double_scores: bool = True, + temperature: float = 1.0, ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. @@ -247,6 +255,7 @@ def fast_beam_search_nbest( beam=beam, max_states=max_states, max_contexts=max_contexts, + temperature=temperature, ) nbest = Nbest.from_lattice( @@ -282,6 +291,7 @@ def fast_beam_search_nbest_oracle( ref_texts: List[List[int]], use_double_scores: bool = True, nbest_scale: float = 0.5, + temperature: float = 1.0, ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. @@ -333,6 +343,7 @@ def fast_beam_search_nbest_oracle( beam=beam, max_states=max_states, max_contexts=max_contexts, + temperature=temperature, ) nbest = Nbest.from_lattice( @@ -373,6 +384,7 @@ def fast_beam_search( beam: float, max_states: int, max_contexts: int, + temperature: float = 1.0, ) -> k2.Fsa: """It limits the maximum number of symbols per frame to 1. @@ -440,7 +452,7 @@ def fast_beam_search( project_input=False, ) logits = logits.squeeze(1).squeeze(1) - log_probs = logits.log_softmax(dim=-1) + log_probs = (logits / temperature).log_softmax(dim=-1) decoding_streams.advance(log_probs) decoding_streams.terminate_and_flush_to_streams() lattice = decoding_streams.format_output(encoder_out_lens.tolist()) @@ -783,6 +795,7 @@ def modified_beam_search( encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, beam: int = 4, + temperature: float = 1.0, ) -> List[List[int]]: """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. @@ -879,7 +892,9 @@ def modified_beam_search( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) @@ -1043,6 +1058,7 @@ def beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, + temperature: float = 1.0, ) -> List[int]: """ It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf @@ -1132,7 +1148,7 @@ def beam_search( ) # TODO(fangjun): Scale the blank posterior - log_prob = logits.log_softmax(dim=-1) + log_prob = (logits / temperature).log_softmax(dim=-1) # log_prob is (1, 1, 1, vocab_size) log_prob = log_prob.squeeze() # Now log_prob is (vocab_size,) @@ -1171,3 +1187,155 @@ def beam_search( best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks return ys + + +def fast_beam_search_with_nbest_rescoring( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + ngram_lm_scale_list: List[float], + num_paths: int, + G: k2.Fsa, + sp: spm.SentencePieceProcessor, + word_table: k2.SymbolTable, + oov_word: str = "", + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, +) -> Dict[str, List[List[int]]]: + """It limits the maximum number of symbols per frame to 1. + A lattice is first obtained using modified beam search, and then + the shortest path within the lattice is used as the final output. + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + ngram_lm_scale_list: + A list of floats representing LM score scales. + num_paths: + Number of paths to extract from the decoded lattice. + G: + An FsaVec containing only a single FSA. It is an n-gram LM. + sp: + The BPE model. + word_table: + The word symbol table. + oov_word: + OOV words are replaced with this word. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + Returns: + Return the decoded result in a dict, where the key has the form + 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the + ngram LM scale value used during decoding, i.e., 0.1. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + am_scores = nbest.tot_scores() + + # Now we need to compute the LM scores of each path. + # (1) Get the token IDs of each Path. We assume the decoding_graph + # is an acceptor, i.e., lattice is also an acceptor + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] + + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) + tokens = tokens.remove_values_leq(0) # remove -1 and 0 + + token_list: List[List[int]] = tokens.tolist() + word_list: List[List[str]] = sp.decode(token_list) + + assert isinstance(oov_word, str), oov_word + assert oov_word in word_table, oov_word + oov_word_id = word_table[oov_word] + + word_ids_list: List[List[int]] = [] + + for words in word_list: + this_word_ids = [] + for w in words.split(): + if w in word_table: + this_word_ids.append(word_table[w]) + else: + this_word_ids.append(oov_word_id) + word_ids_list.append(this_word_ids) + + word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) + word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) + + num_unique_paths = len(word_ids_list) + + b_to_a_map = torch.zeros( + num_unique_paths, + dtype=torch.int32, + device=lattice.device, + ) + + rescored_word_fsas = k2.intersect_device( + a_fsas=G, + b_fsas=word_fsas_with_self_loops, + b_to_a_map=b_to_a_map, + sorted_match_a=True, + ret_arc_maps=False, + ) + + rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) + rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) + ngram_lm_scores = rescored_word_fsas.get_tot_scores( + use_double_scores=True, + log_semiring=False, + ) + + ans: Dict[str, List[List[int]]] = {} + for s in ngram_lm_scale_list: + key = f"ngram_lm_scale_{s}" + tot_scores = am_scores.values + s * ngram_lm_scores + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) + hyps = get_texts(best_path) + + ans[key] = hyps + + return ans