From f43ea5db30a79962651d74fe7989c68703c81191 Mon Sep 17 00:00:00 2001
From: yifanyang
Date: Mon, 6 Feb 2023 17:40:01 +0800
Subject: [PATCH] Fix a bug in beam_search

---
 .../ASR/pruned_transducer_stateless9/beam_search.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py
index 7a39c2e53..181ad253e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py
@@ -665,7 +665,7 @@ def greedy_search_batch(
     assert torch.all(encoder_out_lens > 0), encoder_out_lens
     assert N == batch_size_list[0], (N, batch_size_list)

-    hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)]
+    hyps = [[-1] * context_size + [blank_id] for _ in range(N)]

     # timestamp[n][i] is the frame index after subsampling
     # on which hyp[n][i] is decoded
@@ -675,7 +675,7 @@
         hyps,
         device=device,
         dtype=torch.int64,
-    )  # (N, context_size)
+    )[:, -context_size:]  # (N, context_size)

     k = torch.zeros(N, 1, device=device, dtype=torch.int64)

@@ -717,12 +717,13 @@
             c = torch.tensor(
                 [h[-context_size - 1 :] for h in hyps[:batch_size]],
                 device=device,
+                dtype=torch.int64,
             )
             k[:, 0] = torch.sum(
                 (
-                    c[:, -context_size - 1 : -1]
-                    == c[:, -1].expand_as(c[:, -context_size - 1 : -1])
+                    c[:, :context_size]
+                    == c[:, context_size : context_size + 1].expand_as(c[:, :context_size])
                 ).int(),
                 dim=1,
                 keepdim=True,
@@ -736,7 +737,7 @@
             decoder_out = model.decoder(decoder_input, k, need_pad=False)
             decoder_out = model.joiner.decoder_proj(decoder_out)

-    sorted_ans = [h[context_size:] for h in hyps]
+    sorted_ans = [h[context_size + 1 :] for h in hyps]
     ans = []
     ans_timestamps = []
     for i in range(N):
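
Note on the change: the patch pads each initial hypothesis with `context_size` copies of -1 instead of `context_size - 1`, trims the first decoder input back to the last `context_size` tokens, and rewrites the `k` update so that `k[n]` counts how many of the `context_size` tokens preceding the newest token equal that newest token; `k` is then passed to `model.decoder` alongside `decoder_input`, and `sorted_ans` strips one extra leading padding token accordingly. Below is a minimal, self-contained sketch of the patched `k` computation in isolation; the values of `context_size` and the hypothesis contents are made-up assumptions for illustration, not taken from the repository.

    import torch

    context_size = 2
    blank_id = 0

    # After the fix, each hypothesis starts with `context_size` copies of -1
    # followed by the blank token: [-1] * context_size + [blank_id].
    hyps = [
        [-1, -1, blank_id, 5, 5],  # newest token 5 occurs once in the two tokens before it
        [-1, -1, blank_id, 7, 3],  # newest token 3 does not occur in the two tokens before it
    ]

    # c holds the last (context_size + 1) tokens of each hypothesis,
    # shape (N, context_size + 1): the context tokens plus the newest token.
    c = torch.tensor([h[-context_size - 1 :] for h in hyps], dtype=torch.int64)

    # k[n] counts how many of the context_size tokens preceding the newest
    # token are equal to that newest token, mirroring the patched code.
    k = torch.sum(
        (
            c[:, :context_size]
            == c[:, context_size : context_size + 1].expand_as(c[:, :context_size])
        ).int(),
        dim=1,
        keepdim=True,
    )
    print(k.tolist())  # [[1], [0]]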