mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Fix decoding.
This commit is contained in:
parent
84d6224cad
commit
2084dba2f5
@ -267,6 +267,9 @@ def beam_search(
|
||||
|
||||
sym_per_utt = 0
|
||||
|
||||
encoder_out_len = torch.tensor([1])
|
||||
decoder_out_len = torch.tensor([1])
|
||||
|
||||
decoder_cache: Dict[str, torch.Tensor] = {}
|
||||
|
||||
while t < T and sym_per_utt < max_sym_per_utt:
|
||||
@ -299,7 +302,12 @@ def beam_search(
|
||||
|
||||
cached_key += f"-t-{t}"
|
||||
if cached_key not in joint_cache:
|
||||
logits = model.joiner(current_encoder_out, decoder_out)
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
encoder_out_len,
|
||||
decoder_out_len,
|
||||
)
|
||||
|
||||
# TODO(fangjun): Ccale the blank posterior
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user