Implement weight_scale, set weight_scale=10

Author: Daniel Povey
Date:   2023-05-18 15:48:14 +08:00
parent f6c7392430
commit d2c198c072


@@ -805,13 +805,17 @@ class LearnedDownsamplingModule(nn.Module):
       downsampling_factor: factor to downsample by, e.g. 2 or 4. There is no
          fundamental reason why this has to be an integer, but we make it so
          anyway.
+      weight_scale: constant scaling factor on the weights, introduced to make fp16 training
+          more stable by reducing gradient magnitudes.
     """
     def __init__(self,
                  embed_dim: int,
-                 downsampling_factor: int):
+                 downsampling_factor: int,
+                 weight_scale: float = 10.0):
         super().__init__()
+        self.weight_scale = weight_scale
         self.to_scores = nn.Linear(embed_dim, 1, bias=False)
         # score_balancer is just to keep the magnitudes of the scores in
         # a fixed range and keep them balanced around zero, to stop
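Note on this hunk: per the new docstring text, the scaling exists to reduce gradient magnitudes under fp16. A minimal sketch of that effect, assuming (as the attn_offset hunk below suggests) that the weights are effectively divided back by weight_scale downstream; the score value here is hypothetical:

    import torch

    scale = 10.0
    score = torch.tensor(5.0, requires_grad=True)   # hypothetical raw score
    # Scores may now reach `scale` instead of 1.0; dividing back by `scale`
    # gives the same effective weight in [0, 1] ...
    effective = score.clamp(min=0.0, max=scale) / scale
    effective.backward()
    # ... but the gradient reaching the score shrinks by 1/scale, since
    # d(score/scale)/d(score) == 1/scale.
    print(score.grad)   # tensor(0.1000); without scaling it would be 1.0

Smaller gradients flowing into the score projection are consistent with the stability rationale stated in the docstring.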
@@ -855,7 +859,7 @@ class LearnedDownsamplingModule(nn.Module):
         sscores, indexes = scores.sort(dim=-1, descending=True)
-        weights = sscores.clamp(min=0.0, max=1.0)
+        weights = sscores.clamp(min=0.0, max=self.weight_scale)
         weights = self.copy_weights1(weights)
         if self.training:
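The clamp ceiling moves from 1.0 to self.weight_scale to match. A toy illustration with made-up sorted scores:

    import torch

    weight_scale = 10.0
    sscores = torch.tensor([12.0, 4.0, 0.5, -3.0])   # hypothetical sorted scores
    # Weights now live in [0, weight_scale] rather than [0, 1].
    weights = sscores.clamp(min=0.0, max=weight_scale)
    print(weights)   # tensor([10.0000,  4.0000,  0.5000,  0.0000])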
@@ -987,6 +991,9 @@ class LearnedDownsamplingModule(nn.Module):
         # unsqueeze at position 1 so the extra cost relates to the source position.
         attn_offset = attn_offset + weights.clamp(min=eps).log().unsqueeze(1)
+        if self.weight_scale != 1.0:
+            attn_offset = attn_offset - math.log(self.weight_scale)
         return attn_offset
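This last hunk undoes the scaling where the weights enter the attention offset in log-space: subtracting math.log(self.weight_scale) is equivalent to dividing the weights by weight_scale, so the offset matches what unscaled weights in [0, 1] would have produced. A quick check of that identity (eps is defined elsewhere in the file; the value below is a stand-in):

    import math
    import torch

    weight_scale = 10.0
    eps = 1.0e-05   # stand-in value, not from the commit
    weights = torch.tensor([10.0, 4.0, 0.5])   # hypothetical weights
    # log(w) - log(s) == log(w / s):
    scaled = weights.clamp(min=eps).log() - math.log(weight_scale)
    unscaled = (weights / weight_scale).clamp(min=eps).log()
    print(torch.allclose(scaled, unscaled))   # True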