mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Fix decoding.
This commit is contained in:
parent
7828c6ff73
commit
84d6224cad
@ -66,6 +66,9 @@ def greedy_search(
|
||||
# symbols per utterance decoded so far
|
||||
sym_per_utt = 0
|
||||
|
||||
encoder_out_len = torch.tensor([1])
|
||||
decoder_out_len = torch.tensor([1])
|
||||
|
||||
while t < T and sym_per_utt < max_sym_per_utt:
|
||||
if sym_per_frame >= max_sym_per_frame:
|
||||
sym_per_frame = 0
|
||||
@ -75,7 +78,9 @@ def greedy_search(
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
||||
# fmt: on
|
||||
logits = model.joiner(current_encoder_out, decoder_out)
|
||||
logits = model.joiner(
|
||||
current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
|
||||
)
|
||||
# logits is (1, 1, 1, vocab_size)
|
||||
|
||||
y = logits.argmax().item()
|
||||
|
Loading…
x
Reference in New Issue
Block a user