mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Update beam_search.py
This commit is contained in:
parent
0ed25de93a
commit
b9a8f107a2
@ -818,18 +818,17 @@ def deprecated_greedy_search_batch_for_cross_attn(
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
|
||||
encoder_out_for_attn = encoder_out.unsqueeze(2)
|
||||
|
||||
# encoder_out_for_attn = encoder_out.unsqueeze(2)
|
||||
|
||||
# decoder_out: (batch_size, 1, decoder_out_dim)
|
||||
for t in range(T):
|
||||
# current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
||||
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
|
||||
current_encoder_out = model.joiner.label_level_am_attention(
|
||||
encoder_out,
|
||||
decoder_out,
|
||||
encoder_out_lens
|
||||
)
|
||||
encoder_out[:, : t + 1, :].unsqueeze(2),
|
||||
decoder_out.unsqueeze(2),
|
||||
encoder_out_lens,
|
||||
)
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out.unsqueeze(1),
|
||||
|
Loading…
x
Reference in New Issue
Block a user