diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 71d49af88..8731fee2c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -343,9 +343,12 @@ def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
                           default=x)
 
 def _aux_grad_scale() -> float:
-    return 0.05
-def _aux_grad_prob() -> ScheduledFloat:
+    return 0.1
+def _aux_grad_prob_out() -> ScheduledFloat:
     return ScheduledFloat((0.0, 0.25), (1000.0, 0.05), (8000.0, 0.0125))
+def _aux_grad_prob_in() -> ScheduledFloat:
+    return ScheduledFloat((0.0, 0.25), (1000.0, 0.0))
+
 
 class ZipformerEncoderLayer(nn.Module):
     """
@@ -1289,7 +1292,7 @@ class AttentionSqueeze(nn.Module):
 
         self.in_proj = LinearWithAuxLoss(embed_dim, embed_dim,
                                          bias=False,
-                                         aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob())
+                                         aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())
 
         self.to_bottleneck_proj = LinearWithAuxLoss(embed_dim, bottleneck_dim)
 
@@ -1340,7 +1343,7 @@ class AttentionSqueeze(nn.Module):
         self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim)
 
         self.out_proj = LinearWithAuxLoss(embed_dim, embed_dim,
-                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob(),
+                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out(),
                                           bias=False, initial_scale=0.05)
 
     def forward(self,
@@ -1388,7 +1391,7 @@ class FeedforwardModule(nn.Module):
                  dropout: float):
         super(FeedforwardModule, self).__init__()
         self.in_proj = LinearWithAuxLoss(embed_dim, feedforward_dim,
-                                         aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob())
+                                         aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())
 
         self.hidden_balancer = ActivationBalancer(feedforward_dim,
                                                   channel_dim=-1,
@@ -1399,7 +1402,7 @@ class FeedforwardModule(nn.Module):
         self.dropout = nn.Dropout(dropout)
         self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
                                           initial_scale=0.01,
-                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob())
+                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
         self.out_whiten = Whiten(num_groups=1,
                                  whitening_limit=_whitening_schedule(7.5),
                                  prob=(0.025, 0.25),
@@ -1511,7 +1514,7 @@ class ConvolutionModule(nn.Module):
 
         self.in_proj = LinearWithAuxLoss(
             channels, 2 * channels,
-            aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob()
+            aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in()
         )
 
 
@@ -1564,7 +1567,7 @@ class ConvolutionModule(nn.Module):
 
         self.out_proj = LinearWithAuxLoss(
             channels, channels,
-            aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob(),
+            aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out(),
             initial_scale=0.05,
         )
 
@@ -1689,7 +1692,7 @@ class Conv2dSubsampling(nn.Module):
         self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.01))
 
         self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
-                                     aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob())
+                                     aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
 
         self.dropout = nn.Dropout(dropout)