diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index bcdf859a8..76493678e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1442,6 +1442,7 @@ class NonlinAttentionModule(nn.Module): min_abs=0.2, max_abs=10.0, min_prob=0.05, ) + self.pre_sigmoid = Identity() # for diagnostics. self.sigmoid = nn.Sigmoid() self.activation = Identity() # for diagnostics. @@ -1484,6 +1485,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) v = self.whiten1(v) # GLU mechanism + s = self.pre_sigmoid(s) x = self.sigmoid(s) * v x = self.balancer(x) @@ -1546,6 +1548,9 @@ class ConvolutionModule(nn.Module): max_abs=10.0, min_positive=0.05, max_positive=1.0 ) + self.pre_sigmoid = Identity() # before sigmoid; for diagnostics. + self.sigmoid = nn.Sigmoid() + self.depthwise_conv = nn.Conv1d( channels, channels, @@ -1594,7 +1599,13 @@ class ConvolutionModule(nn.Module): x = self.in_proj(x) # (time, batch, 2*channels) x = self.deriv_balancer1(x) - x = nn.functional.glu(x, dim=-1) # (time, batch, channels) + + x, s = x.chunk(2, dim=-1) + s = self.pre_sigmoid(s) + s = self.sigmoid(s) + x = x * s + + # (time, batch, channels) # exchange the temporal dimension and the feature dimension x = x.permute(1, 2, 0) # (#batch, channels, time).