From d2d0ce03357dc249412105560556587f71bad350 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 15 May 2023 20:26:21 +0800 Subject: [PATCH] Try to get rid of gradient blowup --- egs/libriheavy/LM/zipformer1/subformer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py index b4b2c8e8a..b1583dd16 100644 --- a/egs/libriheavy/LM/zipformer1/subformer.py +++ b/egs/libriheavy/LM/zipformer1/subformer.py @@ -789,8 +789,11 @@ class LearnedDownsamplingModule(nn.Module): def __init__(self, embed_dim: int, downsampling_factor: int, - intermediate_rate: FloatLike = 0.2): + intermediate_rate: Optional[FloatLike] = ScheduledFloat((0.0, 0.5), + (4000.0, 0.2), + default=0.5)): super().__init__() + self.to_scores = nn.Linear(embed_dim, 1, bias=False) # score_balancer is just to keep the magnitudes of the scores in # a fixed range and keep them balanced around zero, to stop @@ -798,7 +801,7 @@ class LearnedDownsamplingModule(nn.Module): # largish range used to keep grads relatively small and avoid overflow in grads. self.score_balancer = Balancer(1, channel_dim=-1, min_positive=0.4, max_positive=0.6, - min_abs=10.0, max_abs=12.0) + min_abs=1.0, max_abs=1.2) self.copy_weights1 = nn.Identity() self.copy_weights2 = nn.Identity() @@ -862,8 +865,7 @@ class LearnedDownsamplingModule(nn.Module): den = (left_avg - right_avg) # the following is to avoid division by near-zero. - den = 0.8 * den + 0.2 * den.mean() - + den = 0.75 * den + 0.25 * den.mean() #logging.info(f"den = {den}") weights = (sscores - right_avg) / den