Merge branch 'scaled_adam_exp271' into scaled_adam_exp274
This commit is contained in:
parent
31d9bbfb3c
commit
cefcd061bd
@ -1044,8 +1044,19 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False,
|
self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False,
|
||||||
initial_scale=0.05)
|
initial_scale=0.05)
|
||||||
|
|
||||||
# the following are for diagnosics only, see --print-diagnostics option.
|
# this is to stop a failure mode where the output gets very small and is
|
||||||
# they only copy their inputs.
|
# dominated by the mean (the min_positive and max_positive will stop the mean
|
||||||
|
# being much larger than the variance). Make min_abs very small because even for normal,
|
||||||
|
# functional self_attn layers, the output rms can be very small.
|
||||||
|
self.out_balancer = ActivationBalancer(embed_dim,
|
||||||
|
channel_dim=-1,
|
||||||
|
min_positive=0.33,
|
||||||
|
max_positive=0.66,
|
||||||
|
min_abs=0.005, max_abs=1.0,
|
||||||
|
min_prob=0.05)
|
||||||
|
|
||||||
|
|
||||||
|
# the following are for diagnosics only, see --print-diagnostics option
|
||||||
self.copy_pos_query = Identity()
|
self.copy_pos_query = Identity()
|
||||||
self.copy_query = Identity()
|
self.copy_query = Identity()
|
||||||
|
|
||||||
@ -1061,6 +1072,13 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
whitening_limit=2.0,
|
whitening_limit=2.0,
|
||||||
prob=(0.025, 0.25),
|
prob=(0.025, 0.25),
|
||||||
grad_scale=0.025)
|
grad_scale=0.025)
|
||||||
|
self.out_balancer2 = ActivationBalancer(embed_dim,
|
||||||
|
channel_dim=-1,
|
||||||
|
min_positive=0.33,
|
||||||
|
max_positive=0.66,
|
||||||
|
min_abs=0.005, max_abs=1.0,
|
||||||
|
min_prob=0.05)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -1113,8 +1131,6 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
self.attention_dim,
|
self.attention_dim,
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.dropout,
|
self.dropout,
|
||||||
self.out_proj.weight,
|
|
||||||
self.out_proj.bias,
|
|
||||||
training=self.training,
|
training=self.training,
|
||||||
key_padding_mask=key_padding_mask,
|
key_padding_mask=key_padding_mask,
|
||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
@ -1129,8 +1145,6 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
attention_dim: int,
|
attention_dim: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
dropout_p: float,
|
dropout_p: float,
|
||||||
out_proj_weight: Tensor,
|
|
||||||
out_proj_bias: Tensor,
|
|
||||||
training: bool = True,
|
training: bool = True,
|
||||||
key_padding_mask: Optional[Tensor] = None,
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
attn_mask: Optional[Tensor] = None,
|
attn_mask: Optional[Tensor] = None,
|
||||||
@ -1142,7 +1156,6 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
attention_dim: dimension inside attention mechanism
|
attention_dim: dimension inside attention mechanism
|
||||||
num_heads: parallel attention heads.
|
num_heads: parallel attention heads.
|
||||||
dropout_p: probability of an element to be zeroed.
|
dropout_p: probability of an element to be zeroed.
|
||||||
out_proj_weight, out_proj_bias: the output projection weight and bias.
|
|
||||||
training: apply dropout if is ``True``.
|
training: apply dropout if is ``True``.
|
||||||
key_padding_mask: if provided, specified padding elements in the key will
|
key_padding_mask: if provided, specified padding elements in the key will
|
||||||
be ignored by the attention. This is an binary mask. When the value is True,
|
be ignored by the attention. This is an binary mask. When the value is True,
|
||||||
@ -1349,9 +1362,8 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
.contiguous()
|
.contiguous()
|
||||||
.view(seq_len, bsz, attention_dim // 2)
|
.view(seq_len, bsz, attention_dim // 2)
|
||||||
)
|
)
|
||||||
attn_output = nn.functional.linear(
|
attn_output = self.out_proj(attn_output)
|
||||||
attn_output, out_proj_weight, out_proj_bias
|
attn_output = self.out_balancer(attn_output)
|
||||||
)
|
|
||||||
|
|
||||||
return attn_output, attn_output_weights
|
return attn_output, attn_output_weights
|
||||||
|
|
||||||
@ -1391,7 +1403,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
.view(seq_len, bsz, self.attention_dim // 2)
|
.view(seq_len, bsz, self.attention_dim // 2)
|
||||||
)
|
)
|
||||||
# returned value is of shape (seq_len, bsz, embed_dim), like x.
|
# returned value is of shape (seq_len, bsz, embed_dim), like x.
|
||||||
return self.out_proj2(attn_output)
|
attn_output = self.out_proj2(attn_output)
|
||||||
|
attn_output = self.out_balancer2(attn_output)
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
def _print_attn_stats(
|
def _print_attn_stats(
|
||||||
|
|||||||
Reference in New Issue
Block a user