Mirror of https://github.com/k2-fsa/icefall.git
Implement random rotation of dims

commit 76e6726178
parent d631ffec5b
@@ -865,16 +865,27 @@ class LearnedDownsamplingModule(nn.Module):
             if random.random() < 0.01 or __name__ == '__main__':
                 logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
 
             weights_discarded = weights[:, seq_len_reduced:2*seq_len_reduced]
             missing = weights_discarded.shape[1] - seq_len_reduced
             if missing != 0:
-                weights_discarded = torch.cat(weights_discarded,
+                weights_discarded = torch.cat((weights_discarded,
                                               torch.zeros(batch_size, missing,
                                                           device=weights.device,
-                                                          dtype=weights.dtype),
+                                                          dtype=weights.dtype)),
                                               dim=1)
 
-            weights = weights[:, :seq_len_reduced] - weights_discarded
+            # randomly rotate `weights_discarded` on the sequence axis; this is
+            # intended to ensure that it doesn't assign the highest scores to
+            # not-so-important elements to avoid the randomness of these
+            # discarded weights.
+            r = random.randint(0, seq_len_reduced - 1)
+            weights_discarded = torch.cat((weights_discarded[:, r:],
+                                           weights_discarded[:, :r]),
+                                          dim=1)
+
+            weights = (weights[:, :seq_len_reduced] - weights_discarded)
         else:
             # test mode. because the sequence might be short, we keep all nonzero scores;
             # and there is no need for any penalty.
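For readers skimming the change: the added lines cyclically shift the discarded-frame weights by a random offset along the sequence axis before subtracting them from the kept-frame weights. Below is a minimal standalone sketch of that step; it is not taken from icefall, and the tensor names, shapes, and the use of torch.roll (equivalent to the two-slice torch.cat in the commit) are illustrative assumptions.

import random

import torch

# Illustrative sketch only, not the icefall implementation.
batch_size, seq_len_reduced = 2, 5
weights_kept = torch.rand(batch_size, seq_len_reduced)       # weights of the frames that are kept
weights_discarded = torch.rand(batch_size, seq_len_reduced)  # weights of the next-best, dropped frames

# Random rotation on the sequence axis. torch.roll with a negative shift is
# equivalent to torch.cat((x[:, r:], x[:, :r]), dim=1) as used in the commit.
r = random.randint(0, seq_len_reduced - 1)
rotated = torch.roll(weights_discarded, shifts=-r, dims=1)

# Penalty: each kept weight has a (randomly aligned) discarded weight subtracted
# from it before the result is used downstream.
weights = weights_kept - rotated
print(weights.shape)  # torch.Size([2, 5])

Because r is drawn fresh on every call, the pairing between kept and discarded positions changes from step to step rather than always being position-for-position.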