Don't drop last batch

Daniel Povey 2023-05-18 12:47:28 +08:00
parent eb64130787
commit 9367ea3646
2 changed files with 37 additions and 43 deletions


@@ -123,7 +123,7 @@ def LmDataloader(dataset: LmDataset,
         dataset=dataset,
         batch_size=batch_size,
         num_workers=num_workers,
-        drop_last=True)
+        drop_last=False)
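Not part of the patch, but for context: a minimal standalone sketch of what flipping drop_last changes. With a dataset whose size is not a multiple of batch_size, drop_last=True silently discards the trailing examples every epoch, while drop_last=False yields them as a smaller final batch (the toy sizes here are made up):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(10))  # 10 examples, batch_size 4
    for drop_last in (True, False):
        loader = DataLoader(dataset, batch_size=4, drop_last=drop_last)
        print(drop_last, [batch[0].numel() for batch in loader])
    # True  -> [4, 4]     (last 2 examples never seen)
    # False -> [4, 4, 2]  (last 2 examples form a partial batch)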


@@ -825,8 +825,9 @@ class LearnedDownsamplingModule(nn.Module):
         # largish range used to keep grads relatively small and avoid overflow in grads.
         self.score_balancer = Balancer(1, channel_dim=-1,
                                        min_positive=1/(2*downsampling_factor),
-                                       min_abs=1.0)
+                                       min_abs=1.0)
         # below are for diagnostics.
         self.copy_weights1 = nn.Identity()
         self.copy_weights2 = nn.Identity()
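As an aside on the constraint being set up here: min_positive=1/(2*downsampling_factor) asks that at least that fraction of the score values be positive. A rough sketch of the statistic involved (the Balancer itself is project-internal; this just computes the fraction on random scores, with made-up shapes):

    import torch

    downsampling_factor = 2
    min_positive = 1 / (2 * downsampling_factor)  # 0.25

    scores = torch.randn(4, 100, 1)  # (batch, time, channel=1), assumed shape
    frac_positive = (scores > 0).to(torch.float32).mean().item()
    # the Balancer adjusts gradients when this fraction falls below min_positive
    print(frac_positive, frac_positive >= min_positive)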
@@ -860,58 +861,54 @@ class LearnedDownsamplingModule(nn.Module):
         sscores, indexes = scores.sort(dim=-1, descending=True)
+        weights = sscores.clamp(min=0.0, max=1.0)
+        weights = self.copy_weights1(weights)
         if self.training:
             d = self.downsampling_factor
             seq_len_reduced = (seq_len + d - 1) // d
             intermediate_rate = float(self.intermediate_rate)
-            # 'right' is the rightmost of the 2 limits; we want the scores indexed
-            # 'right' and above to be mapped to around 0.0.
-            right = seq_len_reduced
-            # we want scores around 'left' to be mapped to around 1.0.
-            left = int(seq_len_reduced * (1.0 - intermediate_rate))
-            # 'collar' determines the range of positions in the sorted list that we use
-            # to compute the average.  We could let collar be 0, which would more
-            # exactly accomplish what we want; but we don't, because this would cause
-            # too-noisy gradients, with too much gradient going to one frame.
-            collar = max(1, int(seq_len_reduced * 0.5 * intermediate_rate))
-            # right_avg: shape (batch_size, 1), this is to be mapped to 0.0.
-            right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1, keepdim=True)
-            # we only shift the scores left (decrease them), to ensure no more than
-            # `intermediate_rate` proportion of the scores are > 0.  This lets us have
-            # batch-independence in test mode; the idea is that the model will "learn"
-            # the right distribution of scores.
-            right_avg_clamped = right_avg.clamp(min=0.0)
-            # left_avg: shape (batch_size, 1), this is to be mapped to 1.0.
-            left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1, keepdim=True)
-            sscores = self.copy_weights1(sscores)
-            # divide by den: only decrease the scores' value.  The clamp(min=1.0)
-            # avoids possible division by zero in case of ties.
-            den = (left_avg - right_avg_clamped).clamp(min=1.0)
-            # logging.info(f"den = {den}")
-            weights = (sscores - right_avg_clamped) / den
+            # penalize any nonzero scores that are indexed higher than the
+            # reduced sequence length -- we don't want such scores present
+            # because they make the derivatives inaccurate (to make the
+            # derivatives accurate, we need the weights to go to zero before we
+            # remove those frames from the computation).
+            penalty1 = weights[:, seq_len_reduced:].mean()
+            # e.g. if intermediate_rate is 0.1, 10% of the kept frames should
+            # have scores between 0 and 1 -- and hence nonzero derivatives -- so
+            # we can learn the scores without the derivatives getting too large
+            # for that subset of frames.  Under the assumption that the scores
+            # go about linearly from 1 to 0, the average of the kept scores
+            # would be (100% - 0.5*10%) = 95%.  If the average of the kept
+            # scores is higher than this, we need to apply a penalty.
+            max_kept_scores = 1.0 - (0.5 * float(self.intermediate_rate))
+            penalty2 = (weights[:, :seq_len_reduced].mean() - max_kept_scores).clamp(min=0.0)
+            # penalty_scale is a heuristic to make sure the penalty is sufficient
+            # to enforce the constraint.
+            penalty_scale = 2.0
+            # the max=1.0 is to make sure we never make the final weights
+            # negative, which would lead to problems.
+            penalty = (penalty_scale * (penalty1 + penalty2)).clamp(max=1.0)
+            if random.random() < 0.01 or __name__ == '__main__':
+                logging.info(f"penalty1={penalty1}, penalty2={penalty2}, mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()}, positive-scores={(scores>0).to(torch.float32).mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
+            # if `penalty` is nonzero, inject some randomness into the weights of
+            # the whole batch.  The hope is that this will be a sufficient penalty;
+            # if this doesn't work well we can consider other ways to apply it.
+            weights = weights * (1.0 + (torch.rand_like(weights) - 0.5) * penalty)
         else:
-            # in test mode, no normalization (we can't have batch-dependent
-            # effects because this would be "seeing the future").  But we train in
-            # such a way that, hopefully, it will usually give us not many more
-            # nonzero scores than at training time.
-            weights = sscores
-        weights = weights.clamp(min=0.0, max=1.0)
-        if not self.training:
-            # need to work out seq_len_reduced.
+            # test mode.  because the sequence might be short, we keep all
+            # nonzero scores; and there is no need for any penalty.
             seq_len_reduced = max(1,
                                   (weights > 0.0).to(torch.int32).sum(dim=-1).max().item())
+            if random.random() < 0.02:
+                logging.info(f"seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
         indexes = indexes[:, :seq_len_reduced]
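To see the new training-time penalty in isolation: a standalone sketch (not from the patch; batch size, sequence length, and intermediate_rate are made up) that mirrors the computation above on random scores:

    import torch

    batch_size, seq_len, d = 2, 12, 3
    intermediate_rate = 0.1
    seq_len_reduced = (seq_len + d - 1) // d  # ceiling division -> 4

    scores = torch.randn(batch_size, seq_len)
    sscores, _ = scores.sort(dim=-1, descending=True)
    weights = sscores.clamp(min=0.0, max=1.0)

    # penalty1: weight mass on frames past the reduced length (want this at 0)
    penalty1 = weights[:, seq_len_reduced:].mean()
    # penalty2: excess of the kept scores' mean over 1 - 0.5*intermediate_rate
    max_kept_scores = 1.0 - 0.5 * intermediate_rate
    penalty2 = (weights[:, :seq_len_reduced].mean() - max_kept_scores).clamp(min=0.0)
    penalty = (2.0 * (penalty1 + penalty2)).clamp(max=1.0)
    # the random factor lies in [1 - 0.5*penalty, 1 + 0.5*penalty]; since
    # penalty <= 1.0 it stays positive, so the weights stay nonnegative
    weights = weights * (1.0 + (torch.rand_like(weights) - 0.5) * penalty)
    print(penalty1.item(), penalty2.item(), bool((weights >= 0).all()))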
@@ -919,9 +916,6 @@ class LearnedDownsamplingModule(nn.Module):
         weights = self.copy_weights2(weights)
-        if random.random() < 0.01 or __name__ == '__main__':
-            logging.info(f"Mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()}, positive-scores={(scores>0).to(torch.float32).mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
         # re-sort the indexes we kept, on index value, so that
         # masking for causal models will be in the correct order.
         # (actually this may not really matter, TODO: see whether we
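And for the test-mode branch: a tiny sketch (weights made up) of how seq_len_reduced falls out of the weights alone, keeping every frame position that any sequence in the batch assigns a nonzero weight:

    import torch

    weights = torch.tensor([[0.9, 0.4, 0.0, 0.0],   # sorted, per sequence
                            [0.8, 0.5, 0.2, 0.0]])
    seq_len_reduced = max(1, int((weights > 0.0).to(torch.int32)
                                 .sum(dim=-1).max().item()))
    print(seq_len_reduced)  # 3: the second sequence keeps three frames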