diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index e454bc1a6..7c5a5ace4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -17,16 +17,23 @@ import warnings from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import k2 import sentencepiece as spm import torch from model import Transducer +from icefall import NgramLm, NgramLmStateCost from icefall.decode import Nbest, one_best_decoding from icefall.rnn_lm.model import RnnLmModel -from icefall.utils import add_eos, add_sos, get_texts +from icefall.utils import ( + DecodingResults, + add_eos, + add_sos, + get_texts, + get_texts_with_timestamp, +) def fast_beam_search_one_best( @@ -38,7 +45,8 @@ def fast_beam_search_one_best( max_states: int, max_contexts: int, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, and then @@ -62,8 +70,12 @@ def fast_beam_search_one_best( Max contexts pre stream per frame. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ lattice = fast_beam_search( model=model, @@ -77,8 +89,11 @@ def fast_beam_search_one_best( ) best_path = one_best_decoding(lattice) - hyps = get_texts(best_path) - return hyps + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search_nbest_LG( @@ -93,7 +108,8 @@ def fast_beam_search_nbest_LG( nbest_scale: float = 0.5, use_double_scores: bool = True, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. The process to get the results is: @@ -130,8 +146,12 @@ def fast_beam_search_nbest_LG( single precision. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ lattice = fast_beam_search( model=model, @@ -196,9 +216,10 @@ def fast_beam_search_nbest_LG( best_hyp_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) - hyps = get_texts(best_path) - - return hyps + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search_nbest( @@ -213,7 +234,8 @@ def fast_beam_search_nbest( nbest_scale: float = 0.5, use_double_scores: bool = True, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. The process to get the results is: @@ -250,8 +272,12 @@ def fast_beam_search_nbest( single precision. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ lattice = fast_beam_search( model=model, @@ -280,9 +306,10 @@ def fast_beam_search_nbest( best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - - return hyps + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search_nbest_oracle( @@ -298,7 +325,8 @@ def fast_beam_search_nbest_oracle( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, and then @@ -339,8 +367,12 @@ def fast_beam_search_nbest_oracle( yields more unique paths. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ lattice = fast_beam_search( model=model, @@ -379,8 +411,10 @@ def fast_beam_search_nbest_oracle( best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - return hyps + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search( @@ -470,8 +504,11 @@ def fast_beam_search( def greedy_search( - model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int -) -> List[int]: + model: Transducer, + encoder_out: torch.Tensor, + max_sym_per_frame: int, + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: """Greedy search for a single utterance. Args: model: @@ -481,8 +518,12 @@ def greedy_search( max_sym_per_frame: Maximum number of symbols per frame. If it is set to 0, the WER would be 100%. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ assert encoder_out.ndim == 3 @@ -508,6 +549,10 @@ def greedy_search( t = 0 hyp = [blank_id] * context_size + # timestamp[i] is the frame index after subsampling + # on which hyp[i] is decoded + timestamp = [] + # Maximum symbols per utterance. max_sym_per_utt = 1000 @@ -534,6 +579,7 @@ def greedy_search( y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) + timestamp.append(t) decoder_input = torch.tensor( [hyp[-context_size:]], device=device ).reshape(1, context_size) @@ -548,14 +594,21 @@ def greedy_search( t += 1 hyp = hyp[context_size:] # remove blanks - return hyp + if not return_timestamps: + return hyp + else: + return DecodingResults( + tokens=[hyp], + timestamps=[timestamp], + ) def greedy_search_batch( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. Args: model: @@ -565,9 +618,12 @@ def greedy_search_batch( encoder_out_lens: A 1-D tensor of shape (N,), containing number of valid frames in encoder_out before padding. + return_timestamps: + Whether to return timestamps. Returns: - Return a list-of-list of token IDs containing the decoded results. - len(ans) equals to encoder_out.size(0). + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ assert encoder_out.ndim == 3 assert encoder_out.size(0) >= 1, encoder_out.size(0) @@ -592,6 +648,10 @@ def greedy_search_batch( hyps = [[blank_id] * context_size for _ in range(N)] + # timestamp[n][i] is the frame index after subsampling + # on which hyp[n][i] is decoded + timestamps = [[] for _ in range(N)] + decoder_input = torch.tensor( hyps, device=device, @@ -605,7 +665,7 @@ def greedy_search_batch( encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) offset = 0 - for batch_size in batch_size_list: + for (t, batch_size) in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -627,6 +687,7 @@ def greedy_search_batch( for i, v in enumerate(y): if v not in (blank_id, unk_id): hyps[i].append(v) + timestamps[i].append(t) emitted = True if emitted: # update decoder output @@ -641,11 +702,19 @@ def greedy_search_batch( sorted_ans = [h[context_size:] for h in hyps] ans = [] + ans_timestamps = [] unsorted_indices = packed_encoder_out.unsorted_indices.tolist() for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(timestamps[unsorted_indices[i]]) - return ans + if not return_timestamps: + return ans + else: + return DecodingResults( + tokens=ans, + timestamps=ans_timestamps, + ) @dataclass @@ -657,9 +726,12 @@ class Hypothesis: # The log prob of ys. # It contains only one entry. log_prob: torch.Tensor - state: Optional=None - lm_score: Optional=None + # timestamp[i] is the frame index after subsampling + # on which ys[i] is decoded + timestamp: List[int] + + state_cost: Optional[NgramLmStateCost] = None @property def key(self) -> str: @@ -808,7 +880,8 @@ def modified_beam_search( encoder_out_lens: torch.Tensor, beam: int = 4, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. Args: @@ -823,9 +896,12 @@ def modified_beam_search( Number of active paths during the beam search. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) @@ -843,7 +919,7 @@ def modified_beam_search( device = next(model.parameters()).device batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) + N = encoder_out.size(0) assert torch.all(encoder_out_lens > 0), encoder_out_lens assert N == batch_size_list[0], (N, batch_size_list) @@ -853,6 +929,7 @@ def modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], ) ) @@ -860,7 +937,7 @@ def modified_beam_search( offset = 0 finalized_B = [] - for batch_size in batch_size_list: + for (t, batch_size) in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -938,30 +1015,44 @@ def modified_beam_search( new_ys = hyp.ys[:] new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] if new_token not in (blank_id, unk_id): new_ys.append(new_token) + new_timestamp.append(t) new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) B[i].add(new_hyp) B = B + finalized_B best_hyps = [b.get_most_probable(length_norm=True) for b in B] sorted_ans = [h.ys[context_size:] for h in best_hyps] + sorted_timestamps = [h.timestamp for h in best_hyps] ans = [] + ans_timestamps = [] unsorted_indices = packed_encoder_out.unsorted_indices.tolist() for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - return ans + if not return_timestamps: + return ans + else: + return DecodingResults( + tokens=ans, + timestamps=ans_timestamps, + ) def _deprecated_modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, -) -> List[int]: + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: """It limits the maximum number of symbols per frame to 1. It decodes only one utterance at a time. We keep it only for reference. @@ -976,8 +1067,13 @@ def _deprecated_modified_beam_search( A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. beam: Beam size. + return_timestamps: + Whether to return timestamps. + Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ assert encoder_out.ndim == 3 @@ -997,6 +1093,7 @@ def _deprecated_modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], ) ) encoder_out = model.joiner.encoder_proj(encoder_out) @@ -1055,17 +1152,24 @@ def _deprecated_modified_beam_search( for i in range(len(topk_hyp_indexes)): hyp = A[topk_hyp_indexes[i]] new_ys = hyp.ys[:] + new_timestamp = hyp.timestamp[:] new_token = topk_token_indexes[i] if new_token not in (blank_id, unk_id): new_ys.append(new_token) + new_timestamp.append(t) new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) B.add(new_hyp) best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - return ys + if not return_timestamps: + return ys + else: + return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp]) def beam_search( @@ -1073,7 +1177,8 @@ def beam_search( encoder_out: torch.Tensor, beam: int = 4, temperature: float = 1.0, -) -> List[int]: + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: """ It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf @@ -1088,8 +1193,13 @@ def beam_search( Beam size. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ assert encoder_out.ndim == 3 @@ -1116,7 +1226,7 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) max_sym_per_utt = 20000 @@ -1177,7 +1287,13 @@ def beam_search( new_y_star_log_prob = y_star.log_prob + skip_log_prob # ys[:] returns a copy of ys - B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) + B.add( + Hypothesis( + ys=y_star.ys[:], + log_prob=new_y_star_log_prob, + timestamp=y_star.timestamp[:], + ) + ) # Second, process other non-blank labels values, indices = log_prob.topk(beam + 1) @@ -1186,7 +1302,14 @@ def beam_search( continue new_ys = y_star.ys + [i] new_log_prob = y_star.log_prob + v - A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) + new_timestamp = y_star.timestamp + [t] + A.add( + Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + timestamp=new_timestamp, + ) + ) # Check whether B contains more than "beam" elements more probable # than the most probable in A @@ -1202,7 +1325,11 @@ def beam_search( best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - return ys + + if not return_timestamps: + return ys + else: + return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp]) def fast_beam_search_with_nbest_rescoring( @@ -1222,7 +1349,8 @@ def fast_beam_search_with_nbest_rescoring( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, -) -> Dict[str, List[List[int]]]: + return_timestamps: bool = False, +) -> Dict[str, Union[List[List[int]], DecodingResults]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, num_path are selected and rescored using a given language model. The shortest path within the @@ -1264,10 +1392,13 @@ def fast_beam_search_with_nbest_rescoring( yields more unique paths. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the - ngram LM scale value used during decoding, i.e., 0.1. + 'ngram_lm_scale_xx' and the value is the decoded results + optionally with timestamps. `xx` is the ngram LM scale value + used during decoding, i.e., 0.1. """ lattice = fast_beam_search( model=model, @@ -1345,16 +1476,18 @@ def fast_beam_search_with_nbest_rescoring( log_semiring=False, ) - ans: Dict[str, List[List[int]]] = {} + ans: Dict[str, Union[List[List[int]], DecodingResults]] = {} for s in ngram_lm_scale_list: key = f"ngram_lm_scale_{s}" tot_scores = am_scores.values + s * ngram_lm_scores ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - ans[key] = hyps + if not return_timestamps: + ans[key] = get_texts(best_path) + else: + ans[key] = get_texts_with_timestamp(best_path) return ans @@ -1378,7 +1511,8 @@ def fast_beam_search_with_nbest_rnn_rescoring( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, -) -> Dict[str, List[List[int]]]: + return_timestamps: bool = False, +) -> Dict[str, Union[List[List[int]], DecodingResults]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, num_path are selected and rescored using a given language model and a rnn-lm. @@ -1424,10 +1558,13 @@ def fast_beam_search_with_nbest_rnn_rescoring( yields more unique paths. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the - ngram LM scale value used during decoding, i.e., 0.1. + 'ngram_lm_scale_xx' and the value is the decoded results + optionally with timestamps. `xx` is the ngram LM scale value + used during decoding, i.e., 0.1. """ lattice = fast_beam_search( model=model, @@ -1539,12 +1676,185 @@ def fast_beam_search_with_nbest_rnn_rescoring( ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - ans[key] = hyps + if not return_timestamps: + ans[key] = get_texts(best_path) + else: + ans[key] = get_texts_with_timestamp(best_path) return ans - + + +def modified_beam_search_ngram_rescoring( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ngram_lm: NgramLm, + ngram_lm_scale: float, + beam: int = 4, + temperature: float = 1.0, +) -> List[List[int]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + lm_scale = ngram_lm_scale + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state_cost=NgramLmStateCost(ngram_lm), + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [ + hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale + for hyps in A + for hyp in hyps + ] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + vocab_size = log_probs.size(-1) + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + state_cost = hyp.state_cost.forward_one_step(new_token) + else: + state_cost = hyp.state_cost + + # We only keep AM scores in new_hyp.log_prob + new_log_prob = ( + topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale + ) + + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, state_cost=state_cost + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + def modified_beam_search_rnnlm_shallow_fusion( model: Transducer, encoder_out: torch.Tensor, @@ -1559,18 +1869,18 @@ def modified_beam_search_rnnlm_shallow_fusion( Args: model (Transducer): The transducer model - encoder_out (torch.Tensor): + encoder_out (torch.Tensor): Encoder output in (N,T,C) - encoder_out_lens (torch.Tensor): - A 1-D tensor of shape (N,), containing the number of + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of valid frames in encoder_out before padding. - sp: + sp: Sentence piece generator. - rnnlm (RnnLmModel): + rnnlm (RnnLmModel): RNNLM - rnnlm_scale (float): + rnnlm_scale (float): scale of RNNLM in shallow fusion - beam (int, optional): + beam (int, optional): Beam size. Defaults to 4. Returns: @@ -1582,7 +1892,7 @@ def modified_beam_search_rnnlm_shallow_fusion( assert rnnlm is not None lm_scale = rnnlm_scale vocab_size = rnnlm.vocab_size - + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( input=encoder_out, lengths=encoder_out_lens.cpu(), @@ -1592,20 +1902,19 @@ def modified_beam_search_rnnlm_shallow_fusion( blank_id = model.decoder.blank_id sos_id = sp.piece_to_id("") - eos_id = sp.piece_to_id("") unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size device = next(model.parameters()).device - + batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) + N = encoder_out.size(0) assert torch.all(encoder_out_lens > 0), encoder_out_lens assert N == batch_size_list[0], (N, batch_size_list) # get initial lm score and lm state by scoring the "sos" token sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) init_score, init_states = rnnlm.score_token(sos_token) - + B = [HypothesisList() for _ in range(N)] for i in range(N): B[i].add( @@ -1613,19 +1922,19 @@ def modified_beam_search_rnnlm_shallow_fusion( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), state=init_states, - lm_score=init_score.reshape(-1) + lm_score=init_score.reshape(-1), ) ) rnnlm.clean_cache() encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - + offset = 0 finalized_B = [] for batch_size in batch_size_list: start = offset end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] # get batch + current_encoder_out = encoder_out.data[start:end] # get batch current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) offset = end @@ -1637,44 +1946,42 @@ def modified_beam_search_rnnlm_shallow_fusion( A = [list(b) for b in B] B = [HypothesisList() for _ in range(batch_size)] - + ys_log_probs = torch.cat( [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] ) - + decoder_input = torch.tensor( [hyp.ys[-context_size:] for hyps in A for hyp in hyps], device=device, dtype=torch.int64, ) # (num_hyps, context_size) - + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) decoder_out = model.joiner.decoder_proj(decoder_out) - + current_encoder_out = torch.index_select( current_encoder_out, dim=0, index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, 1, 1, encoder_out_dim) - + logits = model.joiner( current_encoder_out, decoder_out, project_input=False, ) # (num_hyps, 1, 1, vocab_size) - + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = logits.log_softmax( - dim=-1 - ) # (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) - + vocab_size = log_probs.size(-1) log_probs = log_probs.reshape(-1) - + row_splits = hyps_shape.row_splits(1) * vocab_size log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() @@ -1682,7 +1989,6 @@ def modified_beam_search_rnnlm_shallow_fusion( ragged_log_probs = k2.RaggedTensor( shape=log_probs_shape, value=log_probs ) - # for all hyps with a non-blank new token, score it token_list = [] @@ -1698,7 +2004,7 @@ def modified_beam_search_rnnlm_shallow_fusion( for k in range(len(topk_hyp_indexes)): hyp_idx = topk_hyp_indexes[k] hyp = A[i][hyp_idx] - + new_token = topk_token_indexes[k] if new_token not in (blank_id, unk_id): @@ -1708,13 +2014,18 @@ def modified_beam_search_rnnlm_shallow_fusion( cs.append(hyp.state[1]) # forward RNNLM to get new states and scores if len(token_list) != 0: - tokens_to_score = torch.tensor(token_list).to(torch.int64).to(device).reshape(-1,1) - + tokens_to_score = ( + torch.tensor(token_list) + .to(torch.int64) + .to(device) + .reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) cs = torch.cat(cs, dim=1).to(device) - scores, lm_states = rnnlm.score_token(tokens_to_score, (hs,cs)) - - count = 0 # index, used to locate score and lm states + scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs)) + + count = 0 # index, used to locate score and lm states for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1722,36 +2033,36 @@ def modified_beam_search_rnnlm_shallow_fusion( warnings.simplefilter("ignore") topk_hyp_indexes = (topk_indexes // vocab_size).tolist() topk_token_indexes = (topk_indexes % vocab_size).tolist() - + for k in range(len(topk_hyp_indexes)): hyp_idx = topk_hyp_indexes[k] hyp = A[i][hyp_idx] ys = hyp.ys[:] - + lm_score = hyp.lm_score state = hyp.state - + hyp_log_prob = topk_log_probs[k] # get score of current hyp new_token = topk_token_indexes[k] if new_token not in (blank_id, unk_id): - + ys.append(new_token) hyp_log_prob += ( lm_score[new_token] * lm_scale ) # add the lm score - + lm_score = scores[count] - state = (lm_states[0][:, count, :].unsqueeze(1), lm_states[1][:, count, :].unsqueeze(1)) + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) count += 1 - + new_hyp = Hypothesis( - ys=ys, - log_prob=hyp_log_prob, - state=state, - lm_score=lm_score + ys=ys, log_prob=hyp_log_prob, state=state, lm_score=lm_score ) - B[i].add(new_hyp) + B[i].add(new_hyp) B = B + finalized_B best_hyps = [b.get_most_probable(length_norm=True) for b in B] @@ -1762,4 +2073,4 @@ def modified_beam_search_rnnlm_shallow_fusion( for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) - return ans \ No newline at end of file + return ans