diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 21653e54d..e7b5df4ef 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -959,11 +959,11 @@ class RelPositionMultiheadAttention(nn.Module): self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True) # self.whiten is applied on the values in forward() - self.whiten_values = Whiten(whitening_limit=2.0, + self.whiten_values = Whiten(whitening_limit=1.1, prob=1.0 if __name__ == "__main__" else 0.1, grad_scale=0.0025) # self.whiten_keys is applied on the keys in forward() - self.whiten_keys = Whiten(whitening_limit=2.0, + self.whiten_keys = Whiten(whitening_limit=1.1, prob=1.0 if __name__ == "__main__" else 0.1, grad_scale=0.0025) @@ -980,7 +980,7 @@ class RelPositionMultiheadAttention(nn.Module): self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True, initial_scale=0.05) # self.whiten_values2 is applied on the values in forward2() - self.whiten_values2 = Whiten(whitening_limit=2.0, + self.whiten_values2 = Whiten(whitening_limit=1.1, prob=1.0 if __name__ == "__main__" else 0.1, grad_scale=0.0025)