diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index beac0bb40..896d84032 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1423,7 +1423,7 @@ class NonlinAttentionModule(nn.Module):
             min_prob=0.1,
         )
         self.whiten = Whiten(num_groups=1,
-                             whitening_limit=20.0,
+                             whitening_limit=10.0,
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
 
@@ -1444,7 +1444,9 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         Returns:
            a Tensor with the same shape as x
         """
-        v, s = self.in_proj(x).chunk(2, dim=-1)
+        x = self.in_proj(x)
+        x = self.whiten(x)
+        v, s = x.chunk(2, dim=-1)
 
         if self.training and random.random() < 0.02:
             # prevent the inputs to the sigmoid from getting very large (this is
@@ -1455,7 +1457,6 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         # GLU mechanism
         x = s.sigmoid() * v
         x = self.balancer(x)
-        x = self.whiten(x)
 
         (seq_len, batch_size, embed_dim) = x.shape
         num_heads = attn_weights.shape[0]
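
For readers skimming the patch: the change moves the `self.whiten` call from the GLU output (post-balancer) to the `in_proj` output, before it is split into the value/gate halves, and halves `whitening_limit` from 20.0 to 10.0. Below is a minimal runnable sketch of the new ordering only, not the icefall implementation: `Whiten` here is an identity stand-in for icefall's `scaling.Whiten` (which, as I understand it, nudges the feature covariance toward a multiple of the identity by modifying gradients in the backward pass), and the attention-weighting step, `ActivationBalancer`, and the sigmoid-input clamp of the real `NonlinAttentionModule` are omitted.

# Minimal sketch of the re-ordering only -- NOT the icefall module.
import torch
import torch.nn as nn


class Whiten(nn.Module):
    # Placeholder for icefall's scaling.Whiten: forwards its input
    # unchanged so this sketch runs without the real module.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class NonlinAttentionSketch(nn.Module):
    def __init__(self, channels: int, hidden_channels: int):
        super().__init__()
        # in_proj produces both the value and the gate, hence 2x width.
        self.in_proj = nn.Linear(channels, 2 * hidden_channels)
        self.whiten = Whiten()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.in_proj(x)
        x = self.whiten(x)            # moved here: whiten pre-GLU ...
        v, s = x.chunk(2, dim=-1)
        x = s.sigmoid() * v           # GLU mechanism
        return x                      # ... instead of whitening this output


x = torch.randn(100, 4, 384)          # (seq_len, batch_size, channels)
print(NonlinAttentionSketch(384, 384)(x).shape)  # torch.Size([100, 4, 384])

One plausible reading of the design choice (an assumption, not stated in the patch): applying the whitening constraint before the sigmoid gate keeps the constrained activations on a tensor whose statistics the gate has not yet sparsified, which pairs naturally with the tighter limit of 10.0.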