diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py
index 37873f245..071cf8ac3 100644
--- a/egs/libriheavy/LM/zipformer1/subformer.py
+++ b/egs/libriheavy/LM/zipformer1/subformer.py
@@ -862,9 +862,6 @@ class LearnedDownsamplingModule(nn.Module):
         d = self.downsampling_factor
         seq_len_reduced = (seq_len + d - 1) // d
 
-        if random.random() < 0.01 or __name__ == '__main__':
-            logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
-
         weights_discarded = weights[:, seq_len_reduced:2*seq_len_reduced]
         missing = weights_discarded.shape[1] - seq_len_reduced
 
@@ -875,6 +872,10 @@ class LearnedDownsamplingModule(nn.Module):
                                                        dtype=weights.dtype)),
                                           dim=1)
 
+        if random.random() < 0.01 or __name__ == '__main__':
+            logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, discarded-weights={weights_discarded.mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
+
+
         # randomly rotate `weights_discarded` on the sequence axis; this is
         # intended to ensure that it doesn't assign the highest scores to
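
For context, a minimal standalone sketch (not part of the patch; the toy shapes, the descending-sorted weights, and the sign convention of the padding amount are assumptions) of the slicing this diff sits next to: the kept frames are the first ceil(seq_len / d) columns, `weights_discarded` is the next block of the same width, zero-padded when the sequence runs out. Logging after this point, as the patch does, lets the stats line include `weights_discarded.mean()`.

    import torch

    batch_size, seq_len, d = 2, 5, 2   # toy sizes, assumed for illustration
    # assumed: per-frame weights sorted in descending order per row
    weights = torch.rand(batch_size, seq_len).sort(dim=1, descending=True).values

    seq_len_reduced = (seq_len + d - 1) // d          # ceiling division: ceil(5 / 2) = 3
    weights_discarded = weights[:, seq_len_reduced:2 * seq_len_reduced]
    # zero-pad to a full block when the slice falls short (here the slice has
    # only 2 columns, so 1 column of zeros is appended); the sign convention
    # of `missing` in the real file may differ from `pad` here
    pad = seq_len_reduced - weights_discarded.shape[1]
    if pad > 0:
        weights_discarded = torch.cat(
            (weights_discarded,
             torch.zeros(batch_size, pad,
                         device=weights.device, dtype=weights.dtype)),
            dim=1)

    print(weights_discarded.shape)  # torch.Size([2, 3])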