mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Set lr_factor on to_scores, max_abs=4.0 on balancer
This commit is contained in:
parent
45043e2e21
commit
3a71a53d8d
@ -813,6 +813,7 @@ class LearnedDownsamplingModule(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.to_scores = nn.Linear(embed_dim, 1, bias=False)
|
self.to_scores = nn.Linear(embed_dim, 1, bias=False)
|
||||||
|
self.to_scores.lr_factor = 0.5
|
||||||
# score_balancer is just to keep the magnitudes of the scores in
|
# score_balancer is just to keep the magnitudes of the scores in
|
||||||
# a fixed range and keep them balanced around zero, to stop
|
# a fixed range and keep them balanced around zero, to stop
|
||||||
# these drifting around.
|
# these drifting around.
|
||||||
@ -820,7 +821,8 @@ class LearnedDownsamplingModule(nn.Module):
|
|||||||
self.score_balancer = Balancer(1, channel_dim=-1,
|
self.score_balancer = Balancer(1, channel_dim=-1,
|
||||||
min_positive=1/(2*downsampling_factor),
|
min_positive=1/(2*downsampling_factor),
|
||||||
max_positive=0.6,
|
max_positive=0.6,
|
||||||
min_abs=1.0)
|
min_abs=1.0,
|
||||||
|
max_abs=4.0)
|
||||||
|
|
||||||
# below are for diagnostics.
|
# below are for diagnostics.
|
||||||
self.copy_weights1 = nn.Identity()
|
self.copy_weights1 = nn.Identity()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user