This commit is contained in:
JinZr 2023-07-25 19:56:58 +08:00
parent 3b4fa4863f
commit 755430c29e

View File

@ -407,14 +407,18 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# (batch, max_len)
key_padding_mask = (
(
key_padding_mask.unsqueeze(0)
.repeat(1, self.prune_range, 1)
.unsqueeze(2)
key_padding_mask.unsqueeze(1)
.expand(
key_padding_mask.shape[0], # b
self.prune_range,
key_padding_mask.shape[1], # l
)
if key_padding_mask.shape[0] != attn_scores.shape[1]
else key_padding_mask.unsqueeze(0).unsqueeze(2)
.reshape(b_p_dim, am_seq_len)
.unsqueeze(1)
.unsqueeze(0)
)
# print(key_padding_mask.shape)
# (1, b * p, 1, T)
attn_scores = attn_scores.masked_fill(
key_padding_mask,