diff --git a/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py
index 5b1352b63..7a39c2e53 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py
@@ -659,6 +659,8 @@ def greedy_search_batch(
     context_size = model.decoder.context_size
 
     batch_size_list = packed_encoder_out.batch_sizes.tolist()
+    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+
     N = encoder_out.size(0)
     assert torch.all(encoder_out_lens > 0), encoder_out_lens
     assert N == batch_size_list[0], (N, batch_size_list)
@@ -675,7 +677,9 @@ def greedy_search_batch(
         dtype=torch.int64,
     )  # (N, context_size)
 
-    decoder_out = model.decoder(decoder_input, need_pad=False)
+    k = torch.zeros(N, 1, device=device, dtype=torch.int64)
+
+    decoder_out = model.decoder(decoder_input, k, need_pad=False)
     decoder_out = model.joiner.decoder_proj(decoder_out)
     # decoder_out: (N, 1, decoder_out_dim)
 
@@ -709,18 +713,32 @@ def greedy_search_batch(
         if emitted:
             # update decoder output
             decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
+
+            c = torch.tensor(
+                [h[-context_size - 1 :] for h in hyps[:batch_size]],
+                device=device,
+            )
+
+            k = torch.sum(
+                (
+                    c[:, -context_size - 1 : -1]
+                    == c[:, -1:].expand_as(c[:, -context_size - 1 : -1])
+                ).long(),
+                dim=1,
+                keepdim=True,
+            )
+
             decoder_input = torch.tensor(
                 decoder_input,
                 device=device,
                 dtype=torch.int64,
             )
-            decoder_out = model.decoder(decoder_input, need_pad=False)
+            decoder_out = model.decoder(decoder_input, k, need_pad=False)
             decoder_out = model.joiner.decoder_proj(decoder_out)
 
     sorted_ans = [h[context_size:] for h in hyps]
     ans = []
     ans_timestamps = []
-    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
     for i in range(N):
         ans.append(sorted_ans[unsorted_indices[i]])
         ans_timestamps.append(timestamps[unsorted_indices[i]])
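
The k-update in the last hunk computes, for every active hypothesis, how many of the previous context_size tokens equal the newly emitted token, and feeds that count to the decoder alongside the usual context. Below is a minimal standalone sketch of just that count, under stated assumptions: context_size and the hypothesis rows are toy values, and plain broadcasting stands in for the explicit expand_as used in the diff.

import torch

context_size = 2

# Each row holds the last (context_size + 1) tokens of one hypothesis;
# the final column is the newest emitted token. Toy values, not from the PR.
c = torch.tensor(
    [
        [5, 5, 5],  # both previous tokens equal the new token -> k = 2
        [3, 5, 5],  # one previous token equals the new token  -> k = 1
        [3, 4, 5],  # no previous token equals the new token   -> k = 0
    ]
)

# Compare the previous context_size tokens against the newest token.
# c[:, -1:] has shape (N, 1) and broadcasts against the (N, context_size)
# context slice, so no explicit expand_as is needed.
k = (c[:, :-1] == c[:, -1:]).long().sum(dim=1, keepdim=True)

print(k)  # tensor([[2], [1], [0]])

The resulting (N, 1) int64 tensor has the same shape and dtype as the k initialized with torch.zeros(N, 1, ...) before the first decoder call; how the stateless9 decoder consumes it (e.g. an extra embedding) is not shown in these hunks.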