mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge branch 'scaled_adam_exp271' into scaled_adam_exp274
This commit is contained in:
parent
31d9bbfb3c
commit
cefcd061bd
@ -1044,8 +1044,19 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False,
|
||||
initial_scale=0.05)
|
||||
|
||||
# the following are for diagnosics only, see --print-diagnostics option.
|
||||
# they only copy their inputs.
|
||||
# this is to stop a failure mode where the output gets very small and is
|
||||
# dominated by the mean (the min_positive and max_positive will stop the mean
|
||||
# being much larger than the variance). Make min_abs very small because even for normal,
|
||||
# functional self_attn layers, the output rms can be very small.
|
||||
self.out_balancer = ActivationBalancer(embed_dim,
|
||||
channel_dim=-1,
|
||||
min_positive=0.33,
|
||||
max_positive=0.66,
|
||||
min_abs=0.005, max_abs=1.0,
|
||||
min_prob=0.05)
|
||||
|
||||
|
||||
# the following are for diagnosics only, see --print-diagnostics option
|
||||
self.copy_pos_query = Identity()
|
||||
self.copy_query = Identity()
|
||||
|
||||
@ -1061,6 +1072,13 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
whitening_limit=2.0,
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.025)
|
||||
self.out_balancer2 = ActivationBalancer(embed_dim,
|
||||
channel_dim=-1,
|
||||
min_positive=0.33,
|
||||
max_positive=0.66,
|
||||
min_abs=0.005, max_abs=1.0,
|
||||
min_prob=0.05)
|
||||
|
||||
|
||||
|
||||
def forward(
|
||||
@ -1113,8 +1131,6 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
self.attention_dim,
|
||||
self.num_heads,
|
||||
self.dropout,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask,
|
||||
attn_mask=attn_mask,
|
||||
@ -1129,8 +1145,6 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
attention_dim: int,
|
||||
num_heads: int,
|
||||
dropout_p: float,
|
||||
out_proj_weight: Tensor,
|
||||
out_proj_bias: Tensor,
|
||||
training: bool = True,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
@ -1142,7 +1156,6 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
attention_dim: dimension inside attention mechanism
|
||||
num_heads: parallel attention heads.
|
||||
dropout_p: probability of an element to be zeroed.
|
||||
out_proj_weight, out_proj_bias: the output projection weight and bias.
|
||||
training: apply dropout if is ``True``.
|
||||
key_padding_mask: if provided, specified padding elements in the key will
|
||||
be ignored by the attention. This is an binary mask. When the value is True,
|
||||
@ -1349,9 +1362,8 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
.contiguous()
|
||||
.view(seq_len, bsz, attention_dim // 2)
|
||||
)
|
||||
attn_output = nn.functional.linear(
|
||||
attn_output, out_proj_weight, out_proj_bias
|
||||
)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
attn_output = self.out_balancer(attn_output)
|
||||
|
||||
return attn_output, attn_output_weights
|
||||
|
||||
@ -1391,7 +1403,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
.view(seq_len, bsz, self.attention_dim // 2)
|
||||
)
|
||||
# returned value is of shape (seq_len, bsz, embed_dim), like x.
|
||||
return self.out_proj2(attn_output)
|
||||
attn_output = self.out_proj2(attn_output)
|
||||
attn_output = self.out_balancer2(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
def _print_attn_stats(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user