Try to get rid of gradient blowup

Daniel Povey 2023-05-15 20:26:21 +08:00
parent 2e66392306
commit d2d0ce0335


@@ -789,8 +789,11 @@ class LearnedDownsamplingModule(nn.Module):
     def __init__(self,
                  embed_dim: int,
                  downsampling_factor: int,
-                 intermediate_rate: FloatLike = 0.2):
+                 intermediate_rate: Optional[FloatLike] = ScheduledFloat((0.0, 0.5),
+                                                                         (4000.0, 0.2),
+                                                                         default=0.5)):
         super().__init__()
         self.to_scores = nn.Linear(embed_dim, 1, bias=False)
         # score_balancer is just to keep the magnitudes of the scores in
         # a fixed range and keep them balanced around zero, to stop
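The default intermediate_rate is no longer a fixed 0.2 but a ScheduledFloat that starts at 0.5 and is annealed to 0.2 by batch 4000. Assuming icefall's ScheduledFloat is a piecewise-linear schedule over the training batch count (with default=0.5 used when no count is available), it behaves roughly like the sketch below; this function is illustrative, not the actual class:

def scheduled_value(batch_count: float,
                    points=((0.0, 0.5), (4000.0, 0.2))) -> float:
    # Piecewise-linear interpolation between (batch_count, value) breakpoints,
    # clamped to the end values outside the given range.
    (x0, y0), (x1, y1) = points
    if batch_count <= x0:
        return y0
    if batch_count >= x1:
        return y1
    t = (batch_count - x0) / (x1 - x0)
    return y0 + t * (y1 - y0)

# e.g. scheduled_value(0) == 0.5, scheduled_value(2000) == 0.35, scheduled_value(4000) == 0.2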
@@ -798,7 +801,7 @@ class LearnedDownsamplingModule(nn.Module):
         # largish range used to keep grads relatively small and avoid overflow in grads.
         self.score_balancer = Balancer(1, channel_dim=-1,
                                        min_positive=0.4, max_positive=0.6,
-                                       min_abs=10.0, max_abs=12.0)
+                                       min_abs=1.0, max_abs=1.2)
         self.copy_weights1 = nn.Identity()
         self.copy_weights2 = nn.Identity()
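The score Balancer's magnitude constraint drops from [10.0, 12.0] to [1.0, 1.2], so the per-frame scores are kept roughly ten times smaller; smaller scores mean smaller gradients flowing back through to_scores, which is presumably the "gradient blowup" the commit message targets. As a rough sketch of the kind of constraint being enforced (the real Balancer nudges gradients in the backward pass rather than adding a loss term), consider:

import torch

def mean_abs_penalty(scores: torch.Tensor,
                     min_abs: float = 1.0,
                     max_abs: float = 1.2) -> torch.Tensor:
    # Zero when the mean absolute score lies inside [min_abs, max_abs];
    # grows linearly as it drifts outside that range.  Illustrative only,
    # not icefall's Balancer implementation.
    m = scores.abs().mean()
    return torch.relu(min_abs - m) + torch.relu(m - max_abs)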
@@ -862,8 +865,7 @@ class LearnedDownsamplingModule(nn.Module):
         den = (left_avg - right_avg)
         # the following is to avoid division by near-zero.
-        den = 0.8 * den + 0.2 * den.mean()
+        den = 0.75 * den + 0.25 * den.mean()
-        #logging.info(f"den = {den}")
         weights = (sscores - right_avg) / den
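The denominator smoothing now mixes in 25% of the batch mean rather than 20%, pulling each per-position denominator a little further from zero before the division that produces the weights. With hypothetical values, the effect on a near-zero entry looks like this:

import torch

den = torch.tensor([0.01, 0.5, 1.0])  # hypothetical per-position denominators
old = 0.80 * den + 0.20 * den.mean()  # smallest entry becomes ~0.109
new = 0.75 * den + 0.25 * den.mean()  # smallest entry becomes ~0.133, further from zero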