diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index c2415eb7f..bcdf859a8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -342,8 +342,10 @@ def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: (12000.0, ratio * x), default=x) -def _aux_grad_scale() -> ScheduledFloat: - return ScheduledFloat((0.0, 0.2), (1000.0, 0.01)) +def _aux_grad_scale() -> float: + return 0.1 +def _aux_grad_prob() -> ScheduledFloat: + return ScheduledFloat((0.0, 0.25), (1000.0, 0.0125)) class ZipformerEncoderLayer(nn.Module): """ @@ -1291,7 +1293,7 @@ class AttentionSqueeze(nn.Module): self.in_proj = LinearWithAuxLoss(embed_dim, embed_dim, bias=False, - aux_grad_scale=_aux_grad_scale()) + aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob()) self.to_bottleneck_proj = LinearWithAuxLoss(embed_dim, bottleneck_dim) @@ -1342,7 +1344,7 @@ class AttentionSqueeze(nn.Module): self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim) self.out_proj = LinearWithAuxLoss(embed_dim, embed_dim, - aux_grad_scale=_aux_grad_scale(), + aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob(), bias=False, initial_scale=0.05) def forward(self, @@ -1390,7 +1392,7 @@ class FeedforwardModule(nn.Module): dropout: float): super(FeedforwardModule, self).__init__() self.in_proj = LinearWithAuxLoss(embed_dim, feedforward_dim, - aux_grad_scale=_aux_grad_scale()) + aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob()) self.hidden_balancer = ActivationBalancer(feedforward_dim, channel_dim=-1, max_abs=10.0, @@ -1399,7 +1401,7 @@ class FeedforwardModule(nn.Module): self.dropout = nn.Dropout(dropout) self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim, initial_scale=0.01, - aux_grad_scale=_aux_grad_scale()) + aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob()) self.out_whiten = Whiten(num_groups=1, whitening_limit=_whitening_schedule(7.5), prob=(0.025, 0.25), @@ -1431,7 +1433,7 @@ class NonlinAttentionModule(nn.Module): super().__init__() self.in_proj = LinearWithAuxLoss(channels, 2 * channels, bias=True, - aux_grad_scale=_aux_grad_scale()) + aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob()) # balancer that goes after the glu mechanism. self.balancer = ActivationBalancer( @@ -1445,7 +1447,7 @@ class NonlinAttentionModule(nn.Module): self.activation = Identity() # for diagnostics. self.out_proj = LinearWithAuxLoss(channels, channels, bias=True, - aux_grad_scale=_aux_grad_scale(), + aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob(), initial_scale=0.05) self.whiten1 = Whiten(num_groups=1, @@ -1522,7 +1524,7 @@ class ConvolutionModule(nn.Module): self.in_proj = LinearWithAuxLoss( channels, 2 * channels, - aux_grad_scale=_aux_grad_scale() + aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob() ) @@ -1569,7 +1571,7 @@ class ConvolutionModule(nn.Module): self.out_proj = LinearWithAuxLoss( channels, channels, - aux_grad_scale=_aux_grad_scale(), + aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob(), initial_scale=0.05, ) @@ -1685,7 +1687,7 @@ class Conv2dSubsampling(nn.Module): self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels, - aux_grad_scale=_aux_grad_scale()) + aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob()) self.dropout = nn.Dropout(dropout)