mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Fix balancer code
This commit is contained in:
parent
11a04c50ae
commit
2eef001d39
@ -176,13 +176,13 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
||||||
|
|
||||||
|
|
||||||
self.pre_norm_final = Identity()
|
|
||||||
self.norm_final = BasicNorm(d_model)
|
self.norm_final = BasicNorm(d_model)
|
||||||
|
|
||||||
# try to ensure the output is close to zero-mean (or at least, zero-median).
|
# try to ensure the output is close to zero-mean (or at least, zero-median).
|
||||||
self.balancer = ActivationBalancer(channel_dim=-1,
|
self.balancer = ActivationBalancer(channel_dim=-1,
|
||||||
min_positive=0.45,
|
min_positive=0.45,
|
||||||
max_positive=0.55)
|
max_positive=0.55,
|
||||||
|
max_positive=6.0)
|
||||||
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
@ -232,7 +232,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
# feed forward module
|
# feed forward module
|
||||||
src = src + self.dropout(self.feed_forward(src))
|
src = src + self.dropout(self.feed_forward(src))
|
||||||
|
|
||||||
src = self.balancer(self.norm_final(self.pre_norm_final(src)))
|
src = self.norm_final(self.balancer(src))
|
||||||
|
|
||||||
return src
|
return src
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user