updated code for decoding

JinZr 2023-07-26 19:01:53 +08:00
parent b9a8f107a2
commit d70f6e21f2
2 changed files with 21 additions and 15 deletions

View File

@@ -825,13 +825,15 @@ def deprecated_greedy_search_batch_for_cross_attn(
         # current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
         # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
         current_encoder_out = model.joiner.label_level_am_attention(
-            encoder_out[:, : t + 1, :].unsqueeze(2),
+            encoder_out.unsqueeze(2),
             decoder_out.unsqueeze(2),
-            encoder_out_lens,
+            # encoder_out_lens,
+            None,
         )
         logits = model.joiner(
             current_encoder_out,
             decoder_out.unsqueeze(1),
+            None,
             apply_attn=False,
             project_input=False,
         )
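
For context, a minimal sketch (not the repository's code) of the shapes handed to label_level_am_attention after this change: the attention now receives the full encoder output for the utterance rather than a per-step slice, and None in place of encoder_out_lens, so no length information reaches the joiner during greedy search. The toy dimensions and variable names below are assumptions for illustration only.

import torch

batch_size, T, joiner_dim = 2, 50, 512

# Assumed shapes: encoder frames already projected to the joiner dimension,
# and one decoder (label) step per stream, as in the greedy-search loop above.
encoder_out = torch.randn(batch_size, T, joiner_dim)
decoder_out = torch.randn(batch_size, 1, joiner_dim)

am_input = encoder_out.unsqueeze(2)  # (batch, T, 1, joiner_dim) -- full utterance, no [:, : t + 1, :] slice
lm_input = decoder_out.unsqueeze(2)  # (batch, 1, 1, joiner_dim)
lengths = None                       # replaces encoder_out_lens; no padding mask is built downstream

print(am_input.shape, lm_input.shape, lengths)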

View File

@@ -73,7 +73,7 @@ class CrossAttention(nn.Module):
             batch_size,
             lm_seq_len,
             am_seq_len,
-        ), f"{attn_weights.shape}"
+        ), f"{attn_weights.shape} {x.shape}"

         x = self.in_proj(x)  # (am_seq_len, batch_size, num_heads * value_head_dim)
         # print("projected x.shape", x.shape)
@@ -406,6 +406,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         if key_padding_mask is not None:
             # (batch, max_len)
+            if b_p_dim == key_padding_mask.shape[0] * self.prune_range:
                 key_padding_mask = (
                     key_padding_mask.unsqueeze(1)
                     .expand(
@@ -417,8 +418,11 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
                     .unsqueeze(1)
                     .unsqueeze(0)
                 )
-                # print(key_padding_mask.shape)
                 # (1, b * p, 1, T)
+            else:
+                key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(0)
+                # (1, b, 1, T)
+            # print(key_padding_mask.shape)

             attn_scores = attn_scores.masked_fill(
                 key_padding_mask,
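
A minimal, self-contained sketch of the two key_padding_mask layouts handled above (the exact .expand() arguments are cut off in this view; the (num_heads, b * p, 1, T) score layout and the names b, p, T are assumptions): when the score batch dimension equals batch * prune_range, the (batch, T) mask is repeated per prune position to (1, b * p, 1, T); otherwise it is only broadcast to (1, b, 1, T).

import torch

b, p, num_heads, T = 2, 5, 4, 7
lengths = torch.tensor([7, 4])
key_padding_mask = torch.arange(T).unsqueeze(0) >= lengths.unsqueeze(1)  # (b, T), True at padding

# Pruned case: scores have a leading batch dim of b * p, so repeat the mask per prune position.
pruned_mask = (
    key_padding_mask.unsqueeze(1)  # (b, 1, T)
    .expand(b, p, T)               # (b, p, T)
    .reshape(b * p, T)             # (b * p, T)
    .unsqueeze(1)                  # (b * p, 1, T)
    .unsqueeze(0)                  # (1, b * p, 1, T)
)

# Un-pruned case: the mask is only broadcast over heads.
plain_mask = key_padding_mask.unsqueeze(1).unsqueeze(0)  # (1, b, 1, T)

attn_scores = torch.randn(num_heads, b * p, 1, T)
attn_scores = attn_scores.masked_fill(pruned_mask, float("-inf"))
print(pruned_mask.shape, plain_mask.shape, attn_scores.shape)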
@@ -492,7 +496,7 @@ class AlignmentAttentionModule(nn.Module):
     def forward(
         self, am_pruned: Tensor, lm_pruned: Tensor, lengths: torch.Tensor
     ) -> Tensor:
-        src_key_padding_mask = make_pad_mask(lengths)
+        src_key_padding_mask = make_pad_mask(lengths) if lengths is not None else None
         # (batch, max_len)
         if am_pruned.ndim == 4 and lm_pruned.ndim == 4:
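
The guarded make_pad_mask call above is what lets the decoding path pass lengths=None (see the first hunk): with no lengths, no src_key_padding_mask is built and the cross-attention runs unmasked. A minimal sketch, using a local stand-in for icefall's make_pad_mask (True at padded positions) and a hypothetical helper name:

from typing import Optional

import torch
from torch import Tensor


def make_pad_mask(lengths: Tensor) -> Tensor:
    # Stand-in for icefall.utils.make_pad_mask: (batch, max_len), True at padded frames.
    max_len = int(lengths.max())
    return torch.arange(max_len, device=lengths.device).unsqueeze(0) >= lengths.unsqueeze(1)


def build_src_key_padding_mask(lengths: Optional[Tensor]) -> Optional[Tensor]:
    # Mirrors the guarded expression in forward(): no lengths -> no mask.
    return make_pad_mask(lengths) if lengths is not None else None


print(build_src_key_padding_mask(torch.tensor([5, 3])))  # bool (2, 5) mask
print(build_src_key_padding_mask(None))                  # None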