diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index fd59d4b7f..97e259b40 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -1008,7 +1008,7 @@ def modified_beam_search( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), context_state=None if context_graph is None else context_graph.root, timestamp=[], @@ -1217,7 +1217,7 @@ def modified_beam_search_lm_rescore( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), timestamp=[], ) @@ -1417,7 +1417,7 @@ def modified_beam_search_lm_rescore_LODR( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), timestamp=[], ) @@ -1617,7 +1617,7 @@ def _deprecated_modified_beam_search( B = HypothesisList() B.add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), timestamp=[], ) @@ -1753,7 +1753,11 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) + B.add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], log_prob=0.0, timestamp=[] + ) + ) max_sym_per_utt = 20000 @@ -2265,7 +2269,7 @@ def modified_beam_search_ngram_rescoring( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), state_cost=NgramLmStateCost(ngram_lm), ) @@ -2446,7 +2450,7 @@ def modified_beam_search_LODR( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), state=init_states, # state of the NN LM lm_score=init_score.reshape(-1), @@ -2709,7 +2713,7 @@ def modified_beam_search_lm_shallow_fusion( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), state=init_states, lm_score=init_score.reshape(-1),