Fix decoding.

This commit is contained in:
Fangjun Kuang 2022-01-08 19:10:49 +08:00
parent 84d6224cad
commit 2084dba2f5

View File

@ -267,6 +267,9 @@ def beam_search(
sym_per_utt = 0
encoder_out_len = torch.tensor([1])
decoder_out_len = torch.tensor([1])
decoder_cache: Dict[str, torch.Tensor] = {}
while t < T and sym_per_utt < max_sym_per_utt:
@ -299,7 +302,12 @@ def beam_search(
cached_key += f"-t-{t}"
if cached_key not in joint_cache:
logits = model.joiner(current_encoder_out, decoder_out)
logits = model.joiner(
current_encoder_out,
decoder_out,
encoder_out_len,
decoder_out_len,
)
# TODO(fangjun): Scale the blank posterior