diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index adad7ca97..140638d3c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1332,6 +1332,11 @@ class AttentionSqueeze(nn.Module):
         self.out_proj = ScaledLinear(embed_dim, embed_dim,
                                      bias=False, initial_scale=0.05)
+
+        self.out_whiten = Whiten(num_groups=1,
+                                 whitening_limit=10.0,
+                                 prob=(0.01, 0.1),
+                                 grad_scale=0.01)
 
         self.out_balancer = ActivationBalancer(
             embed_dim, channel_dim=-1,
             min_positive=0.45, max_positive=0.55,
@@ -1371,6 +1376,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         x = x * scales
         x = self.activation(x)  # Identity only. For diagnostics.
         x = self.out_proj(x)
+        x = self.out_whiten(x)
         x = self.out_balancer(x)
         return x
 
@@ -1391,10 +1397,10 @@ class FeedforwardModule(nn.Module):
         self.dropout = nn.Dropout(dropout)
         self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
                                      initial_scale=0.01)
-        self.out_whitener = Whiten(num_groups=1,
-                                   whitening_limit=10.0,
-                                   prob=(0.025, 0.25),
-                                   grad_scale=0.01)
+        self.out_whiten = Whiten(num_groups=1,
+                                 whitening_limit=10.0,
+                                 prob=(0.025, 0.25),
+                                 grad_scale=0.01)
 
     def forward(self, x: Tensor):
@@ -1403,7 +1409,7 @@
         x = self.activation(x)
         x = self.dropout(x)
         x = self.out_proj(x)
-        x = self.out_whitener(x)
+        x = self.out_whiten(x)
         return x
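
Not part of the patch: a minimal, self-contained sketch of how the renamed out_whiten module sits at the end of the feedforward path. It assumes the recipe-local scaling.py exports Whiten and ScaledLinear with the constructor arguments used above; TinyFeedforward, the ReLU stand-in activation, and the toy dimensions are made up for illustration.

# Sketch only (not from the patch): mirrors FeedforwardModule after the rename.
import torch
from torch import Tensor, nn

from scaling import ScaledLinear, Whiten  # assumed: recipe-local scaling.py

class TinyFeedforward(nn.Module):
    def __init__(self, embed_dim: int, feedforward_dim: int, dropout: float):
        super().__init__()
        self.in_proj = nn.Linear(embed_dim, feedforward_dim)
        self.activation = nn.ReLU()  # stand-in; not the activation used in zipformer.py
        self.dropout = nn.Dropout(dropout)
        self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
                                     initial_scale=0.01)
        # Same constructor arguments as in the patch.  As I understand Whiten,
        # it passes activations through unchanged in the forward direction and
        # only perturbs gradients (with probability in `prob`) to push the
        # output covariance toward white.
        self.out_whiten = Whiten(num_groups=1,
                                 whitening_limit=10.0,
                                 prob=(0.025, 0.25),
                                 grad_scale=0.01)

    def forward(self, x: Tensor) -> Tensor:
        x = self.in_proj(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        x = self.out_whiten(x)  # the call added/renamed in this diff
        return x

if __name__ == "__main__":
    m = TinyFeedforward(embed_dim=384, feedforward_dim=1024, dropout=0.1)
    y = m(torch.randn(10, 4, 384))  # (seq_len, batch_size, embed_dim)
    print(y.shape)  # torch.Size([10, 4, 384])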