Fix decoding.

Fangjun Kuang 2022-01-06 06:17:53 +08:00
parent 7828c6ff73
commit 84d6224cad


@@ -66,6 +66,9 @@ def greedy_search(
     # symbols per utterance decoded so far
     sym_per_utt = 0
+    encoder_out_len = torch.tensor([1])
+    decoder_out_len = torch.tensor([1])
     while t < T and sym_per_utt < max_sym_per_utt:
         if sym_per_frame >= max_sym_per_frame:
             sym_per_frame = 0
@@ -75,7 +78,9 @@ def greedy_search(
         # fmt: off
         current_encoder_out = encoder_out[:, t:t+1, :]
         # fmt: on
-        logits = model.joiner(current_encoder_out, decoder_out)
+        logits = model.joiner(
+            current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
+        )
         # logits is (1, 1, 1, vocab_size)
         y = logits.argmax().item()
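
For context, below is a minimal, self-contained sketch of the updated joiner call. The Joiner class here is a hypothetical stand-in, not the implementation from this repository; it only illustrates that greedy_search now passes per-utterance length tensors, encoder_out_len and decoder_out_len, both torch.tensor([1]) because a single encoder frame and a single decoder step are joined at a time, and that the resulting logits have shape (1, 1, 1, vocab_size).

import torch
import torch.nn as nn


class Joiner(nn.Module):
    """Toy joiner used only to illustrate the new call signature.

    The real joiner combines encoder and decoder outputs and uses the
    length arguments added in this commit; this toy accepts them but
    does not use them.
    """

    def __init__(self, dim: int, vocab_size: int):
        super().__init__()
        self.output_linear = nn.Linear(dim, vocab_size)

    def forward(
        self,
        encoder_out: torch.Tensor,      # (N, T, C)
        decoder_out: torch.Tensor,      # (N, U, C)
        encoder_out_len: torch.Tensor,  # (N,)
        decoder_out_len: torch.Tensor,  # (N,)
    ) -> torch.Tensor:
        # Broadcast-add one encoder frame against one decoder step,
        # producing logits of shape (N, T, U, vocab_size).
        logit = encoder_out.unsqueeze(2) + decoder_out.unsqueeze(1)
        return self.output_linear(torch.tanh(logit))


if __name__ == "__main__":
    joiner = Joiner(dim=16, vocab_size=500)
    current_encoder_out = torch.randn(1, 1, 16)  # one frame at time t
    decoder_out = torch.randn(1, 1, 16)          # one decoder step
    encoder_out_len = torch.tensor([1])
    decoder_out_len = torch.tensor([1])

    logits = joiner(
        current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
    )
    print(logits.shape)         # torch.Size([1, 1, 1, 500])
    y = logits.argmax().item()  # greedy pick of the next symbol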