From 35f0ea001595ae18406f8e90e9372cf35dfdf636 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 24 Nov 2022 13:47:22 +0800
Subject: [PATCH] Changes to whitening modules for memory efficiency, moving
 them inside; increase their prob.

---
 .../pruned_transducer_stateless7/zipformer.py | 32 +++++++++++--------
 1 file changed, 18 insertions(+), 14 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index a659803e7..81cfe12f1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1041,6 +1041,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.pos_head_dim = pos_head_dim
         self.dropout = dropout
         self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
+        self.name = None  # will be overwritten in training code; for diagnostics.
 
         key_head_dim = query_head_dim
         in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
@@ -1202,7 +1203,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             attn_weights = attn_weights.to(torch.float32)
             attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum(
                 dim=-1).mean(dim=(1,2))
-            logging.info(f"attn_weights_entropy = {attn_weights_entropy}")
+            logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}")
 
 
 class SelfAttention(nn.Module):
@@ -1328,17 +1329,17 @@ class AttentionSqueeze(nn.Module):
             min_abs=0.2, max_abs=1.0,
             min_prob=0.05,
         )
+        self.activation_whiten = Whiten(num_groups=1,
+                                        whitening_limit=_whitening_schedule(7.5),
+                                        prob=(0.025, 0.25),
+                                        grad_scale=0.01)
+
         self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim)
 
         self.out_proj = ScaledLinear(embed_dim, embed_dim,
                                      bias=False, initial_scale=0.05)
 
-        self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=_whitening_schedule(7.5),
-                                 prob=(0.01, 0.1),
-                                 grad_scale=0.01)
-
     def forward(self,
                 x: Tensor,
                 attn_weights: Tensor):
@@ -1367,11 +1368,11 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         x = self.in_proj(x)
         x = self.activation_balancer(x)
+        x = self.activation_whiten(x)
         scales = self.scale_balancer(scales)
         x = x * scales
         x = self.activation(x)  # Identity only. For diagnostics.
         x = self.out_proj(x)
-        x = self.out_whiten(x)
         return x
 
 
@@ -1548,6 +1549,11 @@ class ConvolutionModule(nn.Module):
 
         self.activation = DoubleSwish()
 
+        self.whiten = Whiten(num_groups=1,
+                             whitening_limit=_whitening_schedule(7.5),
+                             prob=(0.025, 0.25),
+                             grad_scale=0.01)
+
         self.pointwise_conv2 = ScaledConv1d(
             channels,
             channels,
@@ -1558,11 +1564,6 @@ class ConvolutionModule(nn.Module):
             kernel_size=1,
             stride=1,
             padding=0,
             bias=True,
             initial_scale=0.05,
         )
 
-        self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=_whitening_schedule(7.5),
-                                 prob=(0.01, 0.1),
-                                 grad_scale=0.01)
-
     def forward(self,
                 x: Tensor,
@@ -1597,10 +1598,13 @@ class ConvolutionModule(nn.Module):
         x = self.deriv_balancer2(x)
         x = self.activation(x)
 
+        x = x.transpose(1, 2)
+        x = self.whiten(x)  # (batch, time, channel)
+        x = x.transpose(1, 2)
+
         x = self.pointwise_conv2(x)  # (batch, channel, time)
 
-        x = x.permute(2, 0, 1)
-        x = self.out_whiten(x)
+        x = x.permute(2, 0, 1)  # (time, batch, channel)
         return x
 
 
 class Conv2dSubsampling(nn.Module):
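
Note on the diagnostics change: the new `name` attribute on
RelPositionMultiheadAttentionWeights is only declared in this patch; per the
inline comment it is expected to be overwritten by the training code. The loop
below is a hypothetical sketch of how such an assignment is commonly done with
PyTorch's named_modules() traversal; it is an assumption for illustration, not
code from this patch or from the icefall training scripts.

    import torch.nn as nn

    def set_module_names(model: nn.Module) -> None:
        # Hypothetical helper: give every submodule that opted in by declaring
        # `self.name = None` its hierarchical module path, so diagnostics such
        # as the amended attn_weights_entropy log line can report their origin.
        for path, module in model.named_modules():
            if hasattr(module, "name") and module.name is None:
                module.name = path

    # Example: call set_module_names(model) right after constructing the model.

With names filled in, the amended log line can report something like
`name=encoder.layers.0.self_attn, attn_weights_entropy = ...` (the path shown
is illustrative only).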
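
Note on the whitening relocation: in both AttentionSqueeze and
ConvolutionModule, the Whiten module (defined in scaling.py, which this diff
does not touch) moves from the module output to a point just after the
activation, and its application probability rises from (0.01, 0.1) to
(0.025, 0.25). The sketch below illustrates the resulting data flow in
ConvolutionModule's tail under stated assumptions: WhitenStub and ConvTail are
hypothetical names, WhitenStub only mimics the real module's interface (the
real Whiten is an identity in the forward pass and attaches a
covariance-whitening penalty to the gradient), and a plain float stands in for
the _whitening_schedule(7.5) value the patch passes.

    import torch
    import torch.nn as nn
    from torch import Tensor

    class WhitenStub(nn.Module):
        """Interface-only stand-in for icefall's Whiten (see scaling.py)."""

        def __init__(self, num_groups: int, whitening_limit: float,
                     prob: tuple, grad_scale: float) -> None:
            super().__init__()
            self.num_groups = num_groups
            self.whitening_limit = whitening_limit
            self.min_prob, self.max_prob = prob
            self.grad_scale = grad_scale

        def forward(self, x: Tensor) -> Tensor:
            # The real Whiten returns x unchanged in the forward direction and
            # only modifies gradients in the backward pass, applied with a
            # probability in (min_prob, max_prob); a stub can just return x.
            return x

    class ConvTail(nn.Module):
        """Sketch of the tail of ConvolutionModule.forward after this patch."""

        def __init__(self, channels: int) -> None:
            super().__init__()
            self.whiten = WhitenStub(num_groups=1, whitening_limit=7.5,
                                     prob=(0.025, 0.25), grad_scale=0.01)
            self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1)

        def forward(self, x: Tensor) -> Tensor:
            # x: (batch, channel, time)
            x = x.transpose(1, 2)      # (batch, time, channel): channel last for whitening
            x = self.whiten(x)
            x = x.transpose(1, 2)      # back to (batch, channel, time)
            x = self.pointwise_conv2(x)
            return x.permute(2, 0, 1)  # (time, batch, channel), as the module returns

    tail = ConvTail(channels=4)
    out = tail(torch.randn(2, 4, 10))  # batch=2, channels=4, time=10
    print(out.shape)                   # torch.Size([10, 2, 4])

With kernel_size=1 the pointwise convolution is cheap, and placing the
whitening before it keeps the penalty attached to an activation the module
already materializes rather than to an extra copy of the output, which is
presumably the memory saving the commit message refers to.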