Implement random rotation of dims

Daniel Povey 2023-05-18 14:56:44 +08:00
parent d631ffec5b
commit 76e6726178


@@ -865,16 +865,27 @@ class LearnedDownsamplingModule(nn.Module):
             if random.random() < 0.01 or __name__ == '__main__':
                 logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
             weights_discarded = weights[:, seq_len_reduced:2*seq_len_reduced]
             missing = weights_discarded.shape[1] - seq_len_reduced
             if missing != 0:
-                weights_discarded = torch.cat(weights_discarded,
-                                              torch.zeros(batch_size, missing,
-                                                          device=weights.device,
-                                                          dtype=weights.dtype),
-                                              dim=1)
-            weights = weights[:, :seq_len_reduced] - weights_discarded
+                weights_discarded = torch.cat((weights_discarded,
+                                               torch.zeros(batch_size, missing,
+                                                           device=weights.device,
+                                                           dtype=weights.dtype)),
+                                              dim=1)
+
+            # Randomly rotate `weights_discarded` along the sequence axis, so
+            # that each kept weight is penalized by a randomly chosen discarded
+            # weight; this is intended to ensure that the model does not learn
+            # to assign its highest scores to not-so-important elements.
+            r = random.randint(0, seq_len_reduced - 1)
+            weights_discarded = torch.cat((weights_discarded[:, r:],
+                                           weights_discarded[:, :r]),
+                                          dim=1)
+            weights = (weights[:, :seq_len_reduced] - weights_discarded)
         else:
             # test mode. because the sequence might be short, we keep all nonzero scores;
             # and there is no need for any penalty.