Mirror of https://github.com/k2-fsa/icefall.git
Add balancer at output of FeedforwardModule

commit 26916f41e7
parent fe1793e288
@@ -1381,21 +1381,27 @@ class FeedforwardModule(nn.Module):
                  dropout: float):
         super(FeedforwardModule, self).__init__()
         self.in_proj = nn.Linear(embed_dim, feedforward_dim)
-        self.balancer = ActivationBalancer(feedforward_dim,
-                                           channel_dim=-1, max_abs=10.0,
-                                           min_prob=0.25)
+        self.hidden_balancer = ActivationBalancer(feedforward_dim,
+                                                  channel_dim=-1, max_abs=10.0,
+                                                  min_prob=0.25)
         self.activation = DoubleSwish()
         self.dropout = nn.Dropout(dropout)
         self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
                                      initial_scale=0.01)
+        self.out_balancer = ActivationBalancer(embed_dim,
+                                               min_positive=0.4, max_positive=0.6,
+                                               min_abs=0.01, max_abs=5.0,
+                                               channel_dim=-1, min_prob=0.1)
 
     def forward(self,
                 x: Tensor):
         x = self.in_proj(x)
-        x = self.balancer(x)
+        x = self.hidden_balancer(x)
         x = self.activation(x)
         x = self.dropout(x)
         x = self.out_proj(x)
+        x = self.out_balancer(x)
         return x
 
 
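For readability, here is a sketch of FeedforwardModule as it reads after this commit, reconstructed from the hunk above. ActivationBalancer, DoubleSwish and ScaledLinear are the icefall helpers defined in scaling.py next to this class; broadly, ActivationBalancer passes its input through unchanged in the forward direction and adjusts gradients during training to keep per-channel activation statistics within the configured bounds. The dimensions in the usage lines at the end are illustrative example values, not taken from the commit.

import torch
from torch import Tensor, nn

from scaling import ActivationBalancer, DoubleSwish, ScaledLinear  # icefall helpers


class FeedforwardModule(nn.Module):
    """Feedforward module in the Zipformer model (post-commit sketch)."""

    def __init__(self, embed_dim: int, feedforward_dim: int, dropout: float):
        super(FeedforwardModule, self).__init__()
        self.in_proj = nn.Linear(embed_dim, feedforward_dim)
        # Balancer on the hidden activations (previously named `self.balancer`),
        # renamed here to distinguish it from the new output balancer.
        self.hidden_balancer = ActivationBalancer(feedforward_dim,
                                                  channel_dim=-1, max_abs=10.0,
                                                  min_prob=0.25)
        self.activation = DoubleSwish()
        self.dropout = nn.Dropout(dropout)
        self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
                                     initial_scale=0.01)
        # New in this commit: balancer on the module output, constraining the
        # sign balance (min_positive/max_positive) and magnitude
        # (min_abs/max_abs) of the module's output.
        self.out_balancer = ActivationBalancer(embed_dim,
                                               min_positive=0.4, max_positive=0.6,
                                               min_abs=0.01, max_abs=5.0,
                                               channel_dim=-1, min_prob=0.1)

    def forward(self, x: Tensor) -> Tensor:
        x = self.in_proj(x)
        x = self.hidden_balancer(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        x = self.out_balancer(x)
        return x


# Illustrative usage (dimensions are example values, not from the commit);
# the Zipformer encoder passes tensors shaped (seq_len, batch, embed_dim).
ff = FeedforwardModule(embed_dim=384, feedforward_dim=1536, dropout=0.1)
y = ff(torch.randn(100, 8, 384))
print(y.shape)  # torch.Size([100, 8, 384])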