Mirror of https://github.com/k2-fsa/icefall.git
Limit magnitude of linear_pos
commit bc5c782294
parent 61619c031e
@@ -465,6 +465,8 @@ class RelPositionMultiheadAttention(nn.Module):
 
         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
         self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0)
+        self.proj_balancer = ActivationBalancer(channel_dim=-1, min_positive=0.0,
+                                                max_positive=1.0, max_abs=10.0)
         self.out_proj = ScaledLinear(
             embed_dim, embed_dim, bias=True, initial_scale=0.5
         )
@@ -774,7 +776,7 @@ class RelPositionMultiheadAttention(nn.Module):
 
         pos_emb_bsz = pos_emb.size(0)
         assert pos_emb_bsz in (1, bsz)  # actually it is 1
-        p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
+        p = self.proj_balancer(self.linear_pos(pos_emb)).view(pos_emb_bsz, -1, num_heads, head_dim)
         p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
 
         q_with_bias_u = (q + self.pos_bias_u).transpose(
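Summary of the change: the commit adds a second ActivationBalancer (self.proj_balancer) with max_abs=10.0 and routes the output of self.linear_pos through it, so the relative-positional projection used in the attention computation is kept within a bounded magnitude during training. Below is a minimal, hypothetical sketch of how such a magnitude limiter can be built as a module that is the identity in the forward pass and nudges gradients in the backward pass. The names SimpleAbsLimiter and LimitAbsFunction and the 0.01 correction factor are illustrative assumptions only; this is not icefall's actual ActivationBalancer (which lives in scaling.py and also enforces sign-balance constraints such as min_positive/max_positive).

import torch
from torch import nn, Tensor


class LimitAbsFunction(torch.autograd.Function):
    """Identity in forward; in backward, adds a small gradient term that
    pushes channels whose mean |x| exceeds max_abs back toward zero.
    Sketch only, not icefall's ActivationBalancer."""

    @staticmethod
    def forward(ctx, x: Tensor, max_abs: float, channel_dim: int) -> Tensor:
        # Per-channel mean absolute value, reduced over all non-channel dims.
        dims = [d for d in range(x.ndim) if d != channel_dim % x.ndim]
        mean_abs = x.abs().mean(dim=dims, keepdim=True)
        too_large = (mean_abs > max_abs).to(x.dtype)
        ctx.save_for_backward(x.sign(), too_large)
        return x

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        sign, too_large = ctx.saved_tensors
        # For over-large channels, add a term proportional to sign(x), so a
        # gradient-descent step shrinks |x|.  The 0.01 factor plays the role
        # of a small "max_factor"-style correction strength (assumption).
        extra = 0.01 * too_large * sign * grad_output.abs().mean()
        return grad_output + extra, None, None


class SimpleAbsLimiter(nn.Module):
    """Hypothetical stand-in for ActivationBalancer(channel_dim=-1, max_abs=...)."""

    def __init__(self, channel_dim: int = -1, max_abs: float = 10.0):
        super().__init__()
        self.channel_dim = channel_dim
        self.max_abs = max_abs

    def forward(self, x: Tensor) -> Tensor:
        if not self.training:
            # No effect at inference time; the constraint only shapes training.
            return x
        return LimitAbsFunction.apply(x, self.max_abs, self.channel_dim)

With a limiter of this kind, the changed line in the diff corresponds to wrapping the positional projection, p = self.proj_balancer(self.linear_pos(pos_emb)).view(...), so the magnitude constraint is applied before the projection is split into per-head vectors; because the wrapper is the identity in the forward pass, decoding behaviour is unchanged.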