Update beam_search.py

This commit is contained in:
jinzr 2023-07-25 23:51:21 +08:00
parent 0ed25de93a
commit b9a8f107a2

View File

@ -818,18 +818,17 @@ def deprecated_greedy_search_batch_for_cross_attn(
decoder_out = model.joiner.decoder_proj(decoder_out)
encoder_out = model.joiner.encoder_proj(encoder_out)
encoder_out_for_attn = encoder_out.unsqueeze(2)
# encoder_out_for_attn = encoder_out.unsqueeze(2)
# decoder_out: (batch_size, 1, decoder_out_dim)
for t in range(T):
# current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
current_encoder_out = model.joiner.label_level_am_attention(
encoder_out,
decoder_out,
encoder_out_lens
)
encoder_out[:, : t + 1, :].unsqueeze(2),
decoder_out.unsqueeze(2),
encoder_out_lens,
)
logits = model.joiner(
current_encoder_out,
decoder_out.unsqueeze(1),