Bug fix RE masking

This commit is contained in:
Daniel Povey 2022-11-09 13:12:34 +08:00
parent 20e6d2a157
commit cba194aa26

View File

@ -1225,7 +1225,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
num_heads, batch_size, seq_len, seq_len
)
attn_scores = attn_scores.masked_fill(
key_padding_mask.unsqueeze(-1),
key_padding_mask.unsqueeze(1),
float("-inf"),
)