From 5f8080702786d87655b170b3ff5caba58ba02e62 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 26 Nov 2022 14:15:09 +0800
Subject: [PATCH] Add LinearWithAuxLoss in nonlin_attention and AttentionSqueeze modules.

---
 .../pruned_transducer_stateless7/zipformer.py | 33 +++++++++++--------
 1 file changed, 20 insertions(+), 13 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 6a4f08f6d..c2415eb7f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -342,6 +342,9 @@ def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
                           (12000.0, ratio * x),
                           default=x)
 
+def _aux_grad_scale() -> ScheduledFloat:
+    return ScheduledFloat((0.0, 0.2), (1000.0, 0.01))
+
 class ZipformerEncoderLayer(nn.Module):
     """
     Args:
@@ -1286,8 +1289,9 @@ class AttentionSqueeze(nn.Module):
         super().__init__()
         self.bottleneck_dim = bottleneck_dim
 
-        self.in_proj = nn.Linear(embed_dim, embed_dim,
-                                 bias=False)
+        self.in_proj = LinearWithAuxLoss(embed_dim, embed_dim,
+                                         bias=False,
+                                         aux_grad_scale=_aux_grad_scale())
 
         self.to_bottleneck_proj = LinearWithAuxLoss(embed_dim, bottleneck_dim)
 
@@ -1337,8 +1341,9 @@ class AttentionSqueeze(nn.Module):
 
         self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim)
 
-        self.out_proj = ScaledLinear(embed_dim, embed_dim,
-                                     bias=False, initial_scale=0.05)
+        self.out_proj = LinearWithAuxLoss(embed_dim, embed_dim,
+                                          aux_grad_scale=_aux_grad_scale(),
+                                          bias=False, initial_scale=0.05)
 
     def forward(self,
                 x: Tensor,
@@ -1385,7 +1390,7 @@ class FeedforwardModule(nn.Module):
                  dropout: float):
         super(FeedforwardModule, self).__init__()
         self.in_proj = LinearWithAuxLoss(embed_dim, feedforward_dim,
-                                         aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)))
+                                         aux_grad_scale=_aux_grad_scale())
         self.hidden_balancer = ActivationBalancer(feedforward_dim,
                                                   channel_dim=-1,
                                                   max_abs=10.0,
@@ -1394,7 +1399,7 @@ class FeedforwardModule(nn.Module):
         self.dropout = nn.Dropout(dropout)
         self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
                                           initial_scale=0.01,
-                                          aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)))
+                                          aux_grad_scale=_aux_grad_scale())
         self.out_whiten = Whiten(num_groups=1,
                                  whitening_limit=_whitening_schedule(7.5),
                                  prob=(0.025, 0.25),
@@ -1425,7 +1430,8 @@ class NonlinAttentionModule(nn.Module):
     ) -> None:
         super().__init__()
 
-        self.in_proj = nn.Linear(channels, 2 * channels, bias=True)
+        self.in_proj = LinearWithAuxLoss(channels, 2 * channels, bias=True,
+                                         aux_grad_scale=_aux_grad_scale())
 
         # balancer that goes after the glu mechanism.
         self.balancer = ActivationBalancer(
@@ -1437,9 +1443,10 @@ class NonlinAttentionModule(nn.Module):
         self.sigmoid = nn.Sigmoid()
 
         self.activation = Identity()  # for diagnostics.
-        self.out_proj = ScaledLinear(channels, channels,
-                                     bias=True,
-                                     initial_scale=0.05)
+        self.out_proj = LinearWithAuxLoss(channels, channels,
+                                          bias=True,
+                                          aux_grad_scale=_aux_grad_scale(),
+                                          initial_scale=0.05)
 
         self.whiten1 = Whiten(num_groups=1,
                               whitening_limit=_whitening_schedule(5.0),
@@ -1515,7 +1522,7 @@ class ConvolutionModule(nn.Module):
 
         self.in_proj = LinearWithAuxLoss(
             channels, 2 * channels,
-            aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01))
+            aux_grad_scale=_aux_grad_scale()
         )
 
 
@@ -1562,7 +1569,7 @@ class ConvolutionModule(nn.Module):
 
         self.out_proj = LinearWithAuxLoss(
             channels, channels,
-            aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)),
+            aux_grad_scale=_aux_grad_scale(),
             initial_scale=0.05,
         )
 
@@ -1678,7 +1685,7 @@ class Conv2dSubsampling(nn.Module):
 
         self.out = LinearWithAuxLoss(out_height * layer3_channels,
                                      out_channels,
-                                     aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)))
+                                     aux_grad_scale=_aux_grad_scale())
 
         self.dropout = nn.Dropout(dropout)
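
Note on the schedule used above (an illustrative sketch, not part of the patch or of
icefall's scaling.py): the new _aux_grad_scale() helper centralizes the
ScheduledFloat((0.0, 0.2), (1000.0, 0.01)) expression that was previously repeated at
each call site, so every LinearWithAuxLoss touched here shares one auxiliary-gradient
schedule. Assuming ScheduledFloat's usual piecewise-linear interpolation over the
batch count, the scale starts at 0.2 and decays to 0.01 by roughly batch 1000. A
minimal standalone approximation, with a hypothetical helper name, would be:

    # Hypothetical sketch of the value encoded by _aux_grad_scale(); the real
    # schedule is evaluated by ScheduledFloat against the module's batch count.
    def aux_grad_scale_at(batch_count: float) -> float:
        # Piecewise-linear interpolation between (0, 0.2) and (1000, 0.01),
        # clamped to the endpoint values outside that range.
        if batch_count <= 0.0:
            return 0.2
        if batch_count >= 1000.0:
            return 0.01
        frac = batch_count / 1000.0
        return 0.2 + frac * (0.01 - 0.2)

    aux_grad_scale_at(0.0)     # 0.2   -> strong auxiliary gradient early in training
    aux_grad_scale_at(500.0)   # 0.105 -> halfway through the decay
    aux_grad_scale_at(2000.0)  # 0.01  -> mostly switched off for the rest of training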