From de2f5e3e6d66ddccb44c4c41ab04260acce6fb2f Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Wed, 2 Nov 2022 16:15:56 +0800
Subject: [PATCH] support RNNLM shallow fusion for LSTM transducer

---
 .../ASR/lstm_transducer_stateless2/decode.py  | 143 ++++-
 .../beam_search.py                            | 514 +++++++++---------
 icefall/rnn_lm/model.py                       | 126 ++++-
 3 files changed, 503 insertions(+), 280 deletions(-)

diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
index c7b53ebc0..1d46c0177 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
@@ -115,7 +115,8 @@ from beam_search import (
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
-    modified_beam_search_ngram_rescoring,
+    modified_beam_search_rnnlm_shallow_fusion,
+
 )
 from librispeech import LibriSpeech
 from train import add_model_arguments, get_params, get_transducer_model
@@ -128,6 +129,7 @@ from icefall.checkpoint import (
     load_checkpoint,
 )
 from icefall.lexicon import Lexicon
+from icefall.rnn_lm.model import RnnLmModel
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -216,7 +218,7 @@ def get_parser():
           - fast_beam_search_nbest
           - fast_beam_search_nbest_oracle
           - fast_beam_search_nbest_LG
-          - modified_beam_search_ngram_rescoring
+          - modified_beam_search_rnnlm_shallow_fusion  # RNN LM shallow fusion
         If you use fast_beam_search_nbest_LG, you have to specify
         `--lang-dir`, which should contain `LG.pt`.
         """,
@@ -307,21 +309,74 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--tokens-ngram",
-        type=int,
-        default=3,
-        help="""Token Ngram used for rescoring.
-        Used only when the decoding method is modified_beam_search_ngram_rescoring""",
+        "--rnn-lm-scale",
+        type=float,
+        default=0.0,
+        help="""Used only when --decoding-method is
+        modified_beam_search_rnnlm_shallow_fusion.
+        It specifies the scale of the RNN LM in shallow fusion.
+        """,
     )
 
     parser.add_argument(
-        "--backoff-id",
-        type=int,
-        default=500,
-        help="""ID of the backoff symbol.
-        Used only when the decoding method is modified_beam_search_ngram_rescoring""",
+        "--rnn-lm-exp-dir",
+        type=str,
+        default="rnn_lm/exp",
+        help="""Used only when the decoding method uses an RNN LM.
+        It specifies the path to the RNN LM exp dir.
+        """,
     )
 
+    parser.add_argument(
+        "--rnn-lm-epoch",
+        type=int,
+        default=7,
+        help="""Used only when the decoding method uses an RNN LM.
+        It specifies the checkpoint to use.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-avg",
+        type=int,
+        default=2,
+        help="""Used only when the decoding method uses an RNN LM.
+        It specifies the number of checkpoints to average. 
+ """, + ) + + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + parser.add_argument( + "--ilm-scale", + type=float, + default=-0.1 + ) add_model_arguments(parser) return parser @@ -336,6 +391,8 @@ def decode_one_batch( decoding_graph: Optional[k2.Fsa] = None, ngram_lm: Optional[NgramLm] = None, ngram_lm_scale: float = 1.0, + rnnlm: Optional[RnnLmModel] = None, + rnnlm_scale: float = 1.0, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -469,14 +526,14 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_ngram_rescoring": - hyp_tokens = modified_beam_search_ngram_rescoring( + elif params.decoding_method == "modified_beam_search_sf_rnnlm": + hyp_tokens = modified_beam_search_sf_rnnlm_batched( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - beam=params.beam_size, + sp=sp, + rnnlm=rnnlm, + rnnlm_scale=rnnlm_scale, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -531,7 +588,9 @@ def decode_dataset( decoding_graph: Optional[k2.Fsa] = None, ngram_lm: Optional[NgramLm] = None, ngram_lm_scale: float = 1.0, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + rnnlm: Optional[NgramLm] = None, + rnnlm_scale: float = 1.0, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. 
Args: @@ -572,6 +631,9 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration = sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + logging.info(f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}") hyps_dict = decode_one_batch( params=params, @@ -582,6 +644,8 @@ def decode_dataset( batch=batch, ngram_lm=ngram_lm, ngram_lm_scale=ngram_lm_scale, + rnnlm=rnnlm, + rnnlm_scale=rnnlm_scale, ) for name, hyps in hyps_dict.items(): @@ -607,7 +671,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], ): test_set_wers = dict() for key, results in results_dict.items(): @@ -667,7 +731,7 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", - "modified_beam_search_ngram_rescoring", + "modified_beam_search_sf_rnnlm", ) params.res_dir = params.exp_dir / params.decoding_method @@ -692,7 +756,12 @@ def main(): else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + + if "rnnlm" in params.decoding_method: + params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" + + if "ILME" in params.decoding_method: + params.suffix += f"-ILME-scale={params.ilm_scale}" if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -806,14 +875,28 @@ def main(): model.to(device) model.eval() - lm_filename = f"{params.tokens_ngram}gram.fst.txt" - logging.info(f"lm filename: {lm_filename}") - ngram_lm = NgramLm( - str(params.lang_dir / lm_filename), - backoff_id=params.backoff_id, - is_binary=False, - ) - logging.info(f"num states: {ngram_lm.lm.num_states}") + # only load rnnlm if used + if "rnnlm" in params.decoding_method: + rnn_lm_scale = params.rnn_lm_scale + + rnn_lm_model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + assert params.rnn_lm_avg == 1 + + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + rnn_lm_model.to(device) + rnn_lm_model.eval() + + else: + rnn_lm_model = None if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": @@ -861,6 +944,8 @@ def main(): decoding_graph=decoding_graph, ngram_lm=ngram_lm, ngram_lm_scale=params.ngram_lm_scale, + rnnlm=rnn_lm_model, + rnnlm_scale=rnn_lm_scale, ) save_results( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 0004a24eb..01cc566e8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -16,7 +16,7 @@ import warnings from dataclasses import dataclass -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional import k2 import sentencepiece as spm @@ -25,13 +25,8 @@ from model import Transducer from icefall import NgramLm, NgramLmStateCost from icefall.decode import Nbest, one_best_decoding -from icefall.utils import ( - DecodingResults, - add_eos, - add_sos, - 
get_texts, - get_texts_with_timestamp, -) +from icefall.rnn_lm.model import RnnLmModel +from icefall.utils import add_eos, add_sos, get_texts def fast_beam_search_one_best( @@ -43,8 +38,7 @@ def fast_beam_search_one_best( max_states: int, max_contexts: int, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, and then @@ -68,12 +62,8 @@ def fast_beam_search_one_best( Max contexts pre stream per frame. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ lattice = fast_beam_search( model=model, @@ -87,11 +77,8 @@ def fast_beam_search_one_best( ) best_path = one_best_decoding(lattice) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) + hyps = get_texts(best_path) + return hyps def fast_beam_search_nbest_LG( @@ -106,8 +93,7 @@ def fast_beam_search_nbest_LG( nbest_scale: float = 0.5, use_double_scores: bool = True, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. The process to get the results is: @@ -144,12 +130,8 @@ def fast_beam_search_nbest_LG( single precision. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ lattice = fast_beam_search( model=model, @@ -214,10 +196,9 @@ def fast_beam_search_nbest_LG( best_hyp_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) + hyps = get_texts(best_path) + + return hyps def fast_beam_search_nbest( @@ -232,8 +213,7 @@ def fast_beam_search_nbest( nbest_scale: float = 0.5, use_double_scores: bool = True, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. The process to get the results is: @@ -270,12 +250,8 @@ def fast_beam_search_nbest( single precision. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. 
""" lattice = fast_beam_search( model=model, @@ -304,10 +280,9 @@ def fast_beam_search_nbest( best_path = k2.index_fsa(nbest.fsa, max_indexes) - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) + hyps = get_texts(best_path) + + return hyps def fast_beam_search_nbest_oracle( @@ -323,8 +298,7 @@ def fast_beam_search_nbest_oracle( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, and then @@ -365,12 +339,8 @@ def fast_beam_search_nbest_oracle( yields more unique paths. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ lattice = fast_beam_search( model=model, @@ -409,10 +379,8 @@ def fast_beam_search_nbest_oracle( best_path = k2.index_fsa(nbest.fsa, max_indexes) - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) + hyps = get_texts(best_path) + return hyps def fast_beam_search( @@ -502,11 +470,8 @@ def fast_beam_search( def greedy_search( - model: Transducer, - encoder_out: torch.Tensor, - max_sym_per_frame: int, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: + model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int +) -> List[int]: """Greedy search for a single utterance. Args: model: @@ -516,12 +481,8 @@ def greedy_search( max_sym_per_frame: Maximum number of symbols per frame. If it is set to 0, the WER would be 100%. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ assert encoder_out.ndim == 3 @@ -547,10 +508,6 @@ def greedy_search( t = 0 hyp = [blank_id] * context_size - # timestamp[i] is the frame index after subsampling - # on which hyp[i] is decoded - timestamp = [] - # Maximum symbols per utterance. max_sym_per_utt = 1000 @@ -577,7 +534,6 @@ def greedy_search( y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) - timestamp.append(t) decoder_input = torch.tensor( [hyp[-context_size:]], device=device ).reshape(1, context_size) @@ -592,21 +548,14 @@ def greedy_search( t += 1 hyp = hyp[context_size:] # remove blanks - if not return_timestamps: - return hyp - else: - return DecodingResults( - tokens=[hyp], - timestamps=[timestamp], - ) + return hyp def greedy_search_batch( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. Args: model: @@ -616,12 +565,9 @@ def greedy_search_batch( encoder_out_lens: A 1-D tensor of shape (N,), containing number of valid frames in encoder_out before padding. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. 
+ Return a list-of-list of token IDs containing the decoded results. + len(ans) equals to encoder_out.size(0). """ assert encoder_out.ndim == 3 assert encoder_out.size(0) >= 1, encoder_out.size(0) @@ -646,10 +592,6 @@ def greedy_search_batch( hyps = [[blank_id] * context_size for _ in range(N)] - # timestamp[n][i] is the frame index after subsampling - # on which hyp[n][i] is decoded - timestamps = [[] for _ in range(N)] - decoder_input = torch.tensor( hyps, device=device, @@ -663,7 +605,7 @@ def greedy_search_batch( encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) offset = 0 - for (t, batch_size) in enumerate(batch_size_list): + for batch_size in batch_size_list: start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -685,7 +627,6 @@ def greedy_search_batch( for i, v in enumerate(y): if v not in (blank_id, unk_id): hyps[i].append(v) - timestamps[i].append(t) emitted = True if emitted: # update decoder output @@ -700,19 +641,11 @@ def greedy_search_batch( sorted_ans = [h[context_size:] for h in hyps] ans = [] - ans_timestamps = [] unsorted_indices = packed_encoder_out.unsorted_indices.tolist() for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(timestamps[unsorted_indices[i]]) - if not return_timestamps: - return ans - else: - return DecodingResults( - tokens=ans, - timestamps=ans_timestamps, - ) + return ans @dataclass @@ -725,11 +658,9 @@ class Hypothesis: # It contains only one entry. log_prob: torch.Tensor - # timestamp[i] is the frame index after subsampling - # on which ys[i] is decoded - timestamp: List[int] - state_cost: Optional[NgramLmStateCost] = None + state: Optional = None + lm_score: Optional=None @property def key(self) -> str: @@ -878,8 +809,7 @@ def modified_beam_search( encoder_out_lens: torch.Tensor, beam: int = 4, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. Args: @@ -894,12 +824,9 @@ def modified_beam_search( Number of active paths during the beam search. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. 
""" assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) @@ -917,7 +844,7 @@ def modified_beam_search( device = next(model.parameters()).device batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) + N = encoder_out.size(0) assert torch.all(encoder_out_lens > 0), encoder_out_lens assert N == batch_size_list[0], (N, batch_size_list) @@ -927,7 +854,6 @@ def modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], ) ) @@ -935,7 +861,7 @@ def modified_beam_search( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for batch_size in batch_size_list: start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -1013,44 +939,30 @@ def modified_beam_search( new_ys = hyp.ys[:] new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] if new_token not in (blank_id, unk_id): new_ys.append(new_token) - new_timestamp.append(t) new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) B[i].add(new_hyp) B = B + finalized_B best_hyps = [b.get_most_probable(length_norm=True) for b in B] sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] ans = [] - ans_timestamps = [] unsorted_indices = packed_encoder_out.unsorted_indices.tolist() for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - if not return_timestamps: - return ans - else: - return DecodingResults( - tokens=ans, - timestamps=ans_timestamps, - ) + return ans def _deprecated_modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: +) -> List[int]: """It limits the maximum number of symbols per frame to 1. It decodes only one utterance at a time. We keep it only for reference. @@ -1065,13 +977,8 @@ def _deprecated_modified_beam_search( A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. beam: Beam size. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. 
""" assert encoder_out.ndim == 3 @@ -1091,7 +998,6 @@ def _deprecated_modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], ) ) encoder_out = model.joiner.encoder_proj(encoder_out) @@ -1150,24 +1056,17 @@ def _deprecated_modified_beam_search( for i in range(len(topk_hyp_indexes)): hyp = A[topk_hyp_indexes[i]] new_ys = hyp.ys[:] - new_timestamp = hyp.timestamp[:] new_token = topk_token_indexes[i] if new_token not in (blank_id, unk_id): new_ys.append(new_token) - new_timestamp.append(t) new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) B.add(new_hyp) best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - if not return_timestamps: - return ys - else: - return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp]) + return ys def beam_search( @@ -1175,8 +1074,7 @@ def beam_search( encoder_out: torch.Tensor, beam: int = 4, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: +) -> List[int]: """ It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf @@ -1191,13 +1089,8 @@ def beam_search( Beam size. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ assert encoder_out.ndim == 3 @@ -1224,7 +1117,7 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) max_sym_per_utt = 20000 @@ -1285,13 +1178,7 @@ def beam_search( new_y_star_log_prob = y_star.log_prob + skip_log_prob # ys[:] returns a copy of ys - B.add( - Hypothesis( - ys=y_star.ys[:], - log_prob=new_y_star_log_prob, - timestamp=y_star.timestamp[:], - ) - ) + B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) # Second, process other non-blank labels values, indices = log_prob.topk(beam + 1) @@ -1300,14 +1187,7 @@ def beam_search( continue new_ys = y_star.ys + [i] new_log_prob = y_star.log_prob + v - new_timestamp = y_star.timestamp + [t] - A.add( - Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - timestamp=new_timestamp, - ) - ) + A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) # Check whether B contains more than "beam" elements more probable # than the most probable in A @@ -1323,11 +1203,7 @@ def beam_search( best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - if not return_timestamps: - return ys - else: - return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp]) + return ys def fast_beam_search_with_nbest_rescoring( @@ -1347,8 +1223,7 @@ def fast_beam_search_with_nbest_rescoring( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Dict[str, Union[List[List[int]], DecodingResults]]: +) -> Dict[str, List[List[int]]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, num_path are selected and rescored using a given language model. 
The shortest path within the @@ -1390,13 +1265,10 @@ def fast_beam_search_with_nbest_rescoring( yields more unique paths. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results - optionally with timestamps. `xx` is the ngram LM scale value - used during decoding, i.e., 0.1. + 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the + ngram LM scale value used during decoding, i.e., 0.1. """ lattice = fast_beam_search( model=model, @@ -1474,18 +1346,16 @@ def fast_beam_search_with_nbest_rescoring( log_semiring=False, ) - ans: Dict[str, Union[List[List[int]], DecodingResults]] = {} + ans: Dict[str, List[List[int]]] = {} for s in ngram_lm_scale_list: key = f"ngram_lm_scale_{s}" tot_scores = am_scores.values + s * ngram_lm_scores ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, max_indexes) + hyps = get_texts(best_path) - if not return_timestamps: - ans[key] = get_texts(best_path) - else: - ans[key] = get_texts_with_timestamp(best_path) + ans[key] = hyps return ans @@ -1509,8 +1379,7 @@ def fast_beam_search_with_nbest_rnn_rescoring( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Dict[str, Union[List[List[int]], DecodingResults]]: +) -> Dict[str, List[List[int]]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, num_path are selected and rescored using a given language model and a rnn-lm. @@ -1556,13 +1425,10 @@ def fast_beam_search_with_nbest_rnn_rescoring( yields more unique paths. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results - optionally with timestamps. `xx` is the ngram LM scale value - used during decoding, i.e., 0.1. + 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the + ngram LM scale value used during decoding, i.e., 0.1. 
""" lattice = fast_beam_search( model=model, @@ -1674,45 +1540,151 @@ def fast_beam_search_with_nbest_rnn_rescoring( ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, max_indexes) + hyps = get_texts(best_path) - if not return_timestamps: - ans[key] = get_texts(best_path) - else: - ans[key] = get_texts_with_timestamp(best_path) + ans[key] = hyps return ans +def modified_beam_search_sf_rnnlm( + model: Transducer, + encoder_out: torch.Tensor, + sp, + rnnlm: RnnLmModel, + rnnlm_scale: float, + beam: int = 4, +): + encoder_out = model.joiner.encoder_proj(encoder_out) + lm_scale = rnnlm_scale -def modified_beam_search_ngram_rescoring( + assert rnnlm is not None + assert encoder_out.ndim == 2, encoder_out.shape + rnnlm.clean_cache() + blank_id = model.decoder.blank_id + sos_id = sp.piece_to_id("") + eos_id = sp.piece_to_id("") + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + B = HypothesisList() + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + T = encoder_out.shape[0] + for t in range(T): + current_encoder_out = encoder_out[t : t + 1] + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyp in A] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + + # decoder_out is of shape (num_hyps, joiner_dim) + current_encoder_out = current_encoder_out.repeat(len(A), 1) + # current_encoder_out is of shape (num_hyps, encoder_out_dim) + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk( + beam + ) # get topk tokens and scores + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[hyp_idx] # get hyp + new_ys = hyp.ys[:] + state = "ys=" + "+".join(list(map(str, new_ys))) + tokens = k2.RaggedTensor([new_ys[context_size:]]) + + lm_score = rnnlm.predict( + tokens, state, sos_id, eos_id, blank_id + ) # get rnnlm score + + hyp_log_prob = topk_log_probs[k] # get score of current hyp + new_token = topk_token_indexes[k] # get token + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + # state_cost = hyp.state_cost.forward_one_step(new_token) + hyp_log_prob += ( + lm_score[new_token] * lm_scale + ) # add the lm score + else: + new_ys = new_ys + new_log_prob = hyp_log_prob + + new_hyp = Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + ) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + return best_hyp.ys[context_size:] + +def modified_beam_search_rnnlm_shallow_fusion( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, - ngram_lm: NgramLm, - ngram_lm_scale: float, + sp: spm.SentencePieceProcessor, + rnnlm: RnnLmModel, + 
rnnlm_scale: float, beam: int = 4, - temperature: float = 1.0, ) -> List[List[int]]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + """Modified_beam_search + RNNLM shallow fusion Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. + model (Transducer): + The transducer model + encoder_out (torch.Tensor): + Encoder output in (N,T,C) + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of + valid frames in encoder_out before padding. + sp: + Sentence piece generator. + rnnlm (RnnLmModel): + RNNLM + rnnlm_scale (float): + scale of RNNLM in shallow fusion + beam (int, optional): + Beam size. Defaults to 4. + Returns: Return a list-of-list of token IDs. ans[i] is the decoding results for the i-th utterance. """ assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) - + assert rnnlm is not None + lm_scale = rnnlm_scale + vocab_size = rnnlm.vocab_size + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( input=encoder_out, lengths=encoder_out_lens.cpu(), @@ -1721,34 +1693,41 @@ def modified_beam_search_ngram_rescoring( ) blank_id = model.decoder.blank_id + sos_id = sp.piece_to_id("") + eos_id = sp.piece_to_id("") unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size device = next(model.parameters()).device - lm_scale = ngram_lm_scale - + batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) + N = encoder_out.size(0) assert torch.all(encoder_out_lens > 0), encoder_out_lens assert N == batch_size_list[0], (N, batch_size_list) + # get initial lm score and lm state by scoring the "sos" token + sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) + init_score, init_states = rnnlm.score_token(sos_token) + B = [HypothesisList() for _ in range(N)] for i in range(N): B[i].add( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state_cost=NgramLmStateCost(ngram_lm), + state=init_states, + lm_score=init_score.reshape(-1) ) ) + rnnlm.clean_cache() encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - + offset = 0 finalized_B = [] for batch_size in batch_size_list: start = offset end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] + current_encoder_out = encoder_out.data[start:end] # get batch current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) offset = end @@ -1760,49 +1739,44 @@ def modified_beam_search_ngram_rescoring( A = [list(b) for b in B] B = [HypothesisList() for _ in range(batch_size)] - + ys_log_probs = torch.cat( - [ - hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale - for hyps in A - for hyp in hyps - ] - ) # (num_hyps, 1) - + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) + decoder_input = torch.tensor( [hyp.ys[-context_size:] for hyps in A for hyp in hyps], device=device, dtype=torch.int64, ) # (num_hyps, context_size) - + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, 
it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( current_encoder_out, dim=0, index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, 1, 1, encoder_out_dim) - + logits = model.joiner( current_encoder_out, decoder_out, project_input=False, ) # (num_hyps, 1, 1, vocab_size) - + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax( + log_probs = logits.log_softmax( dim=-1 ) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) + vocab_size = log_probs.size(-1) - log_probs = log_probs.reshape(-1) + log_probs = log_probs.reshape(-1) + row_splits = hyps_shape.row_splits(1) * vocab_size log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() @@ -1810,7 +1784,12 @@ def modified_beam_search_ngram_rescoring( ragged_log_probs = k2.RaggedTensor( shape=log_probs_shape, value=log_probs ) + + # for all hyps with a non-blank new token, score it + token_list = [] + hs = [] + cs = [] for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1818,28 +1797,63 @@ def modified_beam_search_ngram_rescoring( warnings.simplefilter("ignore") topk_hyp_indexes = (topk_indexes // vocab_size).tolist() topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + assert new_token != 0, new_token + token_list.append([new_token]) + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + # forward RNNLM to get new states and scores + if len(token_list) != 0: + tokens_to_score = torch.tensor(token_list).to(torch.int64).to(device).reshape(-1,1) + + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + scores, lm_states = rnnlm.score_token(tokens_to_score, (hs,cs)) + + count = 0 # index, used to locate score and lm states + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): hyp_idx = topk_hyp_indexes[k] hyp = A[i][hyp_idx] - new_ys = hyp.ys[:] + ys = hyp.ys[:] + + lm_score = hyp.lm_score + state = hyp.state + + hyp_log_prob = topk_log_probs[k] # get score of current hyp new_token = topk_token_indexes[k] if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - state_cost = hyp.state_cost.forward_one_step(new_token) - else: - state_cost = hyp.state_cost - - # We only keep AM scores in new_hyp.log_prob - new_log_prob = ( - topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale - ) - + + ys.append(new_token) + hyp_log_prob += ( + lm_score[new_token] * lm_scale + ) # add the lm score + + lm_score = scores[count] + state = (lm_states[0][:, count, :].unsqueeze(1), lm_states[1][:, count, :].unsqueeze(1)) + count += 1 + new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, state_cost=state_cost + ys=ys, + log_prob=hyp_log_prob, + state=state, + lm_score=lm_score ) - B[i].add(new_hyp) + B[i].add(new_hyp) B = B + finalized_B best_hyps = [b.get_most_probable(length_norm=True) for b in B] @@ -1850,4 +1864,4 @@ def modified_beam_search_ngram_rescoring( for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) - return ans + return ans \ No 
newline at end of file diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py index 88b2cc41f..2552f65a6 100644 --- a/icefall/rnn_lm/model.py +++ b/icefall/rnn_lm/model.py @@ -18,8 +18,9 @@ import logging import torch import torch.nn.functional as F +import k2 -from icefall.utils import make_pad_mask +from icefall.utils import add_eos, add_sos, make_pad_mask class RnnLmModel(torch.nn.Module): @@ -72,6 +73,8 @@ class RnnLmModel(torch.nn.Module): else: logging.info("Not tying weights") + self.cache = {} + def forward( self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor ) -> torch.Tensor: @@ -118,3 +121,124 @@ class RnnLmModel(torch.nn.Module): nll_loss = nll_loss.reshape(batch_size, -1) return nll_loss + + def get_init_states(self, sos): + p = next(self.parameters()) + + def predict_batch(self, tokens, token_lens, sos_id, eos_id, blank_id): + device = next(self.parameters()).device + batch_size = len(token_lens) + + sos_tokens = add_sos(tokens, sos_id) + tokens_eos = add_eos(tokens, eos_id) + sos_tokens_row_splits = sos_tokens.shape.row_splits(1) + + sentence_lengths = ( + sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] + ) + + x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) + y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) + + x_tokens = x_tokens.to(torch.int64).to(device) + y_tokens = y_tokens.to(torch.int64).to(device) + sentence_lengths = sentence_lengths.to(torch.int64).to(device) + + embedding = self.input_embedding(x_tokens) + + # Note: We use batch_first==True + rnn_out, states = self.rnn(embedding) + logits = self.output_linear(rnn_out) + mask = torch.zeros(logits.shape).bool().to(device) + for i in range(batch_size): + mask[i, token_lens[i], :] = True + logits = logits[mask].reshape(batch_size, -1) + + return logits[:,:].log_softmax(-1), states + + def clean_cache(self): + self.cache = {} + + def score_token(self, tokens: torch.Tensor, state=None): + device = next(self.parameters()).device + batch_size = tokens.size(0) + if state: + h,c = state + else: + h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device) + c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device) + + embedding = self.input_embedding(tokens) + rnn_out, states = self.rnn(embedding, (h,c)) + logits = self.output_linear(rnn_out) + + return logits[:,0].log_softmax(-1), states + + def forward_with_state(self, tokens, token_lens, sos_id, eos_id, blank_id, state=None): + batch_size = len(token_lens) + if state: + h,c = state + else: + h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) + c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) + + device = next(self.parameters()).device + + sos_tokens = add_sos(tokens, sos_id) + tokens_eos = add_eos(tokens, eos_id) + sos_tokens_row_splits = sos_tokens.shape.row_splits(1) + + sentence_lengths = ( + sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] + ) + + x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) + y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) + + x_tokens = x_tokens.to(torch.int64).to(device) + y_tokens = y_tokens.to(torch.int64).to(device) + sentence_lengths = sentence_lengths.to(torch.int64).to(device) + + embedding = self.input_embedding(x_tokens) + + # Note: We use batch_first==True + rnn_out, states = self.rnn(embedding, (h,c)) + logits = self.output_linear(rnn_out) + + return logits, states + +if __name__=="__main__": + LM = RnnLmModel(500, 2048, 2048, 3, True) + h0 = 
torch.zeros(3, 1, 2048)
+    c0 = torch.zeros(3, 1, 2048)
+    seq = [[0,1,2,3]]
+    seq_lens = [len(s) for s in seq]
+    tokens = k2.RaggedTensor(seq)
+    output1, state = LM.forward_with_state(
+        tokens,
+        seq_lens,
+        1,
+        1,
+        0,
+        state=(h0,c0)
+    )
+    seq = [[0,1,2,3,4]]
+    seq_lens = [len(s) for s in seq]
+    tokens = k2.RaggedTensor(seq)
+    output2, _ = LM.forward_with_state(
+        tokens,
+        seq_lens,
+        1,
+        1,
+        0,
+        state=(h0,c0)
+    )
+
+    # score_token() expects a (batch, 1) tensor of int64 token ids and an
+    # optional (h, c) LSTM state; it does not take token lengths.
+    seq = torch.tensor([[4]], dtype=torch.int64)
+    output3, _ = LM.score_token(seq, state)
+
+    print("Finished")
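
Usage note (not part of the patch): below is a minimal sketch of how the stateful `score_token` API added to `RnnLmModel` is meant to be driven by `modified_beam_search_rnnlm_shallow_fusion`: prime the LM with the start-of-sentence token, and whenever the beam search emits a non-blank token, add `rnn_lm_scale * log P(token)` to the hypothesis score and advance the stored `(h, c)` state by one step. The model sizes, token ids, and LM scale here are illustrative assumptions, not values taken from the patch.

```python
import torch

from icefall.rnn_lm.model import RnnLmModel

# Illustrative sizes only; at decode time they come from the --rnn-lm-* flags
# and the checkpoint loaded from --rnn-lm-exp-dir.
lm = RnnLmModel(
    vocab_size=500,
    embedding_dim=2048,
    hidden_dim=2048,
    num_layers=3,
    tie_weights=True,
)
lm.eval()

sos_id = 1      # assumed id of the start-of-sentence piece in the BPE model
lm_scale = 0.3  # assumed value of --rnn-lm-scale

with torch.no_grad():
    # Prime the LM with <sos>: returns log P(token | <sos>) over the vocabulary
    # and the (h, c) LSTM state that each Hypothesis carries forward.
    sos = torch.tensor([[sos_id]], dtype=torch.int64)
    log_probs, state = lm.score_token(sos)

    # Suppose the joiner proposes a non-blank token for this hypothesis
    # (hypothetical id); shallow fusion adds the scaled LM log-probability
    # to the hypothesis score.
    new_token = 42
    hyp_log_prob_bonus = lm_scale * log_probs[0, new_token].item()

    # Advance the LM by one step with the accepted token; the returned state
    # replaces the hypothesis' stored state, so the prefix never has to be
    # re-scored from scratch.
    next_input = torch.tensor([[new_token]], dtype=torch.int64)
    log_probs, state = lm.score_token(next_input, state)
```

Keeping the per-hypothesis `(h, c)` pair and the cached `lm_score` on the `Hypothesis` dataclass is what lets the batched decoding loop above call `score_token` only for the tokens actually emitted in the current frame, rather than re-running the LM over each full prefix.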