From d2c198c072055571a6f860acacddf826e930a013 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 18 May 2023 15:48:14 +0800
Subject: [PATCH] Implement weight_scale, set weight_scale=10

---
 egs/libriheavy/LM/zipformer1/subformer.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py
index d0776594b..51664929c 100644
--- a/egs/libriheavy/LM/zipformer1/subformer.py
+++ b/egs/libriheavy/LM/zipformer1/subformer.py
@@ -805,13 +805,17 @@ class LearnedDownsamplingModule(nn.Module):
       downsampling_factor: factor to downsample by, e.g. 2 or 4.  There is no
          fundamental reason why this has to be an integer, but we make it so
          anyway.
+      weight_scale: constant scaling factor on the weights, introduced to make fp16 training
+         more stable by reducing gradient magnitudes.
     """
     def __init__(self,
                  embed_dim: int,
-                 downsampling_factor: int):
+                 downsampling_factor: int,
+                 weight_scale: float = 10.0):
         super().__init__()
 
 
+        self.weight_scale = weight_scale
         self.to_scores = nn.Linear(embed_dim, 1, bias=False)
         # score_balancer is just to keep the magnitudes of the scores in
         # a fixed range and keep them balanced around zero, to stop
@@ -855,7 +859,7 @@ class LearnedDownsamplingModule(nn.Module):
 
         sscores, indexes = scores.sort(dim=-1, descending=True)
 
-        weights = sscores.clamp(min=0.0, max=1.0)
+        weights = sscores.clamp(min=0.0, max=self.weight_scale)
         weights = self.copy_weights1(weights)
 
         if self.training:
@@ -987,6 +991,9 @@ class LearnedDownsamplingModule(nn.Module):
         # unsqueeze at position 1 so the extra cost relates to the source position.
         attn_offset = attn_offset + weights.clamp(min=eps).log().unsqueeze(1)
 
+        if self.weight_scale != 1.0:
+            attn_offset = attn_offset - math.log(self.weight_scale)
+
         return attn_offset
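
The two pieces of the patch fit together: weights are now clamped to [0, weight_scale]
rather than [0, 1], and the log-domain attention offset is shifted down by
log(weight_scale), so the offset term equals log(weights / weight_scale), the same
value an unscaled weight in [0, 1] would have produced; the weights themselves stay
larger by the factor weight_scale, which is where the commit message's fp16 benefit is
claimed to come from. Below is a minimal standalone sketch of that compensation;
tensor shapes, the eps value, and variable names are illustrative stand-ins, not the
module's real interface.

    import math
    import torch

    weight_scale = 10.0   # the new default introduced by this patch
    eps = 1.0e-10         # illustrative; the real code defines its own eps

    # Made-up scores of shape (batch, seq_len), standing in for the output
    # of a linear scoring layer such as self.to_scores(x).
    scores = torch.randn(2, 8)
    sscores, _indexes = scores.sort(dim=-1, descending=True)

    # Scaled weights live in [0, weight_scale] instead of [0, 1].
    weights = sscores.clamp(min=0.0, max=weight_scale)

    # unsqueeze at position 1 so the extra cost relates to the source position.
    attn_offset = weights.clamp(min=eps).log().unsqueeze(1)
    if weight_scale != 1.0:
        attn_offset = attn_offset - math.log(weight_scale)

    # Same offset as taking the log of the weights rescaled back into [0, 1].
    unscaled = (weights / weight_scale).clamp(min=eps / weight_scale)
    assert torch.allclose(attn_offset, unscaled.log().unsqueeze(1))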