Merge branch 'scaled_adam_exp271' into scaled_adam_exp274

Daniel Povey 2022-11-03 19:55:23 +08:00
parent 31d9bbfb3c
commit cefcd061bd


@@ -1044,8 +1044,19 @@ class RelPositionMultiheadAttention(nn.Module):
         self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False,
                                        initial_scale=0.05)
-        # the following are for diagnosics only, see --print-diagnostics option.
-        # they only copy their inputs.
+        # this is to stop a failure mode where the output gets very small and is
+        # dominated by the mean (the min_positive and max_positive will stop the mean
+        # being much larger than the variance). Make min_abs very small because even for normal,
+        # functional self_attn layers, the output rms can be very small.
+        self.out_balancer = ActivationBalancer(embed_dim,
+                                               channel_dim=-1,
+                                               min_positive=0.33,
+                                               max_positive=0.66,
+                                               min_abs=0.005, max_abs=1.0,
+                                               min_prob=0.05)
+        # the following are for diagnosics only, see --print-diagnostics option
         self.copy_pos_query = Identity()
         self.copy_query = Identity()
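The ActivationBalancer module added in this hunk is defined elsewhere in the recipe, not in this file. As a rough illustration of the behaviour the new comment describes, here is a minimal sketch under assumptions: it is the identity in the forward pass and, during training, adds a small sign-based correction to the gradient of channels whose mean absolute value drifts outside [min_abs, max_abs]. The real module also enforces the min_positive/max_positive bounds and applies its correction only with probability min_prob; the SimpleBalancer name, the grad_scale argument, and the exact penalty formula below are hypothetical, not the icefall implementation.

```python
import torch
from torch import Tensor, nn


class _BalancerGrad(torch.autograd.Function):
    # Identity in the forward pass; in the backward pass, nudges channels whose
    # mean absolute value lies outside [min_abs, max_abs] by adding a small
    # sign-based term to the incoming gradient.
    @staticmethod
    def forward(ctx, x: Tensor, channel_dim: int, min_abs: float,
                max_abs: float, grad_scale: float) -> Tensor:
        ctx.save_for_backward(x)
        ctx.channel_dim = channel_dim
        ctx.min_abs = min_abs
        ctx.max_abs = max_abs
        ctx.grad_scale = grad_scale
        return x

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        (x,) = ctx.saved_tensors
        dim = ctx.channel_dim % x.ndim
        reduce_dims = [d for d in range(x.ndim) if d != dim]
        mean_abs = x.abs().mean(dim=reduce_dims, keepdim=True)
        # +1 where the channel is too small, -1 where it is too large, 0 otherwise.
        direction = (mean_abs < ctx.min_abs).to(x.dtype) - (mean_abs > ctx.max_abs).to(x.dtype)
        # Push |x| up for too-small channels and down for too-large ones.
        penalty = -direction * x.sign() * ctx.grad_scale * grad_output.abs().mean()
        return grad_output + penalty, None, None, None, None


class SimpleBalancer(nn.Module):
    # Hypothetical stand-in for ActivationBalancer covering only the
    # min_abs/max_abs constraint; a no-op at inference time.
    def __init__(self, channel_dim: int = -1, min_abs: float = 0.005,
                 max_abs: float = 1.0, grad_scale: float = 0.01) -> None:
        super().__init__()
        self.channel_dim = channel_dim
        self.min_abs = min_abs
        self.max_abs = max_abs
        self.grad_scale = grad_scale

    def forward(self, x: Tensor) -> Tensor:
        if not self.training or not x.requires_grad:
            return x
        return _BalancerGrad.apply(x, self.channel_dim, self.min_abs,
                                   self.max_abs, self.grad_scale)
```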
@@ -1061,6 +1072,13 @@ class RelPositionMultiheadAttention(nn.Module):
                                             whitening_limit=2.0,
                                             prob=(0.025, 0.25),
                                             grad_scale=0.025)
+        self.out_balancer2 = ActivationBalancer(embed_dim,
+                                                channel_dim=-1,
+                                                min_positive=0.33,
+                                                max_positive=0.66,
+                                                min_abs=0.005, max_abs=1.0,
+                                                min_prob=0.05)

     def forward(
@@ -1113,8 +1131,6 @@ class RelPositionMultiheadAttention(nn.Module):
             self.attention_dim,
             self.num_heads,
             self.dropout,
-            self.out_proj.weight,
-            self.out_proj.bias,
             training=self.training,
             key_padding_mask=key_padding_mask,
             attn_mask=attn_mask,
@@ -1129,8 +1145,6 @@ class RelPositionMultiheadAttention(nn.Module):
        attention_dim: int,
        num_heads: int,
        dropout_p: float,
-        out_proj_weight: Tensor,
-        out_proj_bias: Tensor,
        training: bool = True,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
@@ -1142,7 +1156,6 @@ class RelPositionMultiheadAttention(nn.Module):
            attention_dim: dimension inside attention mechanism
            num_heads: parallel attention heads.
            dropout_p: probability of an element to be zeroed.
-            out_proj_weight, out_proj_bias: the output projection weight and bias.
            training: apply dropout if is ``True``.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. This is an binary mask. When the value is True,
@@ -1349,9 +1362,8 @@ class RelPositionMultiheadAttention(nn.Module):
             .contiguous()
             .view(seq_len, bsz, attention_dim // 2)
         )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
-        )
+        attn_output = self.out_proj(attn_output)
+        attn_output = self.out_balancer(attn_output)

        return attn_output, attn_output_weights
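The hunks above and below share one refactor: the output projection moves from a functional linear call, with weight and bias threaded through as arguments, to a call on the module's own out_proj, followed by the new balancer. A small self-contained sketch of that pattern follows; the Head class and its method names are illustrative, not from the icefall code, and nn.Identity stands in for ActivationBalancer.

```python
import torch
from torch import nn
from torch.nn import functional as F


class Head(nn.Module):
    def __init__(self, dim: int = 512) -> None:
        super().__init__()
        self.out_proj = nn.Linear(dim, dim)
        self.out_balancer = nn.Identity()  # stand-in for ActivationBalancer

    # old style: projection weights passed as explicit arguments
    def project_functional(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.out_proj.weight, self.out_proj.bias)

    # new style: call the submodule directly, then pass through the balancer
    def project_modular(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_balancer(self.out_proj(x))


x = torch.randn(100, 8, 512)  # (seq_len, batch, embed_dim)
head = Head()
# Both paths compute the same projection; the balancer only matters during
# training, when a gradient-nudging module replaces nn.Identity.
assert torch.allclose(head.project_functional(x), head.project_modular(x))
```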
@@ -1391,7 +1403,9 @@ class RelPositionMultiheadAttention(nn.Module):
             .view(seq_len, bsz, self.attention_dim // 2)
         )
         # returned value is of shape (seq_len, bsz, embed_dim), like x.
-        return self.out_proj2(attn_output)
+        attn_output = self.out_proj2(attn_output)
+        attn_output = self.out_balancer2(attn_output)
+        return attn_output

    def _print_attn_stats(