Minor fixes.

This commit is contained in:
Fangjun Kuang 2022-03-12 15:43:16 +08:00
parent 33c0f8f7f6
commit 949b53274c

View File

@ -262,7 +262,7 @@ def modified_beam_search(
for t in range(T):
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
# current_encoder_out is of shape (1, 1, encoder_out_dim)
# current_encoder_out is of shape (1, 1, 1, encoder_out_dim)
# fmt: on
A = list(B)
B = HypothesisList()
@ -278,11 +278,11 @@ def modified_beam_search(
# decoder_input is of shape (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
# decoder_output is of shape (num_hyps, 1,1, decoder_output_dim)
# decoder_output is of shape (num_hyps, 1, 1, decoder_output_dim)
current_encoder_out = current_encoder_out.expand(
decoder_out.size(0), 1, 1, -1
)
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,