From c15578d0bbf2364718872aa44e77730390d3989c Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 31 Dec 2022 01:09:17 +0800
Subject: [PATCH] Add balancer_ff2 to avoid too-small ff2 module

---
 .../ASR/pruned_transducer_stateless7/zipformer.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 87d1e0333..be7130a97 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -461,6 +461,17 @@ class ZipformerEncoderLayer(nn.Module):
             min_positive=0.45, max_positive=0.55,
             min_abs=1.0, max_abs=4.0,
         )
+        # Balancer for the output of feed_forward2; prevents it from staying
+        # too small.  Give this a very small probability, even early in
+        # training: it fixes a rare problem, and fixing it slowly is OK.
+        self.balancer_ff2 = Balancer(
+            embed_dim, channel_dim=-1,
+            min_positive=0.45, max_positive=0.55,
+            min_abs=ScheduledFloat((0.0, 0.0), (8000.0, 0.5), default=0.0),
+            max_abs=2.0,
+            prob=0.05,
+        )
+
         self.whiten = Whiten(num_groups=1,
                              whitening_limit=_whitening_schedule(4.0, ratio=3.0),
                              prob=(0.025, 0.25),
@@ -574,7 +585,7 @@ class ZipformerEncoderLayer(nn.Module):
         if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate):
             src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
 
-        src = src + self.feed_forward2(src)
+        src = src + self.balancer_ff2(self.feed_forward2(src))
 
         src = self.balancer1(src)
         src = self.norm(src)
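
Note on the schedule: min_abs for balancer_ff2 is not a fixed constant but a
ScheduledFloat of the training batch count.  As a rough sketch (assuming
ScheduledFloat interpolates linearly between its (batch_count, value)
breakpoints; the scheduled_min_abs helper below is hypothetical, for
illustration only), the constraint ramps up like this:

# Toy sketch of the min_abs schedule in the patch.  Assumption:
# ScheduledFloat interpolates linearly between (batch_count, value)
# breakpoints and is constant outside them.
def scheduled_min_abs(batch_count: float) -> float:
    (b0, v0), (b1, v1) = (0.0, 0.0), (8000.0, 0.5)  # breakpoints from the patch
    if batch_count <= b0:
        return v0
    if batch_count >= b1:
        return v1
    return v0 + (batch_count - b0) / (b1 - b0) * (v1 - v0)

for batch in (0, 2000, 4000, 8000, 20000):
    print(batch, scheduled_min_abs(batch))  # -> 0.0, 0.125, 0.25, 0.5, 0.5

So the constraint is inactive at the start of training and only reaches its
full strength of 0.5 after about 8000 batches, which, together with
prob=0.05, matches the comment that this rare problem may be fixed slowly.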