From 9a47c08d085f00b63ce2d7c6d0fee16812691ed7 Mon Sep 17 00:00:00 2001 From: Erwan Zerhouni <61225408+ezerhouni@users.noreply.github.com> Date: Mon, 14 Aug 2023 16:10:50 +0200 Subject: [PATCH] Update padding modified beam search (#1217) --- .../beam_search.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index fd59d4b7f..97e259b40 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -1008,7 +1008,7 @@ def modified_beam_search( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), context_state=None if context_graph is None else context_graph.root, timestamp=[], @@ -1217,7 +1217,7 @@ def modified_beam_search_lm_rescore( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), timestamp=[], ) @@ -1417,7 +1417,7 @@ def modified_beam_search_lm_rescore_LODR( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), timestamp=[], ) @@ -1617,7 +1617,7 @@ def _deprecated_modified_beam_search( B = HypothesisList() B.add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), timestamp=[], ) @@ -1753,7 +1753,11 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) + B.add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], log_prob=0.0, timestamp=[] + ) + ) max_sym_per_utt = 20000 @@ -2265,7 +2269,7 @@ def modified_beam_search_ngram_rescoring( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), state_cost=NgramLmStateCost(ngram_lm), ) @@ -2446,7 +2450,7 @@ def modified_beam_search_LODR( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), state=init_states, # state of the NN LM lm_score=init_score.reshape(-1), @@ -2709,7 +2713,7 @@ def modified_beam_search_lm_shallow_fusion( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), state=init_states, lm_score=init_score.reshape(-1),