From bdaeaae1ae33021fadd1f8e9b6bbe45f1bf5ae00 Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Fri, 4 Nov 2022 11:25:10 +0800
Subject: [PATCH] resolve conflicts

---
 .../beam_search.py | 27 +++++++++++++++----
 1 file changed, 22 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index 480146a59..b1fd75204 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 
 import warnings
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Tuple, Union
 
 import k2
@@ -729,7 +729,7 @@ class Hypothesis:
 
     # timestamp[i] is the frame index after subsampling
     # on which ys[i] is decoded
-    timestamp: List[int] = None
+    timestamp: List[int] = field(default_factory=list)
 
     # the lm score for next token given the current ys
     lm_score: Optional[torch.Tensor] = None
@@ -1870,6 +1870,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
     rnnlm: RnnLmModel,
     rnnlm_scale: float,
     beam: int = 4,
+    return_timestamps: bool = False,
 ) -> List[List[int]]:
     """Modified_beam_search + RNNLM shallow fusion
 
@@ -1930,6 +1931,7 @@
             log_prob=torch.zeros(1, dtype=torch.float32, device=device),
             state=init_states,
             lm_score=init_score.reshape(-1),
+            timestamp=[],
         )
     )
 
@@ -1938,7 +1940,7 @@
 
     offset = 0
     finalized_B = []
-    for batch_size in batch_size_list:
+    for (t, batch_size) in enumerate(batch_size_list):
         start = offset
         end = offset + batch_size
         current_encoder_out = encoder_out.data[start:end]  # get batch
@@ -2060,9 +2062,11 @@
                 hyp_log_prob = topk_log_probs[k]  # get score of current hyp
 
                 new_token = topk_token_indexes[k]
+                new_timestamp = hyp.timestamp[:]
                 if new_token not in (blank_id, unk_id):
 
                     ys.append(new_token)
+                    new_timestamp.append(t)
                     hyp_log_prob += (
                         lm_score[new_token] * lm_scale
                     )  # add the lm score
@@ -2075,7 +2079,11 @@
                     count += 1
 
                 new_hyp = Hypothesis(
-                    ys=ys, log_prob=hyp_log_prob, state=state, lm_score=lm_score
+                    ys=ys,
+                    log_prob=hyp_log_prob,
+                    state=state,
+                    lm_score=lm_score,
+                    timestamp=new_timestamp,
                 )
                 B[i].add(new_hyp)
 
@@ -2083,9 +2091,18 @@
     best_hyps = [b.get_most_probable(length_norm=True) for b in B]
 
     sorted_ans = [h.ys[context_size:] for h in best_hyps]
+    sorted_timestamps = [h.timestamp for h in best_hyps]
     ans = []
+    ans_timestamps = []
     unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
     for i in range(N):
         ans.append(sorted_ans[unsorted_indices[i]])
+        ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
 
-    return ans
+    if not return_timestamps:
+        return ans
+    else:
+        return DecodingResults(
+            tokens=ans,
+            timestamps=ans_timestamps,
+        )
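
Note: the sketch below is a minimal, self-contained illustration (not the icefall implementation) of the timestamp bookkeeping this patch threads through modified_beam_search_rnnlm_shallow_fusion. Each Hypothesis carries a timestamp list, and whenever a non-blank token is appended to ys at decoding frame t, that frame index is recorded alongside it. The per_frame_tokens input is a made-up stand-in for the beam search's per-frame token choices; in the patch, t comes from enumerate(batch_size_list) and the tokens from the joiner's top-k output.

    # Minimal sketch of per-hypothesis timestamp bookkeeping.
    # `per_frame_tokens` is illustrative only, not part of the patch.
    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class Hypothesis:
        ys: List[int] = field(default_factory=list)         # decoded token ids
        timestamp: List[int] = field(default_factory=list)  # frame index per token

    blank_id = 0
    hyp = Hypothesis()

    per_frame_tokens = [0, 0, 7, 0, 3, 3, 0, 5]  # fake per-frame decisions

    for t, token in enumerate(per_frame_tokens):
        if token != blank_id:
            hyp.ys.append(token)      # mirrors ys.append(new_token)
            hyp.timestamp.append(t)   # mirrors new_timestamp.append(t)

    print(hyp.ys)         # [7, 3, 3, 5]
    print(hyp.timestamp)  # [2, 4, 5, 7]

This also shows why the default for Hypothesis.timestamp changes from None to field(default_factory=list): copying hyp.timestamp[:] would fail on a hypothesis whose timestamp was left as None, and a bare [] default is not allowed for dataclass fields, so default_factory=list gives every new Hypothesis its own empty list.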