diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index a0cc8fc1f..944e010e8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -297,9 +297,10 @@ class ConformerEncoderLayer(nn.Module): min_positive=0.45, max_positive=0.55, max_abs=6.0, ) - self.max_eig = MaxEig( - d_model, channel_dim=-1, - ) + self.whiten = Whiten(num_groups=1, + whitening_limit=5.0, + prob=(0.025, 0.25), + grad_scale=0.01) def forward( @@ -352,7 +353,7 @@ class ConformerEncoderLayer(nn.Module): src = src + self.feed_forward3(src) - src = self.norm_final(self.max_eig(self.balancer(src))) + src = self.norm_final(self.balancer(src)) delta = src - src_orig bypass_scale = self.bypass_scale @@ -360,7 +361,7 @@ class ConformerEncoderLayer(nn.Module): bypass_scale = bypass_scale.clamp(min=0.1, max=1.0) src = src_orig + delta * self.bypass_scale - return src + return self.whiten(src) class ConformerEncoder(nn.Module):