Introduce factor of 2 to more strongly penalize discarded weights.

Daniel Povey 2023-05-19 16:31:45 +08:00
parent 824d7b4492
commit 5fc0cce553


@@ -805,17 +805,13 @@ class LearnedDownsamplingModule(nn.Module):
       downsampling_factor: factor to downsample by, e.g. 2 or 4.  There is no
          fundamental reason why this has to be an integer, but we make it so
          anyway.
-      weight_scale: constant scaling factor on the weights, introduced to make fp16 training
-         more stable by reducing gradient magnitudes.
     """
     def __init__(self,
                  embed_dim: int,
-                 downsampling_factor: int,
-                 weight_scale: float = 1.0):
+                 downsampling_factor: int):
         super().__init__()
-        self.weight_scale = weight_scale
 
         self.to_scores = nn.Linear(embed_dim, 1, bias=False)
         # score_balancer is just to keep the magnitudes of the scores in
         # a fixed range and keep them balanced around zero, to stop
@@ -859,7 +855,7 @@ class LearnedDownsamplingModule(nn.Module):
         sscores, indexes = scores.sort(dim=-1, descending=True)
 
-        weights = sscores.clamp(min=0.0, max=self.weight_scale)
+        weights = sscores.clamp(min=0.0, max=1.0)
         weights = self.copy_weights1(weights)
 
         if self.training:
@@ -879,9 +875,12 @@ class LearnedDownsamplingModule(nn.Module):
             if random.random() < 0.01 or __name__ == '__main__':
                 logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, discarded-weights={weights_discarded.mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
 
-            #weights_discarded = weights_discarded.flip(dims=(1,))
-            weights = (weights[:, :seq_len_reduced] - weights_discarded)
+            # we were getting too many discarded weights before introducing this factor, which was
+            # hurting test-mode performance by creating a mismatch.
+            discarded_weights_factor = 2.0
+            weights = (weights[:, :seq_len_reduced] - (weights_discarded * discarded_weights_factor)).clamp(min=0.0, max=1.0)
         else:
             # test mode.  because the sequence might be short, we keep all nonzero scores;
             # and there is no need for any penalty.
@@ -991,9 +990,6 @@ class LearnedDownsamplingModule(nn.Module):
         # unsqueeze at position 1 so the extra cost relates to the source position.
         attn_offset = attn_offset + (weights + eps).log().unsqueeze(1)
 
-        if self.weight_scale != 1.0:
-            attn_offset = attn_offset - math.log(self.weight_scale)
-
         return attn_offset
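
For readers skimming the diff, below is a minimal, self-contained sketch of the training-mode penalty after this change. The function name, the tensor shapes, and the stand-in computation of weights_discarded are illustrative assumptions rather than the module's actual code; only the clamp of the sorted scores to [0, 1], the factor of 2 on the discarded weights, and the final clamp come from the diff above.

import torch

def penalized_kept_weights(sscores: torch.Tensor,
                           seq_len_reduced: int,
                           discarded_weights_factor: float = 2.0) -> torch.Tensor:
    # Sketch only: sscores is assumed to be (batch, seq_len), already sorted
    # in descending order, with seq_len_reduced < seq_len.
    weights = sscores.clamp(min=0.0, max=1.0)
    kept = weights[:, :seq_len_reduced]       # frames that survive downsampling
    discarded = weights[:, seq_len_reduced:]  # frames that are dropped

    # Hypothetical stand-in for weights_discarded: the mean discarded weight,
    # broadcast against the kept frames.  The real module pairs discarded
    # frames with kept frames; that pairing is not visible in this commit.
    weights_discarded = discarded.mean(dim=1, keepdim=True).expand_as(kept)

    # The change in this commit: scale the penalty by 2 so that putting weight
    # on frames that end up discarded costs more, then clamp back into [0, 1]
    # so the result is still a valid weight.
    return (kept - weights_discarded * discarded_weights_factor).clamp(min=0.0, max=1.0)

Because the weights are now always clamped into [0, 1], the per-instance weight_scale and the corresponding log(weight_scale) compensation in attn_offset (removed in the last hunk) are presumably no longer needed.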