From faed28ba6a2a0db5e2ec0567c69bdd577ec17e88 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 26 Nov 2022 18:59:15 +0800 Subject: [PATCH] Changes for debugging/stats. --- .../ASR/pruned_transducer_stateless7/zipformer.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 6a4f08f6d..f1f548ac0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1434,6 +1434,7 @@ class NonlinAttentionModule(nn.Module): min_abs=0.2, max_abs=10.0, min_prob=0.05, ) + self.pre_sigmoid = Identity() # for diagnostics. self.sigmoid = nn.Sigmoid() self.activation = Identity() # for diagnostics. @@ -1475,6 +1476,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) v = self.whiten1(v) # GLU mechanism + s = self.pre_sigmoid(s) x = self.sigmoid(s) * v x = self.balancer(x) @@ -1537,6 +1539,9 @@ class ConvolutionModule(nn.Module): max_abs=10.0, min_positive=0.05, max_positive=1.0 ) + self.pre_sigmoid = Identity() # before sigmoid; for diagnostics. + self.sigmoid = nn.Sigmoid() + self.depthwise_conv = nn.Conv1d( channels, channels, @@ -1585,7 +1590,13 @@ class ConvolutionModule(nn.Module): x = self.in_proj(x) # (time, batch, 2*channels) x = self.deriv_balancer1(x) - x = nn.functional.glu(x, dim=-1) # (time, batch, channels) + + x, s = x.chunk(2, dim=-1) + s = self.pre_sigmoid(s) + s = self.sigmoid(s) + x = x * s + + # (time, batch, channels) # exchange the temporal dimension and the feature dimension x = x.permute(1, 2, 0) # (#batch, channels, time).