diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index cd1c40294..1669253d0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1391,7 +1391,9 @@ class FeedforwardModule(nn.Module): aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob()) self.hidden_balancer = ActivationBalancer(feedforward_dim, - channel_dim=-1, max_abs=10.0, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.1), (12000.0, 0.05)), + max_abs=ScheduledFloat((0.0, 4.0), (12000.0, 10.0), default=10), min_prob=0.25) self.activation = DoubleSwish() self.dropout = nn.Dropout(dropout) @@ -1435,7 +1437,7 @@ class NonlinAttentionModule(nn.Module): channels // 2, channel_dim=-1, min_positive=0.05, max_positive=1.0, min_abs=0.2, max_abs=ScheduledFloat((0.0, 2.0), - (4000.0, 10.0), + (12000.0, 10.0), default=1.0), ) self.sigmoid = nn.Sigmoid() @@ -1529,7 +1531,7 @@ class ConvolutionModule(nn.Module): self.deriv_balancer1 = ActivationBalancer( 2 * channels, channel_dim=-1, max_abs=ScheduledFloat((0.0, 2.0), - (4000.0, 10.0), + (12000.0, 10.0), default=1.0), min_positive=0.05, max_positive=1.0 ) @@ -1549,8 +1551,9 @@ class ConvolutionModule(nn.Module): self.deriv_balancer2 = ActivationBalancer( channels, channel_dim=1, - min_positive=0.05, max_positive=1.0, - max_abs=20.0, + min_positive=ScheduledFloat((0.0, 0.1), (12000.0, 0.05)), + max_abs=ScheduledFloat((0.0, 4.0), (12000.0, 20.0), default=10), + max_positive=1.0, ) self.activation = DoubleSwish()