mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Update beam_search.py
This commit is contained in:
parent
0ed25de93a
commit
b9a8f107a2
@ -818,17 +818,16 @@ def deprecated_greedy_search_batch_for_cross_attn(
|
|||||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||||
|
|
||||||
encoder_out_for_attn = encoder_out.unsqueeze(2)
|
# encoder_out_for_attn = encoder_out.unsqueeze(2)
|
||||||
|
|
||||||
|
|
||||||
# decoder_out: (batch_size, 1, decoder_out_dim)
|
# decoder_out: (batch_size, 1, decoder_out_dim)
|
||||||
for t in range(T):
|
for t in range(T):
|
||||||
# current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
# current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
||||||
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
|
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
|
||||||
current_encoder_out = model.joiner.label_level_am_attention(
|
current_encoder_out = model.joiner.label_level_am_attention(
|
||||||
encoder_out,
|
encoder_out[:, : t + 1, :].unsqueeze(2),
|
||||||
decoder_out,
|
decoder_out.unsqueeze(2),
|
||||||
encoder_out_lens
|
encoder_out_lens,
|
||||||
)
|
)
|
||||||
logits = model.joiner(
|
logits = model.joiner(
|
||||||
current_encoder_out,
|
current_encoder_out,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user