From 5f5d02ed0c08e94bd7bcd33f55421b76e551d976 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 7 Dec 2022 18:07:56 +0800
Subject: [PATCH] Add another whitening module, move balancer to output.

---
 .../pruned_transducer_stateless7/zipformer.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index e1f933e50..608dbdd40 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1496,10 +1496,15 @@ class NonlinAttentionModule(nn.Module):
             min_abs=0.01,
         )
 
-        self.whiten = Whiten(num_groups=1,
-                             whitening_limit=_whitening_schedule(5.0),
-                             prob=(0.025, 0.25),
-                             grad_scale=0.01)
+        self.whiten1 = Whiten(num_groups=1,
+                              whitening_limit=_whitening_schedule(5.0),
+                              prob=(0.025, 0.25),
+                              grad_scale=0.01)
+
+        self.whiten2 = Whiten(num_groups=1,
+                              whitening_limit=_whitening_schedule(5.0),
+                              prob=(0.025, 0.25),
+                              grad_scale=0.01)
 
@@ -1539,10 +1544,11 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         x = torch.matmul(attn_weights, x)
         # now x: (num_heads, batch_size, seq_len, head_dim)
         x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+        x = self.whiten1(x)
         x = self.out_proj(x)
         x = self.balancer2(x)
-        x = self.whiten(x)
+        x = self.whiten2(x)
         return x
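
For readers checking the shape logic of this change, below is a minimal, self-contained sketch of the revised forward flow. The names NonlinAttentionSketch and IdentityReg are hypothetical stand-ins introduced here for illustration: the real Whiten and Balancer modules in icefall are identity-like on the forward value and apply their regularization through the backward pass, so identities suffice to verify the new module placement and the tensor shapes.

# Minimal sketch of the revised NonlinAttentionModule forward flow.
# IdentityReg is a hypothetical stand-in for icefall's Whiten/Balancer
# (identity on the forward value; the real modules regularize via gradients).
import torch
import torch.nn as nn


class IdentityReg(nn.Module):
    """Hypothetical placeholder: identity on the forward value."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class NonlinAttentionSketch(nn.Module):
    """Hypothetical simplification showing where whiten1/whiten2 now sit."""

    def __init__(self, num_heads: int, head_dim: int, channels: int):
        super().__init__()
        self.out_proj = nn.Linear(num_heads * head_dim, channels)
        self.balancer2 = IdentityReg()
        # The patch replaces the single `whiten` with two modules:
        self.whiten1 = IdentityReg()  # new: applied before out_proj
        self.whiten2 = IdentityReg()  # renamed from `whiten`: applied at the output

    def forward(self, x: torch.Tensor, attn_weights: torch.Tensor) -> torch.Tensor:
        # x: (num_heads, batch_size, seq_len, head_dim)
        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
        x = torch.matmul(attn_weights, x)
        # now x: (num_heads, batch_size, seq_len, head_dim)
        seq_len, batch_size = x.shape[2], x.shape[1]
        x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
        x = self.whiten1(x)   # new: whiten the concatenated heads pre-projection
        x = self.out_proj(x)
        x = self.balancer2(x)
        x = self.whiten2(x)   # same placement the old `whiten` had
        return x


if __name__ == "__main__":
    m = NonlinAttentionSketch(num_heads=4, head_dim=32, channels=128)
    x = torch.randn(4, 2, 10, 32)                  # (heads, batch, seq, head_dim)
    w = torch.randn(4, 2, 10, 10).softmax(dim=-1)  # attention weights
    print(m(x, w).shape)                           # torch.Size([10, 2, 128])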