Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
commit d2d0ce0335
parent 2e66392306

    Try to get rid of gradient blowup
@@ -789,8 +789,11 @@ class LearnedDownsamplingModule(nn.Module):
     def __init__(self,
                  embed_dim: int,
                  downsampling_factor: int,
-                 intermediate_rate: FloatLike = 0.2):
+                 intermediate_rate: Optional[FloatLike] = ScheduledFloat((0.0, 0.5),
+                                                                         (4000.0, 0.2),
+                                                                         default=0.5)):
         super().__init__()
+
         self.to_scores = nn.Linear(embed_dim, 1, bias=False)
         # score_balancer is just to keep the magnitudes of the scores in
         # a fixed range and keep them balanced around zero, to stop
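This hunk replaces the constant intermediate_rate=0.2 with a schedule that starts at 0.5 and decays to 0.2 by batch 4000, so the value is looser early in training. In icefall, ScheduledFloat (from the zipformer scaling.py) tracks piecewise-linear breakpoints over the training batch count. Below is a minimal standalone sketch of that interpolation semantics, assuming linear interpolation between (batch_count, value) pairs with clamping outside them; it is an illustration, not the actual class:

    # Hypothetical stand-in for icefall's ScheduledFloat.  The real class is a
    # module whose float value tracks a global batch counter; this sketch only
    # models the piecewise-linear interpolation between breakpoints.
    from typing import Sequence, Tuple

    def scheduled_float(points: Sequence[Tuple[float, float]], batch: float) -> float:
        # Interpolate linearly between (batch_count, value) breakpoints,
        # clamping to the first/last value outside the covered range.
        points = sorted(points)
        if batch <= points[0][0]:
            return points[0][1]
        if batch >= points[-1][0]:
            return points[-1][1]
        for (x0, y0), (x1, y1) in zip(points, points[1:]):
            if x0 <= batch <= x1:
                return y0 + (batch - x0) / (x1 - x0) * (y1 - y0)
        return points[-1][1]

    # The schedule introduced by this commit: 0.5 at batch 0, 0.2 from batch 4000 on.
    print(scheduled_float([(0.0, 0.5), (4000.0, 0.2)], 0))     # 0.5
    print(scheduled_float([(0.0, 0.5), (4000.0, 0.2)], 2000))  # 0.35
    print(scheduled_float([(0.0, 0.5), (4000.0, 0.2)], 8000))  # 0.2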
@@ -798,7 +801,7 @@ class LearnedDownsamplingModule(nn.Module):
         # largish range used to keep grads relatively small and avoid overflow in grads.
         self.score_balancer = Balancer(1, channel_dim=-1,
                                        min_positive=0.4, max_positive=0.6,
-                                       min_abs=10.0, max_abs=12.0)
+                                       min_abs=1.0, max_abs=1.2)
 
         self.copy_weights1 = nn.Identity()
         self.copy_weights2 = nn.Identity()
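This hunk shrinks the score Balancer's target magnitude range from [10.0, 12.0] to [1.0, 1.2], roughly a 10x reduction in score scale, which fits the commit message since smaller scores mean smaller gradients through to_scores; note the unchanged context comment about a "largish range" now describes the old values. icefall's Balancer is an identity in the forward pass and only adjusts gradients in backward so that per-channel statistics drift into the given ranges. A hedged sketch of just the two statistics being constrained here (balancer_stats is a made-up helper, not part of icefall):

    import torch

    def balancer_stats(x: torch.Tensor, channel_dim: int = -1):
        # The two per-channel statistics a Balancer(1, channel_dim=-1,
        # min_positive=0.4, max_positive=0.6, min_abs=..., max_abs=...)
        # pushes into range: the fraction of positive values and mean |value|.
        dims = tuple(d for d in range(x.ndim) if d != channel_dim % x.ndim)
        proportion_positive = (x > 0).float().mean(dim=dims)
        mean_abs = x.abs().mean(dim=dims)
        return proportion_positive, mean_abs

    scores = 13.0 * torch.randn(4, 100, 1)   # roughly the old ~10-12 magnitude
    pos, mag = balancer_stats(scores)
    print(pos.item(), mag.item())  # ~0.5, ~10; the new limits pull mag toward [1.0, 1.2]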
@@ -862,8 +865,7 @@ class LearnedDownsamplingModule(nn.Module):
 
         den = (left_avg - right_avg)
         # the following is to avoid division by near-zero.
-        den = 0.8 * den + 0.2 * den.mean()
-
+        den = 0.75 * den + 0.25 * den.mean()
 
         #logging.info(f"den = {den}")
         weights = (sscores - right_avg) / den
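This hunk strengthens the denominator smoothing from a 0.8/0.2 blend to a 0.75/0.25 blend with the mean. Mixing each element of den with the batch mean keeps the divisor bounded away from zero even when an individual left_avg - right_avg gap nearly vanishes, and the larger mean weight pushes near-zero entries further from zero. A small standalone illustration with made-up values:

    import torch

    den = torch.tensor([1e-4, 0.5, 2.0])   # one element dangerously close to zero
    old = 0.80 * den + 0.20 * den.mean()   # blend before this commit
    new = 0.75 * den + 0.25 * den.mean()   # blend after this commit
    print(den.mean().item())  # ~0.833
    print(old)  # first entry ~0.167: already bounded away from zero
    print(new)  # first entry ~0.208: pushed a bit further from zero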