Remove in_balancer.

This commit is contained in:
Daniel Povey 2022-10-19 12:35:17 +08:00
parent f4442de1c4
commit 45c38dec61

View File

@@ -869,8 +869,6 @@ class RelPositionMultiheadAttention(nn.Module):
self.copy_pos_query = Identity()
self.copy_query = Identity()
-self.in_balancer = ActivationBalancer(3 * attention_dim,
-channel_dim=-1, max_abs=5.0)
self.out_proj = ScaledLinear(
attention_dim // 2, embed_dim, bias=True, initial_scale=0.05
)
@@ -930,7 +928,7 @@ class RelPositionMultiheadAttention(nn.Module):
and S is the sequence length.
"""
x, weights = self.multi_head_attention_forward(
-self.in_balancer(self.in_proj(x)),
+self.in_proj(x),
self.linear_pos(pos_emb),
self.attention_dim,
self.num_heads,