Remove persistent attention scores.
parent 1825336841
commit eb58e6d74b
@@ -300,18 +300,15 @@ class ConformerEncoderLayer(nn.Module):
         self,
         src: Tensor,
         pos_emb: Tensor,
-        attn_scores_in: Optional[Tensor] = None,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Tensor]:
+    ) -> Tensor:
         """
         Pass the input through the encoder layer.

         Args:
             src: the sequence to the encoder layer (required).
             pos_emb: Positional embedding tensor (required).
-            attn_scores_in: something with the dimension fo attention weights (bsz, len, len, num_heads) that is
-                passed from layer to layer.
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
             batch_split: if not None, this layer will only be applied to
@@ -329,10 +326,9 @@ class ConformerEncoderLayer(nn.Module):
         src = src + self.feed_forward1(src)

         # multi-headed self-attention module
-        src_att, attn_weights, attn_scores_out = self.self_attn(
+        src_att, attn_weights = self.self_attn(
             src,
             pos_emb=pos_emb,
-            attn_scores_in=attn_scores_in,
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )
@@ -358,7 +354,7 @@ class ConformerEncoderLayer(nn.Module):
         bypass_scale = bypass_scale.clamp(min=0.1, max=1.0)
         src = src_orig + delta * self.bypass_scale

-        return src, attn_scores_out
+        return src


 class ConformerEncoder(nn.Module):
@@ -516,7 +512,6 @@ class ConformerEncoder(nn.Module):
         output = src

         outputs = []
-        attn_scores = None


         rnd_seed = src.numel() + random.randint(0, 1000)
@@ -527,10 +522,9 @@ class ConformerEncoder(nn.Module):
         for i, mod in enumerate(self.layers):
             if i in layers_to_drop:
                 continue
-            output, attn_scores = mod(
+            output = mod(
                 output,
                 pos_emb,
-                attn_scores,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
             )
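
The two ConformerEncoder hunks above drop the state that was threaded through the layer stack. A self-contained sketch of the calling convention before and after, with stub layers standing in for ConformerEncoderLayer (the stubs, shapes, and sizes below are illustrative assumptions, not the real module):

# Stub layers illustrate only the calling convention; the real layers do the
# feed-forward / self-attention work.
import torch

seq_len, batch, embed_dim, num_heads = 50, 4, 256, 8
output = torch.randn(seq_len, batch, embed_dim)
pos_emb = torch.randn(batch, 2 * seq_len - 1, embed_dim)

def stub_layer_old(src, pos_emb, attn_scores_in=None, src_mask=None,
                   src_key_padding_mask=None):
    # Old interface: returns the output *and* accumulated attention scores.
    attn_scores_out = torch.zeros(batch, seq_len, seq_len, num_heads)
    return src, attn_scores_out

def stub_layer_new(src, pos_emb, src_mask=None, src_key_padding_mask=None):
    # New interface: plain Tensor in, Tensor out.
    return src

# Before this commit: attention scores were threaded from layer to layer.
attn_scores = None
for mod in [stub_layer_old] * 4:
    output, attn_scores = mod(output, pos_emb, attn_scores,
                              src_mask=None, src_key_padding_mask=None)

# After this commit: nothing but `output` flows between layers.
for mod in [stub_layer_new] * 4:
    output = mod(output, pos_emb, src_mask=None, src_key_padding_mask=None)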
@@ -851,8 +845,6 @@ class RelPositionMultiheadAttention(nn.Module):
         self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
                                       initial_scale=0.05)

-        self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
-        self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))

         # linear transformation for positional encoding.
         self.linear_pos = nn.Linear(embed_dim, embed_dim // 2, bias=False)
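
The two parameters deleted here are what made the attention scores "persistent": attn_scores_proj_in (identity at init) mixed the incoming per-head scores, and attn_scores_proj_out (zeros at init) mixed the scores a layer emits, so at initialization each layer passed the incoming scores through unchanged and contributed nothing of its own. A minimal, self-contained illustration of that head-mixing step (toy sizes; the (N, L, L, num_heads) layout is taken from the deleted docstrings):

import torch

bsz, seq_len, num_heads = 2, 5, 4
attn_scores_in = torch.randn(bsz, seq_len, seq_len, num_heads)  # (N, L, L, num_heads)

proj_in = torch.eye(num_heads)                # attn_scores_proj_in: identity at init
proj_out = torch.zeros(num_heads, num_heads)  # attn_scores_proj_out: zeros at init

# matmul over the trailing head dimension mixes scores across heads.
mixed_in = torch.matmul(attn_scores_in, proj_in)
assert torch.equal(mixed_in, attn_scores_in)        # identity: unchanged at init

new_scores = torch.randn(bsz, seq_len, seq_len, num_heads)  # this layer's raw scores
contribution = torch.matmul(new_scores, proj_out)
assert torch.count_nonzero(contribution) == 0       # zeros: no contribution at init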
@@ -870,7 +862,6 @@ class RelPositionMultiheadAttention(nn.Module):
         self,
         x: Tensor,
         pos_emb: Tensor,
-        attn_scores_in: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
         attn_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor, Tensor]:
@@ -892,7 +883,6 @@ class RelPositionMultiheadAttention(nn.Module):
          the embedding dimension.
        - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
-       - attn_scores_in: :math:`(N, L, L, num_heads)`
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the position
          with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
@@ -905,19 +895,16 @@ class RelPositionMultiheadAttention(nn.Module):
          is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.

-       - Returns: (attn_output, attn_weights, attn_scores)
+       - Returns: (attn_output, attn_weights)

        - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads
          and S is the sequence length.
-       - attn_scores: :math:`(N, S, S, H)`, these are the attn weights
-         before softmax.
        """
-        x, weights, scores = self.multi_head_attention_forward(
+        x, weights = self.multi_head_attention_forward(
             self.in_balancer(self.in_proj(x)),
             pos_emb,
-            None if attn_scores_in is None else torch.matmul(attn_scores_in, self.attn_scores_proj_in),
             self.embed_dim,
             self.num_heads,
             self.in_proj.weight,
@@ -929,15 +916,7 @@ class RelPositionMultiheadAttention(nn.Module):
             key_padding_mask=key_padding_mask,
             attn_mask=attn_mask,
         )
-        if attn_scores_in is not None:
-            attn_scores_out = torch.matmul(scores, self.attn_scores_proj_out)
-            attn_scores_out = attn_scores_out + attn_scores_in
-        else:
-            # Here, add self.attn_scores_proj_in in order to make sure it has
-            # a grad.
-            attn_scores_out = torch.matmul(scores, self.attn_scores_proj_out +
-                                           self.attn_scores_proj_in)
-        return x, weights, attn_scores_out
+        return x, weights

     def rel_shift(self, x: Tensor) -> Tensor:
         """Compute relative positional encoding.
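
The deleted branch above was the per-layer update rule for the persistent scores: the pre-softmax scores returned by multi_head_attention_forward were projected with attn_scores_proj_out and, when incoming scores existed, added to them; the first layer folded attn_scores_proj_in into the product only so that parameter received a gradient (per the deleted comment). A condensed, self-contained reconstruction of that update (toy sizes; names follow the deleted code):

import torch

bsz, seq_len, num_heads = 2, 5, 4
scores = torch.randn(bsz, seq_len, seq_len, num_heads)   # pre-softmax scores, (N, S, S, H)
proj_in = torch.nn.Parameter(torch.eye(num_heads))
proj_out = torch.nn.Parameter(torch.zeros(num_heads, num_heads))

def update(scores, attn_scores_in):
    # Mirrors the deleted if/else: layers with incoming scores accumulate;
    # the first layer adds proj_in only so it takes part in the backward pass.
    if attn_scores_in is not None:
        return torch.matmul(scores, proj_out) + attn_scores_in
    return torch.matmul(scores, proj_out + proj_in)

first = update(scores, None)    # layer 0: scores @ (proj_out + proj_in)
later = update(scores, first)   # later layers: scores @ proj_out + incoming scores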
@@ -968,7 +947,6 @@ class RelPositionMultiheadAttention(nn.Module):
         self,
         x: Tensor,
         pos_emb: Tensor,
-        attn_scores_in: Optional[Tensor],
         embed_dim: int,
         num_heads: int,
         in_proj_weight: Tensor,
@@ -1023,8 +1001,6 @@ class RelPositionMultiheadAttention(nn.Module):
          E is the embedding dimension.
        - attn_weights: :math:`(N * H, S, S)` where N is the batch size,
          H is the num-heads, S is the sequence length.
-       - attn_scores: :math:`(N, S, S, H)` where N is the batch size,
-         S is the sequence length and H is the num-heads.
        """

        tgt_len, bsz, _ = x.size()
@@ -1137,10 +1113,6 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) # (batch, head, time1, time2)

-        attn_scores_out = attn_output_weights.permute(0, 2, 3, 1) # (batch, time1, time2, head)
-
-        if attn_scores_in is not None:
-            attn_output_weights = attn_output_weights + attn_scores_in.permute(0, 3, 1, 2)


         attn_output_weights = attn_output_weights.view(
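
This hunk is where the persistent scores met the attention computation: the raw weights use a (batch, head, time1, time2) layout, so the incoming (batch, time1, time2, head) scores were permuted before being added, and the outgoing copy was permuted the other way. A small, self-contained check of those two permutes (toy shapes only):

import torch

batch, head, time1, time2 = 2, 4, 5, 5
attn_output_weights = torch.randn(batch, head, time1, time2)

# Outgoing copy, as in the deleted line: (batch, time1, time2, head).
attn_scores_out = attn_output_weights.permute(0, 2, 3, 1)

# Incoming scores arrive in that same layout and were permuted back before the add.
attn_scores_in = torch.randn(batch, time1, time2, head)
combined = attn_output_weights + attn_scores_in.permute(0, 3, 1, 2)
assert combined.shape == (batch, head, time1, time2)

# The two permutes are inverses of each other.
assert torch.equal(attn_scores_out.permute(0, 3, 1, 2), attn_output_weights)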
@@ -1187,7 +1159,7 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output, out_proj_weight, out_proj_bias
         )

-        return attn_output, attn_output_weights, attn_scores_out
+        return attn_output, attn_output_weights


     def forward2(