mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Decrease whitening limit from 2.0 to 1.1.
This commit is contained in:
parent
593a6e946d
commit
252798b6a1
@ -959,11 +959,11 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
|
||||
|
||||
# self.whiten is applied on the values in forward()
|
||||
self.whiten_values = Whiten(whitening_limit=2.0,
|
||||
self.whiten_values = Whiten(whitening_limit=1.1,
|
||||
prob=1.0 if __name__ == "__main__" else 0.1,
|
||||
grad_scale=0.0025)
|
||||
# self.whiten_keys is applied on the keys in forward()
|
||||
self.whiten_keys = Whiten(whitening_limit=2.0,
|
||||
self.whiten_keys = Whiten(whitening_limit=1.1,
|
||||
prob=1.0 if __name__ == "__main__" else 0.1,
|
||||
grad_scale=0.0025)
|
||||
|
||||
@ -980,7 +980,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
|
||||
initial_scale=0.05)
|
||||
# self.whiten_values2 is applied on the values in forward2()
|
||||
self.whiten_values2 = Whiten(whitening_limit=2.0,
|
||||
self.whiten_values2 = Whiten(whitening_limit=1.1,
|
||||
prob=1.0 if __name__ == "__main__" else 0.1,
|
||||
grad_scale=0.0025)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user