diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 33d025f86..9878939be 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1438,6 +1438,13 @@ class ModifiedSEModule(nn.Module): super().__init__() self.squeeze_proj = nn.Linear(d_model, d_model, bias=False) + + # caution: this won't work well if the batch size is extremely small. + self.squeeze_whiten = Whiten(num_groups=1, + whitening_limit=10.0, + prob=(0.025, 0.25), + grad_scale=0.01) + self.in_proj = nn.Linear(d_model, d_model, bias=False) @@ -1494,6 +1501,7 @@ class ModifiedSEModule(nn.Module): squeezed = (x * pooling_mask).sum(dim=0, keepdim=True) squeezed = self.squeeze_proj(squeezed) + squeezed = self.squeeze_whiten(squeezed) squeezed = self.balancer(squeezed) squeezed = self.activation(squeezed) squeezed = self.to_bottleneck_proj(squeezed)