mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add greedy_search_batch
This commit is contained in:
parent
51812b2ddc
commit
59daf83b59
@ -659,6 +659,8 @@ def greedy_search_batch(
|
|||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
|
|
||||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||||
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
|
|
||||||
N = encoder_out.size(0)
|
N = encoder_out.size(0)
|
||||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
assert N == batch_size_list[0], (N, batch_size_list)
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
@ -675,7 +677,9 @@ def greedy_search_batch(
|
|||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
) # (N, context_size)
|
) # (N, context_size)
|
||||||
|
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
k = torch.zeros(N, 1, device=device, dtype=torch.int64)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, k, need_pad=False)
|
||||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||||
# decoder_out: (N, 1, decoder_out_dim)
|
# decoder_out: (N, 1, decoder_out_dim)
|
||||||
|
|
||||||
@ -709,18 +713,32 @@ def greedy_search_batch(
|
|||||||
if emitted:
|
if emitted:
|
||||||
# update decoder output
|
# update decoder output
|
||||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||||
|
|
||||||
|
c = torch.tensor(
|
||||||
|
[h[-context_size - 1 :] for h in hyps[:batch_size]],
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
k[:, 0] = torch.sum(
|
||||||
|
(
|
||||||
|
c[:, -context_size - 1 : -1]
|
||||||
|
== c[:, -1].expand_as(c[:, -context_size - 1 : -1])
|
||||||
|
).int(),
|
||||||
|
dim=1,
|
||||||
|
keepdim=True,
|
||||||
|
)
|
||||||
|
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
decoder_input,
|
decoder_input,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
)
|
)
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
decoder_out = model.decoder(decoder_input, k, need_pad=False)
|
||||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||||
|
|
||||||
sorted_ans = [h[context_size:] for h in hyps]
|
sorted_ans = [h[context_size:] for h in hyps]
|
||||||
ans = []
|
ans = []
|
||||||
ans_timestamps = []
|
ans_timestamps = []
|
||||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
|
||||||
for i in range(N):
|
for i in range(N):
|
||||||
ans.append(sorted_ans[unsorted_indices[i]])
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
ans_timestamps.append(timestamps[unsorted_indices[i]])
|
ans_timestamps.append(timestamps[unsorted_indices[i]])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user