Mirror of https://github.com/k2-fsa/icefall.git
updated code for decoding

Commit: d70f6e21f2
Parent: b9a8f107a2
@@ -825,13 +825,15 @@ def deprecated_greedy_search_batch_for_cross_attn(
         # current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
         # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
         current_encoder_out = model.joiner.label_level_am_attention(
-            encoder_out[:, : t + 1, :].unsqueeze(2),
+            encoder_out.unsqueeze(2),
             decoder_out.unsqueeze(2),
-            encoder_out_lens,
+            # encoder_out_lens,
+            None,
         )
         logits = model.joiner(
             current_encoder_out,
             decoder_out.unsqueeze(1),
             None,
             apply_attn=False,
             project_input=False,
         )
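Note on the hunk above: during greedy search, label_level_am_attention is now given the full encoder_out rather than the prefix encoder_out[:, : t + 1, :], and None is passed in place of encoder_out_lens. As a rough mental model only (this is not the joiner's actual implementation; every name in the sketch below is an illustrative assumption), label-level attention lets each decoder/label state attend over the acoustic frames, and passing None for the lengths simply skips the key-padding mask:

# Illustrative sketch, NOT icefall's code: each label (decoder) position
# attends over all acoustic (encoder) frames; am_lens=None means no
# padding mask is applied, mirroring the None passed in the hunk above.
import torch


def label_level_attention_sketch(am, lm, am_lens=None):
    # am:      (batch, T, d)  encoder output
    # lm:      (batch, U, d)  decoder output
    # am_lens: (batch,) valid frame counts, or None to attend to all frames
    scores = torch.matmul(lm, am.transpose(1, 2))          # (batch, U, T)
    if am_lens is not None:
        t = torch.arange(am.size(1), device=am.device)
        pad = t.unsqueeze(0) >= am_lens.unsqueeze(1)        # (batch, T), True = pad
        scores = scores.masked_fill(pad.unsqueeze(1), float("-inf"))
    weights = scores.softmax(dim=-1)                        # (batch, U, T)
    return torch.matmul(weights, am)                        # (batch, U, d)


am = torch.randn(2, 7, 4)   # 2 utterances, 7 frames, feature dim 4
lm = torch.randn(2, 3, 4)   # 3 label positions
print(label_level_attention_sketch(am, lm, torch.tensor([7, 5])).shape)  # (2, 3, 4)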
@@ -73,7 +73,7 @@ class CrossAttention(nn.Module):
             batch_size,
             lm_seq_len,
             am_seq_len,
-        ), f"{attn_weights.shape}"
+        ), f"{attn_weights.shape} {x.shape}"
 
         x = self.in_proj(x)  # (am_seq_len, batch_size, num_heads * value_head_dim)
         # print("projected x.shape", x.shape)
@@ -406,19 +406,23 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         if key_padding_mask is not None:
             # (batch, max_len)
 
-            key_padding_mask = (
-                key_padding_mask.unsqueeze(1)
-                .expand(
-                    key_padding_mask.shape[0],  # b
-                    self.prune_range,
-                    key_padding_mask.shape[1],  # l
-                )
-                .reshape(b_p_dim, am_seq_len)
-                .unsqueeze(1)
-                .unsqueeze(0)
-            )
-            # (1, b * p, 1, T)
+            if b_p_dim == key_padding_mask.shape[0] * self.prune_range:
+                key_padding_mask = (
+                    key_padding_mask.unsqueeze(1)
+                    .expand(
+                        key_padding_mask.shape[0],  # b
+                        self.prune_range,
+                        key_padding_mask.shape[1],  # l
+                    )
+                    .reshape(b_p_dim, am_seq_len)
+                    .unsqueeze(1)
+                    .unsqueeze(0)
+                )
+                # (1, b * p, 1, T)
+            else:
+                key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(0)
+                # (1, b, 1, T)
             # print(key_padding_mask.shape)
 
             attn_scores = attn_scores.masked_fill(
                 key_padding_mask,
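Note on the hunk above: the expansion to a (1, b * p, 1, T) mask is now only done when the query batch dimension really is batch * prune_range; otherwise the mask keeps the plain (1, b, 1, T) layout. A toy illustration of the two layouts (the shapes and variable names below are assumptions for illustration, not the repo's code):

# Toy example of the two key_padding_mask layouts the new branch distinguishes.
import torch

batch, prune_range, T = 2, 3, 5
key_padding_mask = torch.zeros(batch, T, dtype=torch.bool)
key_padding_mask[:, 4:] = True          # pretend the last frame of each utterance is padding

b_p_dim = batch * prune_range           # queries laid out as (b * p, ...)

if b_p_dim == key_padding_mask.shape[0] * prune_range:
    # Repeat each utterance's mask prune_range times: (1, b * p, 1, T)
    mask = (
        key_padding_mask.unsqueeze(1)
        .expand(batch, prune_range, T)
        .reshape(b_p_dim, T)
        .unsqueeze(1)
        .unsqueeze(0)
    )
else:
    # No pruning expansion: keep one mask row per utterance, (1, b, 1, T)
    mask = key_padding_mask.unsqueeze(1).unsqueeze(0)

print(mask.shape)   # torch.Size([1, 6, 1, 5]) via the first branch here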
@@ -492,7 +496,7 @@ class AlignmentAttentionModule(nn.Module):
     def forward(
         self, am_pruned: Tensor, lm_pruned: Tensor, lengths: torch.Tensor
     ) -> Tensor:
-        src_key_padding_mask = make_pad_mask(lengths)
+        src_key_padding_mask = make_pad_mask(lengths) if lengths is not None else None
         # (batch, max_len)
 
         if am_pruned.ndim == 4 and lm_pruned.ndim == 4:
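Note on the hunk above: the pad mask is now only built when lengths is given, which matches the decoding change earlier in this commit where None is passed instead of encoder_out_lens. A minimal sketch of the guarded call, using a hypothetical stand-in for icefall's make_pad_mask (the stand-in only imitates the usual convention that True marks padded positions):

# Sketch with a stand-in pad-mask helper; make_pad_mask_sketch is NOT the
# real icefall helper, it only mimics the True-means-padding convention.
from typing import Optional

import torch


def make_pad_mask_sketch(lengths: torch.Tensor) -> torch.Tensor:
    max_len = int(lengths.max())
    t = torch.arange(max_len, device=lengths.device)
    return t.unsqueeze(0) >= lengths.unsqueeze(1)   # (batch, max_len), True = pad


def build_mask(lengths: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Mirrors the edited line: skip the mask entirely when no lengths are given,
    # e.g. when the decoding code above passes None instead of encoder_out_lens.
    return make_pad_mask_sketch(lengths) if lengths is not None else None


print(build_mask(torch.tensor([3, 1])))
# tensor([[False, False, False],
#         [False,  True,  True]])
print(build_mask(None))   # None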