Add balancer_ff2 to keep the output of the feed_forward2 module from staying too small

This commit is contained in:
Daniel Povey 2022-12-31 01:09:17 +08:00
parent 9ee4472f36
commit c15578d0bb

View File

@ -461,6 +461,17 @@ class ZipformerEncoderLayer(nn.Module):
min_positive=0.45, max_positive=0.55, min_positive=0.45, max_positive=0.55,
min_abs=1.0, max_abs=4.0, min_abs=1.0, max_abs=4.0,
) )
# Balancer for the output of feed_forward2, to prevent it from staying too
# small. We give it a very small probability, even at the start of
# training: it is there to fix a rare problem, and it's OK to fix it slowly.
self.balancer_ff2 = Balancer(
embed_dim, channel_dim=-1,
min_positive=0.45, max_positive=0.55,
min_abs=ScheduledFloat((0.0, 0.0), (8000.0, 0.5), default=0.0),
max_abs=2.0,
prob=0.05,
)
self.whiten = Whiten(num_groups=1, self.whiten = Whiten(num_groups=1,
whitening_limit=_whitening_schedule(4.0, ratio=3.0), whitening_limit=_whitening_schedule(4.0, ratio=3.0),
prob=(0.025, 0.25), prob=(0.025, 0.25),
@ -574,7 +585,7 @@ class ZipformerEncoderLayer(nn.Module):
if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate): if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate):
src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
src = src + self.feed_forward2(src) src = src + self.balancer_ff2(self.feed_forward2(src))
src = self.balancer1(src) src = self.balancer1(src)
src = self.norm(src) src = self.norm(src)