From 1a2632d0a2d8812fa52cbdf4c72c897273ea7ecc Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 22 Nov 2022 20:01:09 +0800
Subject: [PATCH] Remove whitening in SelfAttention module.

---
 .../ASR/pruned_transducer_stateless7/zipformer.py | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index b406da0c7..e5d2e58ec 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1211,15 +1211,6 @@ class SelfAttention(nn.Module):
                                  num_heads * value_head_dim,
                                  bias=True)
 
-        # attempt to make the output of `in_proj` uncorrelated within each head
-        # and all heads having roughly the same magnitude. the hope is to
-        # improve learning dynamics; this loses no power as there is no constraint
-        # on the condition number of out_proj.
-        self.whiten_values = Whiten(num_groups=num_heads,
-                                    whitening_limit=2.0,
-                                    prob=(0.025, 0.25),
-                                    grad_scale=0.025)
-
         self.out_proj = ScaledLinear(num_heads * value_head_dim,
                                      embed_dim, bias=True,
                                      initial_scale=0.05)
@@ -1252,7 +1243,6 @@ class SelfAttention(nn.Module):
         assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
 
         x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)
-        x = self.whiten_values(x)  # does nothing in the forward pass.
         x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
         # now x: (num_heads, batch_size, seq_len, value_head_dim)
         value_head_dim = x.shape[-1]
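
Context for the change: as the removed inline comment notes, `whiten_values` "does nothing in the forward pass", so deleting the call leaves the forward computation of `SelfAttention` untouched; only the gradient-shaping effect is dropped. Below is a minimal sketch of that forward-identity pattern, using a hypothetical `WhitenLike` module rather than the real `Whiten` from this recipe's scaling utilities:

import torch
from torch import nn


class _IdentityInForward(torch.autograd.Function):
    """Illustration only: the forward pass returns its input unchanged; any
    effect would come solely from adjusting gradients in backward (left as a
    plain pass-through here)."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return x  # no change to the forward computation

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # The real Whiten would add a gradient term here that, per the deleted
        # comment, nudges each head's outputs toward being decorrelated and
        # similarly scaled.
        return grad_output


class WhitenLike(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return _IdentityInForward.apply(x)


if __name__ == "__main__":
    x = torch.randn(10, 2, 512)
    # Forward output is identical with or without the module, which is why
    # removing the call cannot change inference behavior.
    assert torch.equal(WhitenLike()(x), x)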