diff --git a/egs/libriheavy/LM/zipformer1/lm_datamodule.py b/egs/libriheavy/LM/zipformer1/lm_datamodule.py
index 6b91a1dca..0ef0ff98b 100644
--- a/egs/libriheavy/LM/zipformer1/lm_datamodule.py
+++ b/egs/libriheavy/LM/zipformer1/lm_datamodule.py
@@ -123,7 +123,7 @@ def LmDataloader(dataset: LmDataset,
         dataset=dataset,
         batch_size=batch_size,
         num_workers=num_workers,
-        drop_last=True)
+        drop_last=False)
diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py
index fff1d8fe5..d0f11659f 100644
--- a/egs/libriheavy/LM/zipformer1/subformer.py
+++ b/egs/libriheavy/LM/zipformer1/subformer.py
@@ -825,8 +825,9 @@ class LearnedDownsamplingModule(nn.Module):
         # largish range used to keep grads relatively small and avoid overflow in grads.
         self.score_balancer = Balancer(1, channel_dim=-1,
                                        min_positive=1/(2*downsampling_factor),
-                                       min_abs=1.0)
+                                       min_abs=1.0)
 
+        # below are for diagnostics.
         self.copy_weights1 = nn.Identity()
         self.copy_weights2 = nn.Identity()
 
@@ -860,58 +861,54 @@ class LearnedDownsamplingModule(nn.Module):
 
         sscores, indexes = scores.sort(dim=-1, descending=True)
 
+        weights = sscores.clamp(min=0.0, max=1.0)
+        weights = self.copy_weights1(weights)
+
         if self.training:
             d = self.downsampling_factor
             seq_len_reduced = (seq_len + d - 1) // d
-            intermediate_rate = float(self.intermediate_rate)
 
+            # penalize any nonzero scores that are numbered higher than the
+            # reduced sequence length -- we don't want such scores present
+            # because they make the derivatives inaccurate (to make the
+            # derivatives accurate, we need the weights to go to zero before we
+            # remove those frames from the computation).
+            penalty1 = weights[:, seq_len_reduced:].mean()
 
-            # 'right' is the rightmost of the 2 limits; we want the scores indexed
-            # 'upper' to be mapped to around 0.0
-            right = seq_len_reduced
-            # we want scores around 'left' to be mapped to around 1.0.
-            left = int(seq_len_reduced * (1.0 - intermediate_rate))
+            # e.g. if intermediate_rate is 0.1, 10% of the kept frames should
+            # have scores between 0 and 1 -- and hence nonzero derivatives -- so
+            # we can learn the scores without the derivatives getting too large
+            # for that subset of frames.  Under the assumption that the scores
+            # go about linearly from 1 to 0, the average of the kept scores
+            # would be (100% - 0.5*10%) = 95%.  If the average of the kept
+            # scores is higher than this, we need to apply a penalty.
+            max_kept_scores = 1.0 - (0.5 * float(self.intermediate_rate))
 
-            # 'collar' determines the range of positions in the sorted list that we use to
-            # compute the average.  We could let collar be 0.0, which would more exactly
-            # accomplish what we want; but we don't, because this would cause too-noisy
-            # gradients, with too much gradient going to one frame.
-            collar = max(1, int(seq_len_reduced * 0.5 * intermediate_rate))
+            penalty2 = (weights[:, :seq_len_reduced].mean() - max_kept_scores).clamp(min=0.0)
 
-            # right_avg: shape (batch_size,), this is to be mapped to 0.0
-            right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1, keepdim=True)
+            # the clamp(max=1.0) is to make sure we never make the final weights
+            # negative, which would lead to problems.
+            # penalty_scale is a heuristic to make sure the penalty is sufficient to
+            # enforce the constraint.
+            penalty_scale = 2.0
+            penalty = (penalty_scale * (penalty1 + penalty2)).clamp(max=1.0)
 
-            # we only shift the scores left (decrease them, to ensure no more than `intermediate_rate`
-            # proportion of the scores are >0).  This lets us have batch-independence in test-mode,
-            # the idea is that the model will "learn" the right distribution of scores.
-            right_avg_clamped = right_avg.clamp(min=0.0)
+            if random.random() < 0.01 or __name__ == '__main__':
+                logging.info(f"penalty1={penalty1}, penalty2={penalty2}, mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()}, positive-scores={(scores>0).to(torch.float32).mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
 
-            # left_avg: shape (batch_size,), this is to be mapped to 1.0
-            left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1, keepdim=True)
-
-            # the + 0.001 is to avoid possible division by zero in case of ties.
-            sscores = self.copy_weights1(sscores)
-
-            # divide by den: only decrease the scores' value.
-            den = (left_avg - right_avg_clamped).clamp(min=1.0)
-
-            #logging.info(f"den = {den}")
-            weights = (sscores - right_avg_clamped) / den
+            # if `penalty` is nonzero, inject some randomness into the weights of
+            # the whole batch.  The hope is that this will be a sufficient penalty;
+            # if this doesn't work well we can consider other ways to apply the penalty.
+            weights = weights * (1.0 + (torch.rand_like(weights) - 0.5) * penalty)
         else:
-            # in test mode, no normalization (we can't have batch-dependent
-            # effects because this would be "seeing the future").  But we trainin such
-            # a way that, hopefully, it will most of the time give us not much more
-            # nonzero scores than in training time.
-            weights = sscores
+            # test mode: because the sequence might be short, we keep all nonzero
+            # scores; there is no need for any penalty.
-
-        weights = weights.clamp(min=0.0, max=1.0)
-
-        if not self.training:
             # need to work out seq_len_reduced.
             seq_len_reduced = max(1, (weights > 0.0).to(torch.int32).sum(dim=-1).max().item())
+            if random.random() < 0.02:
+                logging.info(f"seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
 
         indexes = indexes[:, :seq_len_reduced]
@@ -919,9 +916,6 @@ class LearnedDownsamplingModule(nn.Module):
 
         weights = self.copy_weights2(weights)
 
-        if random.random() < 0.01 or __name__ == '__main__':
-            logging.info(f"Mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
-
         # re-sort the indexes we kept, on index value, so that
         # masking for causal models will be in the correct order.
         # (actually this may not really matter, TODO: see whether we
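
For reference, a minimal standalone sketch (not part of the patch) of the training-time weighting scheme the subformer.py hunk introduces. The function name downsample_weights, its default arguments, and the (batch_size, seq_len) score layout are assumptions for illustration; the Balancer, the Identity passthroughs, and the diagnostic logging are omitted.

    import torch

    def downsample_weights(scores: torch.Tensor,
                           downsampling_factor: int = 2,
                           intermediate_rate: float = 0.1):
        # scores: (batch_size, seq_len).  Returns (weights, indexes) for the kept frames.
        _, seq_len = scores.shape
        # sort descending; `indexes` remembers each score's original frame position.
        sscores, indexes = scores.sort(dim=-1, descending=True)
        weights = sscores.clamp(min=0.0, max=1.0)

        d = downsampling_factor
        seq_len_reduced = (seq_len + d - 1) // d  # ceil(seq_len / d)

        # penalty1: nonzero weights beyond the frames we intend to keep.
        penalty1 = weights[:, seq_len_reduced:].mean()

        # penalty2: average kept weight above 1 - 0.5 * intermediate_rate
        # (0.95 for intermediate_rate=0.1, assuming the intermediate scores
        # fall roughly linearly from 1 to 0).
        max_kept_scores = 1.0 - 0.5 * intermediate_rate
        penalty2 = (weights[:, :seq_len_reduced].mean() - max_kept_scores).clamp(min=0.0)

        # scale the penalty and cap it at 1.0: the noise factor below then stays
        # in [0.5, 1.5], so the perturbed weights remain non-negative.
        penalty = (2.0 * (penalty1 + penalty2)).clamp(max=1.0)

        # when the constraints are violated, inject noise proportional to the penalty.
        weights = weights * (1.0 + (torch.rand_like(weights) - 0.5) * penalty)

        return weights[:, :seq_len_reduced], indexes[:, :seq_len_reduced]

    # toy usage: 4 sequences of 16 frames, factor-2 downsampling.
    w, idx = downsample_weights(torch.randn(4, 16))
    print(w.shape, idx.shape)  # torch.Size([4, 8]) torch.Size([4, 8])

The clamp(max=1.0) on penalty is what the in-patch comment about "never make the final weights negative" refers to: since (rand - 0.5) lies in [-0.5, 0.5], capping penalty at 1.0 bounds the multiplicative factor away from zero.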