From 5fc0cce553deaa76b46230398b6623e291844447 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 19 May 2023 16:31:45 +0800
Subject: [PATCH] Introduce factor of 2 to more strongly penalize discarded weights.

---
 egs/libriheavy/LM/zipformer1/subformer.py | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py
index 554098366..bbf1ad889 100644
--- a/egs/libriheavy/LM/zipformer1/subformer.py
+++ b/egs/libriheavy/LM/zipformer1/subformer.py
@@ -805,17 +805,13 @@ class LearnedDownsamplingModule(nn.Module):
       downsampling_factor: factor to downsample by, e.g. 2 or 4.  There is no
          fundamental reason why this has to be an integer, but we make it so
          anyway.
-      weight_scale: constant scaling factor on the weights, introduced to make fp16 training
-         more stable by reducing gradient magnitudes.
     """
     def __init__(self,
                  embed_dim: int,
-                 downsampling_factor: int,
-                 weight_scale: float = 1.0):
+                 downsampling_factor: int):
         super().__init__()
 
-        self.weight_scale = weight_scale
         self.to_scores = nn.Linear(embed_dim, 1, bias=False)
         # score_balancer is just to keep the magnitudes of the scores in
         # a fixed range and keep them balanced around zero, to stop
@@ -859,7 +855,7 @@ class LearnedDownsamplingModule(nn.Module):
 
         sscores, indexes = scores.sort(dim=-1, descending=True)
 
-        weights = sscores.clamp(min=0.0, max=self.weight_scale)
+        weights = sscores.clamp(min=0.0, max=1.0)
         weights = self.copy_weights1(weights)
 
         if self.training:
@@ -879,9 +875,12 @@ class LearnedDownsamplingModule(nn.Module):
 
             if random.random() < 0.01 or __name__ == '__main__':
                 logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, discarded-weights={weights_discarded.mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
 
-            #weights_discarded = weights_discarded.flip(dims=(1,))
-            weights = (weights[:, :seq_len_reduced] - weights_discarded)
+            # Before this factor was introduced, too many discarded weights stayed nonzero,
+            # hurting test-mode performance by creating a train/test mismatch.
+            discarded_weights_factor = 2.0
+
+            weights = (weights[:, :seq_len_reduced] - (weights_discarded * discarded_weights_factor)).clamp(min=0.0, max=1.0)
         else:
             # test mode.  because the sequence might be short, we keep all nonzero scores;
             # and there is no need for any penalty.
@@ -991,9 +990,6 @@ class LearnedDownsamplingModule(nn.Module):
         # unsqueeze at position 1 so the extra cost relates to the source position.
         attn_offset = attn_offset + (weights + eps).log().unsqueeze(1)
 
-        if self.weight_scale != 1.0:
-            attn_offset = attn_offset - math.log(self.weight_scale)
-
         return attn_offset
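
For readers outside the subformer codebase, here is a minimal, self-contained sketch of the mechanism this patch changes. It is not the real LearnedDownsamplingModule: copy_weights1, the score balancer, and the exact pairing of kept and discarded positions are omitted or simplified, and toy_keep_weights is a hypothetical name; only the sort/clamp structure and the discarded_weights_factor = 2.0 penalty mirror the patched code.

import torch

def toy_keep_weights(scores: torch.Tensor,
                     downsampling_factor: int,
                     training: bool = True,
                     discarded_weights_factor: float = 2.0) -> torch.Tensor:
    """scores: (batch, seq_len) raw frame scores.
    Returns keep-weights in [0, 1] for the seq_len_reduced highest-scoring frames."""
    _, seq_len = scores.shape
    seq_len_reduced = (seq_len + downsampling_factor - 1) // downsampling_factor

    # Sort descending so kept frames come first, then clamp to [0, 1]
    # (the patch hard-codes max=1.0 now that weight_scale is gone).
    sscores, _ = scores.sort(dim=-1, descending=True)
    weights = sscores.clamp(min=0.0, max=1.0)

    kept = weights[:, :seq_len_reduced]
    if not training:
        # Test mode: keep weights as-is, with no penalty (as in the patch).
        return kept

    # The largest weights among the frames about to be discarded; zero-pad if
    # there are fewer discarded frames than kept ones (a simplifying assumption).
    discarded = weights[:, seq_len_reduced:2 * seq_len_reduced]
    if discarded.shape[1] < seq_len_reduced:
        discarded = torch.nn.functional.pad(
            discarded, (0, seq_len_reduced - discarded.shape[1]))

    # The patch's change: subtract 2x the discarded weights instead of 1x, so
    # gradients push discarded scores below zero more strongly, shrinking the
    # mismatch with test mode, where no penalty is applied.
    return (kept - discarded_weights_factor * discarded).clamp(min=0.0, max=1.0)

A quick check of the gradient this creates: wherever the final clamp is inactive, each discarded weight receives a gradient scaled by -discarded_weights_factor, i.e. twice as strong a push toward zero as before the patch:

scores = torch.randn(2, 8, requires_grad=True)
toy_keep_weights(scores, downsampling_factor=2).sum().backward()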
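
The deleted hunk at the end follows from the hard-coded clamp: the subtracted math.log(self.weight_scale) term existed only to renormalize the attention offset when weights were clamped to [0, weight_scale]. A small illustration, with an assumed eps value and toy tensor shapes (the real eps and shapes are defined elsewhere in subformer.py):

import math
import torch

eps = 1.0e-05       # assumed small constant, not the value used in subformer.py
weight_scale = 1.0  # what the patch effectively hard-codes
weights = torch.rand(2, 5)          # (batch, seq_len_reduced), already in [0, 1]
attn_offset = torch.zeros(2, 3, 5)  # (batch, tgt_len, src_len), toy shape

# unsqueeze at position 1 so the extra cost relates to the source position,
# exactly as in the patched code.
attn_offset = attn_offset + (weights + eps).log().unsqueeze(1)

# With weights bounded by 1.0, the offset already peaks near log(1.0) = 0,
# so the old "- math.log(self.weight_scale)" correction would be a no-op.
assert math.log(weight_scale) == 0.0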