mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-10 17:44:20 +00:00
modify attn_offsets
This commit is contained in:
parent
6aaa971b34
commit
ee485c02fc
@ -1598,6 +1598,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
|
||||
assert attn_scores.shape == (num_heads, new_batch_size, time1, time2)
|
||||
|
||||
assert attn_mask is None
|
||||
if attn_mask is not None:
|
||||
# TODO:
|
||||
assert attn_mask.dtype == torch.bool
|
||||
@ -1607,9 +1608,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
# compares the final weights with zero.
|
||||
attn_scores = attn_scores.masked_fill(attn_mask, -1000)
|
||||
|
||||
assert key_padding_mask is not None
|
||||
assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
|
||||
attn_offsets = (~key_padding_mask).float() # 0 at padding positions
|
||||
# Used to mask out the padding positions
|
||||
attn_offsets = torch.ones(batch_size, seq_len, device=x.device)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
|
||||
attn_offsets = attn_offsets.masked_fill(key_padding_mask, 0.0) # 0 at padding positions
|
||||
|
||||
# (seq_len, batch, 1)
|
||||
attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(-1)
|
||||
@ -1619,14 +1623,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
kernel=block_size * 3, stride=block_size, padding=block_size,
|
||||
).squeeze(-1)
|
||||
|
||||
# For the blocks are all padding
|
||||
# Used for the blocks are all padding
|
||||
all_pad_mask = (attn_offsets.sum(dim=0, keepdim=True) == 0) # (1, new_batch_size)
|
||||
all_pad_mask = all_pad_mask.unsqueeze(-1).unsqueeze(-1) # (1, new_batch_size, 1, 1)
|
||||
|
||||
attn_offsets = 1 - attn_offsets # 1 at padding positions
|
||||
# attn_offsets[attn_offsets != 0] = float("-inf")
|
||||
attn_offsets[attn_offsets != 0] = -1000
|
||||
# attn_offsets = attn_offsets.masked_fill((attn_offsets != 0), -1000)
|
||||
|
||||
# (1, new_batch_size, 1, time2)
|
||||
attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(1).unsqueeze(0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user