From 45069175d9c094f6d733ac2d9faa8f7a73ffd349 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 24 Nov 2022 14:16:13 +0800
Subject: [PATCH] Add a second whitening to the NonlinAttentionModule, after
 the aggregation.

---
 .../ASR/pruned_transducer_stateless7/zipformer.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 81cfe12f1..f02813c78 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1437,10 +1437,14 @@ class NonlinAttentionModule(nn.Module):
                                      bias=True,
                                      initial_scale=0.05)
 
-        self.whiten = Whiten(num_groups=1,
-                             whitening_limit=_whitening_schedule(7.5),
-                             prob=(0.025, 0.25),
-                             grad_scale=0.01)
+        self.whiten1 = Whiten(num_groups=1,
+                              whitening_limit=_whitening_schedule(7.5),
+                              prob=(0.01, 0.1),
+                              grad_scale=0.01)
+        self.whiten2 = Whiten(num_groups=1,
+                              whitening_limit=_whitening_schedule(7.5),
+                              prob=(0.01, 0.1),
+                              grad_scale=0.01)
 
@@ -1465,7 +1469,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         # very small probability to save time).
         s = penalize_abs_values_gt(s, limit=20.0, penalty=1.0e-04)
 
-        v = self.whiten(v)
+        v = self.whiten1(v)
 
         # GLU mechanism
         x = s.sigmoid() * v
         x = self.balancer(x)
@@ -1481,6 +1485,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
 
         x = self.activation(x)  # diagnostics only, it's the identity.
+        x = self.whiten2(x)
         x = self.out_proj(x)
         return x
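
Note (not part of the patch): below is a minimal, self-contained sketch of where the two whitening operations sit after this change. The class name NonlinAttentionSketch and the nn.Identity stand-ins are illustrative only; in the actual recipe both whitening modules are Whiten(num_groups=1, whitening_limit=_whitening_schedule(7.5), prob=(0.01, 0.1), grad_scale=0.01) from scaling.py, which act as (near-)identities in the forward pass and mainly shape gradients during training, and the full module also contains the balancer, tanh, and identity-activation steps visible in the diff.

# Sketch only: nn.Identity() stands in for the recipe's Whiten(...) modules.
# Tensor shapes and the GLU structure follow the NonlinAttentionModule
# docstring shown in the hunk headers above.
import torch
import torch.nn as nn


class NonlinAttentionSketch(nn.Module):
    def __init__(self, channels: int, hidden_channels: int):
        super().__init__()
        self.in_proj = nn.Linear(channels, hidden_channels * 2, bias=True)
        self.whiten1 = nn.Identity()  # value branch, before the GLU
        self.whiten2 = nn.Identity()  # after aggregation (the one this patch adds)
        self.out_proj = nn.Linear(hidden_channels, channels, bias=True)

    def forward(self, x: torch.Tensor, attn_weights: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch_size, channels)
        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
        seq_len, batch_size, _ = x.shape
        num_heads = attn_weights.shape[0]

        s, v = self.in_proj(x).chunk(2, dim=-1)
        v = self.whiten1(v)            # first whitening: value branch
        x = s.sigmoid() * v            # GLU mechanism

        # aggregate over time with the shared attention weights, per head
        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
        x = torch.matmul(attn_weights, x)
        x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)

        x = self.whiten2(x)            # second whitening: after aggregation
        return self.out_proj(x)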