Merge branch 'scaled_adam_exp271' into scaled_adam_exp274

Daniel Povey 2022-11-03 19:55:23 +08:00
parent 31d9bbfb3c
commit cefcd061bd


@@ -1044,8 +1044,19 @@ class RelPositionMultiheadAttention(nn.Module):
self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False,
initial_scale=0.05)
# the following are for diagnostics only, see --print-diagnostics option.
# they only copy their inputs.
# This is to stop a failure mode where the output gets very small and is
# dominated by the mean (the min_positive and max_positive constraints stop
# the mean from being much larger than the variance). Make min_abs very small
# because even for normal, properly functioning self_attn layers, the output
# RMS can be very small.
self.out_balancer = ActivationBalancer(embed_dim,
channel_dim=-1,
min_positive=0.33,
max_positive=0.66,
min_abs=0.005, max_abs=1.0,
min_prob=0.05)
# the following are for diagnostics only, see --print-diagnostics option
self.copy_pos_query = Identity()
self.copy_query = Identity()
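For orientation, here is a minimal sketch (toy tensors; not the icefall implementation) of the per-channel statistics these ActivationBalancer settings constrain. The balancer leaves the forward pass unchanged and only adjusts gradients to push these statistics toward the configured bounds:

import torch

x = torch.randn(10, 4, 512)  # (seq_len, bsz, embed_dim); channel_dim=-1
# fraction of positive values per channel: constrained toward [0.33, 0.66]
frac_positive = (x > 0).float().mean(dim=(0, 1))
# per-channel RMS: constrained toward [0.005, 1.0]
rms = (x * x).mean(dim=(0, 1)).sqrt()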
@@ -1061,6 +1072,13 @@ class RelPositionMultiheadAttention(nn.Module):
whitening_limit=2.0,
prob=(0.025, 0.25),
grad_scale=0.025)
self.out_balancer2 = ActivationBalancer(embed_dim,
channel_dim=-1,
min_positive=0.33,
max_positive=0.66,
min_abs=0.005, max_abs=1.0,
min_prob=0.05)
def forward(
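Incidentally, out_balancer2 repeats out_balancer's settings verbatim; a hypothetical helper (not in this patch; it assumes the ActivationBalancer imported from this repo's scaling.py, as elsewhere in the file) could deduplicate the arguments:

from scaling import ActivationBalancer  # assumed import, as used in this file

def make_out_balancer(embed_dim: int) -> ActivationBalancer:
    # Mirrors the argument lists used for out_balancer and out_balancer2 above.
    return ActivationBalancer(embed_dim,
                              channel_dim=-1,
                              min_positive=0.33,
                              max_positive=0.66,
                              min_abs=0.005, max_abs=1.0,
                              min_prob=0.05)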
@@ -1113,8 +1131,6 @@ class RelPositionMultiheadAttention(nn.Module):
self.attention_dim,
self.num_heads,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
@@ -1129,8 +1145,6 @@ class RelPositionMultiheadAttention(nn.Module):
attention_dim: int,
num_heads: int,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Tensor,
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
@@ -1142,7 +1156,6 @@ class RelPositionMultiheadAttention(nn.Module):
attention_dim: dimension inside attention mechanism
num_heads: parallel attention heads.
dropout_p: probability of an element to be zeroed.
out_proj_weight, out_proj_bias: the output projection weight and bias.
training: apply dropout if is ``True``.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. This is a binary mask. When the value is True,
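As an aside, a minimal sketch (toy lengths; assumed (bsz, seq_len) mask layout) of how a key_padding_mask of this kind is typically built from per-utterance lengths, with True marking padded key positions:

import torch

lengths = torch.tensor([7, 10, 4])   # valid frames per utterance
max_len = int(lengths.max())
# True at padded key positions; shape (bsz, seq_len)
key_padding_mask = torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)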
@@ -1349,9 +1362,8 @@ class RelPositionMultiheadAttention(nn.Module):
.contiguous()
.view(seq_len, bsz, attention_dim // 2)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
)
attn_output = self.out_proj(attn_output)
attn_output = self.out_balancer(attn_output)
return attn_output, attn_output_weights
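The swap above is behavior-preserving: passing out_proj's weight and bias to nn.functional.linear computes the same thing as calling the module directly. A standalone sketch of the equivalence (with a plain nn.Linear standing in for the projection layer):

import torch
import torch.nn as nn

embed_dim = 512
out_proj = nn.Linear(embed_dim, embed_dim)
attn_output = torch.randn(10, 4, embed_dim)  # (seq_len, bsz, embed_dim)

# The old path: weight and bias threaded through a functional call ...
y_old = nn.functional.linear(attn_output, out_proj.weight, out_proj.bias)
# ... matches the new path: calling the module itself.
y_new = out_proj(attn_output)
assert torch.allclose(y_old, y_new)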
@@ -1391,7 +1403,9 @@ class RelPositionMultiheadAttention(nn.Module):
.view(seq_len, bsz, self.attention_dim // 2)
)
# returned value is of shape (seq_len, bsz, embed_dim), like x.
return self.out_proj2(attn_output)
attn_output = self.out_proj2(attn_output)
attn_output = self.out_balancer2(attn_output)
return attn_output
def _print_attn_stats(