diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 75ea8a5d8..87b5fd4f2 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1845,6 +1845,13 @@ class Conv2dSubsampling(nn.Module):
         self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.1))
         self.out = nn.Linear(out_width * layer3_channels, out_channels)
+        # use a much larger than normal grad_scale on this whitening module;
+        # there is only one such module, so there is not a concern about adding
+        # together many copies of this extra gradient term.
+        self.out_whiten = Whiten(num_groups=1,
+                                 whitening_limit=_whitening_schedule(4.0),
+                                 prob=(0.025, 0.25),
+                                 grad_scale=0.05)
         self.out_norm = BasicNorm(out_channels)
         self.dropout = Dropout2(dropout)
@@ -1889,9 +1896,10 @@ class Conv2dSubsampling(nn.Module):
         # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
         x = self.out(x)
+        x = self.out_whiten(x)
         x = self.out_norm(x)
         x = self.dropout(x)
-        return x
+        return 4.0 * x


 class AttentionCombine(nn.Module):
     """
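
For reference, a minimal sketch of how the Conv2dSubsampling output path fits together after this patch: the new out_whiten step sits between the linear projection and the normalization, and the returned features are scaled up by a fixed 4.0. The Whiten, BasicNorm, and Dropout2 modules from this recipe are replaced here by identity stand-ins purely so the snippet runs on its own, and the class name OutputPathSketch plus the sizes (1536 flattened features, 384 output channels, 50 frames, batch of 2) are illustrative, not taken from the patch.

import torch
import torch.nn as nn

class _Identity(nn.Module):
    # Stand-in for the recipe's Whiten / BasicNorm / Dropout2 modules,
    # used only to keep this sketch self-contained and runnable.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

class OutputPathSketch(nn.Module):
    """Order of operations at the end of Conv2dSubsampling.forward() after this patch."""
    def __init__(self, in_dim: int, out_channels: int) -> None:
        super().__init__()
        self.out = nn.Linear(in_dim, out_channels)
        self.out_whiten = _Identity()  # Whiten(num_groups=1, ...) in the real code
        self.out_norm = _Identity()    # BasicNorm(out_channels) in the real code
        self.dropout = _Identity()     # Dropout2(dropout) in the real code

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, ((T-1)//2 - 1)//2, in_dim) after the conv layers are flattened
        x = self.out(x)         # linear projection to out_channels
        x = self.out_whiten(x)  # new: whitening module on the subsampler output
        x = self.out_norm(x)
        x = self.dropout(x)
        return 4.0 * x          # new: fixed scale-up of the returned features

# Usage with illustrative sizes:
y = OutputPathSketch(1536, 384)(torch.randn(2, 50, 1536))
print(y.shape)  # torch.Size([2, 50, 384])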