mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add balancer_ff2 to avoid too small ff2 module
This commit is contained in:
parent
9ee4472f36
commit
c15578d0bb
@ -461,6 +461,17 @@ class ZipformerEncoderLayer(nn.Module):
|
|||||||
min_positive=0.45, max_positive=0.55,
|
min_positive=0.45, max_positive=0.55,
|
||||||
min_abs=1.0, max_abs=4.0,
|
min_abs=1.0, max_abs=4.0,
|
||||||
)
|
)
|
||||||
|
# balancer for output of feedforward2, prevent it from staying too
|
||||||
|
# small. give this a very small probability, even at the start of
|
||||||
|
# training, it's to fix a rare problem and it's OK to fix it slowly.
|
||||||
|
self.balancer_ff2 = Balancer(
|
||||||
|
embed_dim, channel_dim=-1,
|
||||||
|
min_positive=0.45, max_positive=0.55,
|
||||||
|
min_abs=ScheduledFloat((0.0, 0.0), (8000.0, 0.5), default=0.0),
|
||||||
|
max_abs=2.0,
|
||||||
|
prob=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
self.whiten = Whiten(num_groups=1,
|
self.whiten = Whiten(num_groups=1,
|
||||||
whitening_limit=_whitening_schedule(4.0, ratio=3.0),
|
whitening_limit=_whitening_schedule(4.0, ratio=3.0),
|
||||||
prob=(0.025, 0.25),
|
prob=(0.025, 0.25),
|
||||||
@ -574,7 +585,7 @@ class ZipformerEncoderLayer(nn.Module):
|
|||||||
if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate):
|
if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate):
|
||||||
src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
|
src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
|
||||||
|
|
||||||
src = src + self.feed_forward2(src)
|
src = src + self.balancer_ff2(self.feed_forward2(src))
|
||||||
|
|
||||||
src = self.balancer1(src)
|
src = self.balancer1(src)
|
||||||
src = self.norm(src)
|
src = self.norm(src)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user