diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index f96ef34a9..4e8deb88f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -403,8 +403,9 @@ class ZipformerEncoderLayer(nn.Module):
         # to work correctly.
         layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
         attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
-        conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
+        conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.01), default=0),
         const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
+        ff2_skip_rate: FloatLike = 0.01,
         bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
         bypass_max: FloatLike = 1.0,
     ) -> None:
@@ -418,6 +419,9 @@ class ZipformerEncoderLayer(nn.Module):
         # an additional skip probability that applies to ConvModule to stop it from
         # contributing too much early on.
         self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
+        # ff2_skip_rate is to prevent the ff2 module from having output that's too big
+        # compared to its residual.
+        self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
 
         # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
         # ever becoming zero.
@@ -608,7 +612,8 @@ class ZipformerEncoderLayer(nn.Module):
         if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate):
             src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
 
-        src = src + self.balancer_ff2(self.feed_forward2(src))
+        if torch.jit.is_scripting() or random.random() >= float(self.ff2_skip_rate):
+            src = src + self.balancer_ff2(self.feed_forward2(src))
 
         src = self.balancer1(src)
         src = self.norm(src)
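
For context on the skip-rate machinery this diff touches: icefall's `ScheduledFloat` evaluates to a value interpolated between `(batch_count, value)` points as training progresses (with `default` used when no batch count is available, e.g. under `torch.jit` scripting), and `float(self.xxx_skip_rate)` reads the current value, which gates whether a residual branch runs on a given step. The sketch below is a minimal, self-contained illustration of that behavior under the assumption of piecewise-linear interpolation; `scheduled_value` is a hypothetical stand-in for illustration, not icefall's actual implementation.

```python
import random

def scheduled_value(batch_count, points, default=0.0):
    """Toy piecewise-linear schedule in the spirit of ScheduledFloat.

    `points` is a list of (batch_count, value) pairs sorted by batch_count;
    outside the covered range the endpoint values are held constant, and
    `default` stands in when no batch count is known (e.g. at scripting time).
    """
    if batch_count is None:
        return default
    if batch_count <= points[0][0]:
        return points[0][1]
    if batch_count >= points[-1][0]:
        return points[-1][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 <= batch_count <= x1:
            return y0 + (y1 - y0) * (batch_count - x0) / (x1 - x0)

# The ff2 change in the diff amounts to the same Bernoulli gate the other
# branches already use: with ff2_skip_rate = 0.01, the feed_forward2 branch
# is dropped on roughly 1% of training steps.  The conv_skip_rate schedule
# from the diff now decays to 0.01 instead of 0.0:
conv_skip_points = [(0.0, 0.2), (4000.0, 0.05), (16000.0, 0.01)]
for step in (0, 2000, 8000, 20000):
    rate = scheduled_value(step, conv_skip_points)
    keep = random.random() >= rate  # mirrors: random.random() >= float(self.conv_skip_rate)
    print(f"step {step}: skip_rate={rate:.3f}, apply module: {keep}")
```

The design choice mirrors the comments already in the file: dropping a residual branch with small probability keeps its output from dominating the residual stream, which matters most early in training while the modules are still poorly conditioned.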