Add balancer at output of FeedforwardModule

This commit is contained in:
Daniel Povey 2022-11-22 14:43:46 +08:00
parent fe1793e288
commit 26916f41e7

View File

@ -1381,21 +1381,27 @@ class FeedforwardModule(nn.Module):
dropout: float): dropout: float):
super(FeedforwardModule, self).__init__() super(FeedforwardModule, self).__init__()
self.in_proj = nn.Linear(embed_dim, feedforward_dim) self.in_proj = nn.Linear(embed_dim, feedforward_dim)
self.balancer = ActivationBalancer(feedforward_dim, self.hidden_balancer = ActivationBalancer(feedforward_dim,
channel_dim=-1, max_abs=10.0, channel_dim=-1, max_abs=10.0,
min_prob=0.25) min_prob=0.25)
self.activation = DoubleSwish() self.activation = DoubleSwish()
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.out_proj = ScaledLinear(feedforward_dim, embed_dim, self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
initial_scale=0.01) initial_scale=0.01)
self.out_balancer = ActivationBalancer(embed_dim,
min_positive=0.4, max_positive=0.6,
min_abs=0.01, max_abs=5.0,
channel_dim=-1, min_prob=0.1)
def forward(self, def forward(self,
x: Tensor): x: Tensor):
x = self.in_proj(x) x = self.in_proj(x)
x = self.balancer(x) x = self.hidden_balancer(x)
x = self.activation(x) x = self.activation(x)
x = self.dropout(x) x = self.dropout(x)
x = self.out_proj(x) x = self.out_proj(x)
x = self.out_balancer(x)
return x return x