This commit is contained in:
JinZr 2023-07-25 19:56:58 +08:00
parent 3b4fa4863f
commit 755430c29e

View File

@ -407,14 +407,18 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# (batch, max_len) # (batch, max_len)
key_padding_mask = ( key_padding_mask = (
( key_padding_mask.unsqueeze(1)
key_padding_mask.unsqueeze(0) .expand(
.repeat(1, self.prune_range, 1) key_padding_mask.shape[0], # b
.unsqueeze(2) self.prune_range,
key_padding_mask.shape[1], # l
) )
if key_padding_mask.shape[0] != attn_scores.shape[1] .reshape(b_p_dim, am_seq_len)
else key_padding_mask.unsqueeze(0).unsqueeze(2) .unsqueeze(1)
.unsqueeze(0)
) )
# print(key_padding_mask.shape)
# (1, b * p, 1, T)
attn_scores = attn_scores.masked_fill( attn_scores = attn_scores.masked_fill(
key_padding_mask, key_padding_mask,