mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 14:44:18 +00:00
Minor fixes.
This commit is contained in:
parent
33c0f8f7f6
commit
949b53274c
@ -262,7 +262,7 @@ def modified_beam_search(
|
||||
for t in range(T):
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
|
||||
# current_encoder_out is of shape (1, 1, encoder_out_dim)
|
||||
# current_encoder_out is of shape (1, 1, 1, encoder_out_dim)
|
||||
# fmt: on
|
||||
A = list(B)
|
||||
B = HypothesisList()
|
||||
@ -278,11 +278,11 @@ def modified_beam_search(
|
||||
# decoder_input is of shape (num_hyps, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
||||
# decoder_output is of shape (num_hyps, 1,1, decoder_output_dim)
|
||||
# decoder_output is of shape (num_hyps, 1, 1, decoder_output_dim)
|
||||
|
||||
current_encoder_out = current_encoder_out.expand(
|
||||
decoder_out.size(0), 1, 1, -1
|
||||
)
|
||||
) # (num_hyps, 1, 1, encoder_out_dim)
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
|
Loading…
x
Reference in New Issue
Block a user