From 18f9a1d319b40ab26943335d3899dd31436bfc35 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Fri, 28 Jul 2023 09:24:59 +0800 Subject: [PATCH] a runnable version of decoding scripts --- .../ASR/pruned_transducer_stateless2/beam_search.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 13509021e..3a8620226 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -823,10 +823,10 @@ def deprecated_greedy_search_batch( logits = model.joiner( current_encoder_out, decoder_out.unsqueeze(1), project_input=False ) - print(current_encoder_out) - print(decoder_out.unsqueeze(1)) - print(logits) - exit() + # print(current_encoder_out) + # print(decoder_out.unsqueeze(1)) + # print(logits) + # exit() # logits'shape (batch_size, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) assert logits.ndim == 2, logits.shape @@ -878,7 +878,7 @@ def deprecated_greedy_search_batch_for_cross_attn( unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size - hyps = [[blank_id] * context_size for _ in range(batch_size)] + hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(batch_size)] decoder_input = torch.tensor( hyps, @@ -893,6 +893,7 @@ def deprecated_greedy_search_batch_for_cross_attn( # encoder_out_for_attn = encoder_out.unsqueeze(2) # decoder_out: (batch_size, 1, decoder_out_dim) + # emitted = False for t in range(T): current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) @@ -906,7 +907,7 @@ def deprecated_greedy_search_batch_for_cross_attn( logits = model.joiner( current_encoder_out, decoder_out.unsqueeze(1), - torch.zeros_like(current_encoder_out), + attn_encoder_out if t > 0 else torch.zeros_like(current_encoder_out), None, apply_attn=False, project_input=False,