From 70d603dc2845580d45fdbd795e4ac5bb2722c9d2 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Fri, 28 Jul 2023 17:25:28 +0800 Subject: [PATCH] aligned with the beam search approach for algn exp --- .../ASR/pruned_transducer_stateless2/beam_search.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 3a8620226..5d5d38947 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -804,7 +804,7 @@ def deprecated_greedy_search_batch( unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size - hyps = [[blank_id] * context_size for _ in range(batch_size)] + hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(batch_size)] decoder_input = torch.tensor( hyps, @@ -909,7 +909,7 @@ def deprecated_greedy_search_batch_for_cross_attn( decoder_out.unsqueeze(1), attn_encoder_out if t > 0 else torch.zeros_like(current_encoder_out), None, - apply_attn=False, + apply_attn=True, project_input=False, ) # logits'shape (batch_size, 1, 1, vocab_size)