Minor fixes.

This commit is contained in:
Fangjun Kuang 2022-03-12 15:43:16 +08:00
parent 33c0f8f7f6
commit 949b53274c

View File

@ -262,7 +262,7 @@ def modified_beam_search(
for t in range(T):
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
# current_encoder_out is of shape (1, 1, encoder_out_dim)
# current_encoder_out is of shape (1, 1, 1, encoder_out_dim)
# fmt: on
A = list(B)
B = HypothesisList()
@ -282,7 +282,7 @@ def modified_beam_search(
current_encoder_out = current_encoder_out.expand(
decoder_out.size(0), 1, 1, -1
)
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,