From 76e6726178e96425b4aa6fc991fcbe36667b9e36 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 18 May 2023 14:56:44 +0800
Subject: [PATCH] Implement random rotation of dims

---
 egs/libriheavy/LM/zipformer1/subformer.py | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py
index baf0ed914..37873f245 100644
--- a/egs/libriheavy/LM/zipformer1/subformer.py
+++ b/egs/libriheavy/LM/zipformer1/subformer.py
@@ -865,16 +865,27 @@ class LearnedDownsamplingModule(nn.Module):
             if random.random() < 0.01 or __name__ == '__main__':
                 logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
 
+
             weights_discarded = weights[:, seq_len_reduced:2*seq_len_reduced]
             missing = weights_discarded.shape[1] - seq_len_reduced
             if missing != 0:
-                weights_discarded = torch.cat(weights_discarded,
-                                              torch.zeros(batch_size, missing,
-                                                          device=weights.device,
-                                                          dtype=weights.dtype),
+                weights_discarded = torch.cat((weights_discarded,
+                                               torch.zeros(batch_size, missing,
+                                                           device=weights.device,
+                                                           dtype=weights.dtype)),
                                               dim=1)
-            weights = weights[:, :seq_len_reduced] - weights_discarded
+
+            # randomly rotate `weights_discarded` along the sequence axis; this
+            # is intended to ensure that the model does not assign the highest
+            # scores to not-so-important elements in order to sidestep the
+            # randomness of these discarded weights.
+            r = random.randint(0, seq_len_reduced - 1)
+            weights_discarded = torch.cat((weights_discarded[:, r:],
+                                           weights_discarded[:, :r]),
+                                          dim=1)
+
+            weights = (weights[:, :seq_len_reduced] - weights_discarded)
 
         else:
             # test mode.  because the sequence might be short, we keep all nonzero scores;
             # and there is no need for any penalty.
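---

For reference, below is a minimal, self-contained sketch of what the new
training-mode logic computes.  It is illustrative only: the function name
penalize_kept_weights is hypothetical, weights is assumed to be sorted in
descending order along dim 1 (as in LearnedDownsamplingModule), and the pad
size is written here as seq_len_reduced - weights_discarded.shape[1] so that
it is non-negative.  torch.roll with a negative shift stands in for the
two-slice torch.cat of the patch; the two are equivalent.

import random

import torch


def penalize_kept_weights(weights: torch.Tensor,
                          seq_len_reduced: int) -> torch.Tensor:
    # weights: (batch_size, seq_len), assumed sorted in descending order
    # along dim 1; the first seq_len_reduced columns are the kept positions.
    # Assumes seq_len_reduced >= 1.
    batch_size, seq_len = weights.shape

    # the weights that just missed the cut; this slice has fewer than
    # seq_len_reduced columns when seq_len < 2 * seq_len_reduced.
    weights_discarded = weights[:, seq_len_reduced:2 * seq_len_reduced]

    # zero-pad on the right so there are exactly seq_len_reduced columns.
    missing = seq_len_reduced - weights_discarded.shape[1]
    if missing != 0:
        weights_discarded = torch.cat(
            (weights_discarded,
             torch.zeros(batch_size, missing,
                         device=weights.device, dtype=weights.dtype)),
            dim=1)

    # random rotation along the sequence axis, so that each kept weight is
    # penalized by a randomly paired discarded weight rather than always by
    # the one at the same relative position.
    r = random.randint(0, seq_len_reduced - 1)
    weights_discarded = torch.roll(weights_discarded, shifts=-r, dims=1)

    return weights[:, :seq_len_reduced] - weights_discarded

For example, penalize_kept_weights(torch.rand(2, 10).sort(dim=1,
descending=True).values, 4) returns a (2, 4) tensor: the four largest
weights in each row, each reduced by one of the next four.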