modify attn_offsets

yaozengwei 2023-07-21 15:38:22 +08:00
parent 6aaa971b34
commit ee485c02fc


@@ -1598,6 +1598,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
     assert attn_scores.shape == (num_heads, new_batch_size, time1, time2)
+    assert attn_mask is None
     if attn_mask is not None:
         # TODO:
         assert attn_mask.dtype == torch.bool
@@ -1607,9 +1608,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
     # compares the final weights with zero.
     attn_scores = attn_scores.masked_fill(attn_mask, -1000)
-    assert key_padding_mask is not None
-    assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
-    attn_offsets = (~key_padding_mask).float()  # 0 at padding positions
+    # Used to mask out the padding positions
+    attn_offsets = torch.ones(batch_size, seq_len, device=x.device)
+    if key_padding_mask is not None:
+        assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
+        attn_offsets = attn_offsets.masked_fill(key_padding_mask, 0.0)  # 0 at padding positions
     # (seq_len, batch, 1)
     attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(-1)
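
As a minimal, self-contained sketch of what the hunk above computes (batch_size, seq_len, and the dummy input x below are assumptions standing in for the surrounding forward() context): attn_offsets now starts as all ones and is zeroed at padding positions only when key_padding_mask is given, so the mask becomes optional instead of being asserted non-None.

import torch

batch_size, seq_len = 2, 6
x = torch.randn(seq_len, batch_size, 4)  # stand-in for the module input

# True at padding positions; may now be None, unlike the old code path.
key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
key_padding_mask[1, 4:] = True  # last two frames of the second utterance are padding

attn_offsets = torch.ones(batch_size, seq_len, device=x.device)
if key_padding_mask is not None:
    assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
    attn_offsets = attn_offsets.masked_fill(key_padding_mask, 0.0)  # 0 at padding positions

attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(-1)  # (seq_len, batch_size, 1)
print(attn_offsets.squeeze(-1).t())  # rows [1, 1, 1, 1, 1, 1] and [1, 1, 1, 1, 0, 0]
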
@@ -1619,14 +1623,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         kernel=block_size * 3, stride=block_size, padding=block_size,
     ).squeeze(-1)
-    # For the blocks are all padding
+    # Used for the blocks are all padding
     all_pad_mask = (attn_offsets.sum(dim=0, keepdim=True) == 0)  # (1, new_batch_size)
     all_pad_mask = all_pad_mask.unsqueeze(-1).unsqueeze(-1)  # (1, new_batch_size, 1, 1)
     attn_offsets = 1 - attn_offsets  # 1 at padding positions
-    # attn_offsets[attn_offsets != 0] = float("-inf")
     attn_offsets[attn_offsets != 0] = -1000
-    # attn_offsets = attn_offsets.masked_fill((attn_offsets != 0), -1000)
     # (1, new_batch_size, 1, time2)
     attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(1).unsqueeze(0)
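
A hedged sketch of the tail of this hunk, showing how the per-block offsets become additive attention scores. The block-extraction call above is elided (its full signature is not shown); attn_offsets below is a hypothetical result of that step, of shape (time2, new_batch_size), with 1 at real key frames and 0 at padding frames for each block.

import torch

# Hypothetical per-block offsets, shape (time2, new_batch_size); the last
# column corresponds to a block whose frames are all padding.
attn_offsets = torch.tensor([[1.0, 1.0, 0.0, 0.0],
                             [1.0, 1.0, 1.0, 0.0],
                             [1.0, 0.0, 0.0, 0.0]])

# Blocks made entirely of padding: (1, new_batch_size, 1, 1)
all_pad_mask = (attn_offsets.sum(dim=0, keepdim=True) == 0)
all_pad_mask = all_pad_mask.unsqueeze(-1).unsqueeze(-1)

attn_offsets = 1 - attn_offsets          # 1 at padding positions
attn_offsets[attn_offsets != 0] = -1000  # large negative additive score
# (1, new_batch_size, 1, time2), broadcastable against attn_scores of shape
# (num_heads, new_batch_size, time1, time2)
attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(1).unsqueeze(0)

print(attn_offsets[0, 3, 0])     # tensor([-1000., -1000., -1000.]): everything masked
print(all_pad_mask.reshape(-1))  # tensor([False, False, False,  True])

The -1000 fill value matches the one used for attn_mask in the context above; the commented-out float("-inf") and masked_fill variants are removed by this commit.
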