mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add balancer at output of FeedforwardModule
This commit is contained in:
parent
fe1793e288
commit
26916f41e7
@ -1381,21 +1381,27 @@ class FeedforwardModule(nn.Module):
|
|||||||
dropout: float):
|
dropout: float):
|
||||||
super(FeedforwardModule, self).__init__()
|
super(FeedforwardModule, self).__init__()
|
||||||
self.in_proj = nn.Linear(embed_dim, feedforward_dim)
|
self.in_proj = nn.Linear(embed_dim, feedforward_dim)
|
||||||
self.balancer = ActivationBalancer(feedforward_dim,
|
self.hidden_balancer = ActivationBalancer(feedforward_dim,
|
||||||
channel_dim=-1, max_abs=10.0,
|
channel_dim=-1, max_abs=10.0,
|
||||||
min_prob=0.25)
|
min_prob=0.25)
|
||||||
self.activation = DoubleSwish()
|
self.activation = DoubleSwish()
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout = nn.Dropout(dropout)
|
||||||
self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
|
self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
|
||||||
initial_scale=0.01)
|
initial_scale=0.01)
|
||||||
|
self.out_balancer = ActivationBalancer(embed_dim,
|
||||||
|
min_positive=0.4, max_positive=0.6,
|
||||||
|
min_abs=0.01, max_abs=5.0,
|
||||||
|
channel_dim=-1, min_prob=0.1)
|
||||||
|
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
x: Tensor):
|
x: Tensor):
|
||||||
x = self.in_proj(x)
|
x = self.in_proj(x)
|
||||||
x = self.balancer(x)
|
x = self.hidden_balancer(x)
|
||||||
x = self.activation(x)
|
x = self.activation(x)
|
||||||
x = self.dropout(x)
|
x = self.dropout(x)
|
||||||
x = self.out_proj(x)
|
x = self.out_proj(x)
|
||||||
|
x = self.out_balancer(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user