Half the time, flip weights_discarded

This commit is contained in:
Daniel Povey 2023-05-19 17:55:03 +08:00
parent 7d162bf41e
commit 4a425f7eb5

View File

@ -879,6 +879,10 @@ class LearnedDownsamplingModule(nn.Module):
# we were getting too many discarded weights before introducing this factor, which was
# hurting test-mode performance by creating a mismatch.
discarded_weights_factor = 2.0
if random.random() < 0.5:
# flipping it half the time increases the randomness, so gives an extra incentive
# to avoid nonzero weights in the discarded half
weights_discarded = weights_discarded.flip(dims=(1,))
weights = (weights[:, :seq_len_reduced] - (weights_discarded * discarded_weights_factor)).clamp(min=0.0, max=1.0)
else: