Remove balancer from SelfAttention module.

Daniel Povey 2022-11-23 18:41:36 +08:00
parent f2dbf87461
commit 9ceb41acb4


@@ -1226,14 +1226,6 @@ class SelfAttention(nn.Module):
                                      embed_dim, bias=True,
                                      initial_scale=0.05)
 
-        # intended to prevent an observed failure mode where the output of this module is
-        # dominated by its mean.
-        self.out_balancer = ActivationBalancer(embed_dim,
-                                               channel_dim=-1,
-                                               min_positive=0.33,
-                                               max_positive=0.66,
-                                               min_abs=0.005, max_abs=1.0,
-                                               min_prob=0.05)
 
     def forward(
         self,
@@ -1267,7 +1259,6 @@ class SelfAttention(nn.Module):
         # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
         x = self.out_proj(x)
-        x = self.out_balancer(x)
         return x
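
For context on what was removed: ActivationBalancer (defined in this repo's scaling.py) is an identity at inference time; during training it nudges the backward gradients so that, per channel, the fraction of positive activations stays within [min_positive, max_positive] and the mean absolute value stays within [min_abs, max_abs] (min_prob bounds the probability of applying the correction on a given step). The sketch below illustrates that general technique, not the repo's actual implementation: the class name SimpleActivationBalancer and the grad_scale parameter are hypothetical, and the stochastic min_prob behavior is omitted.

import torch
from torch import nn


class _BalancerGrad(torch.autograd.Function):
    # Identity in the forward pass; in the backward pass, adds a small
    # sign-based term to the gradient for channels whose statistics
    # fall outside the target ranges.
    @staticmethod
    def forward(ctx, x, channel_dim, min_positive, max_positive,
                min_abs, max_abs, grad_scale):
        ctx.save_for_backward(x)
        ctx.config = (channel_dim, min_positive, max_positive,
                      min_abs, max_abs, grad_scale)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        (channel_dim, min_positive, max_positive,
         min_abs, max_abs, grad_scale) = ctx.config
        dim = channel_dim % x.ndim
        reduce_dims = [d for d in range(x.ndim) if d != dim]
        # Per-channel statistics of the forward activations.
        frac_positive = (x > 0).to(x.dtype).mean(dim=reduce_dims, keepdim=True)
        mean_abs = x.abs().mean(dim=reduce_dims, keepdim=True)
        # +1 where a channel has too many positive values (push values down),
        # -1 where it has too few (push values up), 0 where within range.
        sign_term = ((frac_positive > max_positive).to(x.dtype)
                     - (frac_positive < min_positive).to(x.dtype))
        # Analogous term for magnitudes: shrink channels whose mean |x| is
        # too large, grow channels whose mean |x| is too small.
        abs_term = ((mean_abs > max_abs).to(x.dtype)
                    - (mean_abs < min_abs).to(x.dtype)) * x.sign()
        # Scale the correction by the incoming gradient magnitude so it
        # stays a gentle nudge rather than dominating training.
        extra = grad_scale * grad_output.abs().mean() * (sign_term + abs_term)
        return grad_output + extra, None, None, None, None, None, None


class SimpleActivationBalancer(nn.Module):
    # Hypothetical, simplified stand-in for the repo's ActivationBalancer.
    def __init__(self, num_channels, channel_dim=-1,
                 min_positive=0.33, max_positive=0.66,
                 min_abs=0.005, max_abs=1.0, grad_scale=0.01):
        super().__init__()
        self.num_channels = num_channels  # kept for signature parity
        self.channel_dim = channel_dim
        self.min_positive = min_positive
        self.max_positive = max_positive
        self.min_abs = min_abs
        self.max_abs = max_abs
        self.grad_scale = grad_scale

    def forward(self, x):
        if not self.training or not x.requires_grad:
            return x  # pure identity outside of training
        return _BalancerGrad.apply(x, self.channel_dim,
                                   self.min_positive, self.max_positive,
                                   self.min_abs, self.max_abs,
                                   self.grad_scale)

With the balancer removed, SelfAttention.forward now ends at self.out_proj, so the module no longer applies this gradient-based constraint to its output statistics.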