From a6657e6b40a169c73276dfd510bc3f917623e8ca Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 23 Nov 2022 19:08:19 +0800
Subject: [PATCH] Harmonize whitening modules, adding them to 3 submodules and
 changing configuration on 2 others and location in NonlinAttention.

---
 .../pruned_transducer_stateless7/zipformer.py | 36 ++++++++++++-------
 1 file changed, 24 insertions(+), 12 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 91898328b..f1dc64cbf 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1226,6 +1226,11 @@ class SelfAttention(nn.Module):
                                      embed_dim, bias=True,
                                      initial_scale=0.05)

+        self.whiten = Whiten(num_groups=1,
+                             whitening_limit=15.0,
+                             prob=(0.025, 0.25),
+                             grad_scale=0.01)
+
     def forward(
         self,
@@ -1259,6 +1264,7 @@ class SelfAttention(nn.Module):
         # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
         x = self.out_proj(x)
+        x = self.whiten(x)

         return x
@@ -1325,7 +1331,7 @@ class AttentionSqueeze(nn.Module):
                                      bias=False, initial_scale=0.05)
         self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=10.0,
+                                 whitening_limit=15.0,
                                  prob=(0.01, 0.1),
                                  grad_scale=0.01)
@@ -1382,7 +1388,7 @@ class FeedforwardModule(nn.Module):
         self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
                                      initial_scale=0.01)
         self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=10.0,
+                                 whitening_limit=15.0,
                                  prob=(0.025, 0.25),
                                  grad_scale=0.01)
@@ -1420,19 +1426,17 @@ class NonlinAttentionModule(nn.Module):
             min_abs=0.2, max_abs=10.0,
             min_prob=0.05,
         )
-        # give it a high limit, because it is quite high-dimensional and is
-        # a projection of a lower-dimensional embedding.
-        self.whiten = Whiten(num_groups=2,
-                             whitening_limit=20.0,
-                             prob=(0.025, 0.25),
-                             grad_scale=0.01)
-
         self.activation = Identity() # for diagnostics.

         self.out_proj = ScaledLinear(channels, channels,
                                      bias=True, initial_scale=0.05)

+        self.whiten = Whiten(num_groups=1,
+                             whitening_limit=15.0,
+                             prob=(0.025, 0.25),
+                             grad_scale=0.01)
+

     def forward(self,
@@ -1447,7 +1451,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
            a Tensor with the same shape as x
         """
         x = self.in_proj(x)
-        x = self.whiten(x)
+
         v, s = x.chunk(2, dim=-1)

         if self.training and random.random() < 0.02:
@@ -1472,6 +1476,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)

         x = self.activation(x) # diagnostics only, it's the identity.
         x = self.out_proj(x)
+        x = self.whiten(x)
         return x
@@ -1549,6 +1554,12 @@ class ConvolutionModule(nn.Module):
             initial_scale=0.05,
         )

+        self.out_whiten = Whiten(num_groups=1,
+                                 whitening_limit=15.0,
+                                 prob=(0.01, 0.1),
+                                 grad_scale=0.01)
+
+
     def forward(self,
                 x: Tensor,
                 src_key_padding_mask: Optional[Tensor] = None,
@@ -1584,8 +1595,9 @@ class ConvolutionModule(nn.Module):

         x = self.pointwise_conv2(x)  # (batch, channel, time)

-        return x.permute(2, 0, 1)
-
+        x = x.permute(2, 0, 1)
+        x = self.out_whiten(x)
+        return x


 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/2 length).
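
Note (not part of the patch): below is a minimal, hypothetical sketch of the
output-whitening pattern this commit standardizes across the Zipformer
submodules -- a final ScaledLinear projection followed by a Whiten module with
num_groups=1 and whitening_limit=15.0, applied to the module's output rather
than its input. It assumes ScaledLinear and Whiten can be imported from the
recipe's scaling.py (as zipformer.py does); the module name TinyWhitenedModule
and its single embed_dim argument are illustrative only, not code from the
repository.

import torch.nn as nn
from torch import Tensor

from scaling import ScaledLinear, Whiten  # same recipe directory as zipformer.py


class TinyWhitenedModule(nn.Module):
    """Hypothetical example module; not taken from zipformer.py."""

    def __init__(self, embed_dim: int) -> None:
        super().__init__()
        # Output projection, as in the real submodules.
        self.out_proj = ScaledLinear(embed_dim, embed_dim,
                                     bias=True, initial_scale=0.05)
        # Same whitening settings the patch uses for SelfAttention and
        # NonlinAttentionModule.
        self.whiten = Whiten(num_groups=1,
                             whitening_limit=15.0,
                             prob=(0.025, 0.25),
                             grad_scale=0.01)

    def forward(self, x: Tensor) -> Tensor:
        # x: (seq_len, batch_size, embed_dim); returns the same shape.
        x = self.out_proj(x)
        x = self.whiten(x)  # whitening applied last, on the module's output
        return x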