Fix a bug in beam_search

yifanyang 2023-02-06 17:40:01 +08:00
parent 59daf83b59
commit f43ea5db30


@@ -665,7 +665,7 @@ def greedy_search_batch(
     assert torch.all(encoder_out_lens > 0), encoder_out_lens
     assert N == batch_size_list[0], (N, batch_size_list)

-    hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)]
+    hyps = [[-1] * context_size + [blank_id] for _ in range(N)]

     # timestamp[n][i] is the frame index after subsampling
     # on which hyp[n][i] is decoded
@@ -675,7 +675,7 @@ def greedy_search_batch(
         hyps,
         device=device,
         dtype=torch.int64,
-    )  # (N, context_size)
+    )[:, -context_size :]  # (N, context_size)
     k = torch.zeros(N, 1, device=device, dtype=torch.int64)
@@ -717,12 +717,13 @@ def greedy_search_batch(
             c = torch.tensor(
                 [h[-context_size - 1 :] for h in hyps[:batch_size]],
                 device=device,
                 dtype=torch.int64,
             )
             k[:, 0] = torch.sum(
                 (
-                    c[:, -context_size - 1 : -1]
-                    == c[:, -1].expand_as(c[:, -context_size - 1 : -1])
+                    c[:, : context_size]
+                    == c[:, context_size : context_size + 1].expand_as(c[:, : context_size])
                 ).int(),
                 dim=1,
                 keepdim=True,
@@ -736,7 +737,7 @@ def greedy_search_batch(
             decoder_out = model.decoder(decoder_input, k, need_pad=False)
             decoder_out = model.joiner.decoder_proj(decoder_out)

-    sorted_ans = [h[context_size:] for h in hyps]
+    sorted_ans = [h[context_size + 1 :] for h in hyps]
     ans = []
     ans_timestamps = []
     for i in range(N):
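
The diff changes two related things: the hypothesis prefix is padded to context_size + 1 tokens (with decoder_input and sorted_ans trimmed accordingly), and the repeat count k is computed by comparing the previous context_size tokens against the newly emitted token kept as an (N, 1) column rather than a 1-D slice. Below is a minimal standalone sketch (not part of the commit) of the second point, using made-up tensor contents and only standard PyTorch broadcasting rules.

# Standalone sketch: why expanding a 1-D column fails, and why keeping
# the column dimension works.  Shapes follow the diff above; values are
# made up for illustration.
import torch

N, context_size = 3, 2

# c holds the last (context_size + 1) tokens of each hypothesis:
# context_size previous tokens plus the newly emitted one.
c = torch.tensor(
    [
        [5, 7, 7],
        [2, 2, 2],
        [9, 4, 1],
    ],
    dtype=torch.int64,
)  # (N, context_size + 1)

prev = c[:, :context_size]  # (N, context_size)

# Old indexing: c[:, -1] has shape (N,), and expand_as() to
# (N, context_size) raises unless N happens to equal context_size.
try:
    bad = c[:, -1].expand_as(prev)
except RuntimeError as e:
    print("old indexing fails:", e)

# New indexing: keep the column dimension so the comparison broadcasts
# each row's newly emitted token against that row's previous tokens.
last = c[:, context_size : context_size + 1]  # (N, 1)
k = torch.sum((prev == last.expand_as(prev)).int(), dim=1, keepdim=True)
print(k)  # per-row count of previous tokens equal to the new token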