Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-19 05:54:20 +00:00
Change how we penalize weights
This commit is contained in:
parent 26cf13a3e1
commit a514d23df7
@@ -825,6 +825,7 @@ class LearnedDownsamplingModule(nn.Module):
         # largish range used to keep grads relatively small and avoid overflow in grads.
         self.score_balancer = Balancer(1, channel_dim=-1,
                                        min_positive=1/(2*downsampling_factor),
+                                       max_positive=0.6,
                                        min_abs=1.0)
 
         # below are for diagnostics.
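The only change in this hunk is the new max_positive=0.6 argument, so the Balancer on the single score channel now bounds the fraction of positive scores from above as well as from below (min_positive=1/(2*downsampling_factor)) and keeps their mean absolute value at or above 1.0. In icefall's scaling.py the Balancer pushes activations toward these targets by adjusting gradients during training rather than changing the forward output. Below is a rough sketch of the statistics being constrained, using a hypothetical helper name and an illustrative downsampling factor (neither is part of the commit):

import torch

def score_stats(scores: torch.Tensor) -> dict:
    # Hypothetical helper (not part of icefall): reports the statistics the
    # Balancer above tries to keep in range for the single score channel.
    frac_positive = (scores > 0).to(torch.float32).mean().item()
    mean_abs = scores.abs().mean().item()
    return {"frac_positive": frac_positive, "mean_abs": mean_abs}

downsampling_factor = 2  # illustrative value; the real one is a constructor argument
stats = score_stats(torch.randn(4, 100, 1))
# Target ranges implied by the constructor call in the hunk above:
#   1/(2*downsampling_factor) <= frac_positive <= 0.6   (max_positive is the new upper bound)
#   mean_abs >= 1.0
print(stats, "lower bound on frac_positive:", 1 / (2 * downsampling_factor))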
@@ -868,38 +869,19 @@ class LearnedDownsamplingModule(nn.Module):
             d = self.downsampling_factor
             seq_len_reduced = (seq_len + d - 1) // d
 
-            # penalize any nonzero scores that are numbered higher than the
-            # reduced sequence length-- we don't want such scores present
-            # because they make the derivatives inaccurate (to make the
-            # derivatives accurate, we need the weights to go to zero before we
-            # remove those frames from the computation).
-            penalty1 = weights[:, seq_len_reduced:].mean()
-
-            # e.g. if intermediate_rate is 0.1, 10% of the kept frames should
-            # have scores between 0 and 1 -- and hence nonzero derivatives -- so
-            # we can learn the scores without the derivatives getting too large
-            # for that subset of frames. Under the assumption that the scores
-            # go about linearly from 1 to 0, the average of the kept scores
-            # would be (100% - 0.5*10%) = 95%. If the average of the kept
-            # scores is higher than this, we need to apply a penalty.
-            max_kept_scores = 1.0 - (0.5 * float(self.intermediate_rate))
-
-            penalty2 = (weights[:, :seq_len_reduced].mean() - max_kept_scores).clamp(min=0.0)
-
-            # the max=1.0 is to make sure we never make the final weights negative, which
-            # would lead to problems
-            # penalty_scale is a heuristic to make sure the penalty is sufficient to
-            # enforce the constraint.
-            penalty_scale = 2.0
-            penalty = (penalty_scale * (penalty1 + penalty2)).clamp(max=1.0)
-
             if random.random() < 0.01 or __name__ == '__main__':
-                logging.info(f"penalty1={penalty1}, penalty2={penalty2}, mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
+                logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
 
-            # if `penalty` is nonzero, inject some randomness into the weights of
-            # the whole batch. The hope is that this will be a sufficient penalty.
-            # if this doesn't work well we can consider other ways to apply the penalty.
-            weights = weights * (1.0 + (torch.rand_like(weights) - 0.5) * penalty)
+            weights_discarded = weights[:, seq_len_reduced:2*seq_len_reduced]
+            missing = weights_discarded.shape[1] - seq_len_reduced
+            if missing != 0:
+                weights_discarded = torch.cat(weights_discarded,
+                                              torch.zeros(batch_size, missing,
+                                                          device=weights.device,
+                                                          dtype=weights.dtype),
+                                              dim=1)
+
+            weights = weights[:, :seq_len_reduced] - weights_discarded
         else:
             # test mode. because the sequence might be short, we keep all nonzero scores;
             # and there is no need for any penalty.
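This hunk is the substance of the commit. The old additive penalties (penalty1, the mean of the weights past the reduced length; penalty2, the amount by which the kept weights' mean exceeds max_kept_scores = 1 - 0.5*intermediate_rate), which were injected as random multiplicative noise over the whole batch, are replaced by subtracting the block of weights about to be discarded from the block that is kept. Any nonzero weight beyond seq_len_reduced then directly reduces a kept weight and receives the opposite gradient, which drives those scores toward zero before their frames are dropped. Below is a minimal runnable sketch of the new computation, assuming the weights are sorted in decreasing order per row as in the surrounding code; note that the sketch pads the discarded block with zeros up to seq_len_reduced and passes a tuple to torch.cat, a small rearrangement of the `missing` arithmetic and the torch.cat call as they appear in the diff, made so the sketch runs as written:

import torch

def penalize_by_subtraction(weights: torch.Tensor, seq_len_reduced: int) -> torch.Tensor:
    # weights: (batch_size, seq_len), assumed sorted in decreasing order per row,
    # as they are at this point of forward() after sorting frames by score.
    batch_size, seq_len = weights.shape
    # The block of weights that will be dropped when only seq_len_reduced frames are kept.
    weights_discarded = weights[:, seq_len_reduced:2 * seq_len_reduced]
    # Pad the discarded block with zeros up to length seq_len_reduced; the diff
    # expresses this via a `missing` count, here rearranged for the sketch.
    pad = seq_len_reduced - weights_discarded.shape[1]
    if pad > 0:
        weights_discarded = torch.cat(
            (weights_discarded,
             torch.zeros(batch_size, pad, device=weights.device, dtype=weights.dtype)),
            dim=1)
    # Subtract: a nonzero discarded weight lowers a kept weight, and its gradient is
    # the negative of that kept weight's gradient, pushing the discarded score to zero.
    return weights[:, :seq_len_reduced] - weights_discarded

w = torch.rand(3, 10).sort(dim=-1, descending=True).values.detach().requires_grad_(True)
kept = penalize_by_subtraction(w, seq_len_reduced=5)
kept.sum().backward()
print(w.grad[0])  # +1 on the five kept columns, -1 on the five discarded columns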
@@ -909,10 +891,10 @@ class LearnedDownsamplingModule(nn.Module):
                                   (weights > 0.0).to(torch.int32).sum(dim=-1).max().item())
             if random.random() < 0.02:
                 logging.info("seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
+            weights = weights[:, :seq_len_reduced]
 
         indexes = indexes[:, :seq_len_reduced]
-        weights = weights[:, :seq_len_reduced]
 
         weights = self.copy_weights2(weights)
 
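The final hunk moves the truncation of weights into the branches: the training path already yields weights of length seq_len_reduced from the subtraction above, and the test path now truncates explicitly, so the shared weights = weights[:, :seq_len_reduced] after the branch is removed. Below is a small sketch of the test-mode path, assuming (as the context lines suggest) that the reduced length is taken from the per-row count of positive weights; the helper name is invented for illustration:

import torch

def test_mode_truncate(weights: torch.Tensor, indexes: torch.Tensor):
    # Hypothetical helper mirroring the test-mode branch: keep every frame whose
    # weight is nonzero, so the reduced length is the largest per-row count of
    # positive weights (cf. (weights > 0.0).to(torch.int32).sum(dim=-1).max() above).
    seq_len_reduced = (weights > 0.0).to(torch.int32).sum(dim=-1).max().item()
    weights = weights[:, :seq_len_reduced]
    indexes = indexes[:, :seq_len_reduced]
    return weights, indexes

w = torch.tensor([[1.0, 0.7, 0.0, 0.0],
                  [1.0, 0.4, 0.2, 0.0]])
idx = torch.arange(4).expand(2, 4)
w2, idx2 = test_mode_truncate(w, idx)
print(w2.shape)  # torch.Size([2, 3]): one row has three positive weights

Keeping the truncation inside each branch avoids a redundant second truncation in training, where the subtraction already produces reduced-length weights.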