mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
Minor fixes.
This commit is contained in:
parent
33c0f8f7f6
commit
949b53274c
@ -262,7 +262,7 @@ def modified_beam_search(
|
|||||||
for t in range(T):
|
for t in range(T):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
|
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
|
||||||
# current_encoder_out is of shape (1, 1, encoder_out_dim)
|
# current_encoder_out is of shape (1, 1, 1, encoder_out_dim)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
A = list(B)
|
A = list(B)
|
||||||
B = HypothesisList()
|
B = HypothesisList()
|
||||||
@ -278,11 +278,11 @@ def modified_beam_search(
|
|||||||
# decoder_input is of shape (num_hyps, context_size)
|
# decoder_input is of shape (num_hyps, context_size)
|
||||||
|
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
||||||
# decoder_output is of shape (num_hyps, 1,1, decoder_output_dim)
|
# decoder_output is of shape (num_hyps, 1, 1, decoder_output_dim)
|
||||||
|
|
||||||
current_encoder_out = current_encoder_out.expand(
|
current_encoder_out = current_encoder_out.expand(
|
||||||
decoder_out.size(0), 1, 1, -1
|
decoder_out.size(0), 1, 1, -1
|
||||||
)
|
) # (num_hyps, 1, 1, encoder_out_dim)
|
||||||
|
|
||||||
logits = model.joiner(
|
logits = model.joiner(
|
||||||
current_encoder_out,
|
current_encoder_out,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user