mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
More debug print
This commit is contained in:
parent
e4a774cb98
commit
299482d02d
@ -862,9 +862,6 @@ class LearnedDownsamplingModule(nn.Module):
|
|||||||
d = self.downsampling_factor
|
d = self.downsampling_factor
|
||||||
seq_len_reduced = (seq_len + d - 1) // d
|
seq_len_reduced = (seq_len + d - 1) // d
|
||||||
|
|
||||||
if random.random() < 0.01 or __name__ == '__main__':
|
|
||||||
logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
|
|
||||||
|
|
||||||
|
|
||||||
weights_discarded = weights[:, seq_len_reduced:2*seq_len_reduced]
|
weights_discarded = weights[:, seq_len_reduced:2*seq_len_reduced]
|
||||||
missing = weights_discarded.shape[1] - seq_len_reduced
|
missing = weights_discarded.shape[1] - seq_len_reduced
|
||||||
@ -875,6 +872,10 @@ class LearnedDownsamplingModule(nn.Module):
|
|||||||
dtype=weights.dtype)),
|
dtype=weights.dtype)),
|
||||||
dim=1)
|
dim=1)
|
||||||
|
|
||||||
|
if random.random() < 0.01 or __name__ == '__main__':
|
||||||
|
logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, discarded-weights={weights_discarded.mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# randomly rotate `weights_discarded` on the sequence axis; this is
|
# randomly rotate `weights_discarded` on the sequence axis; this is
|
||||||
# intended to ensure that it doesn't assign the highest scores to
|
# intended to ensure that it doesn't assign the highest scores to
|
||||||
|
Loading…
x
Reference in New Issue
Block a user