From a514d23df7d4f48a0543135af95824f44d1f7e39 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 18 May 2023 14:14:50 +0800
Subject: [PATCH] Change how we penalize weights

---
 egs/libriheavy/LM/zipformer1/subformer.py | 46 +++++++----------------
 1 file changed, 14 insertions(+), 32 deletions(-)

diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py
index d0f11659f..9465d843d 100644
--- a/egs/libriheavy/LM/zipformer1/subformer.py
+++ b/egs/libriheavy/LM/zipformer1/subformer.py
@@ -825,6 +825,7 @@ class LearnedDownsamplingModule(nn.Module):
         # largish range used to keep grads relatively small and avoid overflow in grads.
         self.score_balancer = Balancer(1, channel_dim=-1,
                                        min_positive=1/(2*downsampling_factor),
+                                       max_positive=0.6,
                                        min_abs=1.0)

         # below are for diagnostics.
@@ -868,38 +869,19 @@ class LearnedDownsamplingModule(nn.Module):
             d = self.downsampling_factor
             seq_len_reduced = (seq_len + d - 1) // d

-            # penalize any nonzero scores that are numbered higher than the
-            # reduced sequence length-- we don't want such scores present
-            # because they make the derivatives inaccurate (to make the
-            # derivatives accurate, we need the weights to go to zero before we
-            # remove those frames from the computation).
-            penalty1 = weights[:, seq_len_reduced:].mean()
-
-            # e.g. if intermediate_rate is 0.1, 10% of the kept frames should
-            # have scores between 0 and 1 -- and hence nonzero derivatives -- so
-            # we can learn the scores without the derivatives getting too large
-            # for that subset of frames.  Under the assumption that the scores
-            # go about linearly from 1 to 0, the average of the kept scores
-            # would be (100% - 0.5*10%) = 95%.  If the average of the kept
-            # scores is higher than this, we need to apply a penalty.
-            max_kept_scores = 1.0 - (0.5 * float(self.intermediate_rate))
-
-            penalty2 = (weights[:, :seq_len_reduced].mean() - max_kept_scores).clamp(min=0.0)
-            # the max=1.0 is to make sure we never make the final weights negative, which
-            # would lead to problems
-            # penalty_scale is a heuristic to make sure the penalty is sufficient to
-            # enforce the constraint.
-            penalty_scale = 2.0
-            penalty = (penalty_scale * (penalty1 + penalty2)).clamp(max=1.0)
-
             if random.random() < 0.01 or __name__ == '__main__':
-                logging.info(f"penalty1={penalty1}, penalty2={penalty2}, mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
+                logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")

-            # if `penalty` is nonzero, inject some randomness into the weights of
-            # the whole batch.  The hope is that this will be a sufficient penalty.
-            # if this doesn't work well we can consider other ways to apply the penalty.
-            weights = weights * (1.0 + (torch.rand_like(weights) - 0.5) * penalty)
+            weights_discarded = weights[:, seq_len_reduced:2*seq_len_reduced]
+            missing = seq_len_reduced - weights_discarded.shape[1]
+            if missing != 0:
+                weights_discarded = torch.cat((weights_discarded,
+                                               torch.zeros(batch_size, missing,
+                                                           device=weights.device,
+                                                           dtype=weights.dtype)),
+                                              dim=1)
+
+            weights = weights[:, :seq_len_reduced] - weights_discarded
         else:
             # test mode.  because the sequence might be short, we keep all nonzero scores;
             # and there is no need for any penalty.
@@ -909,10 +891,10 @@ class LearnedDownsamplingModule(nn.Module):
                        (weights > 0.0).to(torch.int32).sum(dim=-1).max().item())

             if random.random() < 0.02:
                 logging.info(f"seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
-
+            weights = weights[:, :seq_len_reduced]
             indexes = indexes[:, :seq_len_reduced]
-            weights = weights[:, :seq_len_reduced]
+            weights = self.copy_weights2(weights)
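
Note, for illustration only (not part of the patch): the new penalty in the
second hunk can be read as a standalone function. The helper name below is
hypothetical, and it assumes `weights` is a (batch_size, seq_len) tensor sorted
in decreasing order along the sequence axis, as the slicing in
LearnedDownsamplingModule implies. Instead of the old explicit
penalty1/penalty2 terms, the kept weights now have the next-highest block of
discarded weights subtracted from them, so any nonzero discarded weight
directly perturbs the output and receives a gradient pushing it toward zero
before its frame is dropped:

    import torch

    def penalized_kept_weights(weights: torch.Tensor,
                               seq_len_reduced: int) -> torch.Tensor:
        # weights: (batch_size, seq_len), assumed sorted in decreasing order.
        batch_size, _ = weights.shape
        # The largest of the discarded weights: the block that immediately
        # follows the kept block.
        weights_discarded = weights[:, seq_len_reduced:2 * seq_len_reduced]
        # Zero-pad on the right if fewer than seq_len_reduced frames remain.
        missing = seq_len_reduced - weights_discarded.shape[1]
        if missing > 0:
            pad = torch.zeros(batch_size, missing,
                              device=weights.device, dtype=weights.dtype)
            weights_discarded = torch.cat((weights_discarded, pad), dim=1)
        # Couple the discarded weights into the kept weights with a negative
        # sign; gradients w.r.t. the returned weights therefore also flow,
        # negated, into the discarded weights.
        return weights[:, :seq_len_reduced] - weights_discarded

    # e.g. seq_len=9 with downsampling_factor=2 gives seq_len_reduced=5;
    # the discarded slice has 4 columns and is padded back to 5:
    w = torch.sort(torch.rand(3, 9), dim=1, descending=True).values
    print(penalized_kept_weights(w, seq_len_reduced=5).shape)  # torch.Size([3, 5])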