Add greedy_search_batch

This commit is contained in:
yifanyang 2023-02-06 16:57:46 +08:00
parent 51812b2ddc
commit 59daf83b59

View File

@@ -659,6 +659,8 @@ def greedy_search_batch(
context_size = model.decoder.context_size
batch_size_list = packed_encoder_out.batch_sizes.tolist()
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
@@ -675,7 +677,9 @@ def greedy_search_batch(
dtype=torch.int64,
) # (N, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
k = torch.zeros(N, 1, device=device, dtype=torch.int64)
decoder_out = model.decoder(decoder_input, k, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out: (N, 1, decoder_out_dim)
@@ -709,18 +713,32 @@ def greedy_search_batch(
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
c = torch.tensor(
[h[-context_size - 1 :] for h in hyps[:batch_size]],
device=device,
)
k[:, 0] = torch.sum(
(
c[:, -context_size - 1 : -1]
== c[:, -1].expand_as(c[:, -context_size - 1 : -1])
).int(),
dim=1,
keepdim=True,
)
decoder_input = torch.tensor(
decoder_input,
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.decoder(decoder_input, k, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(timestamps[unsorted_indices[i]])