mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-28 03:04:19 +00:00
Fix decoding.
This commit is contained in:
parent
84d6224cad
commit
2084dba2f5
@ -267,6 +267,9 @@ def beam_search(
|
|||||||
|
|
||||||
sym_per_utt = 0
|
sym_per_utt = 0
|
||||||
|
|
||||||
|
encoder_out_len = torch.tensor([1])
|
||||||
|
decoder_out_len = torch.tensor([1])
|
||||||
|
|
||||||
decoder_cache: Dict[str, torch.Tensor] = {}
|
decoder_cache: Dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
while t < T and sym_per_utt < max_sym_per_utt:
|
while t < T and sym_per_utt < max_sym_per_utt:
|
||||||
@ -299,7 +302,12 @@ def beam_search(
|
|||||||
|
|
||||||
cached_key += f"-t-{t}"
|
cached_key += f"-t-{t}"
|
||||||
if cached_key not in joint_cache:
|
if cached_key not in joint_cache:
|
||||||
logits = model.joiner(current_encoder_out, decoder_out)
|
logits = model.joiner(
|
||||||
|
current_encoder_out,
|
||||||
|
decoder_out,
|
||||||
|
encoder_out_len,
|
||||||
|
decoder_out_len,
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(fangjun): Ccale the blank posterior
|
# TODO(fangjun): Ccale the blank posterior
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user