diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 3a8620226..5d5d38947 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -804,7 +804,7 @@ def deprecated_greedy_search_batch( unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size - hyps = [[blank_id] * context_size for _ in range(batch_size)] + hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(batch_size)] decoder_input = torch.tensor( hyps, @@ -909,7 +909,7 @@ def deprecated_greedy_search_batch_for_cross_attn( decoder_out.unsqueeze(1), attn_encoder_out if t > 0 else torch.zeros_like(current_encoder_out), None, - apply_attn=False, + apply_attn=True, project_input=False, ) # logits'shape (batch_size, 1, 1, vocab_size)