From de2f5e3e6d66ddccb44c4c41ab04260acce6fb2f Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Wed, 2 Nov 2022 16:15:56 +0800
Subject: [PATCH 01/14] support RNNLM shallow fusion for LSTM transducer

---
 .../ASR/lstm_transducer_stateless2/decode.py  | 143 ++++-
 .../beam_search.py                            | 514 +++++++++---------
 icefall/rnn_lm/model.py                       | 126 ++++-
 3 files changed, 503 insertions(+), 280 deletions(-)

diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
index c7b53ebc0..1d46c0177 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
@@ -115,7 +115,8 @@ from beam_search import (
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
-    modified_beam_search_ngram_rescoring,
+    modified_beam_search_rnnlm_shallow_fusion,
+
 )
 from librispeech import LibriSpeech
 from train import add_model_arguments, get_params, get_transducer_model
@@ -128,6 +129,7 @@ from icefall.checkpoint import (
     load_checkpoint,
 )
 from icefall.lexicon import Lexicon
+from icefall.rnn_lm.model import RnnLmModel
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -216,7 +218,7 @@ def get_parser():
         - fast_beam_search_nbest
         - fast_beam_search_nbest_oracle
         - fast_beam_search_nbest_LG
-        - modified_beam_search_ngram_rescoring
+        - modified_beam_search_rnnlm_shallow_fusion  # for RNNLM shallow fusion
         If you use fast_beam_search_nbest_LG, you have to specify
         `--lang-dir`, which should contain `LG.pt`.
         """,
@@ -307,21 +309,74 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--tokens-ngram",
-        type=int,
-        default=3,
-        help="""Token Ngram used for rescoring.
-        Used only when the decoding method is modified_beam_search_ngram_rescoring""",
+        "--rnn-lm-scale",
+        type=float,
+        default=0.0,
+        help="""Used only when the decoding method is
+        modified_beam_search_rnnlm_shallow_fusion.
+        It specifies the scale of the RNNLM in shallow fusion.
+        """,
     )
 
     parser.add_argument(
-        "--backoff-id",
-        type=int,
-        default=500,
-        help="""ID of the backoff symbol.
-        Used only when the decoding method is modified_beam_search_ngram_rescoring""",
+        "--rnn-lm-exp-dir",
+        type=str,
+        default="rnn_lm/exp",
+        help="""Used only when RNNLM shallow fusion is used.
+        It specifies the path to the RNNLM exp dir.
+        """,
     )
 
+    parser.add_argument(
+        "--rnn-lm-epoch",
+        type=int,
+        default=7,
+        help="""Used only when RNNLM shallow fusion is used.
+        It specifies the checkpoint to use.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-avg",
+        type=int,
+        default=2,
+        help="""Used only when RNNLM shallow fusion is used.
+        It specifies the number of checkpoints to average.
+ """, + ) + + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + parser.add_argument( + "--ilm-scale", + type=float, + default=-0.1 + ) add_model_arguments(parser) return parser @@ -336,6 +391,8 @@ def decode_one_batch( decoding_graph: Optional[k2.Fsa] = None, ngram_lm: Optional[NgramLm] = None, ngram_lm_scale: float = 1.0, + rnnlm: Optional[RnnLmModel] = None, + rnnlm_scale: float = 1.0, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -469,14 +526,14 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_ngram_rescoring": - hyp_tokens = modified_beam_search_ngram_rescoring( + elif params.decoding_method == "modified_beam_search_sf_rnnlm": + hyp_tokens = modified_beam_search_sf_rnnlm_batched( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - beam=params.beam_size, + sp=sp, + rnnlm=rnnlm, + rnnlm_scale=rnnlm_scale, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -531,7 +588,9 @@ def decode_dataset( decoding_graph: Optional[k2.Fsa] = None, ngram_lm: Optional[NgramLm] = None, ngram_lm_scale: float = 1.0, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + rnnlm: Optional[NgramLm] = None, + rnnlm_scale: float = 1.0, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. 
Args: @@ -572,6 +631,9 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration = sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + logging.info(f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}") hyps_dict = decode_one_batch( params=params, @@ -582,6 +644,8 @@ def decode_dataset( batch=batch, ngram_lm=ngram_lm, ngram_lm_scale=ngram_lm_scale, + rnnlm=rnnlm, + rnnlm_scale=rnnlm_scale, ) for name, hyps in hyps_dict.items(): @@ -607,7 +671,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], ): test_set_wers = dict() for key, results in results_dict.items(): @@ -667,7 +731,7 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", - "modified_beam_search_ngram_rescoring", + "modified_beam_search_sf_rnnlm", ) params.res_dir = params.exp_dir / params.decoding_method @@ -692,7 +756,12 @@ def main(): else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + + if "rnnlm" in params.decoding_method: + params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" + + if "ILME" in params.decoding_method: + params.suffix += f"-ILME-scale={params.ilm_scale}" if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -806,14 +875,28 @@ def main(): model.to(device) model.eval() - lm_filename = f"{params.tokens_ngram}gram.fst.txt" - logging.info(f"lm filename: {lm_filename}") - ngram_lm = NgramLm( - str(params.lang_dir / lm_filename), - backoff_id=params.backoff_id, - is_binary=False, - ) - logging.info(f"num states: {ngram_lm.lm.num_states}") + # only load rnnlm if used + if "rnnlm" in params.decoding_method: + rnn_lm_scale = params.rnn_lm_scale + + rnn_lm_model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + assert params.rnn_lm_avg == 1 + + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + rnn_lm_model.to(device) + rnn_lm_model.eval() + + else: + rnn_lm_model = None if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": @@ -861,6 +944,8 @@ def main(): decoding_graph=decoding_graph, ngram_lm=ngram_lm, ngram_lm_scale=params.ngram_lm_scale, + rnnlm=rnn_lm_model, + rnnlm_scale=rnn_lm_scale, ) save_results( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 0004a24eb..01cc566e8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -16,7 +16,7 @@ import warnings from dataclasses import dataclass -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional import k2 import sentencepiece as spm @@ -25,13 +25,8 @@ from model import Transducer from icefall import NgramLm, NgramLmStateCost from icefall.decode import Nbest, one_best_decoding -from icefall.utils import ( - DecodingResults, - add_eos, - add_sos, - 
get_texts, - get_texts_with_timestamp, -) +from icefall.rnn_lm.model import RnnLmModel +from icefall.utils import add_eos, add_sos, get_texts def fast_beam_search_one_best( @@ -43,8 +38,7 @@ def fast_beam_search_one_best( max_states: int, max_contexts: int, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, and then @@ -68,12 +62,8 @@ def fast_beam_search_one_best( Max contexts pre stream per frame. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ lattice = fast_beam_search( model=model, @@ -87,11 +77,8 @@ def fast_beam_search_one_best( ) best_path = one_best_decoding(lattice) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) + hyps = get_texts(best_path) + return hyps def fast_beam_search_nbest_LG( @@ -106,8 +93,7 @@ def fast_beam_search_nbest_LG( nbest_scale: float = 0.5, use_double_scores: bool = True, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. The process to get the results is: @@ -144,12 +130,8 @@ def fast_beam_search_nbest_LG( single precision. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ lattice = fast_beam_search( model=model, @@ -214,10 +196,9 @@ def fast_beam_search_nbest_LG( best_hyp_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) + hyps = get_texts(best_path) + + return hyps def fast_beam_search_nbest( @@ -232,8 +213,7 @@ def fast_beam_search_nbest( nbest_scale: float = 0.5, use_double_scores: bool = True, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. The process to get the results is: @@ -270,12 +250,8 @@ def fast_beam_search_nbest( single precision. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. 
""" lattice = fast_beam_search( model=model, @@ -304,10 +280,9 @@ def fast_beam_search_nbest( best_path = k2.index_fsa(nbest.fsa, max_indexes) - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) + hyps = get_texts(best_path) + + return hyps def fast_beam_search_nbest_oracle( @@ -323,8 +298,7 @@ def fast_beam_search_nbest_oracle( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, and then @@ -365,12 +339,8 @@ def fast_beam_search_nbest_oracle( yields more unique paths. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ lattice = fast_beam_search( model=model, @@ -409,10 +379,8 @@ def fast_beam_search_nbest_oracle( best_path = k2.index_fsa(nbest.fsa, max_indexes) - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) + hyps = get_texts(best_path) + return hyps def fast_beam_search( @@ -502,11 +470,8 @@ def fast_beam_search( def greedy_search( - model: Transducer, - encoder_out: torch.Tensor, - max_sym_per_frame: int, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: + model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int +) -> List[int]: """Greedy search for a single utterance. Args: model: @@ -516,12 +481,8 @@ def greedy_search( max_sym_per_frame: Maximum number of symbols per frame. If it is set to 0, the WER would be 100%. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ assert encoder_out.ndim == 3 @@ -547,10 +508,6 @@ def greedy_search( t = 0 hyp = [blank_id] * context_size - # timestamp[i] is the frame index after subsampling - # on which hyp[i] is decoded - timestamp = [] - # Maximum symbols per utterance. max_sym_per_utt = 1000 @@ -577,7 +534,6 @@ def greedy_search( y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) - timestamp.append(t) decoder_input = torch.tensor( [hyp[-context_size:]], device=device ).reshape(1, context_size) @@ -592,21 +548,14 @@ def greedy_search( t += 1 hyp = hyp[context_size:] # remove blanks - if not return_timestamps: - return hyp - else: - return DecodingResults( - tokens=[hyp], - timestamps=[timestamp], - ) + return hyp def greedy_search_batch( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. Args: model: @@ -616,12 +565,9 @@ def greedy_search_batch( encoder_out_lens: A 1-D tensor of shape (N,), containing number of valid frames in encoder_out before padding. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. 
+ Return a list-of-list of token IDs containing the decoded results. + len(ans) equals to encoder_out.size(0). """ assert encoder_out.ndim == 3 assert encoder_out.size(0) >= 1, encoder_out.size(0) @@ -646,10 +592,6 @@ def greedy_search_batch( hyps = [[blank_id] * context_size for _ in range(N)] - # timestamp[n][i] is the frame index after subsampling - # on which hyp[n][i] is decoded - timestamps = [[] for _ in range(N)] - decoder_input = torch.tensor( hyps, device=device, @@ -663,7 +605,7 @@ def greedy_search_batch( encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) offset = 0 - for (t, batch_size) in enumerate(batch_size_list): + for batch_size in batch_size_list: start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -685,7 +627,6 @@ def greedy_search_batch( for i, v in enumerate(y): if v not in (blank_id, unk_id): hyps[i].append(v) - timestamps[i].append(t) emitted = True if emitted: # update decoder output @@ -700,19 +641,11 @@ def greedy_search_batch( sorted_ans = [h[context_size:] for h in hyps] ans = [] - ans_timestamps = [] unsorted_indices = packed_encoder_out.unsorted_indices.tolist() for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(timestamps[unsorted_indices[i]]) - if not return_timestamps: - return ans - else: - return DecodingResults( - tokens=ans, - timestamps=ans_timestamps, - ) + return ans @dataclass @@ -725,11 +658,9 @@ class Hypothesis: # It contains only one entry. log_prob: torch.Tensor - # timestamp[i] is the frame index after subsampling - # on which ys[i] is decoded - timestamp: List[int] - state_cost: Optional[NgramLmStateCost] = None + state: Optional = None + lm_score: Optional=None @property def key(self) -> str: @@ -878,8 +809,7 @@ def modified_beam_search( encoder_out_lens: torch.Tensor, beam: int = 4, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. Args: @@ -894,12 +824,9 @@ def modified_beam_search( Number of active paths during the beam search. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. 
""" assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) @@ -917,7 +844,7 @@ def modified_beam_search( device = next(model.parameters()).device batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) + N = encoder_out.size(0) assert torch.all(encoder_out_lens > 0), encoder_out_lens assert N == batch_size_list[0], (N, batch_size_list) @@ -927,7 +854,6 @@ def modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], ) ) @@ -935,7 +861,7 @@ def modified_beam_search( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for batch_size in batch_size_list: start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -1013,44 +939,30 @@ def modified_beam_search( new_ys = hyp.ys[:] new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] if new_token not in (blank_id, unk_id): new_ys.append(new_token) - new_timestamp.append(t) new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) B[i].add(new_hyp) B = B + finalized_B best_hyps = [b.get_most_probable(length_norm=True) for b in B] sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] ans = [] - ans_timestamps = [] unsorted_indices = packed_encoder_out.unsorted_indices.tolist() for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - if not return_timestamps: - return ans - else: - return DecodingResults( - tokens=ans, - timestamps=ans_timestamps, - ) + return ans def _deprecated_modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: +) -> List[int]: """It limits the maximum number of symbols per frame to 1. It decodes only one utterance at a time. We keep it only for reference. @@ -1065,13 +977,8 @@ def _deprecated_modified_beam_search( A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. beam: Beam size. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. 
""" assert encoder_out.ndim == 3 @@ -1091,7 +998,6 @@ def _deprecated_modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], ) ) encoder_out = model.joiner.encoder_proj(encoder_out) @@ -1150,24 +1056,17 @@ def _deprecated_modified_beam_search( for i in range(len(topk_hyp_indexes)): hyp = A[topk_hyp_indexes[i]] new_ys = hyp.ys[:] - new_timestamp = hyp.timestamp[:] new_token = topk_token_indexes[i] if new_token not in (blank_id, unk_id): new_ys.append(new_token) - new_timestamp.append(t) new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) B.add(new_hyp) best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - if not return_timestamps: - return ys - else: - return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp]) + return ys def beam_search( @@ -1175,8 +1074,7 @@ def beam_search( encoder_out: torch.Tensor, beam: int = 4, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: +) -> List[int]: """ It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf @@ -1191,13 +1089,8 @@ def beam_search( Beam size. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ assert encoder_out.ndim == 3 @@ -1224,7 +1117,7 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) max_sym_per_utt = 20000 @@ -1285,13 +1178,7 @@ def beam_search( new_y_star_log_prob = y_star.log_prob + skip_log_prob # ys[:] returns a copy of ys - B.add( - Hypothesis( - ys=y_star.ys[:], - log_prob=new_y_star_log_prob, - timestamp=y_star.timestamp[:], - ) - ) + B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) # Second, process other non-blank labels values, indices = log_prob.topk(beam + 1) @@ -1300,14 +1187,7 @@ def beam_search( continue new_ys = y_star.ys + [i] new_log_prob = y_star.log_prob + v - new_timestamp = y_star.timestamp + [t] - A.add( - Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - timestamp=new_timestamp, - ) - ) + A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) # Check whether B contains more than "beam" elements more probable # than the most probable in A @@ -1323,11 +1203,7 @@ def beam_search( best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - if not return_timestamps: - return ys - else: - return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp]) + return ys def fast_beam_search_with_nbest_rescoring( @@ -1347,8 +1223,7 @@ def fast_beam_search_with_nbest_rescoring( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Dict[str, Union[List[List[int]], DecodingResults]]: +) -> Dict[str, List[List[int]]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, num_path are selected and rescored using a given language model. 
The shortest path within the @@ -1390,13 +1265,10 @@ def fast_beam_search_with_nbest_rescoring( yields more unique paths. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results - optionally with timestamps. `xx` is the ngram LM scale value - used during decoding, i.e., 0.1. + 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the + ngram LM scale value used during decoding, i.e., 0.1. """ lattice = fast_beam_search( model=model, @@ -1474,18 +1346,16 @@ def fast_beam_search_with_nbest_rescoring( log_semiring=False, ) - ans: Dict[str, Union[List[List[int]], DecodingResults]] = {} + ans: Dict[str, List[List[int]]] = {} for s in ngram_lm_scale_list: key = f"ngram_lm_scale_{s}" tot_scores = am_scores.values + s * ngram_lm_scores ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, max_indexes) + hyps = get_texts(best_path) - if not return_timestamps: - ans[key] = get_texts(best_path) - else: - ans[key] = get_texts_with_timestamp(best_path) + ans[key] = hyps return ans @@ -1509,8 +1379,7 @@ def fast_beam_search_with_nbest_rnn_rescoring( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Dict[str, Union[List[List[int]], DecodingResults]]: +) -> Dict[str, List[List[int]]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, num_path are selected and rescored using a given language model and a rnn-lm. @@ -1556,13 +1425,10 @@ def fast_beam_search_with_nbest_rnn_rescoring( yields more unique paths. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results - optionally with timestamps. `xx` is the ngram LM scale value - used during decoding, i.e., 0.1. + 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the + ngram LM scale value used during decoding, i.e., 0.1. 
""" lattice = fast_beam_search( model=model, @@ -1674,45 +1540,151 @@ def fast_beam_search_with_nbest_rnn_rescoring( ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, max_indexes) + hyps = get_texts(best_path) - if not return_timestamps: - ans[key] = get_texts(best_path) - else: - ans[key] = get_texts_with_timestamp(best_path) + ans[key] = hyps return ans +def modified_beam_search_sf_rnnlm( + model: Transducer, + encoder_out: torch.Tensor, + sp, + rnnlm: RnnLmModel, + rnnlm_scale: float, + beam: int = 4, +): + encoder_out = model.joiner.encoder_proj(encoder_out) + lm_scale = rnnlm_scale -def modified_beam_search_ngram_rescoring( + assert rnnlm is not None + assert encoder_out.ndim == 2, encoder_out.shape + rnnlm.clean_cache() + blank_id = model.decoder.blank_id + sos_id = sp.piece_to_id("") + eos_id = sp.piece_to_id("") + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + B = HypothesisList() + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + T = encoder_out.shape[0] + for t in range(T): + current_encoder_out = encoder_out[t : t + 1] + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyp in A] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + + # decoder_out is of shape (num_hyps, joiner_dim) + current_encoder_out = current_encoder_out.repeat(len(A), 1) + # current_encoder_out is of shape (num_hyps, encoder_out_dim) + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk( + beam + ) # get topk tokens and scores + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[hyp_idx] # get hyp + new_ys = hyp.ys[:] + state = "ys=" + "+".join(list(map(str, new_ys))) + tokens = k2.RaggedTensor([new_ys[context_size:]]) + + lm_score = rnnlm.predict( + tokens, state, sos_id, eos_id, blank_id + ) # get rnnlm score + + hyp_log_prob = topk_log_probs[k] # get score of current hyp + new_token = topk_token_indexes[k] # get token + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + # state_cost = hyp.state_cost.forward_one_step(new_token) + hyp_log_prob += ( + lm_score[new_token] * lm_scale + ) # add the lm score + else: + new_ys = new_ys + new_log_prob = hyp_log_prob + + new_hyp = Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + ) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + return best_hyp.ys[context_size:] + +def modified_beam_search_rnnlm_shallow_fusion( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, - ngram_lm: NgramLm, - ngram_lm_scale: float, + sp: spm.SentencePieceProcessor, + rnnlm: RnnLmModel, + 
rnnlm_scale: float, beam: int = 4, - temperature: float = 1.0, ) -> List[List[int]]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + """Modified_beam_search + RNNLM shallow fusion Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. + model (Transducer): + The transducer model + encoder_out (torch.Tensor): + Encoder output in (N,T,C) + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of + valid frames in encoder_out before padding. + sp: + Sentence piece generator. + rnnlm (RnnLmModel): + RNNLM + rnnlm_scale (float): + scale of RNNLM in shallow fusion + beam (int, optional): + Beam size. Defaults to 4. + Returns: Return a list-of-list of token IDs. ans[i] is the decoding results for the i-th utterance. """ assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) - + assert rnnlm is not None + lm_scale = rnnlm_scale + vocab_size = rnnlm.vocab_size + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( input=encoder_out, lengths=encoder_out_lens.cpu(), @@ -1721,34 +1693,41 @@ def modified_beam_search_ngram_rescoring( ) blank_id = model.decoder.blank_id + sos_id = sp.piece_to_id("") + eos_id = sp.piece_to_id("") unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size device = next(model.parameters()).device - lm_scale = ngram_lm_scale - + batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) + N = encoder_out.size(0) assert torch.all(encoder_out_lens > 0), encoder_out_lens assert N == batch_size_list[0], (N, batch_size_list) + # get initial lm score and lm state by scoring the "sos" token + sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) + init_score, init_states = rnnlm.score_token(sos_token) + B = [HypothesisList() for _ in range(N)] for i in range(N): B[i].add( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state_cost=NgramLmStateCost(ngram_lm), + state=init_states, + lm_score=init_score.reshape(-1) ) ) + rnnlm.clean_cache() encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - + offset = 0 finalized_B = [] for batch_size in batch_size_list: start = offset end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] + current_encoder_out = encoder_out.data[start:end] # get batch current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) offset = end @@ -1760,49 +1739,44 @@ def modified_beam_search_ngram_rescoring( A = [list(b) for b in B] B = [HypothesisList() for _ in range(batch_size)] - + ys_log_probs = torch.cat( - [ - hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale - for hyps in A - for hyp in hyps - ] - ) # (num_hyps, 1) - + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) + decoder_input = torch.tensor( [hyp.ys[-context_size:] for hyps in A for hyp in hyps], device=device, dtype=torch.int64, ) # (num_hyps, context_size) - + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, 
it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( current_encoder_out, dim=0, index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, 1, 1, encoder_out_dim) - + logits = model.joiner( current_encoder_out, decoder_out, project_input=False, ) # (num_hyps, 1, 1, vocab_size) - + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax( + log_probs = logits.log_softmax( dim=-1 ) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) + vocab_size = log_probs.size(-1) - log_probs = log_probs.reshape(-1) + log_probs = log_probs.reshape(-1) + row_splits = hyps_shape.row_splits(1) * vocab_size log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() @@ -1810,7 +1784,12 @@ def modified_beam_search_ngram_rescoring( ragged_log_probs = k2.RaggedTensor( shape=log_probs_shape, value=log_probs ) + + # for all hyps with a non-blank new token, score it + token_list = [] + hs = [] + cs = [] for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1818,28 +1797,63 @@ def modified_beam_search_ngram_rescoring( warnings.simplefilter("ignore") topk_hyp_indexes = (topk_indexes // vocab_size).tolist() topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + assert new_token != 0, new_token + token_list.append([new_token]) + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + # forward RNNLM to get new states and scores + if len(token_list) != 0: + tokens_to_score = torch.tensor(token_list).to(torch.int64).to(device).reshape(-1,1) + + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + scores, lm_states = rnnlm.score_token(tokens_to_score, (hs,cs)) + + count = 0 # index, used to locate score and lm states + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): hyp_idx = topk_hyp_indexes[k] hyp = A[i][hyp_idx] - new_ys = hyp.ys[:] + ys = hyp.ys[:] + + lm_score = hyp.lm_score + state = hyp.state + + hyp_log_prob = topk_log_probs[k] # get score of current hyp new_token = topk_token_indexes[k] if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - state_cost = hyp.state_cost.forward_one_step(new_token) - else: - state_cost = hyp.state_cost - - # We only keep AM scores in new_hyp.log_prob - new_log_prob = ( - topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale - ) - + + ys.append(new_token) + hyp_log_prob += ( + lm_score[new_token] * lm_scale + ) # add the lm score + + lm_score = scores[count] + state = (lm_states[0][:, count, :].unsqueeze(1), lm_states[1][:, count, :].unsqueeze(1)) + count += 1 + new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, state_cost=state_cost + ys=ys, + log_prob=hyp_log_prob, + state=state, + lm_score=lm_score ) - B[i].add(new_hyp) + B[i].add(new_hyp) B = B + finalized_B best_hyps = [b.get_most_probable(length_norm=True) for b in B] @@ -1850,4 +1864,4 @@ def modified_beam_search_ngram_rescoring( for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) - return ans + return ans \ No 
newline at end of file
diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py
index 88b2cc41f..2552f65a6 100644
--- a/icefall/rnn_lm/model.py
+++ b/icefall/rnn_lm/model.py
@@ -18,8 +18,9 @@ import logging
 
 import torch
 import torch.nn.functional as F
+import k2
 
-from icefall.utils import make_pad_mask
+from icefall.utils import add_eos, add_sos, make_pad_mask
 
 
 class RnnLmModel(torch.nn.Module):
@@ -72,6 +73,8 @@ class RnnLmModel(torch.nn.Module):
         else:
             logging.info("Not tying weights")
 
+        self.cache = {}
+
     def forward(
         self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor
     ) -> torch.Tensor:
@@ -118,3 +121,124 @@ class RnnLmModel(torch.nn.Module):
         nll_loss = nll_loss.reshape(batch_size, -1)
 
         return nll_loss
+
+    def get_init_states(self, sos):
+        # Return all-zero LSTM states (h, c) for a single stream on the
+        # model's device. `sos` is kept for interface compatibility.
+        p = next(self.parameters())
+        h = torch.zeros(
+            self.rnn.num_layers, 1, self.rnn.hidden_size, device=p.device
+        )
+        c = torch.zeros(
+            self.rnn.num_layers, 1, self.rnn.hidden_size, device=p.device
+        )
+        return h, c
+
+    def predict_batch(self, tokens, token_lens, sos_id, eos_id, blank_id):
+        device = next(self.parameters()).device
+        batch_size = len(token_lens)
+
+        sos_tokens = add_sos(tokens, sos_id)
+        tokens_eos = add_eos(tokens, eos_id)
+        sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
+
+        sentence_lengths = (
+            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
+        )
+
+        x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
+        y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
+
+        x_tokens = x_tokens.to(torch.int64).to(device)
+        y_tokens = y_tokens.to(torch.int64).to(device)
+        sentence_lengths = sentence_lengths.to(torch.int64).to(device)
+
+        embedding = self.input_embedding(x_tokens)
+
+        # Note: We use batch_first==True
+        rnn_out, states = self.rnn(embedding)
+        logits = self.output_linear(rnn_out)
+        # Keep only the prediction made after consuming the last real token
+        # of each sentence; that is position token_lens[i] because of the
+        # prepended sos token.
+        mask = torch.zeros(logits.shape).bool().to(device)
+        for i in range(batch_size):
+            mask[i, token_lens[i], :] = True
+        logits = logits[mask].reshape(batch_size, -1)
+
+        return logits.log_softmax(-1), states
+
+    def clean_cache(self):
+        self.cache = {}
+
+    def score_token(self, tokens: torch.Tensor, state=None):
+        # Score a batch of single tokens given the previous LSTM state.
+        # Returns the next-token log-probs and the updated (h, c) state.
+        device = next(self.parameters()).device
+        batch_size = tokens.size(0)
+        if state:
+            h, c = state
+        else:
+            # h and c have shape (num_layers, batch, hidden_size).
+            h = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.hidden_size
+            ).to(device)
+            c = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.hidden_size
+            ).to(device)
+
+        embedding = self.input_embedding(tokens)
+        rnn_out, states = self.rnn(embedding, (h, c))
+        logits = self.output_linear(rnn_out)
+
+        return logits[:, 0].log_softmax(-1), states
+
+    def forward_with_state(
+        self, tokens, token_lens, sos_id, eos_id, blank_id, state=None
+    ):
+        device = next(self.parameters()).device
+        batch_size = len(token_lens)
+        if state:
+            h, c = state
+        else:
+            # h and c have shape (num_layers, batch, hidden_size).
+            h = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.hidden_size
+            ).to(device)
+            c = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.hidden_size
+            ).to(device)
+
+        sos_tokens = add_sos(tokens, sos_id)
+        tokens_eos = add_eos(tokens, eos_id)
+        sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
+
+        sentence_lengths = (
+            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
+        )
+
+        x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
+        y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
+
+        x_tokens = x_tokens.to(torch.int64).to(device)
+        y_tokens = y_tokens.to(torch.int64).to(device)
+        sentence_lengths = sentence_lengths.to(torch.int64).to(device)
+
+        embedding = self.input_embedding(x_tokens)
+
+        # Note: We use batch_first==True
+        rnn_out, states = self.rnn(embedding, (h, c))
+        logits = self.output_linear(rnn_out)
+
+        return logits, states
+
+
+if __name__ == "__main__":
+    LM = RnnLmModel(500, 2048, 2048, 3, True)
+    h0 = 
torch.zeros(3, 1, 2048) + c0 = torch.zeros(3, 1, 2048) + seq = [[0,1,2,3]] + seq_lens = [len(s) for s in seq] + tokens = k2.RaggedTensor(seq) + output1, state = LM.forward_with_state( + tokens, + seq_lens, + 1, + 1, + 0, + state=(h0,c0) + ) + seq = [[0,1,2,3,4]] + seq_lens = [len(s) for s in seq] + tokens = k2.RaggedTensor(seq) + output2, _ = LM.forward_with_state( + tokens, + seq_lens, + 1, + 1, + 0, + state=(h0,c0) + ) + + seq = [[4]] + seq_lens = [len(s) for s in seq] + output3 = LM.score_token(seq, seq_lens, state) + + print("Finished") + + + From 63d0a52dbd703a0c1692b7ac9f4557fcb0e85df8 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 2 Nov 2022 16:37:29 +0800 Subject: [PATCH 02/14] support RNNLM shallow fusion in stateless5 --- .../beam_search.py | 3 - .../pruned_transducer_stateless5/decode.py | 182 ++++++++++++------ 2 files changed, 124 insertions(+), 61 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 01cc566e8..d569b0752 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -23,7 +23,6 @@ import sentencepiece as spm import torch from model import Transducer -from icefall import NgramLm, NgramLmStateCost from icefall.decode import Nbest, one_best_decoding from icefall.rnn_lm.model import RnnLmModel from icefall.utils import add_eos, add_sos, get_texts @@ -658,8 +657,6 @@ class Hypothesis: # It contains only one entry. log_prob: torch.Tensor - state_cost: Optional[NgramLmStateCost] = None - state: Optional = None lm_score: Optional=None @property diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 632932214..59c646717 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -19,36 +19,36 @@ """ Usage: (1) greedy search -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method greedy_search (2) beam search (not recommended) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method beam_search \ --beam-size 4 (3) modified beam search -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method modified_beam_search \ --beam-size 4 (4) fast beam search (one best) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method fast_beam_search \ --beam 20.0 \ @@ -56,10 +56,10 @@ Usage: --max-states 64 (5) fast beam search (nbest) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 30 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir 
./pruned_transducer_stateless3/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest \ --beam 20.0 \ @@ -69,10 +69,10 @@ Usage: --nbest-scale 0.5 (6) fast beam search (nbest oracle WER) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_oracle \ --beam 20.0 \ @@ -82,10 +82,10 @@ Usage: --nbest-scale 0.5 (7) fast beam search (with LG) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_LG \ --beam 20.0 \ @@ -115,6 +115,7 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_rnnlm_shallow_fusion, ) from train import add_model_arguments, get_params, get_transducer_model @@ -125,6 +126,7 @@ from icefall.checkpoint import ( load_checkpoint, ) from icefall.lexicon import Lexicon +from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( AttributeDict, setup_logger, @@ -183,7 +185,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless5/exp", + default="lstm_transducer_stateless2/exp", help="The experiment dir", ) @@ -213,6 +215,7 @@ def get_parser(): - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG + - modified-beam-search3 # for rnn lm shallow fusion If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -240,16 +243,6 @@ def get_parser(): """, ) - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - parser.add_argument( "--max-contexts", type=int, @@ -275,6 +268,7 @@ def get_parser(): help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) + parser.add_argument( "--max-sym-per-frame", type=int, @@ -302,28 +296,69 @@ def get_parser(): ) parser.add_argument( - "--simulate-streaming", - type=str2bool, - default=False, - help="""Whether to simulate streaming in decoding, this is a good way to - test a streaming model. + "--rnn-lm-scale", + type=float, + default=0.0, + help="""Used only when --method is modified_beam_search3. + It specifies the path to RNN LM exp dir. """, ) parser.add_argument( - "--decode-chunk-size", - type=int, - default=16, - help="The chunk size for decoding (in frames after subsampling)", + "--rnn-lm-exp-dir", + type=str, + default="rnn_lm/exp", + help="""Used only when --method is rnn-lm. + It specifies the path to RNN LM exp dir. + """, ) parser.add_argument( - "--left-context", + "--rnn-lm-epoch", type=int, - default=64, - help="left context can be seen during decoding (in frames after subsampling)", + default=7, + help="""Used only when --method is rnn-lm. + It specifies the checkpoint to use. + """, ) + parser.add_argument( + "--rnn-lm-avg", + type=int, + default=2, + help="""Used only when --method is rnn-lm. + It specifies the number of checkpoints to average. 
+ """, + ) + + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) add_model_arguments(parser) return parser @@ -336,6 +371,8 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + rnnlm: Optional[RnnLmModel] = None, + rnnlm_scale: float = 1.0, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -361,7 +398,7 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: @@ -474,12 +511,21 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": + hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) for i in range(batch_size): # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] # fmt: on if params.decoding_method == "greedy_search": hyp = greedy_search( @@ -523,7 +569,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + rnnlm: Optional[RnnLmModel] = None, + rnnlm_scale: float = 1.0, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. Args: @@ -538,7 +586,7 @@ def decode_dataset( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. 
Returns: @@ -564,6 +612,7 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + logging.info(f"Decoding {batch_idx}-th batch") hyps_dict = decode_one_batch( params=params, @@ -572,6 +621,8 @@ def decode_dataset( decoding_graph=decoding_graph, word_table=word_table, batch=batch, + rnnlm=rnnlm, + rnnlm_scale=rnnlm_scale, ) for name, hyps in hyps_dict.items(): @@ -597,7 +648,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], ): test_set_wers = dict() for key, results in results_dict.items(): @@ -657,6 +708,7 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "modified_beam_search_sf_rnnlm", ) params.res_dir = params.exp_dir / params.decoding_method @@ -665,10 +717,6 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.simulate_streaming: - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" - params.suffix += f"-left-context-{params.left_context}" - if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" @@ -686,6 +734,8 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" + if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -706,11 +756,6 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() - if params.simulate_streaming: - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" - logging.info(params) logging.info("About to create model") @@ -796,6 +841,25 @@ def main(): model.to(device) model.eval() + rnn_lm_model = None + rnn_lm_scale = params.rnn_lm_scale + if params.decoding_method == "modified_beam_search3": + rnn_lm_model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + assert params.rnn_lm_avg == 1 + + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + rnn_lm_model.to(device) + rnn_lm_model.eval() + if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": lexicon = Lexicon(params.lang_dir) @@ -839,6 +903,8 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, + rnnlm=rnn_lm_model, + rnnlm_scale=rnn_lm_scale, ) save_results( From 86662f0b97622c1367fff5e9f974f96ac3874ccf Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 2 Nov 2022 17:24:53 +0800 Subject: [PATCH 03/14] update results --- egs/librispeech/ASR/RESULTS.md | 53 ++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 92323a556..57dd9f230 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -101,6 +101,7 @@ The WERs are: |-------------------------------------|------------|------------|-------------------------| | greedy search (max sym per frame 1) | 2.78 | 7.36 | --iter 468000 --avg 16 | | modified_beam_search | 2.73 | 7.15 | 
--iter 468000 --avg 16  |
+| modified_beam_search + RNNLM shallow fusion | 2.42 | 6.46 | --iter 468000 --avg 16 |
 | fast_beam_search                    | 2.76       | 7.31       | --iter 468000 --avg 16  |
 | greedy search (max sym per frame 1) | 2.77       | 7.35       | --iter 472000 --avg 18  |
 | modified_beam_search                | 2.75       | 7.08       | --iter 472000 --avg 18  |
@@ -155,6 +156,27 @@ for m in greedy_search fast_beam_search modified_beam_search; do
 done
 ```
 
+To decode with RNNLM shallow fusion, use the following decoding command. A well-trained RNNLM
+can be found here:
+
+```bash
+for iter in 472000; do
+  for avg in 8 10 12 14 16 18; do
+    ./lstm_transducer_stateless2/decode.py \
+        --iter $iter \
+        --avg $avg \
+        --exp-dir ./lstm_transducer_stateless2/exp \
+        --max-duration 600 \
+        --decoding-method modified_beam_search_rnnlm_shallow_fusion \
+        --beam 4 \
+        --rnn-lm-scale 0.3 \
+        --rnn-lm-exp-dir /path/to/RNNLM \
+        --rnn-lm-epoch 99 \
+        --rnn-lm-avg 1 \
+        --rnn-lm-num-layers 3 \
+        --rnn-lm-tie-weights 1
+  done
+done
+```
+
 Pretrained models, training logs, decoding logs,
 and decoding results are available at
 
@@ -1311,6 +1333,7 @@ layers (24 v.s 12) but a narrower model (1536 feedforward dim and 384 encoder di
 |-------------------------------------|------------|------------|-----------------------------------------|
 | greedy search (max sym per frame 1) | 2.54 | 5.72 | --epoch 30 --avg 10 --max-duration 600 |
 | modified beam search | 2.47 | 5.71 | --epoch 30 --avg 10 --max-duration 600 |
+| modified beam search + RNNLM shallow fusion | 2.27 | 5.24 | --epoch 30 --avg 10 --max-duration 600 |
 | fast beam search | 2.5 | 5.72 | --epoch 30 --avg 10 --max-duration 600 |
 
 ```bash
@@ -1356,6 +1379,36 @@ for method in greedy_search modified_beam_search fast_beam_search; do
 done
 ```
 
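+At each step of the beam search, RNNLM shallow fusion simply adds the scaled
+RNNLM log-probability of every candidate token to the transducer's own
+log-probability before the beam is pruned. A minimal sketch of the scoring
+rule (the names below are illustrative, not the exact icefall implementation):
+
+```python
+def fused_score(hyp_score, am_log_probs, lm_log_probs, token, lm_scale=0.3):
+    # Extend a partial hypothesis with `token`: transducer (AM) log-prob
+    # plus the RNNLM log-prob weighted by --rnn-lm-scale.
+    return hyp_score + am_log_probs[token] + lm_scale * lm_log_probs[token]
+```
+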
+To decode with RNNLM shallow fusion, use the following decoding command. A well-trained RNNLM
+can be found here:
+
+```bash
+./pruned_transducer_stateless5/decode.py \
+    --epoch 30 \
+    --avg 10 \
+    --exp-dir ./pruned_transducer_stateless5/exp-B \
+    --max-duration 600 \
+    --decoding-method modified_beam_search_rnnlm_shallow_fusion \
+    --max-sym-per-frame 1 \
+    --num-encoder-layers 24 \
+    --dim-feedforward 1536 \
+    --nhead 8 \
+    --encoder-dim 384 \
+    --decoder-dim 512 \
+    --joiner-dim 512 \
+    --use-averaged-model True \
+    --beam 4 \
+    --max-contexts 4 \
+    --rnn-lm-scale 0.4 \
+    --rnn-lm-exp-dir /path/to/RNNLM/exp \
+    --rnn-lm-epoch 99 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1
+```
+
 You can find a pretrained model, training logs, decoding logs, and
 decoding results at:
 

From 0a46a39e24a687487eeab3396d35fd395b156a0c Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Wed, 2 Nov 2022 17:25:31 +0800
Subject: [PATCH 04/14] update decoding commands

---
 .../ASR/lstm_transducer_stateless2/decode.py  |  33 +++---
 .../beam_search.py                            | 105 +-----------------
 .../pruned_transducer_stateless5/decode.py    |  95 +++++++++++-----
 3 files changed, 88 insertions(+), 145 deletions(-)

diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
index 1d46c0177..c43328e08 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
@@ -91,6 +91,21 @@ Usage:
     --beam 20.0 \
     --max-contexts 8 \
     --max-states 64
+
+(8) modified beam search (with RNNLM shallow fusion)
+./lstm_transducer_stateless2/decode.py \
+    --epoch 35 \
+    --avg 15 \
+    --exp-dir ./lstm_transducer_stateless2/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search_rnnlm_shallow_fusion \
+    --beam 4 \
+    --rnn-lm-scale 0.3 \
+    --rnn-lm-exp-dir /path/to/RNNLM \
+    --rnn-lm-epoch 99 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1
 """
 
 
@@ -121,7 +136,6 @@ from beam_search import (
 from librispeech import LibriSpeech
 from train import add_model_arguments, get_params, get_transducer_model
 
-from icefall import NgramLm
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -389,8 +403,6 @@ def decode_one_batch(
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
-    ngram_lm: Optional[NgramLm] = None,
-    ngram_lm_scale: float = 1.0,
     rnnlm: Optional[RnnLmModel] = None,
     rnnlm_scale: float = 1.0,
 ) -> Dict[str, List[List[str]]]:
@@ -526,11 +538,12 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "modified_beam_search_sf_rnnlm":
-        hyp_tokens = modified_beam_search_sf_rnnlm_batched(
+    elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
+        hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
             sp=sp,
             rnnlm=rnnlm,
             rnnlm_scale=rnnlm_scale,
@@ -586,9 +599,7 @@ def decode_dataset(
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
-    ngram_lm: Optional[NgramLm] = None,
-    ngram_lm_scale: float = 1.0,
-    rnnlm: Optional[NgramLm] = None,
+    rnnlm: Optional[RnnLmModel] = None,
     rnnlm_scale: float = 1.0,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset. 
@@ -642,8 +653,6 @@ def decode_dataset( decoding_graph=decoding_graph, word_table=word_table, batch=batch, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, rnnlm=rnnlm, rnnlm_scale=rnnlm_scale, ) @@ -731,7 +740,7 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", - "modified_beam_search_sf_rnnlm", + "modified_beam_search_rnnlm_shallow_fusion", ) params.res_dir = params.exp_dir / params.decoding_method @@ -942,8 +951,6 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, - ngram_lm=ngram_lm, - ngram_lm_scale=params.ngram_lm_scale, rnnlm=rnn_lm_model, rnnlm_scale=rnn_lm_scale, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index d569b0752..e454bc1a6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -1,4 +1,5 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -656,6 +657,7 @@ class Hypothesis: # The log prob of ys. # It contains only one entry. log_prob: torch.Tensor + state: Optional=None lm_score: Optional=None @@ -1542,107 +1544,6 @@ def fast_beam_search_with_nbest_rnn_rescoring( ans[key] = hyps return ans - -def modified_beam_search_sf_rnnlm( - model: Transducer, - encoder_out: torch.Tensor, - sp, - rnnlm: RnnLmModel, - rnnlm_scale: float, - beam: int = 4, -): - encoder_out = model.joiner.encoder_proj(encoder_out) - lm_scale = rnnlm_scale - - assert rnnlm is not None - assert encoder_out.ndim == 2, encoder_out.shape - rnnlm.clean_cache() - blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("") - eos_id = sp.piece_to_id("") - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - B = HypothesisList() - B.add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - - T = encoder_out.shape[0] - for t in range(T): - current_encoder_out = encoder_out[t : t + 1] - A = list(B) - B = HypothesisList() - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyp in A] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyp in A], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - - # decoder_out is of shape (num_hyps, joiner_dim) - current_encoder_out = current_encoder_out.repeat(len(A), 1) - # current_encoder_out is of shape (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, vocab_size) - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - log_probs = log_probs.reshape(-1) - topk_log_probs, topk_indexes = log_probs.topk( - beam - ) # get topk tokens and scores - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[hyp_idx] # get hyp - new_ys = hyp.ys[:] - state = "ys=" 
+ "+".join(list(map(str, new_ys))) - tokens = k2.RaggedTensor([new_ys[context_size:]]) - - lm_score = rnnlm.predict( - tokens, state, sos_id, eos_id, blank_id - ) # get rnnlm score - - hyp_log_prob = topk_log_probs[k] # get score of current hyp - new_token = topk_token_indexes[k] # get token - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - # state_cost = hyp.state_cost.forward_one_step(new_token) - hyp_log_prob += ( - lm_score[new_token] * lm_scale - ) # add the lm score - else: - new_ys = new_ys - new_log_prob = hyp_log_prob - - new_hyp = Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - ) - B.add(new_hyp) - - best_hyp = B.get_most_probable(length_norm=True) - return best_hyp.ys[context_size:] def modified_beam_search_rnnlm_shallow_fusion( model: Transducer, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 59c646717..8c69cfd6e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # # Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) +# Zengwei Yao, +# Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -19,47 +20,43 @@ """ Usage: (1) greedy search -./lstm_transducer_stateless2/decode.py \ - --epoch 35 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./lstm_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method greedy_search - (2) beam search (not recommended) -./lstm_transducer_stateless2/decode.py \ - --epoch 35 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./lstm_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method beam_search \ --beam-size 4 - (3) modified beam search -./lstm_transducer_stateless2/decode.py \ - --epoch 35 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./lstm_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method modified_beam_search \ --beam-size 4 - (4) fast beam search (one best) -./lstm_transducer_stateless2/decode.py \ - --epoch 35 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./lstm_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method fast_beam_search \ --beam 20.0 \ --max-contexts 8 \ --max-states 64 - (5) fast beam search (nbest) -./lstm_transducer_stateless2/decode.py \ - --epoch 30 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest \ --beam 20.0 \ @@ -67,12 +64,11 @@ Usage: --max-states 64 \ --num-paths 200 \ --nbest-scale 0.5 - (6) fast beam search (nbest oracle WER) -./lstm_transducer_stateless2/decode.py \ - --epoch 35 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./lstm_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_oracle \ --beam 20.0 \ @@ -80,17 +76,34 @@ Usage: --max-states 64 \ --num-paths 200 \ --nbest-scale 0.5 - (7) fast beam search (with LG) 
-./lstm_transducer_stateless2/decode.py \ - --epoch 35 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./lstm_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_LG \ --beam 20.0 \ --max-contexts 8 \ --max-states 64 + +(8) modified beam search with RNNLM shallow fusion (with LG) +./pruned_transducer_stateless5/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 4 \ + --max-contexts 4 \ + --rnn-lm-scale 0.4 \ + --rnn-lm-exp-dir /path/to/RNNLM/exp \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + + """ @@ -243,6 +256,16 @@ def get_parser(): """, ) + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + parser.add_argument( "--max-contexts", type=int, @@ -294,6 +317,15 @@ def get_parser(): Used only when the decoding method is fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) parser.add_argument( "--rnn-lm-scale", @@ -517,6 +549,9 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, + sp=sp, + rnnlm=rnnlm, + rnnlm_scale=rnnlm_scale, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -708,7 +743,7 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", - "modified_beam_search_sf_rnnlm", + "modified_beam_search_rnnlm_shallow_fusion", ) params.res_dir = params.exp_dir / params.decoding_method @@ -843,7 +878,7 @@ def main(): rnn_lm_model = None rnn_lm_scale = params.rnn_lm_scale - if params.decoding_method == "modified_beam_search3": + if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": rnn_lm_model = RnnLmModel( vocab_size=params.vocab_size, embedding_dim=params.rnn_lm_embedding_dim, From babcfd4b68a0f6729161eb1aa0c10e2c2aea2764 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 2 Nov 2022 17:27:31 +0800 Subject: [PATCH 05/14] update author info --- .../ASR/lstm_transducer_stateless2/decode.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index c43328e08..fc077f062 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # # Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) +# Zengwei Yao, +# Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -91,7 +92,7 @@ Usage: --beam 20.0 \ --max-contexts 8 \ --max-states 64 - + (8) modified beam search (with RNNLM shallow fusion) ./lstm_transducer_stateless2/decode.py \ --epoch 35 \ @@ -105,7 +106,7 @@ Usage: --rnn-lm-epoch 99 \ --rnn-lm-avg 1 \ --rnn-lm-num-layers 3 \ - --rnn-lm-tie-weights 1 + --rnn-lm-tie-weights 1 """ @@ -131,7 +132,6 @@ from beam_search import ( greedy_search_batch, modified_beam_search, 
modified_beam_search_rnnlm_shallow_fusion, - ) from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model @@ -386,11 +386,7 @@ def get_parser(): last output linear layer """, ) - parser.add_argument( - "--ilm-scale", - type=float, - default=-0.1 - ) + parser.add_argument("--ilm-scale", type=float, default=-0.1) add_model_arguments(parser) return parser @@ -642,9 +638,13 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - total_duration = sum([cut.duration for cut in batch["supervisions"]["cut"]]) - - logging.info(f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}") + total_duration = sum( + [cut.duration for cut in batch["supervisions"]["cut"]] + ) + + logging.info( + f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}" + ) hyps_dict = decode_one_batch( params=params, @@ -765,10 +765,10 @@ def main(): else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - + if "rnnlm" in params.decoding_method: params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" - + if "ILME" in params.decoding_method: params.suffix += f"-ILME-scale={params.ilm_scale}" @@ -903,7 +903,7 @@ def main(): ) rnn_lm_model.to(device) rnn_lm_model.eval() - + else: rnn_lm_model = None From 6c8d1f9ef5feb448565b68a533689794eb83548c Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 2 Nov 2022 17:48:58 +0800 Subject: [PATCH 06/14] update --- .../beam_search.py | 523 ++++++++++++++---- 1 file changed, 417 insertions(+), 106 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index e454bc1a6..7c5a5ace4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -17,16 +17,23 @@ import warnings from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import k2 import sentencepiece as spm import torch from model import Transducer +from icefall import NgramLm, NgramLmStateCost from icefall.decode import Nbest, one_best_decoding from icefall.rnn_lm.model import RnnLmModel -from icefall.utils import add_eos, add_sos, get_texts +from icefall.utils import ( + DecodingResults, + add_eos, + add_sos, + get_texts, + get_texts_with_timestamp, +) def fast_beam_search_one_best( @@ -38,7 +45,8 @@ def fast_beam_search_one_best( max_states: int, max_contexts: int, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, and then @@ -62,8 +70,12 @@ def fast_beam_search_one_best( Max contexts pre stream per frame. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
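+
+    Example (an illustrative sketch, not part of the recipe; assumes
+    `params`, `device`, `model`, `encoder_out` and `encoder_out_lens`
+    already exist as in the decode scripts):
+
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        hyps = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=20.0,
+            max_states=64,
+            max_contexts=8,
+        )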
""" lattice = fast_beam_search( model=model, @@ -77,8 +89,11 @@ def fast_beam_search_one_best( ) best_path = one_best_decoding(lattice) - hyps = get_texts(best_path) - return hyps + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search_nbest_LG( @@ -93,7 +108,8 @@ def fast_beam_search_nbest_LG( nbest_scale: float = 0.5, use_double_scores: bool = True, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. The process to get the results is: @@ -130,8 +146,12 @@ def fast_beam_search_nbest_LG( single precision. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ lattice = fast_beam_search( model=model, @@ -196,9 +216,10 @@ def fast_beam_search_nbest_LG( best_hyp_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) - hyps = get_texts(best_path) - - return hyps + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search_nbest( @@ -213,7 +234,8 @@ def fast_beam_search_nbest( nbest_scale: float = 0.5, use_double_scores: bool = True, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. The process to get the results is: @@ -250,8 +272,12 @@ def fast_beam_search_nbest( single precision. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ lattice = fast_beam_search( model=model, @@ -280,9 +306,10 @@ def fast_beam_search_nbest( best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - - return hyps + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search_nbest_oracle( @@ -298,7 +325,8 @@ def fast_beam_search_nbest_oracle( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, and then @@ -339,8 +367,12 @@ def fast_beam_search_nbest_oracle( yields more unique paths. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
""" lattice = fast_beam_search( model=model, @@ -379,8 +411,10 @@ def fast_beam_search_nbest_oracle( best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - return hyps + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search( @@ -470,8 +504,11 @@ def fast_beam_search( def greedy_search( - model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int -) -> List[int]: + model: Transducer, + encoder_out: torch.Tensor, + max_sym_per_frame: int, + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: """Greedy search for a single utterance. Args: model: @@ -481,8 +518,12 @@ def greedy_search( max_sym_per_frame: Maximum number of symbols per frame. If it is set to 0, the WER would be 100%. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ assert encoder_out.ndim == 3 @@ -508,6 +549,10 @@ def greedy_search( t = 0 hyp = [blank_id] * context_size + # timestamp[i] is the frame index after subsampling + # on which hyp[i] is decoded + timestamp = [] + # Maximum symbols per utterance. max_sym_per_utt = 1000 @@ -534,6 +579,7 @@ def greedy_search( y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) + timestamp.append(t) decoder_input = torch.tensor( [hyp[-context_size:]], device=device ).reshape(1, context_size) @@ -548,14 +594,21 @@ def greedy_search( t += 1 hyp = hyp[context_size:] # remove blanks - return hyp + if not return_timestamps: + return hyp + else: + return DecodingResults( + tokens=[hyp], + timestamps=[timestamp], + ) def greedy_search_batch( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. Args: model: @@ -565,9 +618,12 @@ def greedy_search_batch( encoder_out_lens: A 1-D tensor of shape (N,), containing number of valid frames in encoder_out before padding. + return_timestamps: + Whether to return timestamps. Returns: - Return a list-of-list of token IDs containing the decoded results. - len(ans) equals to encoder_out.size(0). + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
""" assert encoder_out.ndim == 3 assert encoder_out.size(0) >= 1, encoder_out.size(0) @@ -592,6 +648,10 @@ def greedy_search_batch( hyps = [[blank_id] * context_size for _ in range(N)] + # timestamp[n][i] is the frame index after subsampling + # on which hyp[n][i] is decoded + timestamps = [[] for _ in range(N)] + decoder_input = torch.tensor( hyps, device=device, @@ -605,7 +665,7 @@ def greedy_search_batch( encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) offset = 0 - for batch_size in batch_size_list: + for (t, batch_size) in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -627,6 +687,7 @@ def greedy_search_batch( for i, v in enumerate(y): if v not in (blank_id, unk_id): hyps[i].append(v) + timestamps[i].append(t) emitted = True if emitted: # update decoder output @@ -641,11 +702,19 @@ def greedy_search_batch( sorted_ans = [h[context_size:] for h in hyps] ans = [] + ans_timestamps = [] unsorted_indices = packed_encoder_out.unsorted_indices.tolist() for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(timestamps[unsorted_indices[i]]) - return ans + if not return_timestamps: + return ans + else: + return DecodingResults( + tokens=ans, + timestamps=ans_timestamps, + ) @dataclass @@ -657,9 +726,12 @@ class Hypothesis: # The log prob of ys. # It contains only one entry. log_prob: torch.Tensor - state: Optional=None - lm_score: Optional=None + # timestamp[i] is the frame index after subsampling + # on which ys[i] is decoded + timestamp: List[int] + + state_cost: Optional[NgramLmStateCost] = None @property def key(self) -> str: @@ -808,7 +880,8 @@ def modified_beam_search( encoder_out_lens: torch.Tensor, beam: int = 4, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. Args: @@ -823,9 +896,12 @@ def modified_beam_search( Number of active paths during the beam search. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
""" assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) @@ -843,7 +919,7 @@ def modified_beam_search( device = next(model.parameters()).device batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) + N = encoder_out.size(0) assert torch.all(encoder_out_lens > 0), encoder_out_lens assert N == batch_size_list[0], (N, batch_size_list) @@ -853,6 +929,7 @@ def modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], ) ) @@ -860,7 +937,7 @@ def modified_beam_search( offset = 0 finalized_B = [] - for batch_size in batch_size_list: + for (t, batch_size) in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -938,30 +1015,44 @@ def modified_beam_search( new_ys = hyp.ys[:] new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] if new_token not in (blank_id, unk_id): new_ys.append(new_token) + new_timestamp.append(t) new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) B[i].add(new_hyp) B = B + finalized_B best_hyps = [b.get_most_probable(length_norm=True) for b in B] sorted_ans = [h.ys[context_size:] for h in best_hyps] + sorted_timestamps = [h.timestamp for h in best_hyps] ans = [] + ans_timestamps = [] unsorted_indices = packed_encoder_out.unsorted_indices.tolist() for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - return ans + if not return_timestamps: + return ans + else: + return DecodingResults( + tokens=ans, + timestamps=ans_timestamps, + ) def _deprecated_modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, -) -> List[int]: + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: """It limits the maximum number of symbols per frame to 1. It decodes only one utterance at a time. We keep it only for reference. @@ -976,8 +1067,13 @@ def _deprecated_modified_beam_search( A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. beam: Beam size. + return_timestamps: + Whether to return timestamps. + Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
""" assert encoder_out.ndim == 3 @@ -997,6 +1093,7 @@ def _deprecated_modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], ) ) encoder_out = model.joiner.encoder_proj(encoder_out) @@ -1055,17 +1152,24 @@ def _deprecated_modified_beam_search( for i in range(len(topk_hyp_indexes)): hyp = A[topk_hyp_indexes[i]] new_ys = hyp.ys[:] + new_timestamp = hyp.timestamp[:] new_token = topk_token_indexes[i] if new_token not in (blank_id, unk_id): new_ys.append(new_token) + new_timestamp.append(t) new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) B.add(new_hyp) best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - return ys + if not return_timestamps: + return ys + else: + return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp]) def beam_search( @@ -1073,7 +1177,8 @@ def beam_search( encoder_out: torch.Tensor, beam: int = 4, temperature: float = 1.0, -) -> List[int]: + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: """ It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf @@ -1088,8 +1193,13 @@ def beam_search( Beam size. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ assert encoder_out.ndim == 3 @@ -1116,7 +1226,7 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) max_sym_per_utt = 20000 @@ -1177,7 +1287,13 @@ def beam_search( new_y_star_log_prob = y_star.log_prob + skip_log_prob # ys[:] returns a copy of ys - B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) + B.add( + Hypothesis( + ys=y_star.ys[:], + log_prob=new_y_star_log_prob, + timestamp=y_star.timestamp[:], + ) + ) # Second, process other non-blank labels values, indices = log_prob.topk(beam + 1) @@ -1186,7 +1302,14 @@ def beam_search( continue new_ys = y_star.ys + [i] new_log_prob = y_star.log_prob + v - A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) + new_timestamp = y_star.timestamp + [t] + A.add( + Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + timestamp=new_timestamp, + ) + ) # Check whether B contains more than "beam" elements more probable # than the most probable in A @@ -1202,7 +1325,11 @@ def beam_search( best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - return ys + + if not return_timestamps: + return ys + else: + return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp]) def fast_beam_search_with_nbest_rescoring( @@ -1222,7 +1349,8 @@ def fast_beam_search_with_nbest_rescoring( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, -) -> Dict[str, List[List[int]]]: + return_timestamps: bool = False, +) -> Dict[str, Union[List[List[int]], DecodingResults]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, num_path are selected and rescored using a given language model. 
The shortest path within the @@ -1264,10 +1392,13 @@ def fast_beam_search_with_nbest_rescoring( yields more unique paths. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the - ngram LM scale value used during decoding, i.e., 0.1. + 'ngram_lm_scale_xx' and the value is the decoded results + optionally with timestamps. `xx` is the ngram LM scale value + used during decoding, i.e., 0.1. """ lattice = fast_beam_search( model=model, @@ -1345,16 +1476,18 @@ def fast_beam_search_with_nbest_rescoring( log_semiring=False, ) - ans: Dict[str, List[List[int]]] = {} + ans: Dict[str, Union[List[List[int]], DecodingResults]] = {} for s in ngram_lm_scale_list: key = f"ngram_lm_scale_{s}" tot_scores = am_scores.values + s * ngram_lm_scores ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - ans[key] = hyps + if not return_timestamps: + ans[key] = get_texts(best_path) + else: + ans[key] = get_texts_with_timestamp(best_path) return ans @@ -1378,7 +1511,8 @@ def fast_beam_search_with_nbest_rnn_rescoring( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, -) -> Dict[str, List[List[int]]]: + return_timestamps: bool = False, +) -> Dict[str, Union[List[List[int]], DecodingResults]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, num_path are selected and rescored using a given language model and a rnn-lm. @@ -1424,10 +1558,13 @@ def fast_beam_search_with_nbest_rnn_rescoring( yields more unique paths. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the - ngram LM scale value used during decoding, i.e., 0.1. + 'ngram_lm_scale_xx' and the value is the decoded results + optionally with timestamps. `xx` is the ngram LM scale value + used during decoding, i.e., 0.1. """ lattice = fast_beam_search( model=model, @@ -1539,12 +1676,185 @@ def fast_beam_search_with_nbest_rnn_rescoring( ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - ans[key] = hyps + if not return_timestamps: + ans[key] = get_texts(best_path) + else: + ans[key] = get_texts_with_timestamp(best_path) return ans - + + +def modified_beam_search_ngram_rescoring( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ngram_lm: NgramLm, + ngram_lm_scale: float, + beam: int = 4, + temperature: float = 1.0, +) -> List[List[int]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. 
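+
+    Note:
+        `ngram_lm` is expected to be an `icefall.NgramLm`; a typical
+        construction, mirroring the decode script in this series, is
+
+            ngram_lm = NgramLm(
+                str(params.lang_dir / f"{params.tokens_ngram}gram.fst.txt"),
+                backoff_id=params.backoff_id,
+                is_binary=False,
+            )
+
+        Only the AM score is kept in `Hypothesis.log_prob`; the n-gram
+        score is tracked separately through `Hypothesis.state_cost`.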
+ """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + lm_scale = ngram_lm_scale + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state_cost=NgramLmStateCost(ngram_lm), + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [ + hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale + for hyps in A + for hyp in hyps + ] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. 
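+        # hyps_shape.row_ids(1) maps hypothesis index -> utterance index,
+        # so index_select below replicates the current frame of each
+        # utterance once per live hypothesis, e.g. row_ids [0, 0, 1]
+        # turns a 2-utterance batch into 3 rows: two for utterance 0
+        # and one for utterance 1.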
+ current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + vocab_size = log_probs.size(-1) + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + state_cost = hyp.state_cost.forward_one_step(new_token) + else: + state_cost = hyp.state_cost + + # We only keep AM scores in new_hyp.log_prob + new_log_prob = ( + topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale + ) + + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, state_cost=state_cost + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + def modified_beam_search_rnnlm_shallow_fusion( model: Transducer, encoder_out: torch.Tensor, @@ -1559,18 +1869,18 @@ def modified_beam_search_rnnlm_shallow_fusion( Args: model (Transducer): The transducer model - encoder_out (torch.Tensor): + encoder_out (torch.Tensor): Encoder output in (N,T,C) - encoder_out_lens (torch.Tensor): - A 1-D tensor of shape (N,), containing the number of + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of valid frames in encoder_out before padding. - sp: + sp: Sentence piece generator. - rnnlm (RnnLmModel): + rnnlm (RnnLmModel): RNNLM - rnnlm_scale (float): + rnnlm_scale (float): scale of RNNLM in shallow fusion - beam (int, optional): + beam (int, optional): Beam size. Defaults to 4. 
Returns: @@ -1582,7 +1892,7 @@ def modified_beam_search_rnnlm_shallow_fusion( assert rnnlm is not None lm_scale = rnnlm_scale vocab_size = rnnlm.vocab_size - + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( input=encoder_out, lengths=encoder_out_lens.cpu(), @@ -1592,20 +1902,19 @@ def modified_beam_search_rnnlm_shallow_fusion( blank_id = model.decoder.blank_id sos_id = sp.piece_to_id("") - eos_id = sp.piece_to_id("") unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size device = next(model.parameters()).device - + batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) + N = encoder_out.size(0) assert torch.all(encoder_out_lens > 0), encoder_out_lens assert N == batch_size_list[0], (N, batch_size_list) # get initial lm score and lm state by scoring the "sos" token sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) init_score, init_states = rnnlm.score_token(sos_token) - + B = [HypothesisList() for _ in range(N)] for i in range(N): B[i].add( @@ -1613,19 +1922,19 @@ def modified_beam_search_rnnlm_shallow_fusion( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), state=init_states, - lm_score=init_score.reshape(-1) + lm_score=init_score.reshape(-1), ) ) rnnlm.clean_cache() encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - + offset = 0 finalized_B = [] for batch_size in batch_size_list: start = offset end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] # get batch + current_encoder_out = encoder_out.data[start:end] # get batch current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) offset = end @@ -1637,44 +1946,42 @@ def modified_beam_search_rnnlm_shallow_fusion( A = [list(b) for b in B] B = [HypothesisList() for _ in range(batch_size)] - + ys_log_probs = torch.cat( [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] ) - + decoder_input = torch.tensor( [hyp.ys[-context_size:] for hyps in A for hyp in hyps], device=device, dtype=torch.int64, ) # (num_hyps, context_size) - + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) decoder_out = model.joiner.decoder_proj(decoder_out) - + current_encoder_out = torch.index_select( current_encoder_out, dim=0, index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, 1, 1, encoder_out_dim) - + logits = model.joiner( current_encoder_out, decoder_out, project_input=False, ) # (num_hyps, 1, 1, vocab_size) - + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = logits.log_softmax( - dim=-1 - ) # (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) - + vocab_size = log_probs.size(-1) log_probs = log_probs.reshape(-1) - + row_splits = hyps_shape.row_splits(1) * vocab_size log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() @@ -1682,7 +1989,6 @@ def modified_beam_search_rnnlm_shallow_fusion( ragged_log_probs = k2.RaggedTensor( shape=log_probs_shape, value=log_probs ) - # for all hyps with a non-blank new token, score it token_list = [] @@ -1698,7 +2004,7 @@ def modified_beam_search_rnnlm_shallow_fusion( for k in range(len(topk_hyp_indexes)): hyp_idx = topk_hyp_indexes[k] hyp = A[i][hyp_idx] - + new_token = topk_token_indexes[k] if new_token not in (blank_id, unk_id): @@ -1708,13 +2014,18 @@ def modified_beam_search_rnnlm_shallow_fusion( 
cs.append(hyp.state[1]) # forward RNNLM to get new states and scores if len(token_list) != 0: - tokens_to_score = torch.tensor(token_list).to(torch.int64).to(device).reshape(-1,1) - + tokens_to_score = ( + torch.tensor(token_list) + .to(torch.int64) + .to(device) + .reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) cs = torch.cat(cs, dim=1).to(device) - scores, lm_states = rnnlm.score_token(tokens_to_score, (hs,cs)) - - count = 0 # index, used to locate score and lm states + scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs)) + + count = 0 # index, used to locate score and lm states for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1722,36 +2033,36 @@ def modified_beam_search_rnnlm_shallow_fusion( warnings.simplefilter("ignore") topk_hyp_indexes = (topk_indexes // vocab_size).tolist() topk_token_indexes = (topk_indexes % vocab_size).tolist() - + for k in range(len(topk_hyp_indexes)): hyp_idx = topk_hyp_indexes[k] hyp = A[i][hyp_idx] ys = hyp.ys[:] - + lm_score = hyp.lm_score state = hyp.state - + hyp_log_prob = topk_log_probs[k] # get score of current hyp new_token = topk_token_indexes[k] if new_token not in (blank_id, unk_id): - + ys.append(new_token) hyp_log_prob += ( lm_score[new_token] * lm_scale ) # add the lm score - + lm_score = scores[count] - state = (lm_states[0][:, count, :].unsqueeze(1), lm_states[1][:, count, :].unsqueeze(1)) + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) count += 1 - + new_hyp = Hypothesis( - ys=ys, - log_prob=hyp_log_prob, - state=state, - lm_score=lm_score + ys=ys, log_prob=hyp_log_prob, state=state, lm_score=lm_score ) - B[i].add(new_hyp) + B[i].add(new_hyp) B = B + finalized_B best_hyps = [b.get_most_probable(length_norm=True) for b in B] @@ -1762,4 +2073,4 @@ def modified_beam_search_rnnlm_shallow_fusion( for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) - return ans \ No newline at end of file + return ans From 9a01b9098deb56c9c4b048c000b4eead756c98f5 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 2 Nov 2022 18:03:56 +0800 Subject: [PATCH 07/14] include previous added decoding method --- .../ASR/lstm_transducer_stateless2/decode.py | 65 +++++++++++++++---- 1 file changed, 51 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index fc077f062..20a5ebd8b 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -131,11 +131,13 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_ngram_rescoring, modified_beam_search_rnnlm_shallow_fusion, ) from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model +from icefall import NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -232,6 +234,7 @@ def get_parser(): - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG + - modified_beam_search_ngram_rescoring - modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. 
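For orientation, the RNNLM that both recipes load for shallow fusion is built and restored as below. This is a condensed, illustrative sketch of the pattern used in `main()`; the checkpoint path and `device` are placeholders, and the hyper-parameters echo the command-line defaults shown above:

```python
import torch

from icefall.checkpoint import load_checkpoint
from icefall.rnn_lm.model import RnnLmModel

device = torch.device("cuda", 0)  # placeholder

rnn_lm_model = RnnLmModel(
    vocab_size=500,        # BPE vocab size used in these recipes
    embedding_dim=2048,    # --rnn-lm-embedding-dim
    hidden_dim=2048,       # --rnn-lm-hidden-dim
    num_layers=3,          # --rnn-lm-num-layers
    tie_weights=True,      # --rnn-lm-tie-weights
)
# With --rnn-lm-avg 1, a single checkpoint is loaded directly.
load_checkpoint("rnn_lm/exp/epoch-99.pt", rnn_lm_model)  # hypothetical path
rnn_lm_model.to(device)
rnn_lm_model.eval()
```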
@@ -386,7 +389,23 @@ def get_parser(): last output linear layer """, ) - parser.add_argument("--ilm-scale", type=float, default=-0.1) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=3, + help="""Token Ngram used for rescoring. + Used only when the decoding method is modified_beam_search_ngram_rescoring""", + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="""ID of the backoff symbol. + Used only when the decoding method is modified_beam_search_ngram_rescoring""", + ) + add_model_arguments(parser) return parser @@ -399,6 +418,8 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, rnnlm: Optional[RnnLmModel] = None, rnnlm_scale: float = 1.0, ) -> Dict[str, List[List[str]]]: @@ -534,6 +555,17 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_ngram_rescoring": + hyp_tokens = modified_beam_search_ngram_rescoring( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( model=model, @@ -595,9 +627,11 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, rnnlm: Optional[RnnLmModel] = None, rnnlm_scale: float = 1.0, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
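+
+    The returned dict maps a decoding-parameter key (e.g. the beam size or
+    LM scale used) to a list of (cut_id, reference_words, hypothesis_words)
+    triples, one per decoded utterance.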
Args: @@ -638,13 +672,6 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - total_duration = sum( - [cut.duration for cut in batch["supervisions"]["cut"]] - ) - - logging.info( - f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}" - ) hyps_dict = decode_one_batch( params=params, @@ -653,6 +680,8 @@ def decode_dataset( decoding_graph=decoding_graph, word_table=word_table, batch=batch, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, rnnlm=rnnlm, rnnlm_scale=rnnlm_scale, ) @@ -680,7 +709,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): @@ -740,6 +769,7 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "modified_beam_search_ngram_rescoring", "modified_beam_search_rnnlm_shallow_fusion", ) params.res_dir = params.exp_dir / params.decoding_method @@ -765,13 +795,10 @@ def main(): else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" if "rnnlm" in params.decoding_method: params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" - if "ILME" in params.decoding_method: - params.suffix += f"-ILME-scale={params.ilm_scale}" - if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -884,6 +911,14 @@ def main(): model.to(device) model.eval() + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") # only load rnnlm if used if "rnnlm" in params.decoding_method: rnn_lm_scale = params.rnn_lm_scale @@ -951,6 +986,8 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, + ngram_lm=ngram_lm, + ngram_lm_scale=params.ngram_lm_scale, rnnlm=rnn_lm_model, rnnlm_scale=rnn_lm_scale, ) From fb45b95c901de33562d76c277232464fb42bb2bd Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 2 Nov 2022 18:11:39 +0800 Subject: [PATCH 08/14] minor fixes --- .../pruned_transducer_stateless5/decode.py | 45 ++++++++++++++----- 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 8c69cfd6e..8ba36e582 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -86,7 +86,7 @@ Usage: --beam 20.0 \ --max-contexts 8 \ --max-states 64 - + (8) modified beam search with RNNLM shallow fusion (with LG) ./pruned_transducer_stateless5/decode.py \ --epoch 35 \ @@ -103,7 +103,7 @@ Usage: --rnn-lm-num-layers 3 \ --rnn-lm-tie-weights 1 - + """ @@ -198,7 +198,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="lstm_transducer_stateless2/exp", + default="pruned_transducer_stateless5/exp", help="The experiment dir", ) @@ -228,7 +228,7 @@ def get_parser(): - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG - - modified-beam-search3 # for rnn lm shallow fusion + - 
modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -265,7 +265,21 @@ def get_parser(): It specifies the scale for n-gram LM scores. """, ) - + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + parser.add_argument( "--max-contexts", type=int, @@ -317,7 +331,7 @@ def get_parser(): Used only when the decoding method is fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) - + parser.add_argument( "--simulate-streaming", type=str2bool, @@ -331,7 +345,7 @@ def get_parser(): "--rnn-lm-scale", type=float, default=0.0, - help="""Used only when --method is modified_beam_search3. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the path to RNN LM exp dir. """, ) @@ -430,7 +444,7 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: @@ -560,7 +574,7 @@ def decode_one_batch( for i in range(batch_size): # fmt: off - encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] # fmt: on if params.decoding_method == "greedy_search": hyp = greedy_search( @@ -606,7 +620,7 @@ def decode_dataset( decoding_graph: Optional[k2.Fsa] = None, rnnlm: Optional[RnnLmModel] = None, rnnlm_scale: float = 1.0, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
Args: @@ -683,7 +697,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): @@ -751,7 +765,9 @@ def main(): params.suffix = f"iter-{params.iter}-avg-{params.avg}" else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" @@ -791,6 +807,11 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + logging.info(params) logging.info("About to create model") From b62fd917ae54fb0305a3f4fac931d850bfe231c1 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 2 Nov 2022 18:17:05 +0800 Subject: [PATCH 09/14] remove redundant test lines --- icefall/rnn_lm/model.py | 88 ++++++++++++++--------------------------- 1 file changed, 29 insertions(+), 59 deletions(-) diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py index 2552f65a6..a6144727a 100644 --- a/icefall/rnn_lm/model.py +++ b/icefall/rnn_lm/model.py @@ -18,7 +18,6 @@ import logging import torch import torch.nn.functional as F -import k2 from icefall.utils import add_eos, add_sos, make_pad_mask @@ -121,9 +120,6 @@ class RnnLmModel(torch.nn.Module): nll_loss = nll_loss.reshape(batch_size, -1) return nll_loss - - def get_init_states(self, sos): - p = next(self.parameters()) def predict_batch(self, tokens, token_lens, sos_id, eos_id, blank_id): device = next(self.parameters()).device @@ -153,35 +149,45 @@ class RnnLmModel(torch.nn.Module): for i in range(batch_size): mask[i, token_lens[i], :] = True logits = logits[mask].reshape(batch_size, -1) - - return logits[:,:].log_softmax(-1), states - + + return logits[:, :].log_softmax(-1), states + def clean_cache(self): self.cache = {} - + def score_token(self, tokens: torch.Tensor, state=None): device = next(self.parameters()).device batch_size = tokens.size(0) if state: - h,c = state + h, c = state else: - h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device) - c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device) - - embedding = self.input_embedding(tokens) - rnn_out, states = self.rnn(embedding, (h,c)) - logits = self.output_linear(rnn_out) - - return logits[:,0].log_softmax(-1), states + h = torch.zeros( + self.rnn.num_layers, batch_size, self.rnn.input_size + ).to(device) + c = torch.zeros( + self.rnn.num_layers, batch_size, self.rnn.input_size + ).to(device) - def forward_with_state(self, tokens, token_lens, sos_id, eos_id, blank_id, state=None): + embedding = self.input_embedding(tokens) + rnn_out, states = self.rnn(embedding, (h, c)) + logits = self.output_linear(rnn_out) + + return logits[:, 0].log_softmax(-1), states + + def forward_with_state( + self, tokens, token_lens, sos_id, eos_id, blank_id, state=None + ): batch_size = len(token_lens) if state: - h,c = state + h, c = state else: - h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) - c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) - + h = 
torch.zeros( + self.rnn.num_layers, batch_size, self.rnn.input_size + ) + c = torch.zeros( + self.rnn.num_layers, batch_size, self.rnn.input_size + ) + device = next(self.parameters()).device sos_tokens = add_sos(tokens, sos_id) @@ -202,43 +208,7 @@ class RnnLmModel(torch.nn.Module): embedding = self.input_embedding(x_tokens) # Note: We use batch_first==True - rnn_out, states = self.rnn(embedding, (h,c)) + rnn_out, states = self.rnn(embedding, (h, c)) logits = self.output_linear(rnn_out) return logits, states - -if __name__=="__main__": - LM = RnnLmModel(500, 2048, 2048, 3, True) - h0 = torch.zeros(3, 1, 2048) - c0 = torch.zeros(3, 1, 2048) - seq = [[0,1,2,3]] - seq_lens = [len(s) for s in seq] - tokens = k2.RaggedTensor(seq) - output1, state = LM.forward_with_state( - tokens, - seq_lens, - 1, - 1, - 0, - state=(h0,c0) - ) - seq = [[0,1,2,3,4]] - seq_lens = [len(s) for s in seq] - tokens = k2.RaggedTensor(seq) - output2, _ = LM.forward_with_state( - tokens, - seq_lens, - 1, - 1, - 0, - state=(h0,c0) - ) - - seq = [[4]] - seq_lens = [len(s) for s in seq] - output3 = LM.score_token(seq, seq_lens, state) - - print("Finished") - - - From e3f218b62b13408e4688129efd12acb182077bf6 Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Wed, 2 Nov 2022 22:10:23 +0800 Subject: [PATCH 10/14] Update egs/librispeech/ASR/lstm_transducer_stateless2/decode.py Co-authored-by: Fangjun Kuang --- egs/librispeech/ASR/lstm_transducer_stateless2/decode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 20a5ebd8b..ac17da207 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -329,7 +329,7 @@ def get_parser(): "--rnn-lm-scale", type=float, default=0.0, - help="""Used only when --method is modified_beam_search3. + help="""Used only when --method is modified-beam-search_rnnlm_shallow_fusion. It specifies the path to RNN LM exp dir. """, ) From 2a52b8c125019feb305275b4e356ea5969a35046 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Thu, 3 Nov 2022 11:10:21 +0800 Subject: [PATCH 11/14] update docs --- .../ASR/lstm_transducer_stateless2/decode.py | 35 +++++++++++-------- .../beam_search.py | 25 ++++++++++--- .../pruned_transducer_stateless5/decode.py | 8 ++--- 3 files changed, 45 insertions(+), 23 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 20a5ebd8b..40a0d5bf7 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -235,7 +235,7 @@ def get_parser(): - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG - modified_beam_search_ngram_rescoring - - modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -329,7 +329,7 @@ def get_parser(): "--rnn-lm-scale", type=float, default=0.0, - help="""Used only when --method is modified_beam_search3. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the path to RNN LM exp dir. 
""", ) @@ -338,7 +338,7 @@ def get_parser(): "--rnn-lm-exp-dir", type=str, default="rnn_lm/exp", - help="""Used only when --method is rnn-lm. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the path to RNN LM exp dir. """, ) @@ -347,7 +347,7 @@ def get_parser(): "--rnn-lm-epoch", type=int, default=7, - help="""Used only when --method is rnn-lm. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the checkpoint to use. """, ) @@ -356,7 +356,7 @@ def get_parser(): "--rnn-lm-avg", type=int, default=2, - help="""Used only when --method is rnn-lm. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the number of checkpoints to average. """, ) @@ -911,14 +911,20 @@ def main(): model.to(device) model.eval() - lm_filename = f"{params.tokens_ngram}gram.fst.txt" - logging.info(f"lm filename: {lm_filename}") - ngram_lm = NgramLm( - str(params.lang_dir / lm_filename), - backoff_id=params.backoff_id, - is_binary=False, - ) - logging.info(f"num states: {ngram_lm.lm.num_states}") + # only load N-gram LM when needed + if "ngram" in params.decoding_method: + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + else: + ngram_lm = None + ngram_lm_scale = None + # only load rnnlm if used if "rnnlm" in params.decoding_method: rnn_lm_scale = params.rnn_lm_scale @@ -941,6 +947,7 @@ def main(): else: rnn_lm_model = None + rnn_lm_scale = 0.0 if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": @@ -987,7 +994,7 @@ def main(): word_table=word_table, decoding_graph=decoding_graph, ngram_lm=ngram_lm, - ngram_lm_scale=params.ngram_lm_scale, + ngram_lm_scale=ngram_lm_scale, rnnlm=rnn_lm_model, rnnlm_scale=rnn_lm_scale, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 7c5a5ace4..480146a59 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -17,7 +17,7 @@ import warnings from dataclasses import dataclass -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import k2 import sentencepiece as spm @@ -729,8 +729,15 @@ class Hypothesis: # timestamp[i] is the frame index after subsampling # on which ys[i] is decoded - timestamp: List[int] + timestamp: List[int] = None + # the lm score for next token given the current ys + lm_score: Optional[torch.Tensor] = None + + # the RNNLM states (h and c in LSTM) + state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + # N-gram LM state state_cost: Optional[NgramLmStateCost] = None @property @@ -1989,8 +1996,15 @@ def modified_beam_search_rnnlm_shallow_fusion( ragged_log_probs = k2.RaggedTensor( shape=log_probs_shape, value=log_probs ) - - # for all hyps with a non-blank new token, score it + """ + for all hyps with a non-blank new token, score this token. + It is a little confusing here because this for-loop + looks very similar to the one below. Here, we go through all + top-k tokens and only add the non-blanks ones to the token_list. + The RNNLM will score those tokens given the LM states. 
Note that + the variable `scores` is the LM score after seeing the new + non-blank token. + """ token_list = [] hs = [] cs = [] @@ -2007,11 +2021,12 @@ def modified_beam_search_rnnlm_shallow_fusion( new_token = topk_token_indexes[k] if new_token not in (blank_id, unk_id): - assert new_token != 0, new_token token_list.append([new_token]) + # store the LSTM states hs.append(hyp.state[0]) cs.append(hyp.state[1]) + # forward RNNLM to get new states and scores if len(token_list) != 0: tokens_to_score = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 8ba36e582..2711c4cc9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -228,7 +228,7 @@ def get_parser(): - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG - - modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -354,7 +354,7 @@ def get_parser(): "--rnn-lm-exp-dir", type=str, default="rnn_lm/exp", - help="""Used only when --method is rnn-lm. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the path to RNN LM exp dir. """, ) @@ -363,7 +363,7 @@ def get_parser(): "--rnn-lm-epoch", type=int, default=7, - help="""Used only when --method is rnn-lm. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the checkpoint to use. """, ) @@ -372,7 +372,7 @@ def get_parser(): "--rnn-lm-avg", type=int, default=2, - help="""Used only when --method is rnn-lm. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the number of checkpoints to average. """, ) From 0df597291f71bd9c22f28b7482a9ff636dfab351 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Fri, 4 Nov 2022 11:17:56 +0800 Subject: [PATCH 12/14] resolve conflict with timestamp feature --- egs/librispeech/ASR/beam_search.py | 1821 ++++++++++++++++++++++++++++ 1 file changed, 1821 insertions(+) create mode 100644 egs/librispeech/ASR/beam_search.py diff --git a/egs/librispeech/ASR/beam_search.py b/egs/librispeech/ASR/beam_search.py new file mode 100644 index 000000000..cc5c1c09d --- /dev/null +++ b/egs/librispeech/ASR/beam_search.py @@ -0,0 +1,1821 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import warnings +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + +import k2 +import sentencepiece as spm +import torch +from model import Transducer + +from icefall import NgramLm, NgramLmStateCost +from icefall.decode import Nbest, one_best_decoding +from icefall.rnn_lm.model import RnnLmModel +from icefall.utils import ( + DecodingResults, + add_eos, + add_sos, + get_texts, + get_texts_with_timestamp, +) + + +def fast_beam_search_one_best( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first obtained using fast beam search, and then + the shortest path within the lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + best_path = one_best_decoding(lattice) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search_nbest_LG( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + nbest_scale: float = 0.5, + use_double_scores: bool = True, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + The process to get the results is: + - (1) Use fast beam search to get a lattice + - (2) Select `num_paths` paths from the lattice using k2.random_paths() + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + nbest_scale: + It's the scale applied to the lattice.scores. 
A smaller value + yields more unique paths. + use_double_scores: + True to use double precision for computation. False to use + single precision. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + # The following code is modified from nbest.intersect() + word_fsa = k2.invert(nbest.fsa) + if hasattr(lattice, "aux_labels"): + # delete token IDs as it is not needed + del word_fsa.aux_labels + word_fsa.scores.zero_() + word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) + path_to_utt_map = nbest.shape.row_ids(1) + + if hasattr(lattice, "aux_labels"): + # lattice has token IDs as labels and word IDs as aux_labels. + # inv_lattice has word IDs as labels and token IDs as aux_labels + inv_lattice = k2.invert(lattice) + inv_lattice = k2.arc_sort(inv_lattice) + else: + inv_lattice = k2.arc_sort(lattice) + + if inv_lattice.shape[0] == 1: + path_lattice = k2.intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=torch.zeros_like(path_to_utt_map), + sorted_match_a=True, + ) + else: + path_lattice = k2.intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=path_to_utt_map, + sorted_match_a=True, + ) + + # path_lattice has word IDs as labels and token IDs as aux_labels + path_lattice = k2.top_sort(k2.connect(path_lattice)) + tot_scores = path_lattice.get_tot_scores( + use_double_scores=use_double_scores, + log_semiring=True, # Note: we always use True + ) + # See https://github.com/k2-fsa/icefall/pull/420 for why + # we always use log_semiring=True + + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + best_hyp_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search_nbest( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + nbest_scale: float = 0.5, + use_double_scores: bool = True, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + The process to get the results is: + - (1) Use fast beam search to get a lattice + - (2) Select `num_paths` paths from the lattice using k2.random_paths() + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. 
+ beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + use_double_scores: + True to use double precision for computation. False to use + single precision. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + max_indexes = nbest.tot_scores().argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search_nbest_oracle( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + ref_texts: List[List[int]], + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first obtained using fast beam search, and then + we select `num_paths` linear paths from the lattice. The path + that has the minimum edit distance with the given reference transcript + is used as the output. + + This is the best result we can achieve for any nbest based rescoring + methods. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + ref_texts: + A list-of-list of integers containing the reference transcripts. + If the decoding_graph is a trivial_graph, the integer ID is the + BPE token ID. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
+ """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + hyps = nbest.build_levenshtein_graphs() + refs = k2.levenshtein_graph(ref_texts, device=hyps.device) + + levenshtein_alignment = k2.levenshtein_alignment( + refs=refs, + hyps=hyps, + hyp_to_ref_map=nbest.shape.row_ids(1), + sorted_match_ref=True, + ) + + tot_scores = levenshtein_alignment.get_tot_scores( + use_double_scores=False, log_semiring=False + ) + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + + max_indexes = ragged_tot_scores.argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + temperature: float = 1.0, +) -> k2.Fsa: + """It limits the maximum number of symbols per frame to 1. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + temperature: + Softmax temperature. + Returns: + Return an FsaVec with axes [utt][state][arc] containing the decoded + lattice. Note: When the input graph is a TrivialGraph, the returned + lattice is actually an acceptor. 
+ """ + assert encoder_out.ndim == 3 + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + B, T, C = encoder_out.shape + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(k2.RnntDecodingStream(decoding_graph)) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = (logits / temperature).log_softmax(dim=-1) + decoding_streams.advance(log_probs) + decoding_streams.terminate_and_flush_to_streams() + lattice = decoding_streams.format_output(encoder_out_lens.tolist()) + + return lattice + + +def greedy_search( + model: Transducer, + encoder_out: torch.Tensor, + max_sym_per_frame: int, + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: + """Greedy search for a single utterance. + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + max_sym_per_frame: + Maximum number of symbols per frame. If it is set to 0, the WER + would be 100%. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + unk_id = getattr(model, "unk_id", blank_id) + + device = next(model.parameters()).device + + decoder_input = torch.tensor( + [blank_id] * context_size, device=device, dtype=torch.int64 + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + T = encoder_out.size(1) + t = 0 + hyp = [blank_id] * context_size + + # timestamp[i] is the frame index after subsampling + # on which hyp[i] is decoded + timestamp = [] + + # Maximum symbols per utterance. 
+ max_sym_per_utt = 1000 + + # symbols per frame + sym_per_frame = 0 + + # symbols per utterance decoded so far + sym_per_utt = 0 + + while t < T and sym_per_utt < max_sym_per_utt: + if sym_per_frame >= max_sym_per_frame: + sym_per_frame = 0 + t += 1 + continue + + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # fmt: on + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1), project_input=False + ) + # logits is (1, 1, 1, vocab_size) + + y = logits.argmax().item() + if y not in (blank_id, unk_id): + hyp.append(y) + timestamp.append(t) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + sym_per_utt += 1 + sym_per_frame += 1 + else: + sym_per_frame = 0 + t += 1 + hyp = hyp[context_size:] # remove blanks + + if not return_timestamps: + return hyp + else: + return DecodingResults( + tokens=[hyp], + timestamps=[timestamp], + ) + + +def greedy_search_batch( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + Returns: + Return a list-of-list of token IDs containing the decoded results. + len(ans) equals to encoder_out.size(0). + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = next(model.parameters()).device + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out: (N, 1, decoder_out_dim) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1), project_input=False + ) + # logits'shape (batch_size, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v not in (blank_id, unk_id): + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in 
hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@dataclass +class Hypothesis: + # The predicted tokens so far. + # Newly predicted tokens are appended to `ys`. + ys: List[int] + + # The log prob of ys. + # It contains only one entry. + log_prob: torch.Tensor + state: Optional = None + + lm_score: Optional = None + + @property + def key(self) -> str: + """Return a string representation of self.ys""" + return "_".join(map(str, self.ys)) + + +class HypothesisList(object): + def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: + """ + Args: + data: + A dict of Hypotheses. Its key is its `value.key`. + """ + if data is None: + self._data = {} + else: + self._data = data + + @property + def data(self) -> Dict[str, Hypothesis]: + return self._data + + def add(self, hyp: Hypothesis) -> None: + """Add a Hypothesis to `self`. + + If `hyp` already exists in `self`, its probability is updated using + `log-sum-exp` with the existed one. + + Args: + hyp: + The hypothesis to be added. + """ + key = hyp.key + if key in self: + old_hyp = self._data[key] # shallow copy + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) + else: + self._data[key] = hyp + + def get_most_probable(self, length_norm: bool = False) -> Hypothesis: + """Get the most probable hypothesis, i.e., the one with + the largest `log_prob`. + + Args: + length_norm: + If True, the `log_prob` of a hypothesis is normalized by the + number of tokens in it. + Returns: + Return the hypothesis that has the largest `log_prob`. + """ + if length_norm: + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) + else: + return max(self._data.values(), key=lambda hyp: hyp.log_prob) + + def remove(self, hyp: Hypothesis) -> None: + """Remove a given hypothesis. + + Caution: + `self` is modified **in-place**. + + Args: + hyp: + The hypothesis to be removed from `self`. + Note: It must be contained in `self`. Otherwise, + an exception is raised. + """ + key = hyp.key + assert key in self, f"{key} does not exist" + del self._data[key] + + def filter(self, threshold: torch.Tensor) -> "HypothesisList": + """Remove all Hypotheses whose log_prob is less than threshold. + + Caution: + `self` is not modified. Instead, a new HypothesisList is returned. + + Returns: + Return a new HypothesisList containing all hypotheses from `self` + with `log_prob` being greater than the given `threshold`. 
+ """ + ans = HypothesisList() + for _, hyp in self._data.items(): + if hyp.log_prob > threshold: + ans.add(hyp) # shallow copy + return ans + + def topk(self, k: int) -> "HypothesisList": + """Return the top-k hypothesis.""" + hyps = list(self._data.items()) + + hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] + + ans = HypothesisList(dict(hyps)) + return ans + + def __contains__(self, key: str): + return key in self._data + + def __iter__(self): + return iter(self._data.values()) + + def __len__(self) -> int: + return len(self._data) + + def __str__(self) -> str: + s = [] + for key in self: + s.append(key) + return ", ".join(s) + + +def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: + """Return a ragged shape with axes [utt][num_hyps]. + + Args: + hyps: + len(hyps) == batch_size. It contains the current hypothesis for + each utterance in the batch. + Returns: + Return a ragged shape with 2 axes [utt][num_hyps]. Note that + the shape is on CPU. + """ + num_hyps = [len(h) for h in hyps] + + # torch.cumsum() is inclusive sum, so we put a 0 at the beginning + # to get exclusive sum later. + num_hyps.insert(0, 0) + + num_hyps = torch.tensor(num_hyps) + row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) + ans = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=row_splits[-1].item() + ) + return ans + + +def modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: int = 4, + temperature: float = 1.0, +) -> List[List[int]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. 
+ """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. 
+ current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +def _deprecated_modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, +) -> List[int]: + """It limits the maximum number of symbols per frame to 1. + + It decodes only one utterance at a time. We keep it only for reference. + The function :func:`modified_beam_search` should be preferred as it + supports batch decoding. + + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + Returns: + Return the decoded result. 
+ """ + + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + + device = next(model.parameters()).device + + T = encoder_out.size(1) + + B = HypothesisList() + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) + # fmt: on + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs is of shape (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + dtype=torch.int64, + ) + # decoder_input is of shape (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_output is of shape (num_hyps, 1, 1, joiner_dim) + + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, 1, -1 + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) + # logits is of shape (num_hyps, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + # now logits is of shape (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) + + log_probs.add_(ys_log_probs) + + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk(beam) + + # topk_hyp_indexes are indexes into `A` + topk_hyp_indexes = topk_indexes // logits.size(-1) + topk_token_indexes = topk_indexes % logits.size(-1) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() + + for i in range(len(topk_hyp_indexes)): + hyp = A[topk_hyp_indexes[i]] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[i] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_log_prob = topk_log_probs[i] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + return ys + + +def beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, + temperature: float = 1.0, +) -> List[int]: + """ + It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf + + espnet/nets/beam_search_transducer.py#L247 is used as a reference. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + temperature: + Softmax temperature. + Returns: + Return the decoded result. 
+ """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + + device = next(model.parameters()).device + + decoder_input = torch.tensor( + [blank_id] * context_size, + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + T = encoder_out.size(1) + t = 0 + + B = HypothesisList() + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) + + max_sym_per_utt = 20000 + + sym_per_utt = 0 + + decoder_cache: Dict[str, torch.Tensor] = {} + + while t < T and sym_per_utt < max_sym_per_utt: + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # fmt: on + A = B + B = HypothesisList() + + joint_cache: Dict[str, torch.Tensor] = {} + + # TODO(fangjun): Implement prefix search to update the `log_prob` + # of hypotheses in A + + while True: + y_star = A.get_most_probable() + A.remove(y_star) + + cached_key = y_star.key + + if cached_key not in decoder_cache: + decoder_input = torch.tensor( + [y_star.ys[-context_size:]], + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + decoder_cache[cached_key] = decoder_out + else: + decoder_out = decoder_cache[cached_key] + + cached_key += f"-t-{t}" + if cached_key not in joint_cache: + logits = model.joiner( + current_encoder_out, + decoder_out.unsqueeze(1), + project_input=False, + ) + + # TODO(fangjun): Scale the blank posterior + log_prob = (logits / temperature).log_softmax(dim=-1) + # log_prob is (1, 1, 1, vocab_size) + log_prob = log_prob.squeeze() + # Now log_prob is (vocab_size,) + joint_cache[cached_key] = log_prob + else: + log_prob = joint_cache[cached_key] + + # First, process the blank symbol + skip_log_prob = log_prob[blank_id] + new_y_star_log_prob = y_star.log_prob + skip_log_prob + + # ys[:] returns a copy of ys + B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) + + # Second, process other non-blank labels + values, indices = log_prob.topk(beam + 1) + for i, v in zip(indices.tolist(), values.tolist()): + if i in (blank_id, unk_id): + continue + new_ys = y_star.ys + [i] + new_log_prob = y_star.log_prob + v + A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) + + # Check whether B contains more than "beam" elements more probable + # than the most probable in A + A_most_probable = A.get_most_probable() + + kept_B = B.filter(A_most_probable.log_prob) + + if len(kept_B) >= beam: + B = kept_B.topk(beam) + break + + t += 1 + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + return ys + + +def fast_beam_search_with_nbest_rescoring( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + ngram_lm_scale_list: List[float], + num_paths: int, + G: k2.Fsa, + sp: spm.SentencePieceProcessor, + word_table: k2.SymbolTable, + oov_word: str = "", + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, +) -> Dict[str, List[List[int]]]: + """It limits the maximum number of symbols per frame to 1. 
+ A lattice is first obtained using fast beam search, num_path are selected + and rescored using a given language model. The shortest path within the + lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + ngram_lm_scale_list: + A list of floats representing LM score scales. + num_paths: + Number of paths to extract from the decoded lattice. + G: + An FsaVec containing only a single FSA. It is an n-gram LM. + sp: + The BPE model. + word_table: + The word symbol table. + oov_word: + OOV words are replaced with this word. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + Returns: + Return the decoded result in a dict, where the key has the form + 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the + ngram LM scale value used during decoding, i.e., 0.1. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + am_scores = nbest.tot_scores() + + # Now we need to compute the LM scores of each path. + # (1) Get the token IDs of each Path. 
We assume the decoding_graph + # is an acceptor, i.e., lattice is also an acceptor + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] + + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) + tokens = tokens.remove_values_leq(0) # remove -1 and 0 + + token_list: List[List[int]] = tokens.tolist() + word_list: List[List[str]] = sp.decode(token_list) + + assert isinstance(oov_word, str), oov_word + assert oov_word in word_table, oov_word + oov_word_id = word_table[oov_word] + + word_ids_list: List[List[int]] = [] + + for words in word_list: + this_word_ids = [] + for w in words.split(): + if w in word_table: + this_word_ids.append(word_table[w]) + else: + this_word_ids.append(oov_word_id) + word_ids_list.append(this_word_ids) + + word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) + word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) + + num_unique_paths = len(word_ids_list) + + b_to_a_map = torch.zeros( + num_unique_paths, + dtype=torch.int32, + device=lattice.device, + ) + + rescored_word_fsas = k2.intersect_device( + a_fsas=G, + b_fsas=word_fsas_with_self_loops, + b_to_a_map=b_to_a_map, + sorted_match_a=True, + ret_arc_maps=False, + ) + + rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) + rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) + ngram_lm_scores = rescored_word_fsas.get_tot_scores( + use_double_scores=True, + log_semiring=False, + ) + + ans: Dict[str, List[List[int]]] = {} + for s in ngram_lm_scale_list: + key = f"ngram_lm_scale_{s}" + tot_scores = am_scores.values + s * ngram_lm_scores + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) + hyps = get_texts(best_path) + + ans[key] = hyps + + return ans + + +def fast_beam_search_with_nbest_rnn_rescoring( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + ngram_lm_scale_list: List[float], + num_paths: int, + G: k2.Fsa, + sp: spm.SentencePieceProcessor, + word_table: k2.SymbolTable, + rnn_lm_model: torch.nn.Module, + rnn_lm_scale_list: List[float], + oov_word: str = "", + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, +) -> Dict[str, List[List[int]]]: + """It limits the maximum number of symbols per frame to 1. + A lattice is first obtained using fast beam search, num_path are selected + and rescored using a given language model and a rnn-lm. + The shortest path within the lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + ngram_lm_scale_list: + A list of floats representing LM score scales. + num_paths: + Number of paths to extract from the decoded lattice. + G: + An FsaVec containing only a single FSA. It is an n-gram LM. + sp: + The BPE model. + word_table: + The word symbol table. + rnn_lm_model: + A rnn-lm model used for LM rescoring + rnn_lm_scale_list: + A list of floats representing RNN score scales. 
+ oov_word: + OOV words are replaced with this word. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + Returns: + Return the decoded result in a dict, where the key has the form + 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the + ngram LM scale value used during decoding, i.e., 0.1. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + am_scores = nbest.tot_scores() + + # Now we need to compute the LM scores of each path. + # (1) Get the token IDs of each Path. We assume the decoding_graph + # is an acceptor, i.e., lattice is also an acceptor + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] + + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) + tokens = tokens.remove_values_leq(0) # remove -1 and 0 + + token_list: List[List[int]] = tokens.tolist() + word_list: List[List[str]] = sp.decode(token_list) + + assert isinstance(oov_word, str), oov_word + assert oov_word in word_table, oov_word + oov_word_id = word_table[oov_word] + + word_ids_list: List[List[int]] = [] + + for words in word_list: + this_word_ids = [] + for w in words.split(): + if w in word_table: + this_word_ids.append(word_table[w]) + else: + this_word_ids.append(oov_word_id) + word_ids_list.append(this_word_ids) + + word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) + word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) + + num_unique_paths = len(word_ids_list) + + b_to_a_map = torch.zeros( + num_unique_paths, + dtype=torch.int32, + device=lattice.device, + ) + + rescored_word_fsas = k2.intersect_device( + a_fsas=G, + b_fsas=word_fsas_with_self_loops, + b_to_a_map=b_to_a_map, + sorted_match_a=True, + ret_arc_maps=False, + ) + + rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) + rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) + ngram_lm_scores = rescored_word_fsas.get_tot_scores( + use_double_scores=True, + log_semiring=False, + ) + + # Now RNN-LM + blank_id = model.decoder.blank_id + sos_id = sp.piece_to_id("sos_id") + eos_id = sp.piece_to_id("eos_id") + + sos_tokens = add_sos(tokens, sos_id) + tokens_eos = add_eos(tokens, eos_id) + sos_tokens_row_splits = sos_tokens.shape.row_splits(1) + sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] + + x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) + y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) + + x_tokens = x_tokens.to(torch.int64) + y_tokens = y_tokens.to(torch.int64) + sentence_lengths = sentence_lengths.to(torch.int64) + + rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths) + assert rnn_lm_nll.ndim == 2 + assert rnn_lm_nll.shape[0] == len(token_list) + rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1) + + ans: Dict[str, List[List[int]]] = {} + for n_scale in ngram_lm_scale_list: + for rnn_scale in 
rnn_lm_scale_list:
+            key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}"
+            tot_scores = (
+                am_scores.values
+                + n_scale * ngram_lm_scores
+                + rnn_scale * rnn_lm_scores
+            )
+            ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+            max_indexes = ragged_tot_scores.argmax()
+            best_path = k2.index_fsa(nbest.fsa, max_indexes)
+            hyps = get_texts(best_path)
+
+            ans[key] = hyps
+
+    return ans
+
+
+def modified_beam_search_rnnlm_shallow_fusion(
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    sp: spm.SentencePieceProcessor,
+    rnnlm: RnnLmModel,
+    rnnlm_scale: float,
+    beam: int = 4,
+) -> List[List[int]]:
+    """Modified_beam_search + RNNLM shallow fusion
+
+    Args:
+      model (Transducer):
+        The transducer model
+      encoder_out (torch.Tensor):
+        Encoder output in (N,T,C)
+      encoder_out_lens (torch.Tensor):
+        A 1-D tensor of shape (N,), containing the number of
+        valid frames in encoder_out before padding.
+      sp:
+        Sentence piece generator.
+      rnnlm (RnnLmModel):
+        RNNLM
+      rnnlm_scale (float):
+        scale of RNNLM in shallow fusion
+      beam (int, optional):
+        Beam size. Defaults to 4.
+
+    Returns:
+      Return a list-of-list of token IDs. ans[i] is the decoding results
+      for the i-th utterance.
+    """
+    assert encoder_out.ndim == 3, encoder_out.shape
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+    assert rnnlm is not None
+    lm_scale = rnnlm_scale
+    vocab_size = rnnlm.vocab_size
+
+    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+        input=encoder_out,
+        lengths=encoder_out_lens.cpu(),
+        batch_first=True,
+        enforce_sorted=False,
+    )
+
+    blank_id = model.decoder.blank_id
+    sos_id = sp.piece_to_id("<sos/eos>")
+    eos_id = sp.piece_to_id("<sos/eos>")
+    unk_id = getattr(model, "unk_id", blank_id)
+    context_size = model.decoder.context_size
+    device = next(model.parameters()).device
+
+    batch_size_list = packed_encoder_out.batch_sizes.tolist()
+    N = encoder_out.size(0)
+    assert torch.all(encoder_out_lens > 0), encoder_out_lens
+    assert N == batch_size_list[0], (N, batch_size_list)
+
+    # get initial lm score and lm state by scoring the "sos" token
+    sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
+    init_score, init_states = rnnlm.score_token(sos_token)
+
+    B = [HypothesisList() for _ in range(N)]
+    for i in range(N):
+        B[i].add(
+            Hypothesis(
+                ys=[blank_id] * context_size,
+                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+                state=init_states,
+                lm_score=init_score.reshape(-1),
+            )
+        )
+
+    rnnlm.clean_cache()
+    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
+
+    offset = 0
+    finalized_B = []
+    for batch_size in batch_size_list:
+        start = offset
+        end = offset + batch_size
+        current_encoder_out = encoder_out.data[start:end]  # get batch
+        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
+        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
+        offset = end
+
+        finalized_B = B[batch_size:] + finalized_B
+        B = B[:batch_size]
+
+        hyps_shape = get_hyps_shape(B).to(device)
+
+        A = [list(b) for b in B]
+        B = [HypothesisList() for _ in range(batch_size)]
+
+        ys_log_probs = torch.cat(
+            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
+        )
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
+            device=device,
+            dtype=torch.int64,
+        )  # (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
+        decoder_out = model.joiner.decoder_proj(decoder_out)
+
+        current_encoder_out = torch.index_select(
+            current_encoder_out,
+
dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + # for all hyps with a non-blank new token, score it + token_list = [] + hs = [] + cs = [] + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + + assert new_token != 0, new_token + token_list.append([new_token]) + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + # forward RNNLM to get new states and scores + if len(token_list) != 0: + tokens_to_score = ( + torch.tensor(token_list) + .to(torch.int64) + .to(device) + .reshape(-1, 1) + ) + + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs)) + + count = 0 # index, used to locate score and lm states + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + ys = hyp.ys[:] + + lm_score = hyp.lm_score + state = hyp.state + + hyp_log_prob = topk_log_probs[k] # get score of current hyp + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + + ys.append(new_token) + hyp_log_prob += ( + lm_score[new_token] * lm_scale + ) # add the lm score + + lm_score = scores[count] + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) + count += 1 + + new_hyp = Hypothesis( + ys=ys, log_prob=hyp_log_prob, state=state, lm_score=lm_score + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans From bdaeaae1ae33021fadd1f8e9b6bbe45f1bf5ae00 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Fri, 4 Nov 2022 11:25:10 +0800 Subject: [PATCH 13/14] resolve conflicts --- .../beam_search.py | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 480146a59..b1fd75204 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ 
From bdaeaae1ae33021fadd1f8e9b6bbe45f1bf5ae00 Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Fri, 4 Nov 2022 11:25:10 +0800
Subject: [PATCH 13/14] resolve conflicts

---
 .../beam_search.py | 27 +++++++++++++++----
 1 file changed, 22 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index 480146a59..b1fd75204 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 
 import warnings
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Tuple, Union
 
 import k2
@@ -729,7 +729,7 @@ class Hypothesis:
 
     # timestamp[i] is the frame index after subsampling
    # on which ys[i] is decoded
-    timestamp: List[int] = None
+    timestamp: List[int] = field(default_factory=list)
 
     # the lm score for next token given the current ys
     lm_score: Optional[torch.Tensor] = None
@@ -1870,6 +1870,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
     rnnlm: RnnLmModel,
     rnnlm_scale: float,
     beam: int = 4,
+    return_timestamps: bool = False,
 ) -> List[List[int]]:
     """Modified beam search with RNNLM shallow fusion.
@@ -1930,6 +1931,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                 state=init_states,
                 lm_score=init_score.reshape(-1),
+                timestamp=[],
             )
         )
@@ -1938,7 +1940,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
 
     offset = 0
     finalized_B = []
-    for batch_size in batch_size_list:
+    for (t, batch_size) in enumerate(batch_size_list):
         start = offset
         end = offset + batch_size
         current_encoder_out = encoder_out.data[start:end]  # get batch
@@ -2060,9 +2062,11 @@ def modified_beam_search_rnnlm_shallow_fusion(
 
                 hyp_log_prob = topk_log_probs[k]  # get score of current hyp
                 new_token = topk_token_indexes[k]
+                new_timestamp = hyp.timestamp[:]
                 if new_token not in (blank_id, unk_id):
 
                     ys.append(new_token)
+                    new_timestamp.append(t)
                     hyp_log_prob += (
                         lm_score[new_token] * lm_scale
                     )  # add the lm score
@@ -2075,7 +2079,11 @@ def modified_beam_search_rnnlm_shallow_fusion(
                     count += 1
 
                 new_hyp = Hypothesis(
-                    ys=ys, log_prob=hyp_log_prob, state=state, lm_score=lm_score
+                    ys=ys,
+                    log_prob=hyp_log_prob,
+                    state=state,
+                    lm_score=lm_score,
+                    timestamp=new_timestamp,
                 )
                 B[i].add(new_hyp)
@@ -2083,9 +2091,18 @@ def modified_beam_search_rnnlm_shallow_fusion(
     best_hyps = [b.get_most_probable(length_norm=True) for b in B]
 
     sorted_ans = [h.ys[context_size:] for h in best_hyps]
+    sorted_timestamps = [h.timestamp for h in best_hyps]
     ans = []
+    ans_timestamps = []
     unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
     for i in range(N):
         ans.append(sorted_ans[unsorted_indices[i]])
+        ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
 
-    return ans
+    if not return_timestamps:
+        return ans
+    else:
+        return DecodingResults(
+            tokens=ans,
+            timestamps=ans_timestamps,
+        )
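The timestamp bookkeeping in patch 13 records, for each emitted token, the index t of the subsampled encoder frame on which it was decoded. Converting those indices to seconds is left to the caller; a small helper along the following lines works, assuming a 10 ms feature frame shift and an encoder subsampling factor of 4. Both values are assumptions typical of these recipes, not something the patch fixes.

from typing import List


def frames_to_seconds(
    timestamps: List[int],
    subsampling_factor: int = 4,   # assumed encoder subsampling
    frame_shift_ms: float = 10.0,  # assumed feature frame shift
) -> List[float]:
    """Map subsampled frame indices to seconds in the original audio."""
    return [
        t * subsampling_factor * frame_shift_ms / 1000.0 for t in timestamps
    ]


# Tokens emitted on subsampled frames 3, 7 and 12:
print(frames_to_seconds([3, 7, 12]))  # [0.12, 0.28, 0.48]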
From b3c61b85e3340f5d5f68c3f09659e5a05d052665 Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Fri, 4 Nov 2022 11:32:09 +0800
Subject: [PATCH 14/14] minor fixes

---
 egs/librispeech/ASR/pruned_transducer_stateless5/decode.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
index 2711c4cc9..96aa66c29 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
@@ -636,8 +636,8 @@ def decode_dataset(
       The word symbol table.
     decoding_graph:
       The decoding graph. Can be either a `k2.trivial_graph` or LG, used
-      only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
-      fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+      only when --decoding_method is fast_beam_search, fast_beam_search_LG,
+      fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search is
       used, or it may be "beam_7" if beam size of 7 is used.
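To round off the series, here is a usage sketch, not code from the patches, showing how a decode script could call the fused search from the pruned_transducer_stateless2 recipe after patch 13. The names model, sp, rnnlm, encoder_out and encoder_out_lens are assumed to be set up elsewhere, and 0.3 is an arbitrary illustrative LM scale.

from beam_search import modified_beam_search_rnnlm_shallow_fusion

# Without timestamps: returns a plain list-of-list of token IDs.
hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
    model=model,
    encoder_out=encoder_out,            # (N, T, C)
    encoder_out_lens=encoder_out_lens,  # (N,)
    sp=sp,
    rnnlm=rnnlm,
    rnnlm_scale=0.3,
    beam=4,
)

# With timestamps: returns a DecodingResults whose `timestamps` field holds,
# per utterance, the subsampled frame index of each decoded token.
results = modified_beam_search_rnnlm_shallow_fusion(
    model=model,
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    sp=sp,
    rnnlm=rnnlm,
    rnnlm_scale=0.3,
    beam=4,
    return_timestamps=True,
)
for tokens, frames in zip(results.tokens, results.timestamps):
    print(sp.decode(tokens), frames)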