Minor fixes.

This commit is contained in:
Fangjun Kuang 2022-03-12 15:43:16 +08:00
parent 33c0f8f7f6
commit 949b53274c

View File

@ -262,7 +262,7 @@ def modified_beam_search(
for t in range(T): for t in range(T):
# fmt: off # fmt: off
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
# current_encoder_out is of shape (1, 1, encoder_out_dim) # current_encoder_out is of shape (1, 1, 1, encoder_out_dim)
# fmt: on # fmt: on
A = list(B) A = list(B)
B = HypothesisList() B = HypothesisList()
@ -278,11 +278,11 @@ def modified_beam_search(
# decoder_input is of shape (num_hyps, context_size) # decoder_input is of shape (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
# decoder_output is of shape (num_hyps, 1,1, decoder_output_dim) # decoder_output is of shape (num_hyps, 1, 1, decoder_output_dim)
current_encoder_out = current_encoder_out.expand( current_encoder_out = current_encoder_out.expand(
decoder_out.size(0), 1, 1, -1 decoder_out.size(0), 1, 1, -1
) ) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner( logits = model.joiner(
current_encoder_out, current_encoder_out,